[release][project] supersonic 0.7.3 version backend update (#40)

* [improvement] add some features

* [improvement] revise CHANGELOG

---------

Co-authored-by: zuopengge <hwzuopengge@tencent.com>
This commit is contained in:
mainmain
2023-08-29 20:06:34 +08:00
committed by GitHub
parent 6fe9ab79ed
commit e1911bc81b
260 changed files with 6466 additions and 7108 deletions

View File

@@ -4,7 +4,17 @@
- "Breaking Changes" describes any changes that may break existing functionality or cause - "Breaking Changes" describes any changes that may break existing functionality or cause
compatibility issues with previous versions. compatibility issues with previous versions.
## SuperSonic [0.7.3] - 2023-08-29
### Added
- meet checkstyle code requirements
- save parseInfo after parsing
- add time statistics
- add agent
### Updated
- dsl where condition is used for front-end display
- dsl remove context inheritance
## SuperSonic [0.7.2] - 2023-08-12 ## SuperSonic [0.7.2] - 2023-08-12

View File

@@ -1,7 +1,7 @@
#!/usr/bin/env bash #!/usr/bin/env bash
sbinDir=$(cd "$(dirname "$0")"; pwd) sbinDir=$(cd "$(dirname "$0")"; pwd)
baseDir=$(cd "$sbinDir/.." && pwd -P) baseDir=$(readlink -f $sbinDir/../)
runtimeDir=$baseDir/runtime runtimeDir=$baseDir/runtime
buildDir=$baseDir/build buildDir=$baseDir/build

View File

@@ -1,7 +1,7 @@
#!/usr/bin/env bash #!/usr/bin/env bash
sbinDir=$(cd "$(dirname "$0")"; pwd) sbinDir=$(cd "$(dirname "$0")"; pwd)
baseDir=$(cd "$sbinDir/.." && pwd -P) baseDir=$(readlink -f $sbinDir/../)
buildDir=$baseDir/build buildDir=$baseDir/build
cd $baseDir/bin cd $baseDir/bin

View File

@@ -1,7 +1,7 @@
#!/usr/bin/env bash #!/usr/bin/env bash
sbinDir=$(cd "$(dirname "$0")"; pwd) sbinDir=$(cd "$(dirname "$0")"; pwd)
baseDir=$(cd "$sbinDir/.." && pwd -P) baseDir=$(readlink -f $sbinDir/../)
runtimeDir=$baseDir/runtime runtimeDir=$baseDir/runtime
buildDir=$baseDir/build buildDir=$baseDir/build

View File

@@ -1,7 +1,7 @@
#!/usr/bin/env bash #!/usr/bin/env bash
sbinDir=$(cd "$(dirname "$0")"; pwd) sbinDir=$(cd "$(dirname "$0")"; pwd)
baseDir=$(cd "$sbinDir/.." && pwd -P) baseDir=$(readlink -f $sbinDir/../)
runtimeDir=$baseDir/runtime runtimeDir=$baseDir/runtime
buildDir=$baseDir/build buildDir=$baseDir/build

View File

@@ -1,7 +1,7 @@
#!/usr/bin/env bash #!/usr/bin/env bash
sbinDir=$(cd "$(dirname "$0")"; pwd) sbinDir=$(cd "$(dirname "$0")"; pwd)
baseDir=$(cd "$sbinDir/.." && pwd -P) baseDir=$(readlink -f $sbinDir/../)
runtimeDir=$baseDir/../runtime runtimeDir=$baseDir/../runtime
buildDir=$baseDir/build buildDir=$baseDir/build

View File

@@ -1,7 +1,7 @@
#!/usr/bin/env bash #!/usr/bin/env bash
sbinDir=$(cd "$(dirname "$0")"; pwd) sbinDir=$(cd "$(dirname "$0")"; pwd)
baseDir=$(cd "$sbinDir/.." && pwd -P) baseDir=$(readlink -f $sbinDir/../)
runtimeDir=$baseDir/../runtime runtimeDir=$baseDir/../runtime
buildDir=$baseDir/build buildDir=$baseDir/build

View File

@@ -1,7 +1,7 @@
#!/usr/bin/env bash #!/usr/bin/env bash
sbinDir=$(cd "$(dirname "$0")"; pwd) sbinDir=$(cd "$(dirname "$0")"; pwd)
baseDir=$(cd "$sbinDir/.." && pwd -P) baseDir=$(readlink -f $sbinDir/../)
runtimeDir=$baseDir/../runtime runtimeDir=$baseDir/../runtime
buildDir=$baseDir/build buildDir=$baseDir/build
@@ -29,4 +29,4 @@ rm -fr ${buildDir}/supersonic-webapp
#start standalone service #start standalone service
sh ${runtimeDir}/supersonic-standalone/bin/service.sh restart sh ${runtimeDir}/supersonic-standalone/bin/service.sh restart
#start llm service #start llm service
sh ${runtimeDir}/supersonic-standalone/llm/bin/service.sh restart sh ${runtimeDir}/supersonic-standalone/llm/bin/service.sh restart

View File

@@ -5,7 +5,7 @@ import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup;
import com.tencent.supersonic.auth.api.authorization.request.QueryAuthResReq; import com.tencent.supersonic.auth.api.authorization.request.QueryAuthResReq;
import com.tencent.supersonic.auth.api.authorization.response.AuthorizedResourceResp; import com.tencent.supersonic.auth.api.authorization.response.AuthorizedResourceResp;
import java.util.List; import java.util.List;
import javax.servlet.http.HttpServletRequest;
public interface AuthService { public interface AuthService {

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.auth.authentication.persistence.repository.Impl; package com.tencent.supersonic.auth.authentication.persistence.repository.impl;
import com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO; import com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO;

View File

@@ -11,12 +11,12 @@ import java.util.Set;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RestController;
@RestController @RestController
@RequestMapping("/api/auth/user") @RequestMapping("/api/auth/user")

View File

@@ -9,7 +9,6 @@ import static com.tencent.supersonic.auth.api.authentication.constant.UserConsta
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_USER_ID; import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_USER_ID;
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_USER_NAME; import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_USER_NAME;
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_USER_PASSWORD; import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_USER_PASSWORD;
import com.tencent.supersonic.auth.api.authentication.config.AuthenticationConfig; import com.tencent.supersonic.auth.api.authentication.config.AuthenticationConfig;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.pojo.UserWithPassword; import com.tencent.supersonic.auth.api.authentication.pojo.UserWithPassword;
@@ -22,9 +21,11 @@ import java.util.Date;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
@Slf4j
@Component @Component
public class UserTokenUtils { public class UserTokenUtils {
@@ -68,7 +69,9 @@ public class UserTokenUtils {
public UserWithPassword getUserWithPassword(HttpServletRequest request) { public UserWithPassword getUserWithPassword(HttpServletRequest request) {
String token = request.getHeader(authenticationConfig.getTokenHttpHeaderKey()); String token = request.getHeader(authenticationConfig.getTokenHttpHeaderKey());
if (StringUtils.isBlank(token)) { if (StringUtils.isBlank(token)) {
throw new AccessException("token is blank, get user failed"); String message = "token is blank, get user failed";
log.warn("{}, uri: {}", message, request.getServletPath());
throw new AccessException(message);
} }
final Claims claims = getClaims(token); final Claims claims = getClaims(token);
Long userId = Long.parseLong(claims.getOrDefault(TOKEN_USER_ID, 0).toString()); Long userId = Long.parseLong(claims.getOrDefault(TOKEN_USER_ID, 0).toString());

View File

@@ -12,13 +12,16 @@ import com.tencent.supersonic.auth.api.authorization.response.AuthorizedResource
import com.tencent.supersonic.auth.api.authorization.service.AuthService; import com.tencent.supersonic.auth.api.authorization.service.AuthService;
import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup; import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup;
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule; import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule;
import com.tencent.supersonic.common.util.S2ThreadContext;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.*; import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.Map;
import java.util.ArrayList;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@Service @Service

View File

@@ -3,7 +3,7 @@ package com.tencent.supersonic.chat.api.component;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo; import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.JSQLParserException;
public interface DSLOptimizer { public interface SemanticCorrector {
CorrectionInfo rewriter(CorrectionInfo correctionInfo) throws JSQLParserException; CorrectionInfo corrector(CorrectionInfo correctionInfo) throws JSQLParserException;
} }

View File

@@ -6,14 +6,15 @@ import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.common.pojo.enums.AuthType; import com.tencent.supersonic.common.pojo.enums.AuthType;
import com.tencent.supersonic.semantic.api.model.request.PageDimensionReq; import com.tencent.supersonic.semantic.api.model.request.PageDimensionReq;
import com.tencent.supersonic.semantic.api.model.request.PageMetricReq; import com.tencent.supersonic.semantic.api.model.request.PageMetricReq;
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
import com.tencent.supersonic.semantic.api.model.response.DomainResp; import com.tencent.supersonic.semantic.api.model.response.DomainResp;
import com.tencent.supersonic.semantic.api.model.response.MetricResp; import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
import com.tencent.supersonic.semantic.api.model.response.ModelResp; import com.tencent.supersonic.semantic.api.model.response.ModelResp;
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.QueryDslReq; import com.tencent.supersonic.semantic.api.query.request.QueryDslReq;
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq; import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.util.List; import java.util.List;
/** /**
@@ -31,22 +32,13 @@ import java.util.List;
public interface SemanticLayer { public interface SemanticLayer {
QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructReq, User user); QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructReq, User user);
QueryResultWithSchemaResp queryByMultiStruct(QueryMultiStructReq queryMultiStructReq, User user); QueryResultWithSchemaResp queryByMultiStruct(QueryMultiStructReq queryMultiStructReq, User user);
QueryResultWithSchemaResp queryByDsl(QueryDslReq queryDslReq, User user); QueryResultWithSchemaResp queryByDsl(QueryDslReq queryDslReq, User user);
List<ModelSchema> getModelSchema(); List<ModelSchema> getModelSchema();
List<ModelSchema> getModelSchema(List<Long> ids); List<ModelSchema> getModelSchema(List<Long> ids);
ModelSchema getModelSchema(Long model, Boolean cacheEnable); ModelSchema getModelSchema(Long model, Boolean cacheEnable);
PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionCmd); PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionCmd);
PageInfo<MetricResp> getMetricPage(PageMetricReq pageMetricCmd); PageInfo<MetricResp> getMetricPage(PageMetricReq pageMetricCmd);
List<DomainResp> getDomainList(User user); List<DomainResp> getDomainList(User user);
List<ModelResp> getModelList(AuthType authType, Long domainId, User user); List<ModelResp> getModelList(AuthType authType, Long domainId, User user);
} }

View File

@@ -6,6 +6,7 @@ import lombok.Data;
public class ChatContext { public class ChatContext {
private Integer chatId; private Integer chatId;
private Integer agentId;
private String queryText; private String queryText;
private SemanticParseInfo parseInfo = new SemanticParseInfo(); private SemanticParseInfo parseInfo = new SemanticParseInfo();
private String user; private String user;

View File

@@ -32,6 +32,7 @@ public class ModelSchema {
break; break;
case VALUE: case VALUE:
element = dimensionValues.stream().filter(e -> e.getId() == elementID).findFirst(); element = dimensionValues.stream().filter(e -> e.getId() == elementID).findFirst();
break;
default: default:
} }

View File

@@ -1,18 +1,19 @@
package com.tencent.supersonic.chat.api.pojo; package com.tencent.supersonic.chat.api.pojo;
import com.google.common.base.Objects; import com.google.common.base.Objects;
import java.io.Serializable; import java.io.Serializable;
import java.util.List; import java.util.List;
import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.Getter; import lombok.Getter;
import lombok.Builder;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
@Data @Data
@Getter @Getter
@Builder @Builder
@NoArgsConstructor @NoArgsConstructor
//@AllArgsConstructor
public class SchemaElement implements Serializable { public class SchemaElement implements Serializable {
private Long model; private Long model;
@@ -23,11 +24,8 @@ public class SchemaElement implements Serializable {
private SchemaElementType type; private SchemaElementType type;
private List<String> alias; private List<String> alias;
// public SchemaElement() {
// }
public SchemaElement(Long model, Long id, String name, String bizName, public SchemaElement(Long model, Long id, String name, String bizName,
Long useCnt, SchemaElementType type, List<String> alias) { Long useCnt, SchemaElementType type, List<String> alias) {
this.model = model; this.model = model;
this.id = id; this.id = id;
this.name = name; this.name = name;

View File

@@ -1,23 +1,26 @@
package com.tencent.supersonic.chat.api.pojo; package com.tencent.supersonic.chat.api.pojo;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;
import java.util.LinkedHashSet;
import java.util.ArrayList;
import java.util.Map;
import java.util.HashMap;
import java.util.Comparator;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.Order; import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum; import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import lombok.Data; import lombok.Data;
@Data @Data
public class SemanticParseInfo { public class SemanticParseInfo {
private Integer id;
private String queryMode; private String queryMode;
private SchemaElement model; private SchemaElement model;
private Set<SchemaElement> metrics = new TreeSet<>(new SchemaNameLengthComparator()); private Set<SchemaElement> metrics = new TreeSet<>(new SchemaNameLengthComparator());
@@ -33,7 +36,7 @@ public class SemanticParseInfo {
private double score; private double score;
private List<SchemaElementMatch> elementMatches = new ArrayList<>(); private List<SchemaElementMatch> elementMatches = new ArrayList<>();
private Map<String, Object> properties = new HashMap<>(); private Map<String, Object> properties = new HashMap<>();
private EntityInfo entityInfo;
public Long getModelId() { public Long getModelId() {
return model != null ? model.getId() : 0L; return model != null ? model.getId() : 0L;
} }
@@ -43,7 +46,6 @@ public class SemanticParseInfo {
} }
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> { private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
@Override @Override
public int compare(SchemaElement o1, SchemaElement o2) { public int compare(SchemaElement o1, SchemaElement o2) {
int len1 = o1.getName().length(); int len1 = o1.getName().length();

View File

@@ -9,8 +9,11 @@ import lombok.Data;
public class ExecuteQueryReq { public class ExecuteQueryReq {
private User user; private User user;
private Integer agentId;
private Integer chatId; private Integer chatId;
private String queryText; private String queryText;
private Long queryId;
private Integer parseId;
private SemanticParseInfo parseInfo; private SemanticParseInfo parseInfo;
private boolean saveAnswer = true; private boolean saveAnswer = true;
} }

View File

@@ -6,6 +6,5 @@ import lombok.Data;
@Data @Data
public class AggregateInfo { public class AggregateInfo {
private List<MetricInfo> metricInfos = new ArrayList<>(); private List<MetricInfo> metricInfos = new ArrayList<>();
} }

View File

@@ -1,12 +1,13 @@
package com.tencent.supersonic.chat.api.pojo.response; package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.Getter; import lombok.Getter;
import lombok.Builder;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.AllArgsConstructor;
import java.util.List;
@Data @Data
@Getter @Getter
@@ -14,9 +15,9 @@ import lombok.NoArgsConstructor;
@NoArgsConstructor @NoArgsConstructor
@AllArgsConstructor @AllArgsConstructor
public class ParseResp { public class ParseResp {
private Integer chatId; private Integer chatId;
private String queryText; private String queryText;
private Long queryId;
private ParseState state; private ParseState state;
private List<SemanticParseInfo> selectedParses; private List<SemanticParseInfo> selectedParses;
private List<SemanticParseInfo> candidateParses; private List<SemanticParseInfo> candidateParses;

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.chat.api.pojo.response;
import java.util.List; import java.util.List;
import lombok.Data; import lombok.Data;
@Data @Data
public class SearchResp { public class SearchResp {

View File

@@ -0,0 +1,18 @@
package com.tencent.supersonic.chat.api.pojo.response;
import lombok.Data;
import java.util.List;
import java.util.Map;
@Data
public class ShowCaseResp {
private Map<Long, List<QueryResp>> showCaseMap;
private int pageSize;
private int current;
}

View File

@@ -5,6 +5,7 @@ import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.agent.tool.AgentToolType; import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.common.pojo.RecordInfo; import com.tencent.supersonic.common.pojo.RecordInfo;
import java.util.Objects;
import lombok.Data; import lombok.Data;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.List; import java.util.List;
@@ -23,7 +24,6 @@ public class Agent extends RecordInfo {
private Integer status; private Integer status;
private List<String> examples; private List<String> examples;
private String agentConfig; private String agentConfig;
public List<String> getTools(AgentToolType type) { public List<String> getTools(AgentToolType type) {
Map map = JSONObject.parseObject(agentConfig, Map.class); Map map = JSONObject.parseObject(agentConfig, Map.class);
if (CollectionUtils.isEmpty(map) || map.get("tools") == null) { if (CollectionUtils.isEmpty(map) || map.get("tools") == null) {
@@ -31,7 +31,13 @@ public class Agent extends RecordInfo {
} }
List<Map> toolList = (List) map.get("tools"); List<Map> toolList = (List) map.get("tools");
return toolList.stream() return toolList.stream()
.filter(tool -> type.name().equals(tool.get("type"))) .filter(tool -> {
if (Objects.isNull(type)) {
return true;
}
return type.name().equals(tool.get("type"));
}
)
.map(JSONObject::toJSONString) .map(JSONObject::toJSONString)
.collect(Collectors.toList()); .collect(Collectors.toList());
} }

View File

@@ -1,12 +1,9 @@
package com.tencent.supersonic.chat.agent.tool; package com.tencent.supersonic.chat.agent.tool;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import java.util.List;
@Data @Data
@NoArgsConstructor @NoArgsConstructor
@AllArgsConstructor @AllArgsConstructor

View File

@@ -2,12 +2,19 @@ package com.tencent.supersonic.chat.agent.tool;
import lombok.Data; import lombok.Data;
import org.apache.commons.collections.CollectionUtils;
import java.util.List; import java.util.List;
@Data @Data
public class RuleQueryTool extends AgentTool { public class RuleQueryTool extends AgentTool {
private List<Long> modelIds;
private List<String> queryModes; private List<String> queryModes;
public boolean isContainsAllModel() {
return CollectionUtils.isNotEmpty(modelIds) && modelIds.contains(-1L);
}
} }

View File

@@ -0,0 +1,99 @@
/*
//package com.tencent.supersonic.chat.aspect;
//
//import lombok.extern.slf4j.Slf4j;
//import org.aspectj.lang.JoinPoint;
//import org.aspectj.lang.ProceedingJoinPoint;
//import org.aspectj.lang.annotation.*;
//import org.springframework.stereotype.Component;
//
//import java.util.HashMap;
//import java.util.Map;
//
//@Aspect
//@Component
//@Slf4j
//public class TimeCostAspect {
//
// ThreadLocal<Long> startTime = new ThreadLocal<>();
//
// ThreadLocal<Map<String, Long>> map = new ThreadLocal<>();
//
// @Pointcut("execution(public * com.tencent.supersonic.chat.mapper.HanlpDictMapper.*(*))")
// //@Pointcut("execution(* public com.tencent.supersonic.chat.parser.*.*(..))")
// //@Pointcut("execution(* com.tencent.supersonic.chat.mapper.*Mapper.map(..)) ")
// //@Pointcut("execution(* com.tencent.supersonic.chat.mapper.HanlpDictMapper.map(..)) ")
// //@Pointcut("execution(* com.tencent.supersonic.chat.parser.rule.QueryModeParser.*(..)) ")
// public void point() {
// }
//
// @Around("point()")
// public void doAround(ProceedingJoinPoint joinPoint) throws Throwable {
// long start = System.currentTimeMillis();
// try {
// log.info("切面开始");
// Object result = joinPoint.proceed();
// log.info("切面开始");
// if (result == null) {
// //如果切到了 没有返回类型的void方法这里直接返回
// //return null;
// }
// long end = System.currentTimeMillis();
// log.info("===================");
// String targetClassName = joinPoint.getSignature().getDeclaringTypeName();
// String MethodName = joinPoint.getSignature().getName();
// String typeStr = joinPoint.getSignature().getDeclaringType().toString().split(" ")[0];
// log.info("类/接口:" + targetClassName + "(" + typeStr + ")");
// log.info("方法:" + MethodName);
// Long total = end - start;
// log.info("耗时: " + total + " ms!");
// map.get().put(targetClassName + "_" + MethodName, total);
// //return result;
// } catch (Throwable e) {
// long end = System.currentTimeMillis();
// log.info("====around " + joinPoint + "\tUse time : " + (end - start) + " ms with exception : "
// + e.getMessage());
// throw e;
// }
// }
//
//// //对Controller下面的方法执行前进行切入初始化开始时间
//// @Before(value = "execution(* com.appleyk.controller.*.*(..))")
//// public void beforMehhod(JoinPoint jp) {
//// startTime.set(System.currentTimeMillis());
//// }
////
//// //对Controller下面的方法执行后进行切入统计方法执行的次数和耗时情况
//// //注意这里的执行方法统计的数据不止包含Controller下面的方法也包括环绕切入的所有方法的统计信息
//// @AfterReturning(value = "execution(* com.appleyk.controller.*.*(..))")
//// public void afterMehhod(JoinPoint jp) {
//// long end = System.currentTimeMillis();
//// long total = end - startTime.get();
//// String methodName = jp.getSignature().getName();
//// log.info("连接点方法为:" + methodName + ",执行总耗时为:" +total+"ms");
////
//// //重新new一个map
//// Map<String, Long> map = new HashMap<>();
//////从map2中将最后的 连接点方法给移除了,替换成最终的,避免连接点方法多次进行叠加计算
//// //由于map2受ThreadLocal的保护这里不支持remove因此需要单开一个map进行数据交接
//// for(Map.Entry<String, Long> entry:map2.get().entrySet()){
//// if(entry.getKey().equals(methodName)){
//// map.put(methodName, total);
////
//// }else{
//// map.put(entry.getKey(), entry.getValue());
//// }
//// }
////
//// for (Map.Entry<String, Long> entry :map1.get().entrySet()) {
//// for(Map.Entry<String, Long> entry2 :map.entrySet()){
//// if(entry.getKey().equals(entry2.getKey())){
//// System.err.println(entry.getKey()+",被调用次数:"+entry.getValue()+",综合耗时:"+entry2.getValue()+"ms");
//// }
//// }
////
//// }
//// }
//
//}
*/

View File

@@ -3,11 +3,9 @@ package com.tencent.supersonic.chat.config;
import lombok.Data; import lombok.Data;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
@Configuration
@Data @Data
@Configuration
public class AggregatorConfig { public class AggregatorConfig {
@Value("${metric.aggregator.ratio.enable:true}") @Value("${metric.aggregator.ratio.enable:true}")
private Boolean enableRatio; private Boolean enableRatio;
} }

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.query.dsl.optimizer; package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.component.DSLOptimizer; import com.tencent.supersonic.chat.api.component.SemanticCorrector;
import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
@@ -13,7 +13,7 @@ import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@Slf4j @Slf4j
public abstract class BaseDSLOptimizer implements DSLOptimizer { public abstract class BaseSemanticCorrector implements SemanticCorrector {
public static final String DATE_FIELD = "数据日期"; public static final String DATE_FIELD = "数据日期";
protected Map<String, String> getFieldToBizName(Long modelId) { protected Map<String, String> getFieldToBizName(Long modelId) {

View File

@@ -1,23 +1,24 @@
package com.tencent.supersonic.chat.query.dsl.optimizer; package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo; import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper; import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper;
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import java.util.List; import java.util.List;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
@Slf4j @Slf4j
public class DateFieldCorrector extends BaseDSLOptimizer { public class DateFieldCorrector extends BaseSemanticCorrector {
@Override @Override
public CorrectionInfo rewriter(CorrectionInfo correctionInfo) { public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
String sql = correctionInfo.getSql(); String sql = correctionInfo.getSql();
List<String> whereFields = CCJSqlParserUtils.getWhereFields(sql); List<String> whereFields = SqlParserSelectHelper.getWhereFields(sql);
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(BaseDSLOptimizer.DATE_FIELD)) { if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(DATE_FIELD)) {
String currentDate = DSLDateHelper.getCurrentDate(correctionInfo.getParseInfo().getModelId()); String currentDate = DSLDateHelper.getCurrentDate(correctionInfo.getParseInfo().getModelId());
sql = CCJSqlParserUtils.addWhere(sql, BaseDSLOptimizer.DATE_FIELD, currentDate); sql = SqlParserUpdateHelper.addWhere(sql, DATE_FIELD, currentDate);
} }
correctionInfo.setSql(sql); correctionInfo.setSql(sql);
return correctionInfo; return correctionInfo;

View File

@@ -0,0 +1,17 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class FieldCorrector extends BaseSemanticCorrector {
@Override
public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
String replaceFields = SqlParserUpdateHelper.replaceFields(correctionInfo.getSql(),
getFieldToBizName(correctionInfo.getParseInfo().getModelId()));
correctionInfo.setSql(replaceFields);
return correctionInfo;
}
}

View File

@@ -0,0 +1,48 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@Slf4j
public class FieldValueCorrector extends BaseSemanticCorrector {
@Override
public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
Object context = correctionInfo.getParseInfo().getProperties().get(Constants.CONTEXT);
if (Objects.isNull(context)) {
return correctionInfo;
}
DSLParseResult dslParseResult = JsonUtil.toObject(JsonUtil.toString(context), DSLParseResult.class);
if (Objects.isNull(dslParseResult) || Objects.isNull(dslParseResult.getLlmReq())) {
return correctionInfo;
}
LLMReq llmReq = dslParseResult.getLlmReq();
List<ElementValue> linking = llmReq.getLinking();
if (CollectionUtils.isEmpty(linking)) {
return correctionInfo;
}
Map<String, Set<String>> fieldValueToFieldNames = linking.stream().collect(
Collectors.groupingBy(ElementValue::getFieldValue,
Collectors.mapping(ElementValue::getFieldName, Collectors.toSet())));
String sql = SqlParserUpdateHelper.replaceValueFields(correctionInfo.getSql(), fieldValueToFieldNames);
correctionInfo.setSql(sql);
return correctionInfo;
}
}

View File

@@ -0,0 +1,16 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class FunctionCorrector extends BaseSemanticCorrector {
@Override
public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
String replaceFunction = SqlParserUpdateHelper.replaceFunction(correctionInfo.getSql());
correctionInfo.setSql(replaceFunction);
return correctionInfo;
}
}

View File

@@ -1,10 +1,10 @@
package com.tencent.supersonic.chat.query.dsl.optimizer; package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo; import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.StringUtil; import com.tencent.supersonic.common.util.StringUtil;
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils; import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import java.util.Objects; import java.util.Objects;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@@ -15,17 +15,17 @@ import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
@Slf4j @Slf4j
public class QueryFilterAppend extends BaseDSLOptimizer { public class QueryFilterAppend extends BaseSemanticCorrector {
@Override @Override
public CorrectionInfo rewriter(CorrectionInfo correctionInfo) throws JSQLParserException { public CorrectionInfo corrector(CorrectionInfo correctionInfo) throws JSQLParserException {
String queryFilter = getQueryFilter(correctionInfo.getQueryFilters()); String queryFilter = getQueryFilter(correctionInfo.getQueryFilters());
String sql = correctionInfo.getSql(); String sql = correctionInfo.getSql();
if (StringUtils.isNotEmpty(queryFilter)) { if (StringUtils.isNotEmpty(queryFilter)) {
log.info("add queryFilter to sql :{}", queryFilter); log.info("add queryFilter to sql :{}", queryFilter);
Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter); Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
sql = CCJSqlParserUtils.addWhere(sql, expression); sql = SqlParserUpdateHelper.addWhere(sql, expression);
} }
correctionInfo.setSql(sql); correctionInfo.setSql(sql);
return correctionInfo; return correctionInfo;

View File

@@ -1,7 +1,8 @@
package com.tencent.supersonic.chat.query.dsl.optimizer; package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo; import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum; import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashSet; import java.util.HashSet;
@@ -10,25 +11,28 @@ import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
@Slf4j @Slf4j
public class SelectFieldAppendCorrector extends BaseDSLOptimizer { public class SelectFieldAppendCorrector extends BaseSemanticCorrector {
@Override @Override
public CorrectionInfo rewriter(CorrectionInfo correctionInfo) { public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
String sql = correctionInfo.getSql(); String sql = correctionInfo.getSql();
if (CCJSqlParserUtils.hasAggregateFunction(sql)) { if (SqlParserSelectHelper.hasAggregateFunction(sql)) {
return correctionInfo; return correctionInfo;
} }
Set<String> selectFields = new HashSet<>(CCJSqlParserUtils.getSelectFields(sql)); Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(sql));
Set<String> whereFields = new HashSet<>(CCJSqlParserUtils.getWhereFields(sql)); Set<String> whereFields = new HashSet<>(SqlParserSelectHelper.getWhereFields(sql));
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(whereFields)) { if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(whereFields)) {
return correctionInfo; return correctionInfo;
} }
whereFields.addAll(SqlParserSelectHelper.getOrderByFields(sql));
whereFields.removeAll(selectFields); whereFields.removeAll(selectFields);
whereFields.remove(TimeDimensionEnum.DAY.getName()); whereFields.remove(TimeDimensionEnum.DAY.getName());
whereFields.remove(TimeDimensionEnum.WEEK.getName()); whereFields.remove(TimeDimensionEnum.WEEK.getName());
whereFields.remove(TimeDimensionEnum.MONTH.getName()); whereFields.remove(TimeDimensionEnum.MONTH.getName());
String replaceFields = CCJSqlParserUtils.addFieldsToSelect(sql, new ArrayList<>(whereFields));
String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(sql, new ArrayList<>(whereFields));
correctionInfo.setSql(replaceFields); correctionInfo.setSql(replaceFields);
return correctionInfo; return correctionInfo;
} }

View File

@@ -1,19 +1,19 @@
package com.tencent.supersonic.chat.query.dsl.optimizer; package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo; import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils; import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@Slf4j @Slf4j
public class TableNameCorrector extends BaseDSLOptimizer { public class TableNameCorrector extends BaseSemanticCorrector {
public static final String TABLE_PREFIX = "t_"; public static final String TABLE_PREFIX = "t_";
@Override @Override
public CorrectionInfo rewriter(CorrectionInfo correctionInfo) { public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
Long modelId = correctionInfo.getParseInfo().getModelId(); Long modelId = correctionInfo.getParseInfo().getModelId();
String sqlOutput = correctionInfo.getSql(); String sqlOutput = correctionInfo.getSql();
String replaceTable = CCJSqlParserUtils.replaceTable(sqlOutput, TABLE_PREFIX + modelId); String replaceTable = SqlParserUpdateHelper.replaceTable(sqlOutput, TABLE_PREFIX + modelId);
correctionInfo.setSql(replaceTable); correctionInfo.setSql(replaceTable);
return correctionInfo; return correctionInfo;
} }

View File

@@ -2,19 +2,19 @@ package com.tencent.supersonic.chat.mapper;
import com.tencent.supersonic.chat.api.component.SchemaMapper; import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo; import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import java.util.List;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.stream.Collectors;
@Slf4j @Slf4j
@@ -33,7 +33,7 @@ public class EntityMapper implements SchemaMapper {
continue; continue;
} }
List<SchemaElementMatch> valueSchemaElements = schemaElementMatchList.stream().filter(schemaElementMatch -> List<SchemaElementMatch> valueSchemaElements = schemaElementMatchList.stream().filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())) SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toList()); .collect(Collectors.toList());
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) { for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
if (!entity.getId().equals(schemaElementMatch.getElement().getId())) { if (!entity.getId().equals(schemaElementMatch.getElement().getId())) {
@@ -51,7 +51,7 @@ public class EntityMapper implements SchemaMapper {
} }
private boolean checkExistSameEntitySchemaElements(SchemaElementMatch valueSchemaElementMatch, private boolean checkExistSameEntitySchemaElements(SchemaElementMatch valueSchemaElementMatch,
List<SchemaElementMatch> schemaElementMatchList) { List<SchemaElementMatch> schemaElementMatchList) {
List<SchemaElementMatch> entitySchemaElements = schemaElementMatchList.stream().filter(schemaElementMatch -> List<SchemaElementMatch> entitySchemaElements = schemaElementMatchList.stream().filter(schemaElementMatch ->
SchemaElementType.ENTITY.equals(schemaElementMatch.getElement().getType())) SchemaElementType.ENTITY.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toList()); .collect(Collectors.toList());

View File

@@ -2,12 +2,12 @@ package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.seg.common.Term; import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.component.SchemaMapper; import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo; import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.service.SchemaService; import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.knowledge.utils.HanlpHelper; import com.tencent.supersonic.knowledge.utils.HanlpHelper;
@@ -43,13 +43,13 @@ public class FuzzyNameMapper implements SchemaMapper {
log.debug("after db mapper,mapInfo:{}", queryContext.getMapInfo()); log.debug("after db mapper,mapInfo:{}", queryContext.getMapInfo());
} }
private void detectAndAddToSchema(QueryContext queryContext, List<Term> terms, List<SchemaElement> Models, private void detectAndAddToSchema(QueryContext queryContext, List<Term> terms, List<SchemaElement> models,
SchemaElementType schemaElementType) { SchemaElementType schemaElementType) {
try { try {
Map<String, Set<SchemaElement>> ModelResultSet = getResultSet(queryContext, terms, Models); Map<String, Set<SchemaElement>> modelResultSet = getResultSet(queryContext, terms, models);
addToSchemaMapInfo(ModelResultSet, queryContext.getMapInfo(), schemaElementType); addToSchemaMapInfo(modelResultSet, queryContext.getMapInfo(), schemaElementType);
} catch (Exception e) { } catch (Exception e) {
log.error("detectAndAddToSchema error", e); log.error("detectAndAddToSchema error", e);
@@ -57,20 +57,21 @@ public class FuzzyNameMapper implements SchemaMapper {
} }
private Map<String, Set<SchemaElement>> getResultSet(QueryContext queryContext, List<Term> terms, private Map<String, Set<SchemaElement>> getResultSet(QueryContext queryContext, List<Term> terms,
List<SchemaElement> Models) { List<SchemaElement> models) {
String queryText = queryContext.getRequest().getQueryText(); String queryText = queryContext.getRequest().getQueryText();
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class); MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
Set<Long> modelIds = mapperHelper.getModelIds(queryContext.getRequest());
Double metricDimensionThresholdConfig = getThreshold(queryContext, mapperHelper); Double metricDimensionThresholdConfig = getThreshold(queryContext, mapperHelper);
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(Models); Map<String, Set<SchemaElement>> nameToItems = getNameToItems(models);
Map<Integer, Integer> regOffsetToLength = terms.stream().sorted(Comparator.comparing(Term::length)) Map<Integer, Integer> regOffsetToLength = terms.stream().sorted(Comparator.comparing(Term::length))
.collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2)); .collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2));
Map<String, Set<SchemaElement>> ModelResultSet = new HashMap<>(); Map<String, Set<SchemaElement>> modelResultSet = new HashMap<>();
for (Integer startIndex = 0; startIndex <= queryText.length() - 1; ) { for (Integer startIndex = 0; startIndex <= queryText.length() - 1; ) {
for (Integer endIndex = startIndex; endIndex <= queryText.length(); ) { for (Integer endIndex = startIndex; endIndex <= queryText.length(); ) {
endIndex = mapperHelper.getStepIndex(regOffsetToLength, endIndex); endIndex = mapperHelper.getStepIndex(regOffsetToLength, endIndex);
@@ -86,8 +87,12 @@ public class FuzzyNameMapper implements SchemaMapper {
|| mapperHelper.getSimilarity(detectSegment, name) < metricDimensionThresholdConfig) { || mapperHelper.getSimilarity(detectSegment, name) < metricDimensionThresholdConfig) {
continue; continue;
} }
Set<SchemaElement> preSchemaElements = ModelResultSet.putIfAbsent(detectSegment, if (!CollectionUtils.isEmpty(modelIds)) {
schemaElements); schemaElements = schemaElements.stream()
.filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
.collect(Collectors.toSet());
}
Set<SchemaElement> preSchemaElements = modelResultSet.putIfAbsent(detectSegment, schemaElements);
if (Objects.nonNull(preSchemaElements)) { if (Objects.nonNull(preSchemaElements)) {
preSchemaElements.addAll(schemaElements); preSchemaElements.addAll(schemaElements);
} }
@@ -95,7 +100,7 @@ public class FuzzyNameMapper implements SchemaMapper {
} }
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex); startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
} }
return ModelResultSet; return modelResultSet;
} }
private Double getThreshold(QueryContext queryContext, MapperHelper mapperHelper) { private Double getThreshold(QueryContext queryContext, MapperHelper mapperHelper) {
@@ -103,9 +108,9 @@ public class FuzzyNameMapper implements SchemaMapper {
Double metricDimensionThresholdConfig = mapperHelper.getMetricDimensionThresholdConfig(); Double metricDimensionThresholdConfig = mapperHelper.getMetricDimensionThresholdConfig();
Double metricDimensionMinThresholdConfig = mapperHelper.getMetricDimensionMinThresholdConfig(); Double metricDimensionMinThresholdConfig = mapperHelper.getMetricDimensionMinThresholdConfig();
Map<Long, List<SchemaElementMatch>> ModelElementMatches = queryContext.getMapInfo() Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo()
.getModelElementMatches(); .getModelElementMatches();
boolean existElement = ModelElementMatches.entrySet().stream() boolean existElement = modelElementMatches.entrySet().stream()
.anyMatch(entry -> entry.getValue().size() >= 1); .anyMatch(entry -> entry.getValue().size() >= 1);
if (!existElement) { if (!existElement) {
@@ -114,13 +119,13 @@ public class FuzzyNameMapper implements SchemaMapper {
metricDimensionThresholdConfig = halfThreshold >= metricDimensionMinThresholdConfig ? halfThreshold metricDimensionThresholdConfig = halfThreshold >= metricDimensionMinThresholdConfig ? halfThreshold
: metricDimensionMinThresholdConfig; : metricDimensionMinThresholdConfig;
log.info("ModelElementMatches:{} , not exist Element metricDimensionThresholdConfig reduce by half:{}", log.info("ModelElementMatches:{} , not exist Element metricDimensionThresholdConfig reduce by half:{}",
ModelElementMatches, metricDimensionThresholdConfig); modelElementMatches, metricDimensionThresholdConfig);
} }
return metricDimensionThresholdConfig; return metricDimensionThresholdConfig;
} }
private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> Models) { private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> models) {
return Models.stream().collect( return models.stream().collect(
Collectors.toMap(SchemaElement::getName, a -> { Collectors.toMap(SchemaElement::getName, a -> {
Set<SchemaElement> result = new HashSet<>(); Set<SchemaElement> result = new HashSet<>();
result.add(a); result.add(a);

View File

@@ -9,7 +9,7 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType; import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo; import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.chat.utils.NatureHelper; import com.tencent.supersonic.knowledge.utils.NatureHelper;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.dictionary.DictWordType; import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import com.tencent.supersonic.knowledge.dictionary.MapResult; import com.tencent.supersonic.knowledge.dictionary.MapResult;
@@ -21,6 +21,7 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
@@ -37,10 +38,13 @@ public class HanlpDictMapper implements SchemaMapper {
for (Term term : terms) { for (Term term : terms) {
log.info("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency()); log.info("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
} }
Long modelId = queryContext.getRequest().getModelId();
QueryMatchStrategy matchStrategy = ContextUtils.getBean(QueryMatchStrategy.class); QueryMatchStrategy matchStrategy = ContextUtils.getBean(QueryMatchStrategy.class);
Map<MatchText, List<MapResult>> matchResult = matchStrategy.match(queryText, terms, modelId); MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
Set<Long> detectModelIds = mapperHelper.getModelIds(queryContext.getRequest());
Map<MatchText, List<MapResult>> matchResult = matchStrategy.match(queryContext.getRequest(), terms,
detectModelIds);
List<MapResult> matches = getMatches(matchResult); List<MapResult> matches = getMatches(matchResult);
@@ -51,6 +55,7 @@ public class HanlpDictMapper implements SchemaMapper {
convertTermsToSchemaMapInfo(matches, queryContext.getMapInfo(), terms); convertTermsToSchemaMapInfo(matches, queryContext.getMapInfo(), terms);
} }
private void convertTermsToSchemaMapInfo(List<MapResult> mapResults, SchemaMapInfo schemaMap, List<Term> terms) { private void convertTermsToSchemaMapInfo(List<MapResult> mapResults, SchemaMapInfo schemaMap, List<Term> terms) {
if (CollectionUtils.isEmpty(mapResults)) { if (CollectionUtils.isEmpty(mapResults)) {
return; return;

View File

@@ -1,10 +1,16 @@
package com.tencent.supersonic.chat.mapper; package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.algorithm.EditDistance; import com.hankcs.hanlp.algorithm.EditDistance;
import com.tencent.supersonic.chat.utils.NatureHelper; import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.utils.NatureHelper;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.Data; import lombok.Data;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
@@ -84,5 +90,24 @@ public class MapperHelper {
detectSegment.length()); detectSegment.length());
} }
public Set<Long> getModelIds(QueryReq request) {
Long modelId = request.getModelId();
AgentService agentService = ContextUtils.getBean(AgentService.class);
Set<Long> detectModelIds = agentService.getDslToolsModelIds(request.getAgentId(), null);
if (Objects.nonNull(detectModelIds)) {
detectModelIds = detectModelIds.stream().filter(entry -> entry > 0).collect(Collectors.toSet());
}
if (Objects.nonNull(modelId) && modelId > 0 && Objects.nonNull(detectModelIds)) {
if (detectModelIds.contains(modelId)) {
Set<Long> result = new HashSet<>();
result.add(modelId);
return result;
}
}
return detectModelIds;
}
} }

View File

@@ -1,15 +1,17 @@
package com.tencent.supersonic.chat.mapper; package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.seg.common.Term; import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.knowledge.dictionary.MapResult; import com.tencent.supersonic.knowledge.dictionary.MapResult;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set;
/** /**
* match strategy * match strategy
*/ */
public interface MatchStrategy { public interface MatchStrategy {
Map<MatchText, List<MapResult>> match(String text, List<Term> terms, Long detectModelId); Map<MatchText, List<MapResult>> match(QueryReq queryReq, List<Term> terms, Set<Long> detectModelId);
} }

View File

@@ -3,25 +3,25 @@ package com.tencent.supersonic.chat.mapper;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.component.SchemaMapper; import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo; import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@Slf4j @Slf4j
public class QueryFilterMapper implements SchemaMapper { public class QueryFilterMapper implements SchemaMapper {
private Long FREQUENCY = 9999999L; private Long frequency = 9999999L;
private double SIMILARITY = 1.0; private double similarity = 1.0;
@Override @Override
public void map(QueryContext queryContext) { public void map(QueryContext queryContext) {
@@ -49,7 +49,7 @@ public class QueryFilterMapper implements SchemaMapper {
} }
private List<SchemaElementMatch> addValueSchemaElementMatch(List<SchemaElementMatch> candidateElementMatches, private List<SchemaElementMatch> addValueSchemaElementMatch(List<SchemaElementMatch> candidateElementMatches,
QueryFilters queryFilter) { QueryFilters queryFilter) {
if (queryFilter == null || CollectionUtils.isEmpty(queryFilter.getFilters())) { if (queryFilter == null || CollectionUtils.isEmpty(queryFilter.getFilters())) {
return candidateElementMatches; return candidateElementMatches;
} }
@@ -65,9 +65,9 @@ public class QueryFilterMapper implements SchemaMapper {
.build(); .build();
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder() SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(element) .element(element)
.frequency(FREQUENCY) .frequency(frequency)
.word(String.valueOf(filter.getValue())) .word(String.valueOf(filter.getValue()))
.similarity(SIMILARITY) .similarity(similarity)
.detectWord(Constants.EMPTY) .detectWord(Constants.EMPTY)
.build(); .build();
candidateElementMatches.add(schemaElementMatch); candidateElementMatches.add(schemaElementMatch);
@@ -76,7 +76,7 @@ public class QueryFilterMapper implements SchemaMapper {
} }
private boolean checkExistSameValueSchemaElementMatch(QueryFilter queryFilter, private boolean checkExistSameValueSchemaElementMatch(QueryFilter queryFilter,
List<SchemaElementMatch> schemaElementMatches) { List<SchemaElementMatch> schemaElementMatches) {
List<SchemaElementMatch> valueSchemaElements = schemaElementMatches.stream().filter(schemaElementMatch -> List<SchemaElementMatch> valueSchemaElements = schemaElementMatches.stream().filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())) SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toList()); .collect(Collectors.toList());

View File

@@ -1,8 +1,8 @@
package com.tencent.supersonic.chat.mapper; package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.seg.common.Term; import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import com.tencent.supersonic.knowledge.dictionary.MapResult; import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.service.SearchService; import com.tencent.supersonic.knowledge.service.SearchService;
import java.util.ArrayList; import java.util.ArrayList;
@@ -32,7 +32,8 @@ public class QueryMatchStrategy implements MatchStrategy {
private MapperHelper mapperHelper; private MapperHelper mapperHelper;
@Override @Override
public Map<MatchText, List<MapResult>> match(String text, List<Term> terms, Long detectModelId) { public Map<MatchText, List<MapResult>> match(QueryReq queryReq, List<Term> terms, Set<Long> detectModelIds) {
String text = queryReq.getQueryText();
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) { if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
return null; return null;
} }
@@ -43,18 +44,19 @@ public class QueryMatchStrategy implements MatchStrategy {
List<Integer> offsetList = terms.stream().sorted(Comparator.comparing(Term::getOffset)) List<Integer> offsetList = terms.stream().sorted(Comparator.comparing(Term::getOffset))
.map(term -> term.getOffset()).collect(Collectors.toList()); .map(term -> term.getOffset()).collect(Collectors.toList());
log.debug("retryCount:{},terms:{},regOffsetToLength:{},offsetList:{},detectModelId:{}", terms, log.debug("retryCount:{},terms:{},regOffsetToLength:{},offsetList:{},detectModelIds:{}", terms,
regOffsetToLength, offsetList, detectModelId); regOffsetToLength, offsetList, detectModelIds);
List<MapResult> detects = detect(text, regOffsetToLength, offsetList, detectModelId); List<MapResult> detects = detect(queryReq, regOffsetToLength, offsetList, detectModelIds);
Map<MatchText, List<MapResult>> result = new HashMap<>(); Map<MatchText, List<MapResult>> result = new HashMap<>();
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects); result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
return result; return result;
} }
private List<MapResult> detect(String text, Map<Integer, Integer> regOffsetToLength, List<Integer> offsetList, private List<MapResult> detect(QueryReq queryReq, Map<Integer, Integer> regOffsetToLength, List<Integer> offsetList,
Long detectModelId) { Set<Long> detectModelIds) {
String text = queryReq.getQueryText();
List<MapResult> results = Lists.newArrayList(); List<MapResult> results = Lists.newArrayList();
for (Integer index = 0; index <= text.length() - 1; ) { for (Integer index = 0; index <= text.length() - 1; ) {
@@ -65,7 +67,7 @@ public class QueryMatchStrategy implements MatchStrategy {
int offset = mapperHelper.getStepOffset(offsetList, index); int offset = mapperHelper.getStepOffset(offsetList, index);
i = mapperHelper.getStepIndex(regOffsetToLength, i); i = mapperHelper.getStepIndex(regOffsetToLength, i);
if (i <= text.length()) { if (i <= text.length()) {
List<MapResult> mapResults = detectByStep(text, detectModelId, index, i, offset); List<MapResult> mapResults = detectByStep(queryReq, detectModelIds, index, i, offset);
selectMapResultInOneRound(mapResultRowSet, mapResults); selectMapResultInOneRound(mapResultRowSet, mapResults);
} }
} }
@@ -102,16 +104,19 @@ public class QueryMatchStrategy implements MatchStrategy {
return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures()); return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures());
} }
private List<MapResult> detectByStep(String text, Long detectModelId, Integer index, Integer i, int offset) { private List<MapResult> detectByStep(QueryReq queryReq, Set<Long> detectModelIds, Integer index, Integer i,
int offset) {
String text = queryReq.getQueryText();
Integer agentId = queryReq.getAgentId();
String detectSegment = text.substring(index, i); String detectSegment = text.substring(index, i);
// step1. pre search // step1. pre search
Integer oneDetectionMaxSize = mapperHelper.getOneDetectionMaxSize(); Integer oneDetectionMaxSize = mapperHelper.getOneDetectionMaxSize();
LinkedHashSet<MapResult> mapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize) LinkedHashSet<MapResult> mapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize, agentId,
.stream().collect(Collectors.toCollection(LinkedHashSet::new)); detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
// step2. suffix search // step2. suffix search
LinkedHashSet<MapResult> suffixMapResults = SearchService.suffixSearch(detectSegment, oneDetectionMaxSize) LinkedHashSet<MapResult> suffixMapResults = SearchService.suffixSearch(detectSegment, oneDetectionMaxSize,
.stream().collect(Collectors.toCollection(LinkedHashSet::new)); agentId, detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
mapResults.addAll(suffixMapResults); mapResults.addAll(suffixMapResults);
@@ -121,27 +126,15 @@ public class QueryMatchStrategy implements MatchStrategy {
// step3. merge pre/suffix result // step3. merge pre/suffix result
mapResults = mapResults.stream().sorted((a, b) -> -(b.getName().length() - a.getName().length())) mapResults = mapResults.stream().sorted((a, b) -> -(b.getName().length() - a.getName().length()))
.collect(Collectors.toCollection(LinkedHashSet::new)); .collect(Collectors.toCollection(LinkedHashSet::new));
// step4. filter by classId
if (Objects.nonNull(detectModelId) && detectModelId > 0) { // step4. filter by similarity
log.debug("detectModelId:{}, before parseResults:{}", mapResults);
mapResults = mapResults.stream().map(entry -> {
List<String> natures = entry.getNatures().stream().filter(
nature -> nature.startsWith(DictWordType.NATURE_SPILT + detectModelId) || (nature.startsWith(
DictWordType.NATURE_SPILT))
).collect(Collectors.toList());
entry.setNatures(natures);
return entry;
}).collect(Collectors.toCollection(LinkedHashSet::new));
log.info("after modelId parseResults:{}", mapResults);
}
// step5. filter by similarity
mapResults = mapResults.stream() mapResults = mapResults.stream()
.filter(term -> mapperHelper.getSimilarity(detectSegment, term.getName()) .filter(term -> mapperHelper.getSimilarity(detectSegment, term.getName())
>= mapperHelper.getThresholdMatch(term.getNatures())) >= mapperHelper.getThresholdMatch(term.getNatures()))
.filter(term -> CollectionUtils.isNotEmpty(term.getNatures())) .filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
.collect(Collectors.toCollection(LinkedHashSet::new)); .collect(Collectors.toCollection(LinkedHashSet::new));
log.debug("after isSimilarity parseResults:{}", mapResults); log.info("after isSimilarity parseResults:{}", mapResults);
mapResults = mapResults.stream().map(parseResult -> { mapResults = mapResults.stream().map(parseResult -> {
parseResult.setOffset(offset); parseResult.setOffset(offset);
@@ -149,7 +142,7 @@ public class QueryMatchStrategy implements MatchStrategy {
return parseResult; return parseResult;
}).collect(Collectors.toCollection(LinkedHashSet::new)); }).collect(Collectors.toCollection(LinkedHashSet::new));
// step6. take only one dimension or 10 metric/dimension value per rond. // step5. take only one dimension or 10 metric/dimension value per rond.
List<MapResult> dimensionMetrics = mapResults.stream() List<MapResult> dimensionMetrics = mapResults.stream()
.filter(entry -> mapperHelper.existDimensionValues(entry.getNatures())) .filter(entry -> mapperHelper.existDimensionValues(entry.getNatures()))
.collect(Collectors.toList()) .collect(Collectors.toList())

View File

@@ -2,12 +2,14 @@ package com.tencent.supersonic.chat.mapper;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.hankcs.hanlp.seg.common.Term; import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.knowledge.dictionary.DictWordType; import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import com.tencent.supersonic.knowledge.dictionary.MapResult; import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.service.SearchService; import com.tencent.supersonic.knowledge.service.SearchService;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
@@ -23,9 +25,8 @@ public class SearchMatchStrategy implements MatchStrategy {
private static final int SEARCH_SIZE = 3; private static final int SEARCH_SIZE = 3;
@Override @Override
public Map<MatchText, List<MapResult>> match(String text, List<Term> originals, public Map<MatchText, List<MapResult>> match(QueryReq queryReq, List<Term> originals, Set<Long> detectModelIds) {
Long detectModelId) { String text = queryReq.getQueryText();
Map<Integer, Integer> regOffsetToLength = originals.stream() Map<Integer, Integer> regOffsetToLength = originals.stream()
.filter(entry -> !entry.nature.toString().startsWith(DictWordType.NATURE_SPILT)) .filter(entry -> !entry.nature.toString().startsWith(DictWordType.NATURE_SPILT))
.collect(Collectors.toMap(Term::getOffset, value -> value.word.length(), .collect(Collectors.toMap(Term::getOffset, value -> value.word.length(),
@@ -51,24 +52,16 @@ public class SearchMatchStrategy implements MatchStrategy {
String detectSegment = text.substring(detectIndex); String detectSegment = text.substring(detectIndex);
if (StringUtils.isNotEmpty(detectSegment)) { if (StringUtils.isNotEmpty(detectSegment)) {
List<MapResult> mapResults = SearchService.prefixSearch(detectSegment); List<MapResult> mapResults = SearchService.prefixSearch(detectSegment,
List<MapResult> suffixMapResults = SearchService.suffixSearch(detectSegment, SEARCH_SIZE); SearchService.SEARCH_SIZE, queryReq.getAgentId(), detectModelIds);
List<MapResult> suffixMapResults = SearchService.suffixSearch(detectSegment, SEARCH_SIZE,
queryReq.getAgentId(), detectModelIds);
mapResults.addAll(suffixMapResults); mapResults.addAll(suffixMapResults);
// remove entity name where search // remove entity name where search
mapResults = mapResults.stream().filter(entry -> { mapResults = mapResults.stream().filter(entry -> {
List<String> natures = entry.getNatures().stream() List<String> natures = entry.getNatures().stream()
.filter(nature -> !nature.endsWith(DictWordType.ENTITY.getType())) .filter(nature -> !nature.endsWith(DictWordType.ENTITY.getType()))
.filter(nature -> { .collect(Collectors.toList());
if (Objects.isNull(detectModelId) || detectModelId <= 0) {
return true;
}
if (nature.startsWith(DictWordType.NATURE_SPILT + detectModelId)
&& nature.startsWith(DictWordType.NATURE_SPILT)) {
return true;
}
return false;
}
).collect(Collectors.toList());
if (CollectionUtils.isEmpty(natures)) { if (CollectionUtils.isEmpty(natures)) {
return false; return false;
} }
@@ -84,4 +77,4 @@ public class SearchMatchStrategy implements MatchStrategy {
); );
return regTextMap; return regTextMap;
} }
} }

View File

@@ -2,8 +2,9 @@ package com.tencent.supersonic.chat.parser;
import com.tencent.supersonic.chat.api.component.SemanticQuery; import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.*; import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.query.dsl.DSLQuery; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
/** /**
@@ -21,7 +22,7 @@ public class SatisfactionChecker {
// check all the parse info in candidate // check all the parse info in candidate
public static boolean check(QueryContext queryContext) { public static boolean check(QueryContext queryContext) {
for (SemanticQuery query : queryContext.getCandidateQueries()) { for (SemanticQuery query : queryContext.getCandidateQueries()) {
if (query.getQueryMode().equals(DSLQuery.QUERY_MODE)) { if (query.getQueryMode().equals(DslQuery.QUERY_MODE)) {
continue; continue;
} }
if (checkThreshold(queryContext.getRequest().getQueryText(), query.getParseInfo())) { if (checkThreshold(queryContext.getRequest().getQueryText(), query.getParseInfo())) {
@@ -32,7 +33,7 @@ public class SatisfactionChecker {
} }
private static boolean checkThreshold(String queryText, SemanticParseInfo semanticParseInfo) { private static boolean checkThreshold(String queryText, SemanticParseInfo semanticParseInfo) {
int queryTextLength = queryText.length(); int queryTextLength = queryText.replaceAll(" ", "").length();
double degree = semanticParseInfo.getScore() / queryTextLength; double degree = semanticParseInfo.getScore() / queryTextLength;
if (queryTextLength > QUERY_TEXT_LENGTH_THRESHOLD) { if (queryTextLength > QUERY_TEXT_LENGTH_THRESHOLD) {
if (degree < LONG_TEXT_THRESHOLD) { if (degree < LONG_TEXT_THRESHOLD) {

View File

@@ -6,15 +6,6 @@ public class DSLDateHelper {
public static String getCurrentDate(Long modelId) { public static String getCurrentDate(Long modelId) {
return DateUtils.getBeforeDate(4); return DateUtils.getBeforeDate(4);
// ChatConfigFilter filter = new ChatConfigFilter();
// filter.setModelId(modelId);
//
// List<ChatConfigResp> configResps = ContextUtils.getBean(ConfigService.class).search(filter, null);
// if (CollectionUtils.isEmpty(configResps)) {
// return
// }
// ChatConfigResp chatConfigResp = configResps.get(0);
// chatConfigResp.getChatDetailConfig().getChatDefaultConfig().get
} }
} }

View File

@@ -2,13 +2,21 @@ package com.tencent.supersonic.chat.parser.llm.dsl;
import com.tencent.supersonic.chat.agent.tool.DslTool; import com.tencent.supersonic.chat.agent.tool.DslTool;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.plugin.PluginParseResult; import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
import com.tencent.supersonic.chat.query.dsl.LLMResp; import com.tencent.supersonic.chat.query.llm.dsl.LLMResp;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor;
@Data @Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class DSLParseResult { public class DSLParseResult {
private LLMReq llmReq;
private LLMResp llmResp; private LLMResp llmResp;
private QueryReq request; private QueryReq request;

View File

@@ -1,42 +1,51 @@
package com.tencent.supersonic.chat.parser.llm.dsl; package com.tencent.supersonic.chat.parser.llm.dsl;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.tool.AgentToolType; import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.DslTool; import com.tencent.supersonic.chat.agent.tool.DslTool;
import com.tencent.supersonic.chat.api.component.SemanticCorrector;
import com.tencent.supersonic.chat.api.component.SemanticParser; import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType; import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.config.LLMConfig; import com.tencent.supersonic.chat.config.LLMConfig;
import com.tencent.supersonic.chat.corrector.BaseSemanticCorrector;
import com.tencent.supersonic.chat.parser.SatisfactionChecker; import com.tencent.supersonic.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.chat.parser.function.ModelResolver; import com.tencent.supersonic.chat.parser.plugin.function.ModelResolver;
import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.dsl.DSLQuery; import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
import com.tencent.supersonic.chat.query.dsl.LLMReq; import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
import com.tencent.supersonic.chat.query.dsl.LLMReq.ElementValue; import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
import com.tencent.supersonic.chat.query.dsl.LLMResp; import com.tencent.supersonic.chat.query.llm.dsl.LLMResp;
import com.tencent.supersonic.chat.query.dsl.optimizer.BaseDSLOptimizer;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery; import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.service.AgentService; import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.utils.ComponentFactory; import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.FilterExpression;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService; import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
@@ -49,104 +58,217 @@ import org.springframework.util.CollectionUtils;
import org.springframework.web.client.RestTemplate; import org.springframework.web.client.RestTemplate;
@Slf4j @Slf4j
public class LLMDSLParser implements SemanticParser { public class LLMDslParser implements SemanticParser {
public static final double FUNCTION_BONUS_THRESHOLD = 201; public static final double function_bonus_threshold = 201;
@Override @Override
public void parse(QueryContext queryCtx, ChatContext chatCtx) { public void parse(QueryContext queryCtx, ChatContext chatCtx) {
final LLMConfig llmConfig = ContextUtils.getBean(LLMConfig.class); QueryReq request = queryCtx.getRequest();
LLMConfig llmConfig = ContextUtils.getBean(LLMConfig.class);
if (StringUtils.isEmpty(llmConfig.getUrl()) || SatisfactionChecker.check(queryCtx)) { if (StringUtils.isEmpty(llmConfig.getUrl()) || SatisfactionChecker.check(queryCtx)) {
log.info("llmConfig:{}, skip function parser, queryText:{}", llmConfig, log.info("llmConfig:{}, skip function parser, queryText:{}", llmConfig, request.getQueryText());
queryCtx.getRequest().getQueryText());
return; return;
} }
List<DslTool> dslTools = getDslTools(queryCtx.getRequest().getAgentId());
Set<Long> distinctModelIds = dslTools.stream().map(DslTool::getModelIds)
.flatMap(Collection::stream)
.collect(Collectors.toSet());
try { try {
ModelResolver modelResolver = ComponentFactory.getModelResolver(); Long modelId = getModelId(queryCtx, chatCtx, request.getAgentId());
Long modelId = modelResolver.resolve(queryCtx, chatCtx, distinctModelIds);
log.info("resolve modelId:{},dslModels:{}", modelId, distinctModelIds);
if (Objects.isNull(modelId) || modelId <= 0) { if (Objects.isNull(modelId) || modelId <= 0) {
return; return;
} }
Optional<DslTool> dslToolOptional = dslTools.stream().filter(tool ->
tool.getModelIds().contains(modelId)).findFirst(); DslTool dslTool = getDslTool(request, modelId);
if (!dslToolOptional.isPresent()) { if (Objects.isNull(dslTool)) {
log.info("no dsl tool in this agent, skip dsl parser"); log.info("no dsl tool in this agent, skip dsl parser");
return; return;
} }
DslTool dslTool = dslToolOptional.get();
LLMResp llmResp = requestLLM(queryCtx, modelId); LLMReq llmReq = getLlmReq(queryCtx, modelId);
LLMResp llmResp = requestLLM(llmReq, modelId, llmConfig);
if (Objects.isNull(llmResp)) { if (Objects.isNull(llmResp)) {
return; return;
} }
PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(DSLQuery.QUERY_MODE); DSLParseResult dslParseResult = DSLParseResult.builder().request(request).dslTool(dslTool).llmReq(llmReq)
.llmResp(llmResp).build();
SemanticParseInfo parseInfo = semanticQuery.getParseInfo(); SemanticParseInfo parseInfo = getParseInfo(queryCtx, modelId, dslTool, dslParseResult);
if (Objects.nonNull(modelId) && modelId > 0) {
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(modelId)); String correctorSql = getCorrectorSql(queryCtx, parseInfo, llmResp.getSqlOutput());
}
DSLParseResult dslParseResult = new DSLParseResult(); llmResp.setCorrectorSql(correctorSql);
dslParseResult.setRequest(queryCtx.getRequest());
dslParseResult.setLlmResp(llmResp); setFilter(correctorSql, modelId, parseInfo);
dslParseResult.setDslTool(dslToolOptional.get());
Map<String, Object> properties = new HashMap<>();
properties.put(Constants.CONTEXT, dslParseResult);
properties.put("type", "internal");
properties.put("name", dslTool.getName());
parseInfo.setProperties(properties);
parseInfo.setScore(FUNCTION_BONUS_THRESHOLD);
parseInfo.setQueryMode(semanticQuery.getQueryMode());
SchemaElement Model = new SchemaElement();
Model.setModel(modelId);
Model.setId(modelId);
parseInfo.setModel(Model);
queryCtx.getCandidateQueries().add(semanticQuery);
} catch (Exception e) { } catch (Exception e) {
log.error("LLMDSLParser error", e); log.error("LLMDSLParser error", e);
} }
} }
public void setFilter(String correctorSql, Long modelId, SemanticParseInfo parseInfo) {
private LLMResp requestLLM(QueryContext queryCtx, Long modelId) { List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctorSql);
long startTime = System.currentTimeMillis(); if (CollectionUtils.isEmpty(expressions)) {
String queryText = queryCtx.getRequest().getQueryText(); return;
final LLMConfig llmConfig = ContextUtils.getBean(LLMConfig.class);
if (StringUtils.isEmpty(llmConfig.getUrl())) {
log.warn("llmConfig url is null, skip llm parser");
return null;
} }
//set dataInfo
try {
DateConf dateInfo = getDateInfo(expressions);
parseInfo.setDateInfo(dateInfo);
} catch (Exception e) {
log.error("set dateInfo error :", e);
}
//set filter
try {
Map<String, SchemaElement> bizNameToElement = getBizNameToElement(modelId);
List<QueryFilter> result = getDimensionFilter(bizNameToElement, expressions);
parseInfo.getDimensionFilters().addAll(result);
} catch (Exception e) {
log.error("set dimensionFilter error :", e);
}
}
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> bizNameToElement,
List<FilterExpression> filterExpressions) {
List<QueryFilter> result = Lists.newArrayList();
for (FilterExpression expression : filterExpressions) {
QueryFilter dimensionFilter = new QueryFilter();
dimensionFilter.setValue(expression.getFieldValue());
String bizName = expression.getFieldName();
SchemaElement schemaElement = bizNameToElement.get(bizName);
if (Objects.isNull(schemaElement)) {
continue;
}
String fieldName = schemaElement.getName();
dimensionFilter.setName(fieldName);
dimensionFilter.setBizName(bizName);
dimensionFilter.setElementID(schemaElement.getId());
FilterOperatorEnum operatorEnum = FilterOperatorEnum.getSqlOperator(expression.getOperator());
dimensionFilter.setOperator(operatorEnum);
result.add(dimensionFilter);
}
return result;
}
private DateConf getDateInfo(List<FilterExpression> filterExpressions) {
List<FilterExpression> dateExpressions = filterExpressions.stream()
.filter(expression -> {
List<String> nameList = TimeDimensionEnum.getNameList();
if (StringUtils.isEmpty(expression.getFieldName())) {
return false;
}
return nameList.contains(expression.getFieldName().toLowerCase());
}).collect(Collectors.toList());
if (CollectionUtils.isEmpty(dateExpressions)) {
return new DateConf();
}
DateConf dateInfo = new DateConf();
dateInfo.setDateMode(DateMode.BETWEEN);
FilterExpression firstExpression = dateExpressions.get(0);
FilterOperatorEnum firstOperator = FilterOperatorEnum.getSqlOperator(firstExpression.getOperator());
if (FilterOperatorEnum.EQUALS.equals(firstOperator) && Objects.nonNull(firstExpression.getFieldValue())) {
dateInfo.setStartDate(firstExpression.getFieldValue().toString());
dateInfo.setEndDate(firstExpression.getFieldValue().toString());
dateInfo.setDateMode(DateMode.BETWEEN);
return dateInfo;
}
if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.GREATER_THAN,
FilterOperatorEnum.GREATER_THAN_EQUALS)) {
dateInfo.setStartDate(firstExpression.getFieldValue().toString());
if (hasSecondDate(dateExpressions)) {
dateInfo.setEndDate(dateExpressions.get(1).getFieldValue().toString());
}
}
if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.MINOR_THAN,
FilterOperatorEnum.MINOR_THAN_EQUALS)) {
dateInfo.setEndDate(firstExpression.getFieldValue().toString());
if (hasSecondDate(dateExpressions)) {
dateInfo.setStartDate(dateExpressions.get(1).getFieldValue().toString());
}
}
return dateInfo;
}
private boolean containOperators(FilterExpression expression, FilterOperatorEnum firstOperator,
FilterOperatorEnum... operatorEnums) {
return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue()));
}
private boolean hasSecondDate(List<FilterExpression> dateExpressions) {
return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue());
}
private String getCorrectorSql(QueryContext queryCtx, SemanticParseInfo parseInfo, String sql) {
CorrectionInfo correctionInfo = CorrectionInfo.builder()
.queryFilters(queryCtx.getRequest().getQueryFilters()).sql(sql)
.parseInfo(parseInfo).build();
List<SemanticCorrector> dslCorrections = ComponentFactory.getSqlCorrections();
dslCorrections.forEach(dslCorrection -> {
try {
dslCorrection.corrector(correctionInfo);
log.info("sqlCorrection:{} sql:{}", dslCorrection.getClass().getSimpleName(),
correctionInfo.getSql());
} catch (Exception e) {
log.error("sqlCorrection:{} execute error,correctionInfo:{}", dslCorrection, correctionInfo, e);
}
});
return correctionInfo.getSql();
}
private SemanticParseInfo getParseInfo(QueryContext queryCtx, Long modelId, DslTool dslTool,
DSLParseResult dslParseResult) {
PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(DslQuery.QUERY_MODE);
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(modelId));
Map<String, Object> properties = new HashMap<>();
properties.put(Constants.CONTEXT, dslParseResult);
properties.put("type", "internal");
properties.put("name", dslTool.getName());
parseInfo.setProperties(properties);
parseInfo.setScore(function_bonus_threshold);
parseInfo.setQueryMode(semanticQuery.getQueryMode());
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName(); Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
LLMReq llmReq = new LLMReq(); SchemaElement model = new SchemaElement();
llmReq.setQueryText(queryText); model.setModel(modelId);
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema(); model.setId(modelId);
llmSchema.setModelName(modelIdToName.get(modelId)); model.setName(modelIdToName.get(modelId));
llmSchema.setDomainName(modelIdToName.get(modelId)); parseInfo.setModel(model);
List<String> fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema); queryCtx.getCandidateQueries().add(semanticQuery);
fieldNameList.add(BaseDSLOptimizer.DATE_FIELD); return parseInfo;
llmSchema.setFieldNameList(fieldNameList); }
llmReq.setSchema(llmSchema);
List<ElementValue> linking = new ArrayList<>();
linking.addAll(getValueList(queryCtx, modelId, semanticSchema));
llmReq.setLinking(linking);
String currentDate = DSLDateHelper.getCurrentDate(modelId);
llmReq.setCurrentDate(currentDate);
log.info("requestLLM request, modelId:{},llmReq:{}", modelId, llmReq); private DslTool getDslTool(QueryReq request, Long modelId) {
AgentService agentService = ContextUtils.getBean(AgentService.class);
List<DslTool> dslTools = agentService.getDslTools(request.getAgentId(), AgentToolType.DSL);
Optional<DslTool> dslToolOptional = dslTools.stream().filter(tool -> tool.getModelIds().contains(modelId))
.findFirst();
return dslToolOptional.orElse(null);
}
private Long getModelId(QueryContext queryCtx, ChatContext chatCtx, Integer agentId) {
AgentService agentService = ContextUtils.getBean(AgentService.class);
Set<Long> distinctModelIds = agentService.getDslToolsModelIds(agentId, AgentToolType.DSL);
ModelResolver modelResolver = ComponentFactory.getModelResolver();
Long modelId = modelResolver.resolve(queryCtx, chatCtx, distinctModelIds);
log.info("resolve modelId:{},dslModels:{}", modelId, distinctModelIds);
return modelId;
}
private LLMResp requestLLM(LLMReq llmReq, Long modelId, LLMConfig llmConfig) {
String questUrl = llmConfig.getUrl() + llmConfig.getQueryToSqlPath(); String questUrl = llmConfig.getUrl() + llmConfig.getQueryToSqlPath();
long startTime = System.currentTimeMillis();
log.info("requestLLM request, modelId:{},llmReq:{}", modelId, llmReq);
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class); RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
try { try {
HttpHeaders headers = new HttpHeaders(); HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON); headers.setContentType(MediaType.APPLICATION_JSON);
@@ -163,6 +285,27 @@ public class LLMDSLParser implements SemanticParser {
return null; return null;
} }
private LLMReq getLlmReq(QueryContext queryCtx, Long modelId) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
String queryText = queryCtx.getRequest().getQueryText();
LLMReq llmReq = new LLMReq();
llmReq.setQueryText(queryText);
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
llmSchema.setModelName(modelIdToName.get(modelId));
llmSchema.setDomainName(modelIdToName.get(modelId));
List<String> fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema);
fieldNameList.add(BaseSemanticCorrector.DATE_FIELD);
llmSchema.setFieldNameList(fieldNameList);
llmReq.setSchema(llmSchema);
List<ElementValue> linking = new ArrayList<>();
linking.addAll(getValueList(queryCtx, modelId, semanticSchema));
llmReq.setLinking(linking);
String currentDate = DSLDateHelper.getCurrentDate(modelId);
llmReq.setCurrentDate(currentDate);
return llmReq;
}
private List<ElementValue> getValueList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) { private List<ElementValue> getValueList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema); Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
@@ -170,23 +313,37 @@ public class LLMDSLParser implements SemanticParser {
if (CollectionUtils.isEmpty(matchedElements)) { if (CollectionUtils.isEmpty(matchedElements)) {
return new ArrayList<>(); return new ArrayList<>();
} }
Set<ElementValue> valueMatches = matchedElements.stream() Set<ElementValue> valueMatches = matchedElements
.stream()
.filter(elementMatch -> !elementMatch.isInherited())
.filter(schemaElementMatch -> { .filter(schemaElementMatch -> {
SchemaElementType type = schemaElementMatch.getElement().getType(); SchemaElementType type = schemaElementMatch.getElement().getType();
return SchemaElementType.VALUE.equals(type) || SchemaElementType.ID.equals(type); return SchemaElementType.VALUE.equals(type) || SchemaElementType.ID.equals(type);
}) })
.map(elementMatch -> .map(elementMatch -> {
{ ElementValue elementValue = new ElementValue();
ElementValue elementValue = new ElementValue(); elementValue.setFieldName(itemIdToName.get(elementMatch.getElement().getId()));
elementValue.setFieldName(itemIdToName.get(elementMatch.getElement().getId())); elementValue.setFieldValue(elementMatch.getWord());
elementValue.setFieldValue(elementMatch.getWord()); return elementValue;
return elementValue; }).collect(Collectors.toSet());
}
)
.collect(Collectors.toSet());
return new ArrayList<>(valueMatches); return new ArrayList<>(valueMatches);
} }
protected Map<String, SchemaElement> getBizNameToElement(Long modelId) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
List<SchemaElement> dimensions = semanticSchema.getDimensions();
List<SchemaElement> metrics = semanticSchema.getMetrics();
List<SchemaElement> allElements = Lists.newArrayList();
allElements.addAll(dimensions);
allElements.addAll(metrics);
return allElements.stream()
.filter(schemaElement -> schemaElement.getModel().equals(modelId))
.collect(Collectors.toMap(SchemaElement::getBizName, Function.identity(), (value1, value2) -> value2));
}
private List<String> getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) { private List<String> getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema); Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
@@ -197,9 +354,9 @@ public class LLMDSLParser implements SemanticParser {
Set<String> fieldNameList = matchedElements.stream() Set<String> fieldNameList = matchedElements.stream()
.filter(schemaElementMatch -> { .filter(schemaElementMatch -> {
SchemaElementType elementType = schemaElementMatch.getElement().getType(); SchemaElementType elementType = schemaElementMatch.getElement().getType();
return SchemaElementType.METRIC.equals(elementType) || return SchemaElementType.METRIC.equals(elementType)
SchemaElementType.DIMENSION.equals(elementType) || || SchemaElementType.DIMENSION.equals(elementType)
SchemaElementType.VALUE.equals(elementType); || SchemaElementType.VALUE.equals(elementType);
}) })
.map(schemaElementMatch -> { .map(schemaElementMatch -> {
SchemaElementType elementType = schemaElementMatch.getElement().getType(); SchemaElementType elementType = schemaElementMatch.getElement().getType();
@@ -220,18 +377,4 @@ public class LLMDSLParser implements SemanticParser {
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2)); .collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
} }
private List<DslTool> getDslTools(Integer agentId) {
AgentService agentService = ContextUtils.getBean(AgentService.class);
Agent agent = agentService.getAgent(agentId);
if (agent == null) {
return Lists.newArrayList();
}
List<String> tools = agent.getTools(AgentToolType.DSL);
if (CollectionUtils.isEmpty(tools)) {
return Lists.newArrayList();
}
return tools.stream().map(tool -> JSONObject.parseObject(tool, DslTool.class))
.collect(Collectors.toList());
}
} }

View File

@@ -7,11 +7,17 @@ import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.MetricInterpretTool; import com.tencent.supersonic.chat.agent.tool.MetricInterpretTool;
import com.tencent.supersonic.chat.api.component.SemanticLayer; import com.tencent.supersonic.chat.api.component.SemanticLayer;
import com.tencent.supersonic.chat.api.component.SemanticParser; import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.*; import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.parser.SatisfactionChecker; import com.tencent.supersonic.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.chat.query.metricInterpret.MetricInterpretQuery; import com.tencent.supersonic.chat.query.metricinterpret.MetricInterpretQuery;
import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery; import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.service.AgentService; import com.tencent.supersonic.chat.service.AgentService;
@@ -22,7 +28,11 @@ import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum; import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.*;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.HashMap;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@Slf4j @Slf4j
@@ -34,7 +44,8 @@ public class MetricInterpretParser implements SemanticParser {
log.info("skip MetricInterpretParser"); log.info("skip MetricInterpretParser");
return; return;
} }
Map<Long, MetricInterpretTool> metricInterpretToolMap = getMetricInterpretTools(queryContext.getRequest().getAgentId()); Map<Long, MetricInterpretTool> metricInterpretToolMap =
getMetricInterpretTools(queryContext.getRequest().getAgentId());
log.info("metric interpret tool : {}", metricInterpretToolMap); log.info("metric interpret tool : {}", metricInterpretToolMap);
if (CollectionUtils.isEmpty(metricInterpretToolMap)) { if (CollectionUtils.isEmpty(metricInterpretToolMap)) {
return; return;
@@ -50,8 +61,10 @@ public class MetricInterpretParser implements SemanticParser {
} }
List<MetricOption> metricOptions = metricInterpretTool.getMetricOptions(); List<MetricOption> metricOptions = metricInterpretTool.getMetricOptions();
if (!CollectionUtils.isEmpty(metricOptions)) { if (!CollectionUtils.isEmpty(metricOptions)) {
List<Long> metricIds = metricOptions.stream().map(MetricOption::getMetricId).collect(Collectors.toList()); List<Long> metricIds = metricOptions.stream()
buildQuery(modelId, queryContext, metricIds, elementMatches.get(modelId), metricInterpretTool.getName()); .map(MetricOption::getMetricId).collect(Collectors.toList());
String name = metricInterpretTool.getName();
buildQuery(modelId, queryContext, metricIds, elementMatches.get(modelId), name);
} }
} }
} }
@@ -82,7 +95,7 @@ public class MetricInterpretParser implements SemanticParser {
if (agent == null) { if (agent == null) {
return new HashMap<>(); return new HashMap<>();
} }
List<String> tools= agent.getTools(AgentToolType.INTERPRET); List<String> tools = agent.getTools(AgentToolType.INTERPRET);
if (CollectionUtils.isEmpty(tools)) { if (CollectionUtils.isEmpty(tools)) {
return new HashMap<>(); return new HashMap<>();
} }
@@ -100,16 +113,16 @@ public class MetricInterpretParser implements SemanticParser {
private SemanticParseInfo buildSemanticParseInfo(Long modelId, QueryReq queryReq, Set<SchemaElement> metrics, private SemanticParseInfo buildSemanticParseInfo(Long modelId, QueryReq queryReq, Set<SchemaElement> metrics,
List<SchemaElementMatch> schemaElementMatches, String toolName) { List<SchemaElementMatch> schemaElementMatches, String toolName) {
SchemaElement Model = new SchemaElement(); SchemaElement model = new SchemaElement();
Model.setModel(modelId); model.setModel(modelId);
Model.setId(modelId); model.setId(modelId);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.setMetrics(metrics); semanticParseInfo.setMetrics(metrics);
SchemaElement dimension = new SchemaElement(); SchemaElement dimension = new SchemaElement();
dimension.setBizName(TimeDimensionEnum.DAY.getName()); dimension.setBizName(TimeDimensionEnum.DAY.getName());
semanticParseInfo.setDimensions(Sets.newHashSet(dimension)); semanticParseInfo.setDimensions(Sets.newHashSet(dimension));
semanticParseInfo.setElementMatches(schemaElementMatches); semanticParseInfo.setElementMatches(schemaElementMatches);
semanticParseInfo.setModel(Model); semanticParseInfo.setModel(model);
semanticParseInfo.setScore(queryReq.getQueryText().length()); semanticParseInfo.setScore(queryReq.getQueryText().length());
DateConf dateConf = new DateConf(); DateConf dateConf = new DateConf();
dateConf.setDateMode(DateConf.DateMode.RECENT); dateConf.setDateMode(DateConf.DateMode.RECENT);

View File

@@ -17,7 +17,7 @@ public class LLMTimeEnhancementParse implements SemanticParser {
@Override @Override
public void parse(QueryContext queryContext, ChatContext chatContext) { public void parse(QueryContext queryContext, ChatContext chatContext) {
log.info("before queryContext:{},chatContext:{}",queryContext,chatContext); log.info("before queryContext:{},chatContext:{}", queryContext, chatContext);
ChatGptHelper chatGptHelper = ContextUtils.getBean(ChatGptHelper.class); ChatGptHelper chatGptHelper = ContextUtils.getBean(ChatGptHelper.class);
try { try {
String inferredTime = chatGptHelper.inferredTime(queryContext.getRequest().getQueryText()); String inferredTime = chatGptHelper.inferredTime(queryContext.getRequest().getQueryText());
@@ -25,12 +25,12 @@ public class LLMTimeEnhancementParse implements SemanticParser {
for (SemanticQuery query : queryContext.getCandidateQueries()) { for (SemanticQuery query : queryContext.getCandidateQueries()) {
DateConf dateInfo = query.getParseInfo().getDateInfo(); DateConf dateInfo = query.getParseInfo().getDateInfo();
JSONObject jsonObject = JSON.parseObject(inferredTime); JSONObject jsonObject = JSON.parseObject(inferredTime);
if (jsonObject.containsKey("date")){ if (jsonObject.containsKey("date")) {
dateInfo.setDateMode(DateConf.DateMode.BETWEEN); dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
dateInfo.setStartDate(jsonObject.getString("date")); dateInfo.setStartDate(jsonObject.getString("date"));
dateInfo.setEndDate(jsonObject.getString("date")); dateInfo.setEndDate(jsonObject.getString("date"));
query.getParseInfo().setDateInfo(dateInfo); query.getParseInfo().setDateInfo(dateInfo);
}else if (jsonObject.containsKey("start")){ } else if (jsonObject.containsKey("start")) {
dateInfo.setDateMode(DateConf.DateMode.BETWEEN); dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
dateInfo.setStartDate(jsonObject.getString("start")); dateInfo.setStartDate(jsonObject.getString("start"));
dateInfo.setEndDate(jsonObject.getString("end")); dateInfo.setEndDate(jsonObject.getString("end"));
@@ -38,11 +38,13 @@ public class LLMTimeEnhancementParse implements SemanticParser {
} }
} }
} }
}catch (Exception exception){ } catch (Exception exception) {
log.error("{} parse error,this reason is:{}",LLMTimeEnhancementParse.class.getSimpleName(), (Object) exception.getStackTrace()); log.error("{} parse error,this reason is:{}", LLMTimeEnhancementParse.class.getSimpleName(),
(Object) exception.getStackTrace());
} }
log.info("{} after queryContext:{},chatContext:{}",LLMTimeEnhancementParse.class.getSimpleName(),queryContext,chatContext); log.info("{} after queryContext:{},chatContext:{}",
LLMTimeEnhancementParse.class.getSimpleName(), queryContext, chatContext);
} }

View File

@@ -1,8 +1,14 @@
package com.tencent.supersonic.chat.parser.embedding; package com.tencent.supersonic.chat.parser.plugin.embedding;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.component.SemanticParser; import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.*; import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.parser.ParseMode; import com.tencent.supersonic.chat.parser.ParseMode;
@@ -10,12 +16,16 @@ import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginManager; import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.plugin.PluginParseResult; import com.tencent.supersonic.chat.plugin.PluginParseResult;
import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.dsl.DSLQuery; import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery; import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import java.util.*; import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.HashMap;
import java.util.Comparator;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum; import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@@ -47,7 +57,7 @@ public class EmbeddingBasedParser implements SemanticParser {
Map<Long, Plugin> pluginMap = plugins.stream().collect(Collectors.toMap(Plugin::getId, p -> p)); Map<Long, Plugin> pluginMap = plugins.stream().collect(Collectors.toMap(Plugin::getId, p -> p));
for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) { for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) {
Plugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId())); Plugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId()));
if (plugin == null || DSLQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) { if (plugin == null || DslQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) {
continue; continue;
} }
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, queryContext); Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, queryContext);
@@ -88,12 +98,12 @@ public class EmbeddingBasedParser implements SemanticParser {
if (modelId == null && !CollectionUtils.isEmpty(plugin.getModelList())) { if (modelId == null && !CollectionUtils.isEmpty(plugin.getModelList())) {
modelId = plugin.getModelList().get(0); modelId = plugin.getModelList().get(0);
} }
SchemaElement Model = new SchemaElement(); SchemaElement model = new SchemaElement();
Model.setModel(modelId); model.setModel(modelId);
Model.setId(modelId); model.setId(modelId);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.setElementMatches(schemaElementMatches); semanticParseInfo.setElementMatches(schemaElementMatches);
semanticParseInfo.setModel(Model); semanticParseInfo.setModel(model);
Map<String, Object> properties = new HashMap<>(); Map<String, Object> properties = new HashMap<>();
PluginParseResult pluginParseResult = new PluginParseResult(); PluginParseResult pluginParseResult = new PluginParseResult();
pluginParseResult.setPlugin(plugin); pluginParseResult.setPlugin(plugin);
@@ -111,9 +121,9 @@ public class EmbeddingBasedParser implements SemanticParser {
private void setEntity(Long modelId, SemanticParseInfo semanticParseInfo) { private void setEntity(Long modelId, SemanticParseInfo semanticParseInfo) {
SemanticService semanticService = ContextUtils.getBean(SemanticService.class); SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
ModelSchema ModelSchema = semanticService.getModelSchema(modelId); ModelSchema modelSchema = semanticService.getModelSchema(modelId);
if (ModelSchema != null && ModelSchema.getEntity() != null) { if (modelSchema != null && modelSchema.getEntity() != null) {
semanticParseInfo.setEntity(ModelSchema.getEntity()); semanticParseInfo.setEntity(modelSchema.getEntity());
} }
} }

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.embedding; package com.tencent.supersonic.chat.parser.plugin.embedding;
import lombok.Data; import lombok.Data;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;

View File

@@ -1,26 +1,26 @@
package com.tencent.supersonic.chat.parser.embedding; package com.tencent.supersonic.chat.parser.plugin.embedding;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo; import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.service.ConfigService; import com.tencent.supersonic.chat.service.ConfigService;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import java.util.List;
import java.util.Objects;
import java.util.Set;
@Slf4j @Slf4j
@Component("EmbeddingEntityResolver") @Component("EmbeddingEntityResolver")
public class EmbeddingEntityResolver { public class EmbeddingEntityResolver {
private ConfigService configService; private ConfigService configService;
public EmbeddingEntityResolver(ConfigService configService) { public EmbeddingEntityResolver(ConfigService configService) {
@@ -39,8 +39,8 @@ public class EmbeddingEntityResolver {
} }
} }
entityId = getEntityValueFromSchemaMapInfo(modelId, queryCtx.getMapInfo(), entityElementId); entityId = getEntityValueFromSchemaMapInfo(modelId, queryCtx.getMapInfo(), entityElementId);
log.info("get entity id:{} from schema map Info :{} ", entityId, log.info("get entity id:{} from schema map Info :{} ",
JSONObject.toJSONString(queryCtx.getMapInfo())); entityId, JSONObject.toJSONString(queryCtx.getMapInfo()));
if (entityId == null || entityId == 0) { if (entityId == null || entityId == 0) {
Long entityIdFromChat = getEntityValueFromParseInfo(chatCtx.getParseInfo(), entityElementId); Long entityIdFromChat = getEntityValueFromParseInfo(chatCtx.getParseInfo(), entityElementId);
if (entityIdFromChat != null && entityIdFromChat > 0) { if (entityIdFromChat != null && entityIdFromChat > 0) {
@@ -95,4 +95,4 @@ public class EmbeddingEntityResolver {
return null; return null;
} }
} }

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.embedding; package com.tencent.supersonic.chat.parser.plugin.embedding;
import java.util.List; import java.util.List;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.embedding; package com.tencent.supersonic.chat.parser.plugin.embedding;
import lombok.Data; import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.function; package com.tencent.supersonic.chat.parser.plugin.function;
import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSON;
import com.tencent.supersonic.chat.api.component.SemanticParser; import com.tencent.supersonic.chat.api.component.SemanticParser;
@@ -15,13 +15,18 @@ import com.tencent.supersonic.chat.plugin.PluginParseConfig;
import com.tencent.supersonic.chat.plugin.PluginParseResult; import com.tencent.supersonic.chat.plugin.PluginParseResult;
import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery; import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.query.dsl.DSLQuery; import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
import com.tencent.supersonic.chat.service.PluginService; import com.tencent.supersonic.chat.service.PluginService;
import com.tencent.supersonic.chat.utils.ComponentFactory; import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import java.net.URI; import java.net.URI;
import java.util.*; import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.Objects;
import java.util.Map;
import java.util.HashMap;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@@ -39,11 +44,6 @@ import org.springframework.web.util.UriComponentsBuilder;
@Slf4j @Slf4j
public class FunctionBasedParser implements SemanticParser { public class FunctionBasedParser implements SemanticParser {
public static final double FUNCTION_BONUS_THRESHOLD = 200;
public static final double SKIP_DSL_LENGTH = 10;
@Override @Override
public void parse(QueryContext queryCtx, ChatContext chatCtx) { public void parse(QueryContext queryCtx, ChatContext chatCtx) {
FunctionCallInfoConfig functionCallConfig = ContextUtils.getBean(FunctionCallInfoConfig.class); FunctionCallInfoConfig functionCallConfig = ContextUtils.getBean(FunctionCallInfoConfig.class);
@@ -59,12 +59,17 @@ public class FunctionBasedParser implements SemanticParser {
log.info("function call parser, plugin is empty, skip"); log.info("function call parser, plugin is empty, skip");
return; return;
} }
FunctionReq functionReq = FunctionReq.builder() FunctionResp functionResp = new FunctionResp();
.queryText(queryCtx.getRequest().getQueryText()) if (functionDOList.size() == 1) {
.pluginConfigs(functionDOList).build(); functionResp.setToolSelection(functionDOList.iterator().next().getName());
FunctionResp functionResp = requestFunction(functionUrl, functionReq); } else {
FunctionReq functionReq = FunctionReq.builder()
.queryText(queryCtx.getRequest().getQueryText())
.pluginConfigs(functionDOList).build();
functionResp = requestFunction(functionUrl, functionReq);
}
log.info("requestFunction result:{}", functionResp.getToolSelection()); log.info("requestFunction result:{}", functionResp.getToolSelection());
if (skipFunction(queryCtx, functionResp)) { if (skipFunction(functionResp)) {
return; return;
} }
PluginParseResult functionCallParseResult = new PluginParseResult(); PluginParseResult functionCallParseResult = new PluginParseResult();
@@ -80,10 +85,10 @@ public class FunctionBasedParser implements SemanticParser {
functionCallParseResult.setPlugin(plugin); functionCallParseResult.setPlugin(plugin);
log.info("QueryManager PluginQueryModes:{}", QueryManager.getPluginQueryModes()); log.info("QueryManager PluginQueryModes:{}", QueryManager.getPluginQueryModes());
PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(toolSelection); PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(toolSelection);
ModelResolver ModelResolver = ComponentFactory.getModelResolver(); ModelResolver modelResolver = ComponentFactory.getModelResolver();
log.info("plugin ModelList:{}", plugin.getModelList()); log.info("plugin ModelList:{}", plugin.getModelList());
Pair<Boolean, Set<Long>> pluginResolveResult = PluginManager.resolve(plugin, queryCtx); Pair<Boolean, Set<Long>> pluginResolveResult = PluginManager.resolve(plugin, queryCtx);
Long modelId = ModelResolver.resolve(queryCtx, chatCtx, pluginResolveResult.getRight()); Long modelId = modelResolver.resolve(queryCtx, chatCtx, pluginResolveResult.getRight());
log.info("FunctionBasedParser modelId:{}", modelId); log.info("FunctionBasedParser modelId:{}", modelId);
if ((Objects.isNull(modelId) || modelId <= 0) && !plugin.isContainsAllModel()) { if ((Objects.isNull(modelId) || modelId <= 0) && !plugin.isContainsAllModel()) {
log.info("Model is null, skip the parse, select tool: {}", toolSelection); log.info("Model is null, skip the parse, select tool: {}", toolSelection);
@@ -102,35 +107,24 @@ public class FunctionBasedParser implements SemanticParser {
properties.put("type", "plugin"); properties.put("type", "plugin");
properties.put("name", plugin.getName()); properties.put("name", plugin.getName());
parseInfo.setProperties(properties); parseInfo.setProperties(properties);
parseInfo.setScore(FUNCTION_BONUS_THRESHOLD); parseInfo.setScore(queryCtx.getRequest().getQueryText().length());
parseInfo.setQueryMode(semanticQuery.getQueryMode()); parseInfo.setQueryMode(semanticQuery.getQueryMode());
SchemaElement Model = new SchemaElement(); SchemaElement model = new SchemaElement();
Model.setModel(modelId); model.setModel(modelId);
Model.setId(modelId); model.setId(modelId);
parseInfo.setModel(Model); parseInfo.setModel(model);
queryCtx.getCandidateQueries().add(semanticQuery); queryCtx.getCandidateQueries().add(semanticQuery);
} }
private boolean skipFunction(QueryContext queryCtx, FunctionResp functionResp) { private boolean skipFunction(FunctionResp functionResp) {
if (Objects.isNull(functionResp) || StringUtils.isBlank(functionResp.getToolSelection())) { return Objects.isNull(functionResp) || StringUtils.isBlank(functionResp.getToolSelection());
return true;
}
String queryText = queryCtx.getRequest().getQueryText();
if (functionResp.getToolSelection().equalsIgnoreCase(DSLQuery.QUERY_MODE)
&& queryText.length() < SKIP_DSL_LENGTH) {
log.info("queryText length is :{}, less than the threshold :{}, skip dsl.", queryText.length(),
SKIP_DSL_LENGTH);
return true;
}
return false;
} }
private List<PluginParseConfig> getFunctionDO(Long modelId, QueryContext queryContext) { private List<PluginParseConfig> getFunctionDO(Long modelId, QueryContext queryContext) {
log.info("user decide Model:{}", modelId); log.info("user decide Model:{}", modelId);
List<Plugin> plugins = getPluginList(queryContext); List<Plugin> plugins = getPluginList(queryContext);
List<PluginParseConfig> functionDOList = plugins.stream().filter(plugin -> { List<PluginParseConfig> functionDOList = plugins.stream().filter(plugin -> {
if (DSLQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) { if (DslQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) {
return false; return false;
} }
if (plugin.getParseModeConfig() == null) { if (plugin.getParseModeConfig() == null) {

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.function; package com.tencent.supersonic.chat.parser.plugin.function;
import lombok.Data; import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.function; package com.tencent.supersonic.chat.parser.plugin.function;
import com.tencent.supersonic.chat.plugin.PluginParseConfig; import com.tencent.supersonic.chat.plugin.PluginParseConfig;
import java.util.List; import java.util.List;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.function; package com.tencent.supersonic.chat.parser.plugin.function;
import lombok.Data; import lombok.Data;

View File

@@ -1,29 +1,40 @@
package com.tencent.supersonic.chat.parser.function; package com.tencent.supersonic.chat.parser.plugin.function;
import com.tencent.supersonic.chat.api.pojo.*; import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.component.SemanticQuery; import com.tencent.supersonic.chat.api.component.SemanticQuery;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import java.util.*; import java.util.Map;
import java.util.HashMap;
import java.util.Objects;
import java.util.List;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
@Slf4j @Slf4j
public class HeuristicModelResolver implements ModelResolver { public class HeuristicModelResolver implements ModelResolver {
protected static Long selectModelBySchemaElementCount(Map<Long, SemanticQuery> ModelQueryModes, protected static Long selectModelBySchemaElementCount(Map<Long, SemanticQuery> modelQueryModes,
SchemaMapInfo schemaMap) { SchemaMapInfo schemaMap) {
Map<Long, ModelMatchResult> ModelTypeMap = getModelTypeMap(schemaMap); Map<Long, ModelMatchResult> modelTypeMap = getModelTypeMap(schemaMap);
if (ModelTypeMap.size() == 1) { if (modelTypeMap.size() == 1) {
Long ModelSelect = ModelTypeMap.entrySet().stream().collect(Collectors.toList()).get(0).getKey(); Long modelSelect = modelTypeMap.entrySet().stream().collect(Collectors.toList()).get(0).getKey();
if (ModelQueryModes.containsKey(ModelSelect)) { if (modelQueryModes.containsKey(modelSelect)) {
log.info("selectModel with only one Model [{}]", ModelSelect); log.info("selectModel with only one Model [{}]", modelSelect);
return ModelSelect; return modelSelect;
} }
} else { } else {
Map.Entry<Long, ModelMatchResult> maxModel = ModelTypeMap.entrySet().stream() Map.Entry<Long, ModelMatchResult> maxModel = modelTypeMap.entrySet().stream()
.filter(entry -> ModelQueryModes.containsKey(entry.getKey())) .filter(entry -> modelQueryModes.containsKey(entry.getKey()))
.sorted((o1, o2) -> { .sorted((o1, o2) -> {
int difference = o2.getValue().getCount() - o1.getValue().getCount(); int difference = o2.getValue().getCount() - o1.getValue().getCount();
if (difference == 0) { if (difference == 0) {
@@ -45,23 +56,24 @@ public class HeuristicModelResolver implements ModelResolver {
* *
* @return false will use context Model, true will use other Model , maybe include context Model * @return false will use context Model, true will use other Model , maybe include context Model
*/ */
protected static boolean isAllowSwitch(Map<Long, SemanticQuery> ModelQueryModes, SchemaMapInfo schemaMap, protected static boolean isAllowSwitch(Map<Long, SemanticQuery> modelQueryModes, SchemaMapInfo schemaMap,
ChatContext chatCtx, QueryReq searchCtx, Long modelId, Set<Long> restrictiveModels) { ChatContext chatCtx, QueryReq searchCtx,
Long modelId, Set<Long> restrictiveModels) {
if (!Objects.nonNull(modelId) || modelId <= 0) { if (!Objects.nonNull(modelId) || modelId <= 0) {
return true; return true;
} }
// except content Model, calculate the number of types for each Model, if numbers<=1 will not switch // except content Model, calculate the number of types for each Model, if numbers<=1 will not switch
Map<Long, ModelMatchResult> ModelTypeMap = getModelTypeMap(schemaMap); Map<Long, ModelMatchResult> modelTypeMap = getModelTypeMap(schemaMap);
log.info("isAllowSwitch ModelTypeMap [{}]", ModelTypeMap); log.info("isAllowSwitch ModelTypeMap [{}]", modelTypeMap);
long otherModelTypeNumBigOneCount = ModelTypeMap.entrySet().stream() long otherModelTypeNumBigOneCount = modelTypeMap.entrySet().stream()
.filter(entry -> ModelQueryModes.containsKey(entry.getKey()) && !entry.getKey().equals(modelId)) .filter(entry -> modelQueryModes.containsKey(entry.getKey()) && !entry.getKey().equals(modelId))
.filter(entry -> entry.getValue().getCount() > 1).count(); .filter(entry -> entry.getValue().getCount() > 1).count();
if (otherModelTypeNumBigOneCount >= 1) { if (otherModelTypeNumBigOneCount >= 1) {
return true; return true;
} }
// if query text only contain time , will not switch // if query text only contain time , will not switch
if (!CollectionUtils.isEmpty(ModelQueryModes.values())) { if (!CollectionUtils.isEmpty(modelQueryModes.values())) {
for (SemanticQuery semanticQuery : ModelQueryModes.values()) { for (SemanticQuery semanticQuery : modelQueryModes.values()) {
if (semanticQuery == null) { if (semanticQuery == null) {
continue; continue;
} }
@@ -71,7 +83,8 @@ public class HeuristicModelResolver implements ModelResolver {
} }
if (searchCtx.getQueryText() != null && semanticParseInfo.getDateInfo() != null) { if (searchCtx.getQueryText() != null && semanticParseInfo.getDateInfo() != null) {
if (semanticParseInfo.getDateInfo().getDetectWord() != null) { if (semanticParseInfo.getDateInfo().getDetectWord() != null) {
if (semanticParseInfo.getDateInfo().getDetectWord().equalsIgnoreCase(searchCtx.getQueryText())) { if (semanticParseInfo.getDateInfo().getDetectWord()
.equalsIgnoreCase(searchCtx.getQueryText())) {
log.info("timeParseResults is not null , can not switch context , timeParseResults:{},", log.info("timeParseResults is not null , can not switch context , timeParseResults:{},",
semanticParseInfo.getDateInfo()); semanticParseInfo.getDateInfo());
return false; return false;
@@ -94,14 +107,14 @@ public class HeuristicModelResolver implements ModelResolver {
} }
public static Map<Long, ModelMatchResult> getModelTypeMap(SchemaMapInfo schemaMap) { public static Map<Long, ModelMatchResult> getModelTypeMap(SchemaMapInfo schemaMap) {
Map<Long, ModelMatchResult> ModelCount = new HashMap<>(); Map<Long, ModelMatchResult> modelCount = new HashMap<>();
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMap.getModelElementMatches().entrySet()) { for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMap.getModelElementMatches().entrySet()) {
List<SchemaElementMatch> schemaElementMatches = schemaMap.getMatchedElements(entry.getKey()); List<SchemaElementMatch> schemaElementMatches = schemaMap.getMatchedElements(entry.getKey());
if (schemaElementMatches != null && schemaElementMatches.size() > 0) { if (schemaElementMatches != null && schemaElementMatches.size() > 0) {
if (!ModelCount.containsKey(entry.getKey())) { if (!modelCount.containsKey(entry.getKey())) {
ModelCount.put(entry.getKey(), new ModelMatchResult()); modelCount.put(entry.getKey(), new ModelMatchResult());
} }
ModelMatchResult ModelMatchResult = ModelCount.get(entry.getKey()); ModelMatchResult modelMatchResult = modelCount.get(entry.getKey());
Set<SchemaElementType> schemaElementTypes = new HashSet<>(); Set<SchemaElementType> schemaElementTypes = new HashSet<>();
schemaElementMatches.stream() schemaElementMatches.stream()
.forEach(schemaElementMatch -> schemaElementTypes.add( .forEach(schemaElementMatch -> schemaElementTypes.add(
@@ -111,13 +124,13 @@ public class HeuristicModelResolver implements ModelResolver {
((int) ((o2.getSimilarity() - o1.getSimilarity()) * 100)) ((int) ((o2.getSimilarity() - o1.getSimilarity()) * 100))
).findFirst().orElse(null); ).findFirst().orElse(null);
if (schemaElementMatchMax != null) { if (schemaElementMatchMax != null) {
ModelMatchResult.setMaxSimilarity(schemaElementMatchMax.getSimilarity()); modelMatchResult.setMaxSimilarity(schemaElementMatchMax.getSimilarity());
} }
ModelMatchResult.setCount(schemaElementTypes.size()); modelMatchResult.setCount(schemaElementTypes.size());
} }
} }
return ModelCount; return modelCount;
} }
@@ -137,40 +150,41 @@ public class HeuristicModelResolver implements ModelResolver {
.filter(restrictiveModels::contains) .filter(restrictiveModels::contains)
.collect(Collectors.toSet()); .collect(Collectors.toSet());
} }
Map<Long, SemanticQuery> ModelQueryModes = new HashMap<>(); Map<Long, SemanticQuery> modelQueryModes = new HashMap<>();
for (Long matchedModel : matchedModels) { for (Long matchedModel : matchedModels) {
ModelQueryModes.put(matchedModel, null); modelQueryModes.put(matchedModel, null);
} }
if(ModelQueryModes.size()==1){ if (modelQueryModes.size() == 1) {
return ModelQueryModes.keySet().stream().findFirst().get(); return modelQueryModes.keySet().stream().findFirst().get();
} }
return resolve(ModelQueryModes, queryContext, chatCtx, return resolve(modelQueryModes, queryContext, chatCtx,
queryContext.getMapInfo(),restrictiveModels); queryContext.getMapInfo(), restrictiveModels);
} }
public Long resolve(Map<Long, SemanticQuery> ModelQueryModes, QueryContext queryContext, public Long resolve(Map<Long, SemanticQuery> modelQueryModes, QueryContext queryContext,
ChatContext chatCtx, SchemaMapInfo schemaMap, Set<Long> restrictiveModels) { ChatContext chatCtx, SchemaMapInfo schemaMap, Set<Long> restrictiveModels) {
Long selectModel = selectModel(ModelQueryModes, queryContext.getRequest(), chatCtx, schemaMap,restrictiveModels); Long selectModel = selectModel(modelQueryModes, queryContext.getRequest(),
chatCtx, schemaMap, restrictiveModels);
if (selectModel > 0) { if (selectModel > 0) {
log.info("selectModel {} ", selectModel); log.info("selectModel {} ", selectModel);
return selectModel; return selectModel;
} }
// get the max SchemaElementType number // get the max SchemaElementType number
return selectModelBySchemaElementCount(ModelQueryModes, schemaMap); return selectModelBySchemaElementCount(modelQueryModes, schemaMap);
} }
public Long selectModel(Map<Long, SemanticQuery> ModelQueryModes, QueryReq queryContext, public Long selectModel(Map<Long, SemanticQuery> modelQueryModes, QueryReq queryContext,
ChatContext chatCtx, ChatContext chatCtx,
SchemaMapInfo schemaMap, Set<Long> restrictiveModels) { SchemaMapInfo schemaMap, Set<Long> restrictiveModels) {
// if QueryContext has modelId and in ModelQueryModes // if QueryContext has modelId and in ModelQueryModes
if (ModelQueryModes.containsKey(queryContext.getModelId())) { if (modelQueryModes.containsKey(queryContext.getModelId())) {
log.info("selectModel from QueryContext [{}]", queryContext.getModelId()); log.info("selectModel from QueryContext [{}]", queryContext.getModelId());
return queryContext.getModelId(); return queryContext.getModelId();
} }
// if ChatContext has modelId and in ModelQueryModes // if ChatContext has modelId and in ModelQueryModes
if (chatCtx.getParseInfo().getModelId() > 0) { if (chatCtx.getParseInfo().getModelId() > 0) {
Long modelId = chatCtx.getParseInfo().getModelId(); Long modelId = chatCtx.getParseInfo().getModelId();
if (!isAllowSwitch(ModelQueryModes, schemaMap, chatCtx, queryContext, modelId,restrictiveModels)) { if (!isAllowSwitch(modelQueryModes, schemaMap, chatCtx, queryContext, modelId, restrictiveModels)) {
log.info("selectModel from ChatContext [{}]", modelId); log.info("selectModel from ChatContext [{}]", modelId);
return modelId; return modelId;
} }

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.function; package com.tencent.supersonic.chat.parser.plugin.function;
import lombok.Data; import lombok.Data;

View File

@@ -1,13 +1,12 @@
package com.tencent.supersonic.chat.parser.function; package com.tencent.supersonic.chat.parser.plugin.function;
import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.QueryContext;
import java.util.List;
import java.util.Set; import java.util.Set;
public interface ModelResolver { public interface ModelResolver {
Long resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels); Long resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels);
} }

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.function; package com.tencent.supersonic.chat.parser.plugin.function;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;

View File

@@ -13,7 +13,6 @@ import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
import java.util.Collection;
import java.util.List; import java.util.List;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@@ -32,16 +31,24 @@ public class AgentCheckParser implements SemanticParser {
if (agent == null) { if (agent == null) {
return; return;
} }
List<String> queryModes = getRuleTools(agentId).stream().map(RuleQueryTool::getQueryModes) List<RuleQueryTool> queryTools = getRuleTools(agentId);
.flatMap(Collection::stream).collect(Collectors.toList()); if (CollectionUtils.isEmpty(queryTools)) {
if (CollectionUtils.isEmpty(queries)) {
queries.clear(); queries.clear();
return; return;
} }
log.info("queries resolved:{} {}", agent.getName(), log.info("queries resolved:{} {}", agent.getName(),
queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList())); queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
queries.removeIf(query -> queries.removeIf(query -> {
!queryModes.contains(query.getQueryMode())); for (RuleQueryTool tool : queryTools) {
if (!tool.getQueryModes().contains(query.getQueryMode())) {
return true;
}
if (tool.isContainsAllModel() || tool.getModelIds().contains(query.getParseInfo().getModelId())) {
return false;
}
}
return true;
});
log.info("rule queries witch can be supported by agent :{} {}", agent.getName(), log.info("rule queries witch can be supported by agent :{} {}", agent.getName(),
queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList())); queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
} }

View File

@@ -14,6 +14,7 @@ import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum; import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import java.util.AbstractMap; import java.util.AbstractMap;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@@ -21,6 +22,7 @@ import java.util.regex.Matcher;
import java.util.regex.Pattern; import java.util.regex.Pattern;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;

View File

@@ -1,12 +1,5 @@
package com.tencent.supersonic.chat.parser.rule; package com.tencent.supersonic.chat.parser.rule;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.DIMENSION;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ENTITY;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ID;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.METRIC;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.MODEL;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.VALUE;
import com.tencent.supersonic.chat.api.component.SemanticParser; import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.component.SemanticQuery; import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.ChatContext;
@@ -15,8 +8,8 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType; import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery; import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery; import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricSemanticQuery; import com.tencent.supersonic.chat.query.rule.metric.MetricSemanticQuery;
import java.util.AbstractMap; import java.util.AbstractMap;
import java.util.ArrayList; import java.util.ArrayList;
@@ -28,6 +21,13 @@ import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.METRIC;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.DIMENSION;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.VALUE;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ENTITY;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.MODEL;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ID;
@Slf4j @Slf4j
public class ContextInheritParser implements SemanticParser { public class ContextInheritParser implements SemanticParser {

View File

@@ -1,9 +1,12 @@
package com.tencent.supersonic.chat.parser.rule; package com.tencent.supersonic.chat.parser.rule;
import com.tencent.supersonic.chat.api.component.SemanticParser; import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.*; import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery; import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import java.util.*; import java.util.List;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@Slf4j @Slf4j

View File

@@ -4,21 +4,23 @@ import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.component.SemanticQuery; import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery; import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.pojo.DateConf;
import com.xkzhangsan.time.nlp.TimeNLP;
import com.xkzhangsan.time.nlp.TimeNLPUtil;
import java.text.DateFormat; import java.text.DateFormat;
import java.text.ParseException; import java.text.ParseException;
import java.text.SimpleDateFormat; import java.text.SimpleDateFormat;
import java.time.LocalDate; import java.time.LocalDate;
import java.util.Stack;
import java.util.Date; import java.util.Date;
import java.util.List; import java.util.List;
import java.util.Stack;
import java.util.regex.Matcher; import java.util.regex.Matcher;
import java.util.regex.Pattern; import java.util.regex.Pattern;
import com.xkzhangsan.time.nlp.TimeNLP;
import com.xkzhangsan.time.nlp.TimeNLPUtil;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.logging.log4j.util.Strings; import org.apache.logging.log4j.util.Strings;

View File

@@ -4,17 +4,14 @@ import java.util.Date;
public class AgentDO { public class AgentDO {
/** /**
*
*/ */
private Integer id; private Integer id;
/** /**
*
*/ */
private String name; private String name;
/** /**
*
*/ */
private String description; private String description;
@@ -24,83 +21,70 @@ public class AgentDO {
private Integer status; private Integer status;
/** /**
*
*/ */
private String examples; private String examples;
/** /**
*
*/ */
private String config; private String config;
/** /**
*
*/ */
private String createdBy; private String createdBy;
/** /**
*
*/ */
private Date createdAt; private Date createdAt;
/** /**
*
*/ */
private String updatedBy; private String updatedBy;
/** /**
*
*/ */
private Date updatedAt; private Date updatedAt;
/** /**
*
*/ */
private Integer enableSearch; private Integer enableSearch;
/** /**
* * @return id
* @return id
*/ */
public Integer getId() { public Integer getId() {
return id; return id;
} }
/** /**
* * @param id
* @param id
*/ */
public void setId(Integer id) { public void setId(Integer id) {
this.id = id; this.id = id;
} }
/** /**
* * @return name
* @return name
*/ */
public String getName() { public String getName() {
return name; return name;
} }
/** /**
* * @param name
* @param name
*/ */
public void setName(String name) { public void setName(String name) {
this.name = name == null ? null : name.trim(); this.name = name == null ? null : name.trim();
} }
/** /**
* * @return description
* @return description
*/ */
public String getDescription() { public String getDescription() {
return description; return description;
} }
/** /**
* * @param description
* @param description
*/ */
public void setDescription(String description) { public void setDescription(String description) {
this.description = description == null ? null : description.trim(); this.description = description == null ? null : description.trim();
@@ -123,114 +107,100 @@ public class AgentDO {
} }
/** /**
* * @return examples
* @return examples
*/ */
public String getExamples() { public String getExamples() {
return examples; return examples;
} }
/** /**
* * @param examples
* @param examples
*/ */
public void setExamples(String examples) { public void setExamples(String examples) {
this.examples = examples == null ? null : examples.trim(); this.examples = examples == null ? null : examples.trim();
} }
/** /**
* * @return config
* @return config
*/ */
public String getConfig() { public String getConfig() {
return config; return config;
} }
/** /**
* * @param config
* @param config
*/ */
public void setConfig(String config) { public void setConfig(String config) {
this.config = config == null ? null : config.trim(); this.config = config == null ? null : config.trim();
} }
/** /**
* * @return created_by
* @return created_by
*/ */
public String getCreatedBy() { public String getCreatedBy() {
return createdBy; return createdBy;
} }
/** /**
* * @param createdBy
* @param createdBy
*/ */
public void setCreatedBy(String createdBy) { public void setCreatedBy(String createdBy) {
this.createdBy = createdBy == null ? null : createdBy.trim(); this.createdBy = createdBy == null ? null : createdBy.trim();
} }
/** /**
* * @return created_at
* @return created_at
*/ */
public Date getCreatedAt() { public Date getCreatedAt() {
return createdAt; return createdAt;
} }
/** /**
* * @param createdAt
* @param createdAt
*/ */
public void setCreatedAt(Date createdAt) { public void setCreatedAt(Date createdAt) {
this.createdAt = createdAt; this.createdAt = createdAt;
} }
/** /**
* * @return updated_by
* @return updated_by
*/ */
public String getUpdatedBy() { public String getUpdatedBy() {
return updatedBy; return updatedBy;
} }
/** /**
* * @param updatedBy
* @param updatedBy
*/ */
public void setUpdatedBy(String updatedBy) { public void setUpdatedBy(String updatedBy) {
this.updatedBy = updatedBy == null ? null : updatedBy.trim(); this.updatedBy = updatedBy == null ? null : updatedBy.trim();
} }
/** /**
* * @return updated_at
* @return updated_at
*/ */
public Date getUpdatedAt() { public Date getUpdatedAt() {
return updatedAt; return updatedAt;
} }
/** /**
* * @param updatedAt
* @param updatedAt
*/ */
public void setUpdatedAt(Date updatedAt) { public void setUpdatedAt(Date updatedAt) {
this.updatedAt = updatedAt; this.updatedAt = updatedAt;
} }
/** /**
* * @return enable_search
* @return enable_search
*/ */
public Integer getEnableSearch() { public Integer getEnableSearch() {
return enableSearch; return enableSearch;
} }
/** /**
* * @param enableSearch
* @param enableSearch
*/ */
public void setEnableSearch(Integer enableSearch) { public void setEnableSearch(Integer enableSearch) {
this.enableSearch = enableSearch; this.enableSearch = enableSearch;
} }
} }

View File

@@ -31,7 +31,6 @@ public class AgentDOExample {
protected Integer limitEnd; protected Integer limitEnd;
/** /**
*
* @mbg.generated * @mbg.generated
*/ */
public AgentDOExample() { public AgentDOExample() {
@@ -39,7 +38,6 @@ public class AgentDOExample {
} }
/** /**
*
* @mbg.generated * @mbg.generated
*/ */
public void setOrderByClause(String orderByClause) { public void setOrderByClause(String orderByClause) {
@@ -47,7 +45,6 @@ public class AgentDOExample {
} }
/** /**
*
* @mbg.generated * @mbg.generated
*/ */
public String getOrderByClause() { public String getOrderByClause() {
@@ -55,7 +52,6 @@ public class AgentDOExample {
} }
/** /**
*
* @mbg.generated * @mbg.generated
*/ */
public void setDistinct(boolean distinct) { public void setDistinct(boolean distinct) {
@@ -63,7 +59,6 @@ public class AgentDOExample {
} }
/** /**
*
* @mbg.generated * @mbg.generated
*/ */
public boolean isDistinct() { public boolean isDistinct() {
@@ -71,7 +66,6 @@ public class AgentDOExample {
} }
/** /**
*
* @mbg.generated * @mbg.generated
*/ */
public List<Criteria> getOredCriteria() { public List<Criteria> getOredCriteria() {
@@ -79,7 +73,6 @@ public class AgentDOExample {
} }
/** /**
*
* @mbg.generated * @mbg.generated
*/ */
public void or(Criteria criteria) { public void or(Criteria criteria) {
@@ -87,7 +80,6 @@ public class AgentDOExample {
} }
/** /**
*
* @mbg.generated * @mbg.generated
*/ */
public Criteria or() { public Criteria or() {
@@ -97,7 +89,6 @@ public class AgentDOExample {
} }
/** /**
*
* @mbg.generated * @mbg.generated
*/ */
public Criteria createCriteria() { public Criteria createCriteria() {
@@ -109,7 +100,6 @@ public class AgentDOExample {
} }
/** /**
*
* @mbg.generated * @mbg.generated
*/ */
protected Criteria createCriteriaInternal() { protected Criteria createCriteriaInternal() {
@@ -118,7 +108,6 @@ public class AgentDOExample {
} }
/** /**
*
* @mbg.generated * @mbg.generated
*/ */
public void clear() { public void clear() {
@@ -128,15 +117,13 @@ public class AgentDOExample {
} }
/** /**
*
* @mbg.generated * @mbg.generated
*/ */
public void setLimitStart(Integer limitStart) { public void setLimitStart(Integer limitStart) {
this.limitStart=limitStart; this.limitStart = limitStart;
} }
/** /**
*
* @mbg.generated * @mbg.generated
*/ */
public Integer getLimitStart() { public Integer getLimitStart() {
@@ -144,15 +131,13 @@ public class AgentDOExample {
} }
/** /**
*
* @mbg.generated * @mbg.generated
*/ */
public void setLimitEnd(Integer limitEnd) { public void setLimitEnd(Integer limitEnd) {
this.limitEnd=limitEnd; this.limitEnd = limitEnd;
} }
/** /**
*
* @mbg.generated * @mbg.generated
*/ */
public Integer getLimitEnd() { public Integer getLimitEnd() {
@@ -954,38 +939,6 @@ public class AgentDOExample {
private String typeHandler; private String typeHandler;
public String getCondition() {
return condition;
}
public Object getValue() {
return value;
}
public Object getSecondValue() {
return secondValue;
}
public boolean isNoValue() {
return noValue;
}
public boolean isSingleValue() {
return singleValue;
}
public boolean isBetweenValue() {
return betweenValue;
}
public boolean isListValue() {
return listValue;
}
public String getTypeHandler() {
return typeHandler;
}
protected Criterion(String condition) { protected Criterion(String condition) {
super(); super();
this.condition = condition; this.condition = condition;
@@ -1021,5 +974,37 @@ public class AgentDOExample {
protected Criterion(String condition, Object value, Object secondValue) { protected Criterion(String condition, Object value, Object secondValue) {
this(condition, value, secondValue, null); this(condition, value, secondValue, null);
} }
public String getCondition() {
return condition;
}
public Object getValue() {
return value;
}
public Object getSecondValue() {
return secondValue;
}
public boolean isNoValue() {
return noValue;
}
public boolean isSingleValue() {
return singleValue;
}
public boolean isBetweenValue() {
return betweenValue;
}
public boolean isListValue() {
return listValue;
}
public String getTypeHandler() {
return typeHandler;
}
} }
} }

View File

@@ -0,0 +1,142 @@
package com.tencent.supersonic.chat.persistence.dataobject;
import java.util.Date;
public class ChatParseDO {
/**
* questionId
*/
private Long questionId;
/**
* chatId
*/
private Long chatId;
/**
* parseId
*/
private Integer parseId;
/**
* createTime
*/
private Date createTime;
/**
* queryText
*/
private String queryText;
/**
* userName
*/
private String userName;
/**
* parseInfo
*/
private String parseInfo;
/**
* isCandidate
*/
private Integer isCandidate;
/**
* return question_id
*/
public Long getQuestionId() {
return questionId;
}
/**
* questionId
*/
public void setQuestionId(Long questionId) {
this.questionId = questionId;
}
/**
* return create_time
*/
public Date getCreateTime() {
return createTime;
}
/**
* createTime
*/
public void setCreateTime(Date createTime) {
this.createTime = createTime;
}
/**
* return user_name
*/
public String getUserName() {
return userName;
}
/**
* userName
*/
public void setUserName(String userName) {
this.userName = userName == null ? null : userName.trim();
}
/**
* return chat_id
*/
public Long getChatId() {
return chatId;
}
/**
* chatId
*/
public void setChatId(Long chatId) {
this.chatId = chatId;
}
/**
* return query_text
*/
public String getQueryText() {
return queryText;
}
/**
* queryText
*/
public void setQueryText(String queryText) {
this.queryText = queryText == null ? null : queryText.trim();
}
public Integer getIsCandidate() {
return isCandidate;
}
public Integer getParseId() {
return parseId;
}
public String getParseInfo() {
return parseInfo;
}
public void setParseId(Integer parseId) {
this.parseId = parseId;
}
public void setIsCandidate(Integer isCandidate) {
this.isCandidate = isCandidate;
}
public void setParseInfo(String parseInfo) {
this.parseInfo = parseInfo;
}
}

View File

@@ -2,177 +2,196 @@ package com.tencent.supersonic.chat.persistence.dataobject;
import java.util.Date; import java.util.Date;
public class ChatQueryDO { public class ChatQueryDO {
/** /**
* questionId
*/ */
private Long questionId; private Long questionId;
/** /**
* createTime */
private Integer agentId;
/**
*/ */
private Date createTime; private Date createTime;
/** /**
* userName
*/ */
private String userName; private String userName;
/** /**
* queryState
*/ */
private Integer queryState; private Integer queryState;
/** /**
* chatId
*/ */
private Long chatId; private Long chatId;
/** /**
* score
*/ */
private Integer score; private Integer score;
/** /**
* feedback
*/ */
private String feedback; private String feedback;
/** /**
* queryText
*/ */
private String queryText; private String queryText;
/** /**
* queryResponse
*/ */
private String queryResult; private String queryResult;
/** /**
* return question_id * @return question_id
*/ */
public Long getQuestionId() { public Long getQuestionId() {
return questionId; return questionId;
} }
/** /**
* questionId * @param questionId
*/ */
public void setQuestionId(Long questionId) { public void setQuestionId(Long questionId) {
this.questionId = questionId; this.questionId = questionId;
} }
/** /**
* return create_time * @return agent_id
*/
public Integer getAgentId() {
return agentId;
}
/**
* @param agentId
*/
public void setAgentId(Integer agentId) {
this.agentId = agentId;
}
/**
* @return create_time
*/ */
public Date getCreateTime() { public Date getCreateTime() {
return createTime; return createTime;
} }
/** /**
* createTime * @param createTime
*/ */
public void setCreateTime(Date createTime) { public void setCreateTime(Date createTime) {
this.createTime = createTime; this.createTime = createTime;
} }
/** /**
* return user_name * @return user_name
*/ */
public String getUserName() { public String getUserName() {
return userName; return userName;
} }
/** /**
* userName * @param userName
*/ */
public void setUserName(String userName) { public void setUserName(String userName) {
this.userName = userName == null ? null : userName.trim(); this.userName = userName == null ? null : userName.trim();
} }
/** /**
* return query_state *
* @return query_state
*/ */
public Integer getQueryState() { public Integer getQueryState() {
return queryState; return queryState;
} }
/** /**
* queryState *
* @param queryState
*/ */
public void setQueryState(Integer queryState) { public void setQueryState(Integer queryState) {
this.queryState = queryState; this.queryState = queryState;
} }
/** /**
* return chat_id *
* @return chat_id
*/ */
public Long getChatId() { public Long getChatId() {
return chatId; return chatId;
} }
/** /**
* chatId *
* @param chatId
*/ */
public void setChatId(Long chatId) { public void setChatId(Long chatId) {
this.chatId = chatId; this.chatId = chatId;
} }
/** /**
* return score *
* @return score
*/ */
public Integer getScore() { public Integer getScore() {
return score; return score;
} }
/** /**
* score *
* @param score
*/ */
public void setScore(Integer score) { public void setScore(Integer score) {
this.score = score; this.score = score;
} }
/** /**
* return feedback *
* @return feedback
*/ */
public String getFeedback() { public String getFeedback() {
return feedback; return feedback;
} }
/** /**
* feedback *
* @param feedback
*/ */
public void setFeedback(String feedback) { public void setFeedback(String feedback) {
this.feedback = feedback == null ? null : feedback.trim(); this.feedback = feedback == null ? null : feedback.trim();
} }
/** /**
* return query_text *
* @return query_text
*/ */
public String getQueryText() { public String getQueryText() {
return queryText; return queryText;
} }
/** /**
* queryText *
* @param queryText
*/ */
public void setQueryText(String queryText) { public void setQueryText(String queryText) {
this.queryText = queryText == null ? null : queryText.trim(); this.queryText = queryText == null ? null : queryText.trim();
} }
/** /**
* return query_response *
* @return query_result
*/ */
public String getQueryResult() { public String getQueryResult() {
return queryResult; return queryResult;
} }
/** /**
* queryResponse *
* @param queryResult
*/ */
public void setQueryResult(String queryResult) { public void setQueryResult(String queryResult) {
this.queryResult = queryResult == null ? null : queryResult.trim(); this.queryResult = queryResult == null ? null : queryResult.trim();
} }
} }

View File

@@ -5,47 +5,92 @@ import java.util.Date;
import java.util.List; import java.util.List;
public class ChatQueryDOExample { public class ChatQueryDOExample {
/**
* s2_chat_query
*/
protected String orderByClause; protected String orderByClause;
/**
* s2_chat_query
*/
protected boolean distinct; protected boolean distinct;
/**
* s2_chat_query
*/
protected List<Criteria> oredCriteria; protected List<Criteria> oredCriteria;
/**
* s2_chat_query
*/
protected Integer limitStart; protected Integer limitStart;
/**
* s2_chat_query
*/
protected Integer limitEnd; protected Integer limitEnd;
/**
* @mbg.generated
*/
public ChatQueryDOExample() { public ChatQueryDOExample() {
oredCriteria = new ArrayList<Criteria>(); oredCriteria = new ArrayList<Criteria>();
} }
public String getOrderByClause() { /**
return orderByClause; * @mbg.generated
} */
public void setOrderByClause(String orderByClause) { public void setOrderByClause(String orderByClause) {
this.orderByClause = orderByClause; this.orderByClause = orderByClause;
} }
public boolean isDistinct() { /**
return distinct; * @mbg.generated
*/
public String getOrderByClause() {
return orderByClause;
} }
/**
* @mbg.generated
*/
public void setDistinct(boolean distinct) { public void setDistinct(boolean distinct) {
this.distinct = distinct; this.distinct = distinct;
} }
/**
* @mbg.generated
*/
public boolean isDistinct() {
return distinct;
}
/**
* @mbg.generated
*/
public List<Criteria> getOredCriteria() { public List<Criteria> getOredCriteria() {
return oredCriteria; return oredCriteria;
} }
/**
* @mbg.generated
*/
public void or(Criteria criteria) { public void or(Criteria criteria) {
oredCriteria.add(criteria); oredCriteria.add(criteria);
} }
/**
* @mbg.generated
*/
public Criteria or() { public Criteria or() {
Criteria criteria = createCriteriaInternal(); Criteria criteria = createCriteriaInternal();
oredCriteria.add(criteria); oredCriteria.add(criteria);
return criteria; return criteria;
} }
/**
* @mbg.generated
*/
public Criteria createCriteria() { public Criteria createCriteria() {
Criteria criteria = createCriteriaInternal(); Criteria criteria = createCriteriaInternal();
if (oredCriteria.size() == 0) { if (oredCriteria.size() == 0) {
@@ -54,35 +99,55 @@ public class ChatQueryDOExample {
return criteria; return criteria;
} }
/**
* @mbg.generated
*/
protected Criteria createCriteriaInternal() { protected Criteria createCriteriaInternal() {
Criteria criteria = new Criteria(); Criteria criteria = new Criteria();
return criteria; return criteria;
} }
/**
* @mbg.generated
*/
public void clear() { public void clear() {
oredCriteria.clear(); oredCriteria.clear();
orderByClause = null; orderByClause = null;
distinct = false; distinct = false;
} }
public Integer getLimitStart() { /**
return limitStart; * @mbg.generated
} */
public void setLimitStart(Integer limitStart) { public void setLimitStart(Integer limitStart) {
this.limitStart = limitStart; this.limitStart = limitStart;
} }
public Integer getLimitEnd() { /**
return limitEnd; * @mbg.generated
*/
public Integer getLimitStart() {
return limitStart;
} }
/**
* @mbg.generated
*/
public void setLimitEnd(Integer limitEnd) { public void setLimitEnd(Integer limitEnd) {
this.limitEnd = limitEnd; this.limitEnd = limitEnd;
} }
protected abstract static class GeneratedCriteria { /**
* @mbg.generated
*/
public Integer getLimitEnd() {
return limitEnd;
}
/**
* s2_chat_query null
*/
protected abstract static class GeneratedCriteria {
protected List<Criterion> criteria; protected List<Criterion> criteria;
protected GeneratedCriteria() { protected GeneratedCriteria() {
@@ -183,6 +248,66 @@ public class ChatQueryDOExample {
return (Criteria) this; return (Criteria) this;
} }
public Criteria andAgentIdIsNull() {
addCriterion("agent_id is null");
return (Criteria) this;
}
public Criteria andAgentIdIsNotNull() {
addCriterion("agent_id is not null");
return (Criteria) this;
}
public Criteria andAgentIdEqualTo(Integer value) {
addCriterion("agent_id =", value, "agentId");
return (Criteria) this;
}
public Criteria andAgentIdNotEqualTo(Integer value) {
addCriterion("agent_id <>", value, "agentId");
return (Criteria) this;
}
public Criteria andAgentIdGreaterThan(Integer value) {
addCriterion("agent_id >", value, "agentId");
return (Criteria) this;
}
public Criteria andAgentIdGreaterThanOrEqualTo(Integer value) {
addCriterion("agent_id >=", value, "agentId");
return (Criteria) this;
}
public Criteria andAgentIdLessThan(Integer value) {
addCriterion("agent_id <", value, "agentId");
return (Criteria) this;
}
public Criteria andAgentIdLessThanOrEqualTo(Integer value) {
addCriterion("agent_id <=", value, "agentId");
return (Criteria) this;
}
public Criteria andAgentIdIn(List<Integer> values) {
addCriterion("agent_id in", values, "agentId");
return (Criteria) this;
}
public Criteria andAgentIdNotIn(List<Integer> values) {
addCriterion("agent_id not in", values, "agentId");
return (Criteria) this;
}
public Criteria andAgentIdBetween(Integer value1, Integer value2) {
addCriterion("agent_id between", value1, value2, "agentId");
return (Criteria) this;
}
public Criteria andAgentIdNotBetween(Integer value1, Integer value2) {
addCriterion("agent_id not between", value1, value2, "agentId");
return (Criteria) this;
}
public Criteria andCreateTimeIsNull() { public Criteria andCreateTimeIsNull() {
addCriterion("create_time is null"); addCriterion("create_time is null");
return (Criteria) this; return (Criteria) this;
@@ -564,6 +689,9 @@ public class ChatQueryDOExample {
} }
} }
/**
* s2_chat_query
*/
public static class Criteria extends GeneratedCriteria { public static class Criteria extends GeneratedCriteria {
protected Criteria() { protected Criteria() {
@@ -571,8 +699,10 @@ public class ChatQueryDOExample {
} }
} }
/**
* s2_chat_query null
*/
public static class Criterion { public static class Criterion {
private String condition; private String condition;
private Object value; private Object value;
@@ -657,4 +787,4 @@ public class ChatQueryDOExample {
return typeHandler; return typeHandler;
} }
} }
} }

View File

@@ -0,0 +1,23 @@
package com.tencent.supersonic.chat.persistence.dataobject;
public enum CostType {
MAPPER(1, "mapper"),
PARSER(2, "parser"),
QUERY(3, "query");
private Integer type;
private String name;
CostType(Integer type, String name) {
this.type = type;
this.name = name;
}
public Integer getType() {
return type;
}
public String getName() {
return name;
}
}

View File

@@ -254,4 +254,4 @@ public class PluginDO {
public void setComment(String comment) { public void setComment(String comment) {
this.comment = comment == null ? null : comment.trim(); this.comment = comment == null ? null : comment.trim();
} }
} }

View File

@@ -892,38 +892,6 @@ public class PluginDOExample {
private String typeHandler; private String typeHandler;
public String getCondition() {
return condition;
}
public Object getValue() {
return value;
}
public Object getSecondValue() {
return secondValue;
}
public boolean isNoValue() {
return noValue;
}
public boolean isSingleValue() {
return singleValue;
}
public boolean isBetweenValue() {
return betweenValue;
}
public boolean isListValue() {
return listValue;
}
public String getTypeHandler() {
return typeHandler;
}
protected Criterion(String condition) { protected Criterion(String condition) {
super(); super();
this.condition = condition; this.condition = condition;
@@ -959,5 +927,37 @@ public class PluginDOExample {
protected Criterion(String condition, Object value, Object secondValue) { protected Criterion(String condition, Object value, Object secondValue) {
this(condition, value, secondValue, null); this(condition, value, secondValue, null);
} }
public String getCondition() {
return condition;
}
public Object getValue() {
return value;
}
public Object getSecondValue() {
return secondValue;
}
public boolean isNoValue() {
return noValue;
}
public boolean isSingleValue() {
return singleValue;
}
public boolean isBetweenValue() {
return betweenValue;
}
public boolean isListValue() {
return listValue;
}
public String getTypeHandler() {
return typeHandler;
}
} }
} }

View File

@@ -0,0 +1,54 @@
package com.tencent.supersonic.chat.persistence.dataobject;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import java.util.Date;
@Data
@Builder
@NoArgsConstructor
@Getter
@AllArgsConstructor
public class StatisticsDO {
/**
* questionId
*/
private Long questionId;
/**
* chatId
*/
private Long chatId;
/**
* createTime
*/
private Date createTime;
/**
* queryText
*/
private String queryText;
/**
* userName
*/
private String userName;
/**
* interface
*/
private String interfaceName;
/**
* cost
*/
private Integer cost;
private Integer type;
}

View File

@@ -0,0 +1,17 @@
package com.tencent.supersonic.chat.persistence.mapper;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
import java.util.List;
@Mapper
public interface ChatParseMapper {
boolean batchSaveParseInfo(@Param("list") List<ChatParseDO> list);
ChatParseDO getParseInfo(Long questionId, String userName, int parseId);
}

View File

@@ -14,4 +14,5 @@ public interface ChatQueryDOMapper {
int updateByPrimaryKeyWithBLOBs(ChatQueryDO record); int updateByPrimaryKeyWithBLOBs(ChatQueryDO record);
Boolean deleteByPrimaryKey(Long questionId);
} }

View File

@@ -0,0 +1,12 @@
package com.tencent.supersonic.chat.persistence.mapper;
import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
import java.util.List;
@Mapper
public interface StatisticsMapper {
boolean batchSaveStatistics(@Param("list") List<StatisticsDO> list);
}

View File

@@ -0,0 +1,12 @@
package com.tencent.supersonic.chat.persistence.mapper.custom;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import org.apache.ibatis.annotations.Mapper;
import java.util.List;
@Mapper
public interface ShowCaseCustomMapper {
List<ChatQueryDO> queryShowCase(int start, int limit, int agentId);
}

View File

@@ -2,18 +2,37 @@ package com.tencent.supersonic.chat.persistence.repository;
import com.github.pagehelper.PageInfo; import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp; import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO; import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import java.util.List;
public interface ChatQueryRepository { public interface ChatQueryRepository {
PageInfo<QueryResp> getChatQuery(PageQueryInfoReq pageQueryInfoCommend, long chatId); PageInfo<QueryResp> getChatQuery(PageQueryInfoReq pageQueryInfoCommend, long chatId);
List<QueryResp> queryShowCase(PageQueryInfoReq pageQueryInfoCommend, int agentId);
void createChatQuery(QueryResult queryResult, ChatContext chatCtx); void createChatQuery(QueryResult queryResult, ChatContext chatCtx);
ChatQueryDO getLastChatQuery(long chatId); ChatQueryDO getLastChatQuery(long chatId);
int updateChatQuery(ChatQueryDO chatQueryDO); int updateChatQuery(ChatQueryDO chatQueryDO);
Long createChatParse(ParseResp parseResult, ChatContext chatCtx, QueryReq queryReq);
Boolean batchSaveParseInfo(ChatContext chatCtx, QueryReq queryReq,
ParseResp parseResult,
List<SemanticParseInfo> candidateParses,
List<SemanticParseInfo> selectedParses);
public ChatParseDO getParseInfo(Long questionId, String userName, int parseId);
Boolean deleteChatQuery(Long questionId);
} }

View File

@@ -0,0 +1,11 @@
package com.tencent.supersonic.chat.persistence.repository;
import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO;
import java.util.List;
public interface StatisticsRepository {
boolean batchSaveStatistics(List<StatisticsDO> list);
}

View File

@@ -1,13 +1,14 @@
package com.tencent.supersonic.chat.persistence.repository.impl; package com.tencent.supersonic.chat.persistence.repository.impl;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.config.ChatConfig; import com.tencent.supersonic.chat.config.ChatConfig;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
import com.tencent.supersonic.chat.config.ChatConfigFilterInternal; import com.tencent.supersonic.chat.config.ChatConfigFilterInternal;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.persistence.dataobject.ChatConfigDO; import com.tencent.supersonic.chat.persistence.dataobject.ChatConfigDO;
import com.tencent.supersonic.chat.persistence.mapper.ChatConfigMapper;
import com.tencent.supersonic.chat.persistence.repository.ChatConfigRepository; import com.tencent.supersonic.chat.persistence.repository.ChatConfigRepository;
import com.tencent.supersonic.chat.utils.ChatConfigHelper; import com.tencent.supersonic.chat.utils.ChatConfigHelper;
import com.tencent.supersonic.chat.persistence.mapper.ChatConfigMapper;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
@@ -23,7 +24,7 @@ public class ChatConfigRepositoryImpl implements ChatConfigRepository {
private final ChatConfigMapper chatConfigMapper; private final ChatConfigMapper chatConfigMapper;
public ChatConfigRepositoryImpl(ChatConfigHelper chatConfigHelper, public ChatConfigRepositoryImpl(ChatConfigHelper chatConfigHelper,
ChatConfigMapper chatConfigMapper) { ChatConfigMapper chatConfigMapper) {
this.chatConfigHelper = chatConfigHelper; this.chatConfigHelper = chatConfigHelper;
this.chatConfigMapper = chatConfigMapper; this.chatConfigMapper = chatConfigMapper;
} }
@@ -52,8 +53,8 @@ public class ChatConfigRepositoryImpl implements ChatConfigRepository {
List<ChatConfigDO> chaConfigDOList = chatConfigMapper.search(filterInternal); List<ChatConfigDO> chaConfigDOList = chatConfigMapper.search(filterInternal);
if (!CollectionUtils.isEmpty(chaConfigDOList)) { if (!CollectionUtils.isEmpty(chaConfigDOList)) {
chaConfigDOList.stream().forEach(chaConfigDO -> chaConfigDOList.stream().forEach(chaConfigDO ->
chaConfigDescriptorList.add( chaConfigDescriptorList.add(chatConfigHelper
chatConfigHelper.chatConfigDO2Descriptor(chaConfigDO.getModelId(), chaConfigDO))); .chatConfigDO2Descriptor(chaConfigDO.getModelId(), chaConfigDO)));
} }
return chaConfigDescriptorList; return chaConfigDescriptorList;
} }

View File

@@ -52,8 +52,9 @@ public class ChatContextRepositoryImpl implements ChatContextRepository {
chatContext.setUser(contextDO.getUser()); chatContext.setUser(contextDO.getUser());
chatContext.setQueryText(contextDO.getQueryText()); chatContext.setQueryText(contextDO.getQueryText());
if (contextDO.getSemanticParse() != null && !contextDO.getSemanticParse().isEmpty()) { if (contextDO.getSemanticParse() != null && !contextDO.getSemanticParse().isEmpty()) {
log.info("--->: {}",contextDO.getSemanticParse()); log.info("--->: {}", contextDO.getSemanticParse());
SemanticParseInfo semanticParseInfo = JsonUtil.toObject(contextDO.getSemanticParse(), SemanticParseInfo.class); SemanticParseInfo semanticParseInfo = JsonUtil.toObject(contextDO.getSemanticParse(),
SemanticParseInfo.class);
chatContext.setParseInfo(semanticParseInfo); chatContext.setParseInfo(semanticParseInfo);
} }
return chatContext; return chatContext;

View File

@@ -3,20 +3,30 @@ package com.tencent.supersonic.chat.persistence.repository.impl;
import com.github.pagehelper.PageHelper; import com.github.pagehelper.PageHelper;
import com.github.pagehelper.PageInfo; import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp; import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO; import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample; import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample.Criteria; import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample.Criteria;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.persistence.mapper.ChatParseMapper;
import com.tencent.supersonic.chat.persistence.mapper.ChatQueryDOMapper; import com.tencent.supersonic.chat.persistence.mapper.ChatQueryDOMapper;
import com.tencent.supersonic.chat.persistence.mapper.custom.ShowCaseCustomMapper;
import com.tencent.supersonic.chat.persistence.repository.ChatQueryRepository; import com.tencent.supersonic.chat.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.PageUtils; import com.tencent.supersonic.common.util.PageUtils;
import java.util.ArrayList;
import java.util.Comparator; import java.util.Comparator;
import java.util.List; import java.util.List;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
import org.springframework.context.annotation.Primary; import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Repository; import org.springframework.stereotype.Repository;
@@ -29,8 +39,16 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
private final ChatQueryDOMapper chatQueryDOMapper; private final ChatQueryDOMapper chatQueryDOMapper;
public ChatQueryRepositoryImpl(ChatQueryDOMapper chatQueryDOMapper) { private final ChatParseMapper chatParseMapper;
private final ShowCaseCustomMapper showCaseCustomMapper;
public ChatQueryRepositoryImpl(ChatQueryDOMapper chatQueryDOMapper,
ChatParseMapper chatParseMapper,
ShowCaseCustomMapper showCaseCustomMapper) {
this.chatQueryDOMapper = chatQueryDOMapper; this.chatQueryDOMapper = chatQueryDOMapper;
this.chatParseMapper = chatParseMapper;
this.showCaseCustomMapper = showCaseCustomMapper;
} }
@Override @Override
@@ -47,18 +65,27 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
PageInfo<QueryResp> chatQueryVOPageInfo = PageUtils.pageInfo2PageInfoVo(pageInfo); PageInfo<QueryResp> chatQueryVOPageInfo = PageUtils.pageInfo2PageInfoVo(pageInfo);
chatQueryVOPageInfo.setList( chatQueryVOPageInfo.setList(
pageInfo.getList().stream().map(this::convertTo) pageInfo.getList().stream().filter(o -> !StringUtils.isEmpty(o.getQueryResult())).map(this::convertTo)
.sorted(Comparator.comparingInt(o -> o.getQuestionId().intValue())) .sorted(Comparator.comparingInt(o -> o.getQuestionId().intValue()))
.collect(Collectors.toList())); .collect(Collectors.toList()));
return chatQueryVOPageInfo; return chatQueryVOPageInfo;
} }
@Override
public List<QueryResp> queryShowCase(PageQueryInfoReq pageQueryInfoCommend, int agentId) {
return showCaseCustomMapper.queryShowCase(pageQueryInfoCommend.getCurrent(),
pageQueryInfoCommend.getPageSize(), agentId).stream().map(this::convertTo)
.collect(Collectors.toList());
}
private QueryResp convertTo(ChatQueryDO chatQueryDO) { private QueryResp convertTo(ChatQueryDO chatQueryDO) {
QueryResp queryResponse = new QueryResp(); QueryResp queryResponse = new QueryResp();
BeanUtils.copyProperties(chatQueryDO, queryResponse); BeanUtils.copyProperties(chatQueryDO, queryResponse);
QueryResult queryResult = JsonUtil.toObject(chatQueryDO.getQueryResult(), QueryResult.class); QueryResult queryResult = JsonUtil.toObject(chatQueryDO.getQueryResult(), QueryResult.class);
queryResult.setQueryId(chatQueryDO.getQuestionId()); if (queryResult != null) {
queryResponse.setQueryResult(queryResult); queryResult.setQueryId(chatQueryDO.getQuestionId());
queryResponse.setQueryResult(queryResult);
}
return queryResponse; return queryResponse;
} }
@@ -71,12 +98,63 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
chatQueryDO.setQueryState(queryResult.getQueryState().ordinal()); chatQueryDO.setQueryState(queryResult.getQueryState().ordinal());
chatQueryDO.setQueryText(chatCtx.getQueryText()); chatQueryDO.setQueryText(chatCtx.getQueryText());
chatQueryDO.setQueryResult(JsonUtil.toString(queryResult)); chatQueryDO.setQueryResult(JsonUtil.toString(queryResult));
chatQueryDO.setAgentId(chatCtx.getAgentId());
chatQueryDOMapper.insert(chatQueryDO); chatQueryDOMapper.insert(chatQueryDO);
ChatQueryDO lastChatQuery = getLastChatQuery(chatCtx.getChatId()); ChatQueryDO lastChatQuery = getLastChatQuery(chatCtx.getChatId());
Long queryId = lastChatQuery.getQuestionId(); Long queryId = lastChatQuery.getQuestionId();
queryResult.setQueryId(queryId); queryResult.setQueryId(queryId);
} }
public Long createChatParse(ParseResp parseResult, ChatContext chatCtx, QueryReq queryReq) {
ChatQueryDO chatQueryDO = new ChatQueryDO();
chatQueryDO.setChatId(Long.valueOf(chatCtx.getChatId()));
chatQueryDO.setCreateTime(new java.util.Date());
chatQueryDO.setUserName(queryReq.getUser().getName());
chatQueryDO.setQueryText(queryReq.getQueryText());
chatQueryDO.setAgentId(queryReq.getAgentId());
chatQueryDO.setQueryResult("");
try {
chatQueryDOMapper.insert(chatQueryDO);
} catch (Exception e) {
log.info("database insert has an exception:{}", e.toString());
}
ChatQueryDO lastChatQuery = getLastChatQuery(chatCtx.getChatId());
Long queryId = lastChatQuery.getQuestionId();
parseResult.setQueryId(queryId);
return queryId;
}
public Boolean batchSaveParseInfo(ChatContext chatCtx, QueryReq queryReq,
ParseResp parseResult,
List<SemanticParseInfo> candidateParses,
List<SemanticParseInfo> selectedParses) {
Long queryId = createChatParse(parseResult, chatCtx, queryReq);
List<ChatParseDO> chatParseDOList = new ArrayList<>();
log.info("candidateParses size:{},selectedParses size:{}", candidateParses.size(), selectedParses.size());
getChatParseDO(chatCtx, queryReq, queryId, 0, 1, candidateParses, chatParseDOList);
getChatParseDO(chatCtx, queryReq, queryId, candidateParses.size(), 0, selectedParses, chatParseDOList);
Boolean save = chatParseMapper.batchSaveParseInfo(chatParseDOList);
return save;
}
public void getChatParseDO(ChatContext chatCtx, QueryReq queryReq, Long queryId, int base, int isCandidate,
List<SemanticParseInfo> parses, List<ChatParseDO> chatParseDOList) {
for (int i = 0; i < parses.size(); i++) {
ChatParseDO chatParseDO = new ChatParseDO();
parses.get(i).setId(base + i + 1);
chatParseDO.setChatId(Long.valueOf(chatCtx.getChatId()));
chatParseDO.setQuestionId(queryId);
chatParseDO.setQueryText(queryReq.getQueryText());
chatParseDO.setParseInfo(JsonUtil.toString(parses.get(i)));
chatParseDO.setIsCandidate(isCandidate);
chatParseDO.setParseId(base + i + 1);
chatParseDO.setCreateTime(new java.util.Date());
chatParseDO.setUserName(queryReq.getUser().getName());
chatParseDOList.add(chatParseDO);
}
}
@Override @Override
public ChatQueryDO getLastChatQuery(long chatId) { public ChatQueryDO getLastChatQuery(long chatId) {
ChatQueryDOExample example = new ChatQueryDOExample(); ChatQueryDOExample example = new ChatQueryDOExample();
@@ -96,4 +174,13 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
public int updateChatQuery(ChatQueryDO chatQueryDO) { public int updateChatQuery(ChatQueryDO chatQueryDO) {
return chatQueryDOMapper.updateByPrimaryKeyWithBLOBs(chatQueryDO); return chatQueryDOMapper.updateByPrimaryKeyWithBLOBs(chatQueryDO);
} }
public ChatParseDO getParseInfo(Long questionId, String userName, int parseId) {
return chatParseMapper.getParseInfo(questionId, userName, parseId);
}
@Override
public Boolean deleteChatQuery(Long questionId) {
return chatQueryDOMapper.deleteByPrimaryKey(questionId);
}
} }

View File

@@ -5,13 +5,14 @@ import com.tencent.supersonic.chat.persistence.dataobject.PluginDOExample;
import com.tencent.supersonic.chat.persistence.mapper.PluginDOMapper; import com.tencent.supersonic.chat.persistence.mapper.PluginDOMapper;
import com.tencent.supersonic.chat.persistence.repository.PluginRepository; import com.tencent.supersonic.chat.persistence.repository.PluginRepository;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.logging.log4j.util.Strings;
import org.springframework.stereotype.Repository;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.regex.Matcher; import java.util.regex.Matcher;
import java.util.regex.Pattern; import java.util.regex.Pattern;
import lombok.extern.slf4j.Slf4j;
import org.apache.logging.log4j.util.Strings;
import org.springframework.stereotype.Repository;
@Repository @Repository
@Slf4j @Slf4j

View File

@@ -0,0 +1,29 @@
package com.tencent.supersonic.chat.persistence.repository.impl;
import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.chat.persistence.mapper.StatisticsMapper;
import com.tencent.supersonic.chat.persistence.repository.StatisticsRepository;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Repository;
import java.util.List;
@Repository
@Primary
@Slf4j
public class StatisticsRepositoryImpl implements StatisticsRepository {
private final StatisticsMapper statisticsMapper;
public StatisticsRepositoryImpl(StatisticsMapper statisticsMapper) {
this.statisticsMapper = statisticsMapper;
}
public boolean batchSaveStatistics(List<StatisticsDO> list) {
return statisticsMapper.batchSaveStatistics(list);
}
;
}

View File

@@ -3,15 +3,17 @@ package com.tencent.supersonic.chat.plugin;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.agent.tool.DslTool; import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.*; import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.agent.Agent; import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.tool.AgentToolType; import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.PluginTool; import com.tencent.supersonic.chat.agent.tool.PluginTool;
import com.tencent.supersonic.chat.parser.ParseMode; import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingConfig;
import com.tencent.supersonic.chat.parser.embedding.EmbeddingConfig; import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingResp;
import com.tencent.supersonic.chat.parser.embedding.EmbeddingResp; import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval;
import com.tencent.supersonic.chat.parser.embedding.RecallRetrieval;
import com.tencent.supersonic.chat.plugin.event.PluginAddEvent; import com.tencent.supersonic.chat.plugin.event.PluginAddEvent;
import com.tencent.supersonic.chat.plugin.event.PluginUpdateEvent; import com.tencent.supersonic.chat.plugin.event.PluginUpdateEvent;
import com.tencent.supersonic.chat.query.plugin.ParamOption; import com.tencent.supersonic.chat.query.plugin.ParamOption;
@@ -20,10 +22,16 @@ import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.service.PluginService; import com.tencent.supersonic.chat.service.PluginService;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import java.net.URI; import java.net.URI;
import java.util.*; import java.util.List;
import java.util.Collection;
import java.util.Set;
import java.util.Optional;
import java.util.HashSet;
import java.util.HashMap;
import java.util.Objects;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
@@ -31,7 +39,11 @@ import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.util.Strings; import org.apache.logging.log4j.util.Strings;
import org.springframework.context.event.EventListener; import org.springframework.context.event.EventListener;
import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.*; import org.springframework.http.ResponseEntity;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.HttpEntity;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate; import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.UriComponentsBuilder;
@@ -149,9 +161,6 @@ public class PluginManager {
} }
public void requestEmbeddingPluginAddALL(List<Plugin> plugins) { public void requestEmbeddingPluginAddALL(List<Plugin> plugins) {
plugins = plugins.stream()
.filter(plugin -> ParseMode.EMBEDDING_RECALL.equals(plugin.getParseMode()))
.collect(Collectors.toList());
requestEmbeddingPluginAdd(convert(plugins)); requestEmbeddingPluginAdd(convert(plugins));
} }
@@ -229,11 +238,11 @@ public class PluginManager {
} }
List<ParamOption> paramOptions = getSemanticOption(plugin); List<ParamOption> paramOptions = getSemanticOption(plugin);
if (CollectionUtils.isEmpty(paramOptions)) { if (CollectionUtils.isEmpty(paramOptions)) {
return Pair.of(true, Sets.newHashSet()); return Pair.of(true, pluginMatchedModel);
} }
Set<Long> matchedModel = Sets.newHashSet(); Set<Long> matchedModel = Sets.newHashSet();
Map<Long, List<ParamOption>> paramOptionMap = paramOptions.stream(). Map<Long, List<ParamOption>> paramOptionMap = paramOptions.stream()
collect(Collectors.groupingBy(ParamOption::getModelId)); .collect(Collectors.groupingBy(ParamOption::getModelId));
for (Long modelId : paramOptionMap.keySet()) { for (Long modelId : paramOptionMap.keySet()) {
List<ParamOption> params = paramOptionMap.get(modelId); List<ParamOption> params = paramOptionMap.get(modelId);
if (CollectionUtils.isEmpty(params)) { if (CollectionUtils.isEmpty(params)) {
@@ -268,8 +277,8 @@ public class PluginManager {
return Sets.newHashSet(); return Sets.newHashSet();
} }
return schemaElementMatches.stream().filter(schemaElementMatch -> return schemaElementMatches.stream().filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()) || SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
SchemaElementType.ID.equals(schemaElementMatch.getElement().getType())) || SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
.map(SchemaElementMatch::getElement) .map(SchemaElementMatch::getElement)
.map(SchemaElement::getId) .map(SchemaElement::getId)
.collect(Collectors.toSet()); .collect(Collectors.toSet());

View File

@@ -1,12 +1,13 @@
package com.tencent.supersonic.chat.plugin; package com.tencent.supersonic.chat.plugin;
import com.tencent.supersonic.chat.parser.function.Parameters; import com.tencent.supersonic.chat.parser.plugin.function.Parameters;
import java.io.Serializable;
import java.util.List;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import java.io.Serializable;
import java.util.List;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.ToString; import lombok.ToString;
@@ -17,12 +18,12 @@ import lombok.ToString;
@NoArgsConstructor @NoArgsConstructor
public class PluginParseConfig implements Serializable { public class PluginParseConfig implements Serializable {
private String name;
private String description;
public Parameters parameters; public Parameters parameters;
public List<String> examples; public List<String> examples;
private String name;
private String description;
} }

View File

@@ -5,11 +5,13 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery; import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import java.util.List;
import java.util.ArrayList;
import java.util.OptionalDouble;
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery; import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery; import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
import java.util.ArrayList;
import java.util.List;
import java.util.OptionalDouble;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
@@ -50,8 +52,8 @@ public class HeuristicQuerySelector implements QuerySelector {
return true; return true;
} }
for (SemanticQuery candidateQuery : candidateQueries) { for (SemanticQuery candidateQuery : candidateQueries) {
if (candidateQuery.getQueryMode().equals(MetricEntityQuery.QUERY_MODE) && if (candidateQuery.getQueryMode().equals(MetricEntityQuery.QUERY_MODE)
semanticQuery.getParseInfo().getScore() == candidateQuery.getParseInfo().getScore()) { && semanticQuery.getParseInfo().getScore() == candidateQuery.getParseInfo().getScore()) {
return false; return false;
} }
} }

View File

@@ -1,17 +0,0 @@
package com.tencent.supersonic.chat.query.dsl.optimizer;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class FieldCorrector extends BaseDSLOptimizer {
@Override
public CorrectionInfo rewriter(CorrectionInfo correctionInfo) {
String replaceFields = CCJSqlParserUtils.replaceFields(correctionInfo.getSql(),
getFieldToBizName(correctionInfo.getParseInfo().getModelId()));
correctionInfo.setSql(replaceFields);
return correctionInfo;
}
}

View File

@@ -1,16 +0,0 @@
package com.tencent.supersonic.chat.query.dsl.optimizer;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class FunctionCorrector extends BaseDSLOptimizer {
@Override
public CorrectionInfo rewriter(CorrectionInfo correctionInfo) {
String replaceFunction = CCJSqlParserUtils.replaceFunction(correctionInfo.getSql());
correctionInfo.setSql(replaceFunction);
return correctionInfo;
}
}

View File

@@ -1,15 +1,12 @@
package com.tencent.supersonic.chat.query.dsl; package com.tencent.supersonic.chat.query.llm.dsl;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticLayer; import com.tencent.supersonic.chat.api.component.SemanticLayer;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo; import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState; import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult; import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.api.component.DSLOptimizer;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery; import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.chat.utils.ComponentFactory; import com.tencent.supersonic.chat.utils.ComponentFactory;
@@ -29,12 +26,12 @@ import org.springframework.stereotype.Component;
@Slf4j @Slf4j
@Component @Component
public class DSLQuery extends PluginSemanticQuery { public class DslQuery extends PluginSemanticQuery {
public static final String QUERY_MODE = "DSL"; public static final String QUERY_MODE = "DSL";
protected SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer(); protected SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer();
public DSLQuery() { public DslQuery() {
QueryManager.register(this); QueryManager.register(this);
} }
@@ -48,31 +45,12 @@ public class DSLQuery extends PluginSemanticQuery {
String json = JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT)); String json = JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT));
DSLParseResult dslParseResult = JsonUtil.toObject(json, DSLParseResult.class); DSLParseResult dslParseResult = JsonUtil.toObject(json, DSLParseResult.class);
LLMResp llmResp = dslParseResult.getLlmResp(); LLMResp llmResp = dslParseResult.getLlmResp();
QueryReq queryReq = dslParseResult.getRequest();
CorrectionInfo correctionInfo = CorrectionInfo.builder()
.queryFilters(queryReq.getQueryFilters())
.sql(llmResp.getSqlOutput())
.parseInfo(parseInfo)
.build();
List<DSLOptimizer> DSLCorrections = ComponentFactory.getSqlCorrections();
DSLCorrections.forEach(DSLCorrection -> {
try {
DSLCorrection.rewriter(correctionInfo);
log.info("sqlCorrection:{} sql:{}", DSLCorrection.getClass().getSimpleName(), correctionInfo.getSql());
} catch (Exception e) {
log.error("sqlCorrection:{} execute error,correctionInfo:{}", DSLCorrection, correctionInfo, e);
}
});
String querySql = correctionInfo.getSql();
long startTime = System.currentTimeMillis(); long startTime = System.currentTimeMillis();
QueryDslReq queryDslReq = QueryReqBuilder.buildDslReq(querySql, parseInfo.getModelId()); QueryDslReq queryDslReq = QueryReqBuilder.buildDslReq(llmResp.getCorrectorSql(), parseInfo.getModelId());
QueryResultWithSchemaResp queryResp = semanticLayer.queryByDsl(queryDslReq, user); QueryResultWithSchemaResp queryResp = semanticLayer.queryByDsl(queryDslReq, user);
log.info("queryByDsl cost:{},querySql:{}", System.currentTimeMillis() - startTime, querySql); log.info("queryByDsl cost:{},querySql:{}", System.currentTimeMillis() - startTime, llmResp.getSqlOutput());
QueryResult queryResult = new QueryResult(); QueryResult queryResult = new QueryResult();
if (Objects.nonNull(queryResp)) { if (Objects.nonNull(queryResp)) {

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.query.dsl; package com.tencent.supersonic.chat.query.llm.dsl;
import java.util.List; import java.util.List;
import lombok.Data; import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.query.dsl; package com.tencent.supersonic.chat.query.llm.dsl;
import java.util.List; import java.util.List;
import lombok.Data; import lombok.Data;
@@ -17,4 +17,6 @@ public class LLMResp {
private String schemaLinkingOutput; private String schemaLinkingOutput;
private String schemaLinkStr; private String schemaLinkStr;
private String correctorSql;
} }

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.query.metricInterpret; package com.tencent.supersonic.chat.query.metricinterpret;
import lombok.Data; import lombok.Data;

View File

@@ -1,11 +1,9 @@
package com.tencent.supersonic.chat.query.metricInterpret; package com.tencent.supersonic.chat.query.metricinterpret;
import lombok.Data; import lombok.Data;
@Data @Data
public class LLmAnswerResp { public class LLmAnswerResp {
private String assistantMessage;
private String assistant_message;
} }

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.query.metricInterpret; package com.tencent.supersonic.chat.query.metricinterpret;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
@@ -26,7 +26,11 @@ import org.apache.commons.lang3.StringUtils;
import org.springframework.http.ResponseEntity; import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.*;
import java.util.Map;
import java.util.HashMap;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@Slf4j @Slf4j
@@ -34,17 +38,17 @@ import java.util.stream.Collectors;
public class MetricInterpretQuery extends PluginSemanticQuery { public class MetricInterpretQuery extends PluginSemanticQuery {
public final static String QUERY_MODE = "METRIC_INTERPRET"; public static final String QUERY_MODE = "METRIC_INTERPRET";
public MetricInterpretQuery() {
QueryManager.register(this);
}
@Override @Override
public String getQueryMode() { public String getQueryMode() {
return QUERY_MODE; return QUERY_MODE;
} }
public MetricInterpretQuery() {
QueryManager.register(this);
}
@Override @Override
public QueryResult execute(User user) throws SqlParseException { public QueryResult execute(User user) throws SqlParseException {
QueryStructReq queryStructReq = QueryReqBuilder.buildStructReq(parseInfo); QueryStructReq queryStructReq = QueryReqBuilder.buildStructReq(parseInfo);
@@ -55,10 +59,11 @@ public class MetricInterpretQuery extends PluginSemanticQuery {
String text = generateTableText(queryResultWithSchemaResp); String text = generateTableText(queryResultWithSchemaResp);
Map<String, Object> properties = parseInfo.getProperties(); Map<String, Object> properties = parseInfo.getProperties();
Map<String, String> replacedMap = new HashMap<>(); Map<String, String> replacedMap = new HashMap<>();
String textReplaced = replaceText((String) properties.get("queryText"), parseInfo.getElementMatches(), replacedMap); String textReplaced = replaceText((String) properties.get("queryText"),
parseInfo.getElementMatches(), replacedMap);
String answer = replaceAnswer(fetchInterpret(textReplaced, text), replacedMap); String answer = replaceAnswer(fetchInterpret(textReplaced, text), replacedMap);
QueryResult queryResult = new QueryResult(); QueryResult queryResult = new QueryResult();
List<QueryColumn> queryColumns = Lists.newArrayList(new QueryColumn("结果","string","answer")); List<QueryColumn> queryColumns = Lists.newArrayList(new QueryColumn("结果", "string", "answer"));
Map<String, Object> result = new HashMap<>(); Map<String, Object> result = new HashMap<>();
result.put("answer", answer); result.put("answer", answer);
List<Map<String, Object>> resultList = Lists.newArrayList(); List<Map<String, Object>> resultList = Lists.newArrayList();
@@ -70,7 +75,8 @@ public class MetricInterpretQuery extends PluginSemanticQuery {
return queryResult; return queryResult;
} }
private String replaceText(String text, List<SchemaElementMatch> schemaElementMatches, Map<String, String> replacedMap) { private String replaceText(String text, List<SchemaElementMatch> schemaElementMatches,
Map<String, String> replacedMap) {
if (CollectionUtils.isEmpty(schemaElementMatches)) { if (CollectionUtils.isEmpty(schemaElementMatches)) {
return text; return text;
} }
@@ -134,10 +140,10 @@ public class MetricInterpretQuery extends PluginSemanticQuery {
JSONObject.toJSONString(lLmAnswerReq)); JSONObject.toJSONString(lLmAnswerReq));
LLmAnswerResp lLmAnswerResp = JSONObject.parseObject(responseEntity.getBody(), LLmAnswerResp.class); LLmAnswerResp lLmAnswerResp = JSONObject.parseObject(responseEntity.getBody(), LLmAnswerResp.class);
if (lLmAnswerResp != null) { if (lLmAnswerResp != null) {
return lLmAnswerResp.getAssistant_message(); return lLmAnswerResp.getAssistantMessage();
} }
return null; return null;
} }
} }

Some files were not shown because too many files have changed in this diff Show More