[release][project] supersonic 0.7.3 version backend update (#40)

* [improvement] add some features

* [improvement] revise CHANGELOG

---------

Co-authored-by: zuopengge <hwzuopengge@tencent.com>
This commit is contained in:
mainmain
2023-08-29 20:06:34 +08:00
committed by GitHub
parent 6fe9ab79ed
commit e1911bc81b
260 changed files with 6466 additions and 7108 deletions

View File

@@ -4,7 +4,17 @@
- "Breaking Changes" describes any changes that may break existing functionality or cause
compatibility issues with previous versions.
## SuperSonic [0.7.3] - 2023-08-29
### Added
- meet checkstyle code requirements
- save parseInfo after parsing
- add time statistics
- add agent
### Updated
- dsl where condition is used for front-end display
- dsl remove context inheritance
## SuperSonic [0.7.2] - 2023-08-12

View File

@@ -1,7 +1,7 @@
#!/usr/bin/env bash
sbinDir=$(cd "$(dirname "$0")"; pwd)
baseDir=$(cd "$sbinDir/.." && pwd -P)
baseDir=$(readlink -f $sbinDir/../)
runtimeDir=$baseDir/runtime
buildDir=$baseDir/build

View File

@@ -1,7 +1,7 @@
#!/usr/bin/env bash
sbinDir=$(cd "$(dirname "$0")"; pwd)
baseDir=$(cd "$sbinDir/.." && pwd -P)
baseDir=$(readlink -f $sbinDir/../)
buildDir=$baseDir/build
cd $baseDir/bin

View File

@@ -1,7 +1,7 @@
#!/usr/bin/env bash
sbinDir=$(cd "$(dirname "$0")"; pwd)
baseDir=$(cd "$sbinDir/.." && pwd -P)
baseDir=$(readlink -f $sbinDir/../)
runtimeDir=$baseDir/runtime
buildDir=$baseDir/build

View File

@@ -1,7 +1,7 @@
#!/usr/bin/env bash
sbinDir=$(cd "$(dirname "$0")"; pwd)
baseDir=$(cd "$sbinDir/.." && pwd -P)
baseDir=$(readlink -f $sbinDir/../)
runtimeDir=$baseDir/runtime
buildDir=$baseDir/build

View File

@@ -1,7 +1,7 @@
#!/usr/bin/env bash
sbinDir=$(cd "$(dirname "$0")"; pwd)
baseDir=$(cd "$sbinDir/.." && pwd -P)
baseDir=$(readlink -f $sbinDir/../)
runtimeDir=$baseDir/../runtime
buildDir=$baseDir/build

View File

@@ -1,7 +1,7 @@
#!/usr/bin/env bash
sbinDir=$(cd "$(dirname "$0")"; pwd)
baseDir=$(cd "$sbinDir/.." && pwd -P)
baseDir=$(readlink -f $sbinDir/../)
runtimeDir=$baseDir/../runtime
buildDir=$baseDir/build

View File

@@ -1,7 +1,7 @@
#!/usr/bin/env bash
sbinDir=$(cd "$(dirname "$0")"; pwd)
baseDir=$(cd "$sbinDir/.." && pwd -P)
baseDir=$(readlink -f $sbinDir/../)
runtimeDir=$baseDir/../runtime
buildDir=$baseDir/build
@@ -29,4 +29,4 @@ rm -fr ${buildDir}/supersonic-webapp
#start standalone service
sh ${runtimeDir}/supersonic-standalone/bin/service.sh restart
#start llm service
sh ${runtimeDir}/supersonic-standalone/llm/bin/service.sh restart
sh ${runtimeDir}/supersonic-standalone/llm/bin/service.sh restart

View File

@@ -5,7 +5,7 @@ import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup;
import com.tencent.supersonic.auth.api.authorization.request.QueryAuthResReq;
import com.tencent.supersonic.auth.api.authorization.response.AuthorizedResourceResp;
import java.util.List;
import javax.servlet.http.HttpServletRequest;
public interface AuthService {

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.auth.authentication.persistence.repository.Impl;
package com.tencent.supersonic.auth.authentication.persistence.repository.impl;
import com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO;

View File

@@ -11,12 +11,12 @@ import java.util.Set;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.bind.annotation.PathVariable;
@RestController
@RequestMapping("/api/auth/user")

View File

@@ -9,7 +9,6 @@ import static com.tencent.supersonic.auth.api.authentication.constant.UserConsta
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_USER_ID;
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_USER_NAME;
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_USER_PASSWORD;
import com.tencent.supersonic.auth.api.authentication.config.AuthenticationConfig;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.pojo.UserWithPassword;
@@ -22,9 +21,11 @@ import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import javax.servlet.http.HttpServletRequest;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Component;
@Slf4j
@Component
public class UserTokenUtils {
@@ -68,7 +69,9 @@ public class UserTokenUtils {
public UserWithPassword getUserWithPassword(HttpServletRequest request) {
String token = request.getHeader(authenticationConfig.getTokenHttpHeaderKey());
if (StringUtils.isBlank(token)) {
throw new AccessException("token is blank, get user failed");
String message = "token is blank, get user failed";
log.warn("{}, uri: {}", message, request.getServletPath());
throw new AccessException(message);
}
final Claims claims = getClaims(token);
Long userId = Long.parseLong(claims.getOrDefault(TOKEN_USER_ID, 0).toString());

View File

@@ -12,13 +12,16 @@ import com.tencent.supersonic.auth.api.authorization.response.AuthorizedResource
import com.tencent.supersonic.auth.api.authorization.service.AuthService;
import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup;
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule;
import com.tencent.supersonic.common.util.S2ThreadContext;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.*;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.Map;
import java.util.ArrayList;
import java.util.stream.Collectors;
@Service

View File

@@ -3,7 +3,7 @@ package com.tencent.supersonic.chat.api.component;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import net.sf.jsqlparser.JSQLParserException;
public interface DSLOptimizer {
CorrectionInfo rewriter(CorrectionInfo correctionInfo) throws JSQLParserException;
public interface SemanticCorrector {
CorrectionInfo corrector(CorrectionInfo correctionInfo) throws JSQLParserException;
}

View File

@@ -6,14 +6,15 @@ import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.common.pojo.enums.AuthType;
import com.tencent.supersonic.semantic.api.model.request.PageDimensionReq;
import com.tencent.supersonic.semantic.api.model.request.PageMetricReq;
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
import com.tencent.supersonic.semantic.api.model.response.DomainResp;
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.QueryDslReq;
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.util.List;
/**
@@ -31,22 +32,13 @@ import java.util.List;
public interface SemanticLayer {
QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructReq, User user);
QueryResultWithSchemaResp queryByMultiStruct(QueryMultiStructReq queryMultiStructReq, User user);
QueryResultWithSchemaResp queryByDsl(QueryDslReq queryDslReq, User user);
List<ModelSchema> getModelSchema();
List<ModelSchema> getModelSchema(List<Long> ids);
ModelSchema getModelSchema(Long model, Boolean cacheEnable);
PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionCmd);
PageInfo<MetricResp> getMetricPage(PageMetricReq pageMetricCmd);
List<DomainResp> getDomainList(User user);
List<ModelResp> getModelList(AuthType authType, Long domainId, User user);
}

View File

@@ -6,6 +6,7 @@ import lombok.Data;
public class ChatContext {
private Integer chatId;
private Integer agentId;
private String queryText;
private SemanticParseInfo parseInfo = new SemanticParseInfo();
private String user;

View File

@@ -32,6 +32,7 @@ public class ModelSchema {
break;
case VALUE:
element = dimensionValues.stream().filter(e -> e.getId() == elementID).findFirst();
break;
default:
}

View File

@@ -1,18 +1,19 @@
package com.tencent.supersonic.chat.api.pojo;
import com.google.common.base.Objects;
import java.io.Serializable;
import java.util.List;
import lombok.Builder;
import lombok.Data;
import lombok.Getter;
import lombok.Builder;
import lombok.NoArgsConstructor;
@Data
@Getter
@Builder
@NoArgsConstructor
//@AllArgsConstructor
public class SchemaElement implements Serializable {
private Long model;
@@ -23,11 +24,8 @@ public class SchemaElement implements Serializable {
private SchemaElementType type;
private List<String> alias;
// public SchemaElement() {
// }
public SchemaElement(Long model, Long id, String name, String bizName,
Long useCnt, SchemaElementType type, List<String> alias) {
Long useCnt, SchemaElementType type, List<String> alias) {
this.model = model;
this.id = id;
this.name = name;

View File

@@ -1,23 +1,26 @@
package com.tencent.supersonic.chat.api.pojo;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;
import java.util.LinkedHashSet;
import java.util.ArrayList;
import java.util.Map;
import java.util.HashMap;
import java.util.Comparator;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import lombok.Data;
@Data
public class SemanticParseInfo {
private Integer id;
private String queryMode;
private SchemaElement model;
private Set<SchemaElement> metrics = new TreeSet<>(new SchemaNameLengthComparator());
@@ -33,7 +36,7 @@ public class SemanticParseInfo {
private double score;
private List<SchemaElementMatch> elementMatches = new ArrayList<>();
private Map<String, Object> properties = new HashMap<>();
private EntityInfo entityInfo;
public Long getModelId() {
return model != null ? model.getId() : 0L;
}
@@ -43,7 +46,6 @@ public class SemanticParseInfo {
}
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
@Override
public int compare(SchemaElement o1, SchemaElement o2) {
int len1 = o1.getName().length();

View File

@@ -9,8 +9,11 @@ import lombok.Data;
public class ExecuteQueryReq {
private User user;
private Integer agentId;
private Integer chatId;
private String queryText;
private Long queryId;
private Integer parseId;
private SemanticParseInfo parseInfo;
private boolean saveAnswer = true;
}

View File

@@ -6,6 +6,5 @@ import lombok.Data;
@Data
public class AggregateInfo {
private List<MetricInfo> metricInfos = new ArrayList<>();
}

View File

@@ -1,12 +1,13 @@
package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.Getter;
import lombok.Builder;
import lombok.NoArgsConstructor;
import lombok.AllArgsConstructor;
import java.util.List;
@Data
@Getter
@@ -14,9 +15,9 @@ import lombok.NoArgsConstructor;
@NoArgsConstructor
@AllArgsConstructor
public class ParseResp {
private Integer chatId;
private String queryText;
private Long queryId;
private ParseState state;
private List<SemanticParseInfo> selectedParses;
private List<SemanticParseInfo> candidateParses;

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.chat.api.pojo.response;
import java.util.List;
import lombok.Data;
@Data
public class SearchResp {

View File

@@ -0,0 +1,18 @@
package com.tencent.supersonic.chat.api.pojo.response;
import lombok.Data;
import java.util.List;
import java.util.Map;
@Data
public class ShowCaseResp {
private Map<Long, List<QueryResp>> showCaseMap;
private int pageSize;
private int current;
}

View File

@@ -5,6 +5,7 @@ import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.common.pojo.RecordInfo;
import java.util.Objects;
import lombok.Data;
import org.springframework.util.CollectionUtils;
import java.util.List;
@@ -23,7 +24,6 @@ public class Agent extends RecordInfo {
private Integer status;
private List<String> examples;
private String agentConfig;
public List<String> getTools(AgentToolType type) {
Map map = JSONObject.parseObject(agentConfig, Map.class);
if (CollectionUtils.isEmpty(map) || map.get("tools") == null) {
@@ -31,7 +31,13 @@ public class Agent extends RecordInfo {
}
List<Map> toolList = (List) map.get("tools");
return toolList.stream()
.filter(tool -> type.name().equals(tool.get("type")))
.filter(tool -> {
if (Objects.isNull(type)) {
return true;
}
return type.name().equals(tool.get("type"));
}
)
.map(JSONObject::toJSONString)
.collect(Collectors.toList());
}

View File

@@ -1,12 +1,9 @@
package com.tencent.supersonic.chat.agent.tool;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
@Data
@NoArgsConstructor
@AllArgsConstructor

View File

@@ -2,12 +2,19 @@ package com.tencent.supersonic.chat.agent.tool;
import lombok.Data;
import org.apache.commons.collections.CollectionUtils;
import java.util.List;
@Data
public class RuleQueryTool extends AgentTool {
private List<Long> modelIds;
private List<String> queryModes;
public boolean isContainsAllModel() {
return CollectionUtils.isNotEmpty(modelIds) && modelIds.contains(-1L);
}
}

View File

@@ -0,0 +1,99 @@
/*
//package com.tencent.supersonic.chat.aspect;
//
//import lombok.extern.slf4j.Slf4j;
//import org.aspectj.lang.JoinPoint;
//import org.aspectj.lang.ProceedingJoinPoint;
//import org.aspectj.lang.annotation.*;
//import org.springframework.stereotype.Component;
//
//import java.util.HashMap;
//import java.util.Map;
//
//@Aspect
//@Component
//@Slf4j
//public class TimeCostAspect {
//
// ThreadLocal<Long> startTime = new ThreadLocal<>();
//
// ThreadLocal<Map<String, Long>> map = new ThreadLocal<>();
//
// @Pointcut("execution(public * com.tencent.supersonic.chat.mapper.HanlpDictMapper.*(*))")
// //@Pointcut("execution(* public com.tencent.supersonic.chat.parser.*.*(..))")
// //@Pointcut("execution(* com.tencent.supersonic.chat.mapper.*Mapper.map(..)) ")
// //@Pointcut("execution(* com.tencent.supersonic.chat.mapper.HanlpDictMapper.map(..)) ")
// //@Pointcut("execution(* com.tencent.supersonic.chat.parser.rule.QueryModeParser.*(..)) ")
// public void point() {
// }
//
// @Around("point()")
// public void doAround(ProceedingJoinPoint joinPoint) throws Throwable {
// long start = System.currentTimeMillis();
// try {
// log.info("切面开始");
// Object result = joinPoint.proceed();
// log.info("切面开始");
// if (result == null) {
// //如果切到了 没有返回类型的void方法这里直接返回
// //return null;
// }
// long end = System.currentTimeMillis();
// log.info("===================");
// String targetClassName = joinPoint.getSignature().getDeclaringTypeName();
// String MethodName = joinPoint.getSignature().getName();
// String typeStr = joinPoint.getSignature().getDeclaringType().toString().split(" ")[0];
// log.info("类/接口:" + targetClassName + "(" + typeStr + ")");
// log.info("方法:" + MethodName);
// Long total = end - start;
// log.info("耗时: " + total + " ms!");
// map.get().put(targetClassName + "_" + MethodName, total);
// //return result;
// } catch (Throwable e) {
// long end = System.currentTimeMillis();
// log.info("====around " + joinPoint + "\tUse time : " + (end - start) + " ms with exception : "
// + e.getMessage());
// throw e;
// }
// }
//
//// //对Controller下面的方法执行前进行切入初始化开始时间
//// @Before(value = "execution(* com.appleyk.controller.*.*(..))")
//// public void beforMehhod(JoinPoint jp) {
//// startTime.set(System.currentTimeMillis());
//// }
////
//// //对Controller下面的方法执行后进行切入统计方法执行的次数和耗时情况
//// //注意这里的执行方法统计的数据不止包含Controller下面的方法也包括环绕切入的所有方法的统计信息
//// @AfterReturning(value = "execution(* com.appleyk.controller.*.*(..))")
//// public void afterMehhod(JoinPoint jp) {
//// long end = System.currentTimeMillis();
//// long total = end - startTime.get();
//// String methodName = jp.getSignature().getName();
//// log.info("连接点方法为:" + methodName + ",执行总耗时为:" +total+"ms");
////
//// //重新new一个map
//// Map<String, Long> map = new HashMap<>();
//////从map2中将最后的 连接点方法给移除了,替换成最终的,避免连接点方法多次进行叠加计算
//// //由于map2受ThreadLocal的保护这里不支持remove因此需要单开一个map进行数据交接
//// for(Map.Entry<String, Long> entry:map2.get().entrySet()){
//// if(entry.getKey().equals(methodName)){
//// map.put(methodName, total);
////
//// }else{
//// map.put(entry.getKey(), entry.getValue());
//// }
//// }
////
//// for (Map.Entry<String, Long> entry :map1.get().entrySet()) {
//// for(Map.Entry<String, Long> entry2 :map.entrySet()){
//// if(entry.getKey().equals(entry2.getKey())){
//// System.err.println(entry.getKey()+",被调用次数:"+entry.getValue()+",综合耗时:"+entry2.getValue()+"ms");
//// }
//// }
////
//// }
//// }
//
//}
*/

View File

@@ -3,11 +3,9 @@ package com.tencent.supersonic.chat.config;
import lombok.Data;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
@Configuration
@Data
@Configuration
public class AggregatorConfig {
@Value("${metric.aggregator.ratio.enable:true}")
private Boolean enableRatio;
}

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.query.dsl.optimizer;
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.component.DSLOptimizer;
import com.tencent.supersonic.chat.api.component.SemanticCorrector;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.util.ContextUtils;
@@ -13,7 +13,7 @@ import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public abstract class BaseDSLOptimizer implements DSLOptimizer {
public abstract class BaseSemanticCorrector implements SemanticCorrector {
public static final String DATE_FIELD = "数据日期";
protected Map<String, String> getFieldToBizName(Long modelId) {

View File

@@ -1,23 +1,24 @@
package com.tencent.supersonic.chat.query.dsl.optimizer;
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper;
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@Slf4j
public class DateFieldCorrector extends BaseDSLOptimizer {
public class DateFieldCorrector extends BaseSemanticCorrector {
@Override
public CorrectionInfo rewriter(CorrectionInfo correctionInfo) {
public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
String sql = correctionInfo.getSql();
List<String> whereFields = CCJSqlParserUtils.getWhereFields(sql);
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(BaseDSLOptimizer.DATE_FIELD)) {
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sql);
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(DATE_FIELD)) {
String currentDate = DSLDateHelper.getCurrentDate(correctionInfo.getParseInfo().getModelId());
sql = CCJSqlParserUtils.addWhere(sql, BaseDSLOptimizer.DATE_FIELD, currentDate);
sql = SqlParserUpdateHelper.addWhere(sql, DATE_FIELD, currentDate);
}
correctionInfo.setSql(sql);
return correctionInfo;

View File

@@ -0,0 +1,17 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class FieldCorrector extends BaseSemanticCorrector {
@Override
public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
String replaceFields = SqlParserUpdateHelper.replaceFields(correctionInfo.getSql(),
getFieldToBizName(correctionInfo.getParseInfo().getModelId()));
correctionInfo.setSql(replaceFields);
return correctionInfo;
}
}

View File

@@ -0,0 +1,48 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@Slf4j
public class FieldValueCorrector extends BaseSemanticCorrector {
@Override
public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
Object context = correctionInfo.getParseInfo().getProperties().get(Constants.CONTEXT);
if (Objects.isNull(context)) {
return correctionInfo;
}
DSLParseResult dslParseResult = JsonUtil.toObject(JsonUtil.toString(context), DSLParseResult.class);
if (Objects.isNull(dslParseResult) || Objects.isNull(dslParseResult.getLlmReq())) {
return correctionInfo;
}
LLMReq llmReq = dslParseResult.getLlmReq();
List<ElementValue> linking = llmReq.getLinking();
if (CollectionUtils.isEmpty(linking)) {
return correctionInfo;
}
Map<String, Set<String>> fieldValueToFieldNames = linking.stream().collect(
Collectors.groupingBy(ElementValue::getFieldValue,
Collectors.mapping(ElementValue::getFieldName, Collectors.toSet())));
String sql = SqlParserUpdateHelper.replaceValueFields(correctionInfo.getSql(), fieldValueToFieldNames);
correctionInfo.setSql(sql);
return correctionInfo;
}
}

View File

@@ -0,0 +1,16 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class FunctionCorrector extends BaseSemanticCorrector {
@Override
public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
String replaceFunction = SqlParserUpdateHelper.replaceFunction(correctionInfo.getSql());
correctionInfo.setSql(replaceFunction);
return correctionInfo;
}
}

View File

@@ -1,10 +1,10 @@
package com.tencent.supersonic.chat.query.dsl.optimizer;
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.StringUtil;
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
@@ -15,17 +15,17 @@ import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
@Slf4j
public class QueryFilterAppend extends BaseDSLOptimizer {
public class QueryFilterAppend extends BaseSemanticCorrector {
@Override
public CorrectionInfo rewriter(CorrectionInfo correctionInfo) throws JSQLParserException {
public CorrectionInfo corrector(CorrectionInfo correctionInfo) throws JSQLParserException {
String queryFilter = getQueryFilter(correctionInfo.getQueryFilters());
String sql = correctionInfo.getSql();
if (StringUtils.isNotEmpty(queryFilter)) {
log.info("add queryFilter to sql :{}", queryFilter);
Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
sql = CCJSqlParserUtils.addWhere(sql, expression);
sql = SqlParserUpdateHelper.addWhere(sql, expression);
}
correctionInfo.setSql(sql);
return correctionInfo;

View File

@@ -1,7 +1,8 @@
package com.tencent.supersonic.chat.query.dsl.optimizer;
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import java.util.ArrayList;
import java.util.HashSet;
@@ -10,25 +11,28 @@ import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@Slf4j
public class SelectFieldAppendCorrector extends BaseDSLOptimizer {
public class SelectFieldAppendCorrector extends BaseSemanticCorrector {
@Override
public CorrectionInfo rewriter(CorrectionInfo correctionInfo) {
public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
String sql = correctionInfo.getSql();
if (CCJSqlParserUtils.hasAggregateFunction(sql)) {
if (SqlParserSelectHelper.hasAggregateFunction(sql)) {
return correctionInfo;
}
Set<String> selectFields = new HashSet<>(CCJSqlParserUtils.getSelectFields(sql));
Set<String> whereFields = new HashSet<>(CCJSqlParserUtils.getWhereFields(sql));
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(sql));
Set<String> whereFields = new HashSet<>(SqlParserSelectHelper.getWhereFields(sql));
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(whereFields)) {
return correctionInfo;
}
whereFields.addAll(SqlParserSelectHelper.getOrderByFields(sql));
whereFields.removeAll(selectFields);
whereFields.remove(TimeDimensionEnum.DAY.getName());
whereFields.remove(TimeDimensionEnum.WEEK.getName());
whereFields.remove(TimeDimensionEnum.MONTH.getName());
String replaceFields = CCJSqlParserUtils.addFieldsToSelect(sql, new ArrayList<>(whereFields));
String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(sql, new ArrayList<>(whereFields));
correctionInfo.setSql(replaceFields);
return correctionInfo;
}

View File

@@ -1,19 +1,19 @@
package com.tencent.supersonic.chat.query.dsl.optimizer;
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class TableNameCorrector extends BaseDSLOptimizer {
public class TableNameCorrector extends BaseSemanticCorrector {
public static final String TABLE_PREFIX = "t_";
@Override
public CorrectionInfo rewriter(CorrectionInfo correctionInfo) {
public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
Long modelId = correctionInfo.getParseInfo().getModelId();
String sqlOutput = correctionInfo.getSql();
String replaceTable = CCJSqlParserUtils.replaceTable(sqlOutput, TABLE_PREFIX + modelId);
String replaceTable = SqlParserUpdateHelper.replaceTable(sqlOutput, TABLE_PREFIX + modelId);
correctionInfo.setSql(replaceTable);
return correctionInfo;
}

View File

@@ -2,19 +2,19 @@ package com.tencent.supersonic.chat.mapper;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.List;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.stream.Collectors;
@Slf4j
@@ -33,7 +33,7 @@ public class EntityMapper implements SchemaMapper {
continue;
}
List<SchemaElementMatch> valueSchemaElements = schemaElementMatchList.stream().filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toList());
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
if (!entity.getId().equals(schemaElementMatch.getElement().getId())) {
@@ -51,7 +51,7 @@ public class EntityMapper implements SchemaMapper {
}
private boolean checkExistSameEntitySchemaElements(SchemaElementMatch valueSchemaElementMatch,
List<SchemaElementMatch> schemaElementMatchList) {
List<SchemaElementMatch> schemaElementMatchList) {
List<SchemaElementMatch> entitySchemaElements = schemaElementMatchList.stream().filter(schemaElementMatch ->
SchemaElementType.ENTITY.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toList());

View File

@@ -2,12 +2,12 @@ package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
@@ -43,13 +43,13 @@ public class FuzzyNameMapper implements SchemaMapper {
log.debug("after db mapper,mapInfo:{}", queryContext.getMapInfo());
}
private void detectAndAddToSchema(QueryContext queryContext, List<Term> terms, List<SchemaElement> Models,
private void detectAndAddToSchema(QueryContext queryContext, List<Term> terms, List<SchemaElement> models,
SchemaElementType schemaElementType) {
try {
Map<String, Set<SchemaElement>> ModelResultSet = getResultSet(queryContext, terms, Models);
Map<String, Set<SchemaElement>> modelResultSet = getResultSet(queryContext, terms, models);
addToSchemaMapInfo(ModelResultSet, queryContext.getMapInfo(), schemaElementType);
addToSchemaMapInfo(modelResultSet, queryContext.getMapInfo(), schemaElementType);
} catch (Exception e) {
log.error("detectAndAddToSchema error", e);
@@ -57,20 +57,21 @@ public class FuzzyNameMapper implements SchemaMapper {
}
private Map<String, Set<SchemaElement>> getResultSet(QueryContext queryContext, List<Term> terms,
List<SchemaElement> Models) {
List<SchemaElement> models) {
String queryText = queryContext.getRequest().getQueryText();
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
Set<Long> modelIds = mapperHelper.getModelIds(queryContext.getRequest());
Double metricDimensionThresholdConfig = getThreshold(queryContext, mapperHelper);
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(Models);
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(models);
Map<Integer, Integer> regOffsetToLength = terms.stream().sorted(Comparator.comparing(Term::length))
.collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2));
Map<String, Set<SchemaElement>> ModelResultSet = new HashMap<>();
Map<String, Set<SchemaElement>> modelResultSet = new HashMap<>();
for (Integer startIndex = 0; startIndex <= queryText.length() - 1; ) {
for (Integer endIndex = startIndex; endIndex <= queryText.length(); ) {
endIndex = mapperHelper.getStepIndex(regOffsetToLength, endIndex);
@@ -86,8 +87,12 @@ public class FuzzyNameMapper implements SchemaMapper {
|| mapperHelper.getSimilarity(detectSegment, name) < metricDimensionThresholdConfig) {
continue;
}
Set<SchemaElement> preSchemaElements = ModelResultSet.putIfAbsent(detectSegment,
schemaElements);
if (!CollectionUtils.isEmpty(modelIds)) {
schemaElements = schemaElements.stream()
.filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
.collect(Collectors.toSet());
}
Set<SchemaElement> preSchemaElements = modelResultSet.putIfAbsent(detectSegment, schemaElements);
if (Objects.nonNull(preSchemaElements)) {
preSchemaElements.addAll(schemaElements);
}
@@ -95,7 +100,7 @@ public class FuzzyNameMapper implements SchemaMapper {
}
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
}
return ModelResultSet;
return modelResultSet;
}
private Double getThreshold(QueryContext queryContext, MapperHelper mapperHelper) {
@@ -103,9 +108,9 @@ public class FuzzyNameMapper implements SchemaMapper {
Double metricDimensionThresholdConfig = mapperHelper.getMetricDimensionThresholdConfig();
Double metricDimensionMinThresholdConfig = mapperHelper.getMetricDimensionMinThresholdConfig();
Map<Long, List<SchemaElementMatch>> ModelElementMatches = queryContext.getMapInfo()
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo()
.getModelElementMatches();
boolean existElement = ModelElementMatches.entrySet().stream()
boolean existElement = modelElementMatches.entrySet().stream()
.anyMatch(entry -> entry.getValue().size() >= 1);
if (!existElement) {
@@ -114,13 +119,13 @@ public class FuzzyNameMapper implements SchemaMapper {
metricDimensionThresholdConfig = halfThreshold >= metricDimensionMinThresholdConfig ? halfThreshold
: metricDimensionMinThresholdConfig;
log.info("ModelElementMatches:{} , not exist Element metricDimensionThresholdConfig reduce by half:{}",
ModelElementMatches, metricDimensionThresholdConfig);
modelElementMatches, metricDimensionThresholdConfig);
}
return metricDimensionThresholdConfig;
}
private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> Models) {
return Models.stream().collect(
private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> models) {
return models.stream().collect(
Collectors.toMap(SchemaElement::getName, a -> {
Set<SchemaElement> result = new HashSet<>();
result.add(a);

View File

@@ -9,7 +9,7 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.chat.utils.NatureHelper;
import com.tencent.supersonic.knowledge.utils.NatureHelper;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
@@ -21,6 +21,7 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
@@ -37,10 +38,13 @@ public class HanlpDictMapper implements SchemaMapper {
for (Term term : terms) {
log.info("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
}
Long modelId = queryContext.getRequest().getModelId();
QueryMatchStrategy matchStrategy = ContextUtils.getBean(QueryMatchStrategy.class);
Map<MatchText, List<MapResult>> matchResult = matchStrategy.match(queryText, terms, modelId);
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
Set<Long> detectModelIds = mapperHelper.getModelIds(queryContext.getRequest());
Map<MatchText, List<MapResult>> matchResult = matchStrategy.match(queryContext.getRequest(), terms,
detectModelIds);
List<MapResult> matches = getMatches(matchResult);
@@ -51,6 +55,7 @@ public class HanlpDictMapper implements SchemaMapper {
convertTermsToSchemaMapInfo(matches, queryContext.getMapInfo(), terms);
}
private void convertTermsToSchemaMapInfo(List<MapResult> mapResults, SchemaMapInfo schemaMap, List<Term> terms) {
if (CollectionUtils.isEmpty(mapResults)) {
return;

View File

@@ -1,10 +1,16 @@
package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.algorithm.EditDistance;
import com.tencent.supersonic.chat.utils.NatureHelper;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.utils.NatureHelper;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
@@ -84,5 +90,24 @@ public class MapperHelper {
detectSegment.length());
}
public Set<Long> getModelIds(QueryReq request) {
Long modelId = request.getModelId();
AgentService agentService = ContextUtils.getBean(AgentService.class);
Set<Long> detectModelIds = agentService.getDslToolsModelIds(request.getAgentId(), null);
if (Objects.nonNull(detectModelIds)) {
detectModelIds = detectModelIds.stream().filter(entry -> entry > 0).collect(Collectors.toSet());
}
if (Objects.nonNull(modelId) && modelId > 0 && Objects.nonNull(detectModelIds)) {
if (detectModelIds.contains(modelId)) {
Set<Long> result = new HashSet<>();
result.add(modelId);
return result;
}
}
return detectModelIds;
}
}

View File

@@ -1,15 +1,17 @@
package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
* match strategy
*/
public interface MatchStrategy {
Map<MatchText, List<MapResult>> match(String text, List<Term> terms, Long detectModelId);
Map<MatchText, List<MapResult>> match(QueryReq queryReq, List<Term> terms, Set<Long> detectModelId);
}

View File

@@ -3,25 +3,25 @@ package com.tencent.supersonic.chat.mapper;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.common.pojo.Constants;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@Slf4j
public class QueryFilterMapper implements SchemaMapper {
private Long FREQUENCY = 9999999L;
private double SIMILARITY = 1.0;
private Long frequency = 9999999L;
private double similarity = 1.0;
@Override
public void map(QueryContext queryContext) {
@@ -49,7 +49,7 @@ public class QueryFilterMapper implements SchemaMapper {
}
private List<SchemaElementMatch> addValueSchemaElementMatch(List<SchemaElementMatch> candidateElementMatches,
QueryFilters queryFilter) {
QueryFilters queryFilter) {
if (queryFilter == null || CollectionUtils.isEmpty(queryFilter.getFilters())) {
return candidateElementMatches;
}
@@ -65,9 +65,9 @@ public class QueryFilterMapper implements SchemaMapper {
.build();
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(element)
.frequency(FREQUENCY)
.frequency(frequency)
.word(String.valueOf(filter.getValue()))
.similarity(SIMILARITY)
.similarity(similarity)
.detectWord(Constants.EMPTY)
.build();
candidateElementMatches.add(schemaElementMatch);
@@ -76,7 +76,7 @@ public class QueryFilterMapper implements SchemaMapper {
}
private boolean checkExistSameValueSchemaElementMatch(QueryFilter queryFilter,
List<SchemaElementMatch> schemaElementMatches) {
List<SchemaElementMatch> schemaElementMatches) {
List<SchemaElementMatch> valueSchemaElements = schemaElementMatches.stream().filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toList());

View File

@@ -1,8 +1,8 @@
package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.service.SearchService;
import java.util.ArrayList;
@@ -32,7 +32,8 @@ public class QueryMatchStrategy implements MatchStrategy {
private MapperHelper mapperHelper;
@Override
public Map<MatchText, List<MapResult>> match(String text, List<Term> terms, Long detectModelId) {
public Map<MatchText, List<MapResult>> match(QueryReq queryReq, List<Term> terms, Set<Long> detectModelIds) {
String text = queryReq.getQueryText();
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
return null;
}
@@ -43,18 +44,19 @@ public class QueryMatchStrategy implements MatchStrategy {
List<Integer> offsetList = terms.stream().sorted(Comparator.comparing(Term::getOffset))
.map(term -> term.getOffset()).collect(Collectors.toList());
log.debug("retryCount:{},terms:{},regOffsetToLength:{},offsetList:{},detectModelId:{}", terms,
regOffsetToLength, offsetList, detectModelId);
log.debug("retryCount:{},terms:{},regOffsetToLength:{},offsetList:{},detectModelIds:{}", terms,
regOffsetToLength, offsetList, detectModelIds);
List<MapResult> detects = detect(text, regOffsetToLength, offsetList, detectModelId);
List<MapResult> detects = detect(queryReq, regOffsetToLength, offsetList, detectModelIds);
Map<MatchText, List<MapResult>> result = new HashMap<>();
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
return result;
}
private List<MapResult> detect(String text, Map<Integer, Integer> regOffsetToLength, List<Integer> offsetList,
Long detectModelId) {
private List<MapResult> detect(QueryReq queryReq, Map<Integer, Integer> regOffsetToLength, List<Integer> offsetList,
Set<Long> detectModelIds) {
String text = queryReq.getQueryText();
List<MapResult> results = Lists.newArrayList();
for (Integer index = 0; index <= text.length() - 1; ) {
@@ -65,7 +67,7 @@ public class QueryMatchStrategy implements MatchStrategy {
int offset = mapperHelper.getStepOffset(offsetList, index);
i = mapperHelper.getStepIndex(regOffsetToLength, i);
if (i <= text.length()) {
List<MapResult> mapResults = detectByStep(text, detectModelId, index, i, offset);
List<MapResult> mapResults = detectByStep(queryReq, detectModelIds, index, i, offset);
selectMapResultInOneRound(mapResultRowSet, mapResults);
}
}
@@ -102,16 +104,19 @@ public class QueryMatchStrategy implements MatchStrategy {
return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures());
}
private List<MapResult> detectByStep(String text, Long detectModelId, Integer index, Integer i, int offset) {
private List<MapResult> detectByStep(QueryReq queryReq, Set<Long> detectModelIds, Integer index, Integer i,
int offset) {
String text = queryReq.getQueryText();
Integer agentId = queryReq.getAgentId();
String detectSegment = text.substring(index, i);
// step1. pre search
Integer oneDetectionMaxSize = mapperHelper.getOneDetectionMaxSize();
LinkedHashSet<MapResult> mapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize)
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
LinkedHashSet<MapResult> mapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize, agentId,
detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
// step2. suffix search
LinkedHashSet<MapResult> suffixMapResults = SearchService.suffixSearch(detectSegment, oneDetectionMaxSize)
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
LinkedHashSet<MapResult> suffixMapResults = SearchService.suffixSearch(detectSegment, oneDetectionMaxSize,
agentId, detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
mapResults.addAll(suffixMapResults);
@@ -121,27 +126,15 @@ public class QueryMatchStrategy implements MatchStrategy {
// step3. merge pre/suffix result
mapResults = mapResults.stream().sorted((a, b) -> -(b.getName().length() - a.getName().length()))
.collect(Collectors.toCollection(LinkedHashSet::new));
// step4. filter by classId
if (Objects.nonNull(detectModelId) && detectModelId > 0) {
log.debug("detectModelId:{}, before parseResults:{}", mapResults);
mapResults = mapResults.stream().map(entry -> {
List<String> natures = entry.getNatures().stream().filter(
nature -> nature.startsWith(DictWordType.NATURE_SPILT + detectModelId) || (nature.startsWith(
DictWordType.NATURE_SPILT))
).collect(Collectors.toList());
entry.setNatures(natures);
return entry;
}).collect(Collectors.toCollection(LinkedHashSet::new));
log.info("after modelId parseResults:{}", mapResults);
}
// step5. filter by similarity
// step4. filter by similarity
mapResults = mapResults.stream()
.filter(term -> mapperHelper.getSimilarity(detectSegment, term.getName())
>= mapperHelper.getThresholdMatch(term.getNatures()))
.filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
.collect(Collectors.toCollection(LinkedHashSet::new));
log.debug("after isSimilarity parseResults:{}", mapResults);
log.info("after isSimilarity parseResults:{}", mapResults);
mapResults = mapResults.stream().map(parseResult -> {
parseResult.setOffset(offset);
@@ -149,7 +142,7 @@ public class QueryMatchStrategy implements MatchStrategy {
return parseResult;
}).collect(Collectors.toCollection(LinkedHashSet::new));
// step6. take only one dimension or 10 metric/dimension value per rond.
// step5. take only one dimension or 10 metric/dimension value per rond.
List<MapResult> dimensionMetrics = mapResults.stream()
.filter(entry -> mapperHelper.existDimensionValues(entry.getNatures()))
.collect(Collectors.toList())

View File

@@ -2,12 +2,14 @@ package com.tencent.supersonic.chat.mapper;
import com.google.common.collect.Lists;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.service.SearchService;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import org.apache.commons.collections.CollectionUtils;
@@ -23,9 +25,8 @@ public class SearchMatchStrategy implements MatchStrategy {
private static final int SEARCH_SIZE = 3;
@Override
public Map<MatchText, List<MapResult>> match(String text, List<Term> originals,
Long detectModelId) {
public Map<MatchText, List<MapResult>> match(QueryReq queryReq, List<Term> originals, Set<Long> detectModelIds) {
String text = queryReq.getQueryText();
Map<Integer, Integer> regOffsetToLength = originals.stream()
.filter(entry -> !entry.nature.toString().startsWith(DictWordType.NATURE_SPILT))
.collect(Collectors.toMap(Term::getOffset, value -> value.word.length(),
@@ -51,24 +52,16 @@ public class SearchMatchStrategy implements MatchStrategy {
String detectSegment = text.substring(detectIndex);
if (StringUtils.isNotEmpty(detectSegment)) {
List<MapResult> mapResults = SearchService.prefixSearch(detectSegment);
List<MapResult> suffixMapResults = SearchService.suffixSearch(detectSegment, SEARCH_SIZE);
List<MapResult> mapResults = SearchService.prefixSearch(detectSegment,
SearchService.SEARCH_SIZE, queryReq.getAgentId(), detectModelIds);
List<MapResult> suffixMapResults = SearchService.suffixSearch(detectSegment, SEARCH_SIZE,
queryReq.getAgentId(), detectModelIds);
mapResults.addAll(suffixMapResults);
// remove entity name where search
mapResults = mapResults.stream().filter(entry -> {
List<String> natures = entry.getNatures().stream()
.filter(nature -> !nature.endsWith(DictWordType.ENTITY.getType()))
.filter(nature -> {
if (Objects.isNull(detectModelId) || detectModelId <= 0) {
return true;
}
if (nature.startsWith(DictWordType.NATURE_SPILT + detectModelId)
&& nature.startsWith(DictWordType.NATURE_SPILT)) {
return true;
}
return false;
}
).collect(Collectors.toList());
.collect(Collectors.toList());
if (CollectionUtils.isEmpty(natures)) {
return false;
}
@@ -84,4 +77,4 @@ public class SearchMatchStrategy implements MatchStrategy {
);
return regTextMap;
}
}
}

View File

@@ -2,8 +2,9 @@ package com.tencent.supersonic.chat.parser;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.query.dsl.DSLQuery;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
import lombok.extern.slf4j.Slf4j;
/**
@@ -21,7 +22,7 @@ public class SatisfactionChecker {
// check all the parse info in candidate
public static boolean check(QueryContext queryContext) {
for (SemanticQuery query : queryContext.getCandidateQueries()) {
if (query.getQueryMode().equals(DSLQuery.QUERY_MODE)) {
if (query.getQueryMode().equals(DslQuery.QUERY_MODE)) {
continue;
}
if (checkThreshold(queryContext.getRequest().getQueryText(), query.getParseInfo())) {
@@ -32,7 +33,7 @@ public class SatisfactionChecker {
}
private static boolean checkThreshold(String queryText, SemanticParseInfo semanticParseInfo) {
int queryTextLength = queryText.length();
int queryTextLength = queryText.replaceAll(" ", "").length();
double degree = semanticParseInfo.getScore() / queryTextLength;
if (queryTextLength > QUERY_TEXT_LENGTH_THRESHOLD) {
if (degree < LONG_TEXT_THRESHOLD) {

View File

@@ -6,15 +6,6 @@ public class DSLDateHelper {
public static String getCurrentDate(Long modelId) {
return DateUtils.getBeforeDate(4);
// ChatConfigFilter filter = new ChatConfigFilter();
// filter.setModelId(modelId);
//
// List<ChatConfigResp> configResps = ContextUtils.getBean(ConfigService.class).search(filter, null);
// if (CollectionUtils.isEmpty(configResps)) {
// return
// }
// ChatConfigResp chatConfigResp = configResps.get(0);
// chatConfigResp.getChatDetailConfig().getChatDefaultConfig().get
}
}

View File

@@ -2,13 +2,21 @@ package com.tencent.supersonic.chat.parser.llm.dsl;
import com.tencent.supersonic.chat.agent.tool.DslTool;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.plugin.PluginParseResult;
import com.tencent.supersonic.chat.query.dsl.LLMResp;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
import com.tencent.supersonic.chat.query.llm.dsl.LLMResp;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class DSLParseResult {
private LLMReq llmReq;
private LLMResp llmResp;
private QueryReq request;

View File

@@ -1,42 +1,51 @@
package com.tencent.supersonic.chat.parser.llm.dsl;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.DslTool;
import com.tencent.supersonic.chat.api.component.SemanticCorrector;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.config.LLMConfig;
import com.tencent.supersonic.chat.corrector.BaseSemanticCorrector;
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.chat.parser.function.ModelResolver;
import com.tencent.supersonic.chat.parser.plugin.function.ModelResolver;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.dsl.DSLQuery;
import com.tencent.supersonic.chat.query.dsl.LLMReq;
import com.tencent.supersonic.chat.query.dsl.LLMReq.ElementValue;
import com.tencent.supersonic.chat.query.dsl.LLMResp;
import com.tencent.supersonic.chat.query.dsl.optimizer.BaseDSLOptimizer;
import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
import com.tencent.supersonic.chat.query.llm.dsl.LLMResp;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.FilterExpression;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
@@ -49,104 +58,217 @@ import org.springframework.util.CollectionUtils;
import org.springframework.web.client.RestTemplate;
@Slf4j
public class LLMDSLParser implements SemanticParser {
public class LLMDslParser implements SemanticParser {
public static final double FUNCTION_BONUS_THRESHOLD = 201;
public static final double function_bonus_threshold = 201;
@Override
public void parse(QueryContext queryCtx, ChatContext chatCtx) {
final LLMConfig llmConfig = ContextUtils.getBean(LLMConfig.class);
QueryReq request = queryCtx.getRequest();
LLMConfig llmConfig = ContextUtils.getBean(LLMConfig.class);
if (StringUtils.isEmpty(llmConfig.getUrl()) || SatisfactionChecker.check(queryCtx)) {
log.info("llmConfig:{}, skip function parser, queryText:{}", llmConfig,
queryCtx.getRequest().getQueryText());
log.info("llmConfig:{}, skip function parser, queryText:{}", llmConfig, request.getQueryText());
return;
}
List<DslTool> dslTools = getDslTools(queryCtx.getRequest().getAgentId());
Set<Long> distinctModelIds = dslTools.stream().map(DslTool::getModelIds)
.flatMap(Collection::stream)
.collect(Collectors.toSet());
try {
ModelResolver modelResolver = ComponentFactory.getModelResolver();
Long modelId = modelResolver.resolve(queryCtx, chatCtx, distinctModelIds);
log.info("resolve modelId:{},dslModels:{}", modelId, distinctModelIds);
Long modelId = getModelId(queryCtx, chatCtx, request.getAgentId());
if (Objects.isNull(modelId) || modelId <= 0) {
return;
}
Optional<DslTool> dslToolOptional = dslTools.stream().filter(tool ->
tool.getModelIds().contains(modelId)).findFirst();
if (!dslToolOptional.isPresent()) {
DslTool dslTool = getDslTool(request, modelId);
if (Objects.isNull(dslTool)) {
log.info("no dsl tool in this agent, skip dsl parser");
return;
}
DslTool dslTool = dslToolOptional.get();
LLMResp llmResp = requestLLM(queryCtx, modelId);
LLMReq llmReq = getLlmReq(queryCtx, modelId);
LLMResp llmResp = requestLLM(llmReq, modelId, llmConfig);
if (Objects.isNull(llmResp)) {
return;
}
PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(DSLQuery.QUERY_MODE);
DSLParseResult dslParseResult = DSLParseResult.builder().request(request).dslTool(dslTool).llmReq(llmReq)
.llmResp(llmResp).build();
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
if (Objects.nonNull(modelId) && modelId > 0) {
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(modelId));
}
DSLParseResult dslParseResult = new DSLParseResult();
dslParseResult.setRequest(queryCtx.getRequest());
dslParseResult.setLlmResp(llmResp);
dslParseResult.setDslTool(dslToolOptional.get());
SemanticParseInfo parseInfo = getParseInfo(queryCtx, modelId, dslTool, dslParseResult);
String correctorSql = getCorrectorSql(queryCtx, parseInfo, llmResp.getSqlOutput());
llmResp.setCorrectorSql(correctorSql);
setFilter(correctorSql, modelId, parseInfo);
Map<String, Object> properties = new HashMap<>();
properties.put(Constants.CONTEXT, dslParseResult);
properties.put("type", "internal");
properties.put("name", dslTool.getName());
parseInfo.setProperties(properties);
parseInfo.setScore(FUNCTION_BONUS_THRESHOLD);
parseInfo.setQueryMode(semanticQuery.getQueryMode());
SchemaElement Model = new SchemaElement();
Model.setModel(modelId);
Model.setId(modelId);
parseInfo.setModel(Model);
queryCtx.getCandidateQueries().add(semanticQuery);
} catch (Exception e) {
log.error("LLMDSLParser error", e);
}
}
public void setFilter(String correctorSql, Long modelId, SemanticParseInfo parseInfo) {
private LLMResp requestLLM(QueryContext queryCtx, Long modelId) {
long startTime = System.currentTimeMillis();
String queryText = queryCtx.getRequest().getQueryText();
final LLMConfig llmConfig = ContextUtils.getBean(LLMConfig.class);
if (StringUtils.isEmpty(llmConfig.getUrl())) {
log.warn("llmConfig url is null, skip llm parser");
return null;
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctorSql);
if (CollectionUtils.isEmpty(expressions)) {
return;
}
//set dataInfo
try {
DateConf dateInfo = getDateInfo(expressions);
parseInfo.setDateInfo(dateInfo);
} catch (Exception e) {
log.error("set dateInfo error :", e);
}
//set filter
try {
Map<String, SchemaElement> bizNameToElement = getBizNameToElement(modelId);
List<QueryFilter> result = getDimensionFilter(bizNameToElement, expressions);
parseInfo.getDimensionFilters().addAll(result);
} catch (Exception e) {
log.error("set dimensionFilter error :", e);
}
}
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> bizNameToElement,
List<FilterExpression> filterExpressions) {
List<QueryFilter> result = Lists.newArrayList();
for (FilterExpression expression : filterExpressions) {
QueryFilter dimensionFilter = new QueryFilter();
dimensionFilter.setValue(expression.getFieldValue());
String bizName = expression.getFieldName();
SchemaElement schemaElement = bizNameToElement.get(bizName);
if (Objects.isNull(schemaElement)) {
continue;
}
String fieldName = schemaElement.getName();
dimensionFilter.setName(fieldName);
dimensionFilter.setBizName(bizName);
dimensionFilter.setElementID(schemaElement.getId());
FilterOperatorEnum operatorEnum = FilterOperatorEnum.getSqlOperator(expression.getOperator());
dimensionFilter.setOperator(operatorEnum);
result.add(dimensionFilter);
}
return result;
}
private DateConf getDateInfo(List<FilterExpression> filterExpressions) {
List<FilterExpression> dateExpressions = filterExpressions.stream()
.filter(expression -> {
List<String> nameList = TimeDimensionEnum.getNameList();
if (StringUtils.isEmpty(expression.getFieldName())) {
return false;
}
return nameList.contains(expression.getFieldName().toLowerCase());
}).collect(Collectors.toList());
if (CollectionUtils.isEmpty(dateExpressions)) {
return new DateConf();
}
DateConf dateInfo = new DateConf();
dateInfo.setDateMode(DateMode.BETWEEN);
FilterExpression firstExpression = dateExpressions.get(0);
FilterOperatorEnum firstOperator = FilterOperatorEnum.getSqlOperator(firstExpression.getOperator());
if (FilterOperatorEnum.EQUALS.equals(firstOperator) && Objects.nonNull(firstExpression.getFieldValue())) {
dateInfo.setStartDate(firstExpression.getFieldValue().toString());
dateInfo.setEndDate(firstExpression.getFieldValue().toString());
dateInfo.setDateMode(DateMode.BETWEEN);
return dateInfo;
}
if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.GREATER_THAN,
FilterOperatorEnum.GREATER_THAN_EQUALS)) {
dateInfo.setStartDate(firstExpression.getFieldValue().toString());
if (hasSecondDate(dateExpressions)) {
dateInfo.setEndDate(dateExpressions.get(1).getFieldValue().toString());
}
}
if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.MINOR_THAN,
FilterOperatorEnum.MINOR_THAN_EQUALS)) {
dateInfo.setEndDate(firstExpression.getFieldValue().toString());
if (hasSecondDate(dateExpressions)) {
dateInfo.setStartDate(dateExpressions.get(1).getFieldValue().toString());
}
}
return dateInfo;
}
private boolean containOperators(FilterExpression expression, FilterOperatorEnum firstOperator,
FilterOperatorEnum... operatorEnums) {
return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue()));
}
private boolean hasSecondDate(List<FilterExpression> dateExpressions) {
return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue());
}
private String getCorrectorSql(QueryContext queryCtx, SemanticParseInfo parseInfo, String sql) {
CorrectionInfo correctionInfo = CorrectionInfo.builder()
.queryFilters(queryCtx.getRequest().getQueryFilters()).sql(sql)
.parseInfo(parseInfo).build();
List<SemanticCorrector> dslCorrections = ComponentFactory.getSqlCorrections();
dslCorrections.forEach(dslCorrection -> {
try {
dslCorrection.corrector(correctionInfo);
log.info("sqlCorrection:{} sql:{}", dslCorrection.getClass().getSimpleName(),
correctionInfo.getSql());
} catch (Exception e) {
log.error("sqlCorrection:{} execute error,correctionInfo:{}", dslCorrection, correctionInfo, e);
}
});
return correctionInfo.getSql();
}
private SemanticParseInfo getParseInfo(QueryContext queryCtx, Long modelId, DslTool dslTool,
DSLParseResult dslParseResult) {
PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(DslQuery.QUERY_MODE);
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(modelId));
Map<String, Object> properties = new HashMap<>();
properties.put(Constants.CONTEXT, dslParseResult);
properties.put("type", "internal");
properties.put("name", dslTool.getName());
parseInfo.setProperties(properties);
parseInfo.setScore(function_bonus_threshold);
parseInfo.setQueryMode(semanticQuery.getQueryMode());
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
LLMReq llmReq = new LLMReq();
llmReq.setQueryText(queryText);
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
llmSchema.setModelName(modelIdToName.get(modelId));
llmSchema.setDomainName(modelIdToName.get(modelId));
List<String> fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema);
fieldNameList.add(BaseDSLOptimizer.DATE_FIELD);
llmSchema.setFieldNameList(fieldNameList);
llmReq.setSchema(llmSchema);
List<ElementValue> linking = new ArrayList<>();
linking.addAll(getValueList(queryCtx, modelId, semanticSchema));
llmReq.setLinking(linking);
String currentDate = DSLDateHelper.getCurrentDate(modelId);
llmReq.setCurrentDate(currentDate);
SchemaElement model = new SchemaElement();
model.setModel(modelId);
model.setId(modelId);
model.setName(modelIdToName.get(modelId));
parseInfo.setModel(model);
queryCtx.getCandidateQueries().add(semanticQuery);
return parseInfo;
}
log.info("requestLLM request, modelId:{},llmReq:{}", modelId, llmReq);
private DslTool getDslTool(QueryReq request, Long modelId) {
AgentService agentService = ContextUtils.getBean(AgentService.class);
List<DslTool> dslTools = agentService.getDslTools(request.getAgentId(), AgentToolType.DSL);
Optional<DslTool> dslToolOptional = dslTools.stream().filter(tool -> tool.getModelIds().contains(modelId))
.findFirst();
return dslToolOptional.orElse(null);
}
private Long getModelId(QueryContext queryCtx, ChatContext chatCtx, Integer agentId) {
AgentService agentService = ContextUtils.getBean(AgentService.class);
Set<Long> distinctModelIds = agentService.getDslToolsModelIds(agentId, AgentToolType.DSL);
ModelResolver modelResolver = ComponentFactory.getModelResolver();
Long modelId = modelResolver.resolve(queryCtx, chatCtx, distinctModelIds);
log.info("resolve modelId:{},dslModels:{}", modelId, distinctModelIds);
return modelId;
}
private LLMResp requestLLM(LLMReq llmReq, Long modelId, LLMConfig llmConfig) {
String questUrl = llmConfig.getUrl() + llmConfig.getQueryToSqlPath();
long startTime = System.currentTimeMillis();
log.info("requestLLM request, modelId:{},llmReq:{}", modelId, llmReq);
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
try {
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
@@ -163,6 +285,27 @@ public class LLMDSLParser implements SemanticParser {
return null;
}
private LLMReq getLlmReq(QueryContext queryCtx, Long modelId) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
String queryText = queryCtx.getRequest().getQueryText();
LLMReq llmReq = new LLMReq();
llmReq.setQueryText(queryText);
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
llmSchema.setModelName(modelIdToName.get(modelId));
llmSchema.setDomainName(modelIdToName.get(modelId));
List<String> fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema);
fieldNameList.add(BaseSemanticCorrector.DATE_FIELD);
llmSchema.setFieldNameList(fieldNameList);
llmReq.setSchema(llmSchema);
List<ElementValue> linking = new ArrayList<>();
linking.addAll(getValueList(queryCtx, modelId, semanticSchema));
llmReq.setLinking(linking);
String currentDate = DSLDateHelper.getCurrentDate(modelId);
llmReq.setCurrentDate(currentDate);
return llmReq;
}
private List<ElementValue> getValueList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
@@ -170,23 +313,37 @@ public class LLMDSLParser implements SemanticParser {
if (CollectionUtils.isEmpty(matchedElements)) {
return new ArrayList<>();
}
Set<ElementValue> valueMatches = matchedElements.stream()
Set<ElementValue> valueMatches = matchedElements
.stream()
.filter(elementMatch -> !elementMatch.isInherited())
.filter(schemaElementMatch -> {
SchemaElementType type = schemaElementMatch.getElement().getType();
return SchemaElementType.VALUE.equals(type) || SchemaElementType.ID.equals(type);
})
.map(elementMatch ->
{
ElementValue elementValue = new ElementValue();
elementValue.setFieldName(itemIdToName.get(elementMatch.getElement().getId()));
elementValue.setFieldValue(elementMatch.getWord());
return elementValue;
}
)
.collect(Collectors.toSet());
.map(elementMatch -> {
ElementValue elementValue = new ElementValue();
elementValue.setFieldName(itemIdToName.get(elementMatch.getElement().getId()));
elementValue.setFieldValue(elementMatch.getWord());
return elementValue;
}).collect(Collectors.toSet());
return new ArrayList<>(valueMatches);
}
protected Map<String, SchemaElement> getBizNameToElement(Long modelId) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
List<SchemaElement> dimensions = semanticSchema.getDimensions();
List<SchemaElement> metrics = semanticSchema.getMetrics();
List<SchemaElement> allElements = Lists.newArrayList();
allElements.addAll(dimensions);
allElements.addAll(metrics);
return allElements.stream()
.filter(schemaElement -> schemaElement.getModel().equals(modelId))
.collect(Collectors.toMap(SchemaElement::getBizName, Function.identity(), (value1, value2) -> value2));
}
private List<String> getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
@@ -197,9 +354,9 @@ public class LLMDSLParser implements SemanticParser {
Set<String> fieldNameList = matchedElements.stream()
.filter(schemaElementMatch -> {
SchemaElementType elementType = schemaElementMatch.getElement().getType();
return SchemaElementType.METRIC.equals(elementType) ||
SchemaElementType.DIMENSION.equals(elementType) ||
SchemaElementType.VALUE.equals(elementType);
return SchemaElementType.METRIC.equals(elementType)
|| SchemaElementType.DIMENSION.equals(elementType)
|| SchemaElementType.VALUE.equals(elementType);
})
.map(schemaElementMatch -> {
SchemaElementType elementType = schemaElementMatch.getElement().getType();
@@ -220,18 +377,4 @@ public class LLMDSLParser implements SemanticParser {
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
}
private List<DslTool> getDslTools(Integer agentId) {
AgentService agentService = ContextUtils.getBean(AgentService.class);
Agent agent = agentService.getAgent(agentId);
if (agent == null) {
return Lists.newArrayList();
}
List<String> tools = agent.getTools(AgentToolType.DSL);
if (CollectionUtils.isEmpty(tools)) {
return Lists.newArrayList();
}
return tools.stream().map(tool -> JSONObject.parseObject(tool, DslTool.class))
.collect(Collectors.toList());
}
}

View File

@@ -7,11 +7,17 @@ import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.MetricInterpretTool;
import com.tencent.supersonic.chat.api.component.SemanticLayer;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.chat.query.metricInterpret.MetricInterpretQuery;
import com.tencent.supersonic.chat.query.metricinterpret.MetricInterpretQuery;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.service.AgentService;
@@ -22,7 +28,11 @@ import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
import java.util.*;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.HashMap;
import java.util.stream.Collectors;
@Slf4j
@@ -34,7 +44,8 @@ public class MetricInterpretParser implements SemanticParser {
log.info("skip MetricInterpretParser");
return;
}
Map<Long, MetricInterpretTool> metricInterpretToolMap = getMetricInterpretTools(queryContext.getRequest().getAgentId());
Map<Long, MetricInterpretTool> metricInterpretToolMap =
getMetricInterpretTools(queryContext.getRequest().getAgentId());
log.info("metric interpret tool : {}", metricInterpretToolMap);
if (CollectionUtils.isEmpty(metricInterpretToolMap)) {
return;
@@ -50,8 +61,10 @@ public class MetricInterpretParser implements SemanticParser {
}
List<MetricOption> metricOptions = metricInterpretTool.getMetricOptions();
if (!CollectionUtils.isEmpty(metricOptions)) {
List<Long> metricIds = metricOptions.stream().map(MetricOption::getMetricId).collect(Collectors.toList());
buildQuery(modelId, queryContext, metricIds, elementMatches.get(modelId), metricInterpretTool.getName());
List<Long> metricIds = metricOptions.stream()
.map(MetricOption::getMetricId).collect(Collectors.toList());
String name = metricInterpretTool.getName();
buildQuery(modelId, queryContext, metricIds, elementMatches.get(modelId), name);
}
}
}
@@ -82,7 +95,7 @@ public class MetricInterpretParser implements SemanticParser {
if (agent == null) {
return new HashMap<>();
}
List<String> tools= agent.getTools(AgentToolType.INTERPRET);
List<String> tools = agent.getTools(AgentToolType.INTERPRET);
if (CollectionUtils.isEmpty(tools)) {
return new HashMap<>();
}
@@ -100,16 +113,16 @@ public class MetricInterpretParser implements SemanticParser {
private SemanticParseInfo buildSemanticParseInfo(Long modelId, QueryReq queryReq, Set<SchemaElement> metrics,
List<SchemaElementMatch> schemaElementMatches, String toolName) {
SchemaElement Model = new SchemaElement();
Model.setModel(modelId);
Model.setId(modelId);
SchemaElement model = new SchemaElement();
model.setModel(modelId);
model.setId(modelId);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.setMetrics(metrics);
SchemaElement dimension = new SchemaElement();
dimension.setBizName(TimeDimensionEnum.DAY.getName());
semanticParseInfo.setDimensions(Sets.newHashSet(dimension));
semanticParseInfo.setElementMatches(schemaElementMatches);
semanticParseInfo.setModel(Model);
semanticParseInfo.setModel(model);
semanticParseInfo.setScore(queryReq.getQueryText().length());
DateConf dateConf = new DateConf();
dateConf.setDateMode(DateConf.DateMode.RECENT);

View File

@@ -17,7 +17,7 @@ public class LLMTimeEnhancementParse implements SemanticParser {
@Override
public void parse(QueryContext queryContext, ChatContext chatContext) {
log.info("before queryContext:{},chatContext:{}",queryContext,chatContext);
log.info("before queryContext:{},chatContext:{}", queryContext, chatContext);
ChatGptHelper chatGptHelper = ContextUtils.getBean(ChatGptHelper.class);
try {
String inferredTime = chatGptHelper.inferredTime(queryContext.getRequest().getQueryText());
@@ -25,12 +25,12 @@ public class LLMTimeEnhancementParse implements SemanticParser {
for (SemanticQuery query : queryContext.getCandidateQueries()) {
DateConf dateInfo = query.getParseInfo().getDateInfo();
JSONObject jsonObject = JSON.parseObject(inferredTime);
if (jsonObject.containsKey("date")){
if (jsonObject.containsKey("date")) {
dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
dateInfo.setStartDate(jsonObject.getString("date"));
dateInfo.setEndDate(jsonObject.getString("date"));
query.getParseInfo().setDateInfo(dateInfo);
}else if (jsonObject.containsKey("start")){
} else if (jsonObject.containsKey("start")) {
dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
dateInfo.setStartDate(jsonObject.getString("start"));
dateInfo.setEndDate(jsonObject.getString("end"));
@@ -38,11 +38,13 @@ public class LLMTimeEnhancementParse implements SemanticParser {
}
}
}
}catch (Exception exception){
log.error("{} parse error,this reason is:{}",LLMTimeEnhancementParse.class.getSimpleName(), (Object) exception.getStackTrace());
} catch (Exception exception) {
log.error("{} parse error,this reason is:{}", LLMTimeEnhancementParse.class.getSimpleName(),
(Object) exception.getStackTrace());
}
log.info("{} after queryContext:{},chatContext:{}",LLMTimeEnhancementParse.class.getSimpleName(),queryContext,chatContext);
log.info("{} after queryContext:{},chatContext:{}",
LLMTimeEnhancementParse.class.getSimpleName(), queryContext, chatContext);
}

View File

@@ -1,8 +1,14 @@
package com.tencent.supersonic.chat.parser.embedding;
package com.tencent.supersonic.chat.parser.plugin.embedding;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.parser.ParseMode;
@@ -10,12 +16,16 @@ import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.plugin.PluginParseResult;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.dsl.DSLQuery;
import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.*;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.HashMap;
import java.util.Comparator;
import java.util.stream.Collectors;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import lombok.extern.slf4j.Slf4j;
@@ -47,7 +57,7 @@ public class EmbeddingBasedParser implements SemanticParser {
Map<Long, Plugin> pluginMap = plugins.stream().collect(Collectors.toMap(Plugin::getId, p -> p));
for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) {
Plugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId()));
if (plugin == null || DSLQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) {
if (plugin == null || DslQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) {
continue;
}
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, queryContext);
@@ -88,12 +98,12 @@ public class EmbeddingBasedParser implements SemanticParser {
if (modelId == null && !CollectionUtils.isEmpty(plugin.getModelList())) {
modelId = plugin.getModelList().get(0);
}
SchemaElement Model = new SchemaElement();
Model.setModel(modelId);
Model.setId(modelId);
SchemaElement model = new SchemaElement();
model.setModel(modelId);
model.setId(modelId);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.setElementMatches(schemaElementMatches);
semanticParseInfo.setModel(Model);
semanticParseInfo.setModel(model);
Map<String, Object> properties = new HashMap<>();
PluginParseResult pluginParseResult = new PluginParseResult();
pluginParseResult.setPlugin(plugin);
@@ -111,9 +121,9 @@ public class EmbeddingBasedParser implements SemanticParser {
private void setEntity(Long modelId, SemanticParseInfo semanticParseInfo) {
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
ModelSchema ModelSchema = semanticService.getModelSchema(modelId);
if (ModelSchema != null && ModelSchema.getEntity() != null) {
semanticParseInfo.setEntity(ModelSchema.getEntity());
ModelSchema modelSchema = semanticService.getModelSchema(modelId);
if (modelSchema != null && modelSchema.getEntity() != null) {
semanticParseInfo.setEntity(modelSchema.getEntity());
}
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.embedding;
package com.tencent.supersonic.chat.parser.plugin.embedding;
import lombok.Data;
import org.springframework.beans.factory.annotation.Value;

View File

@@ -1,26 +1,26 @@
package com.tencent.supersonic.chat.parser.embedding;
package com.tencent.supersonic.chat.parser.plugin.embedding;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.service.ConfigService;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Component;
import java.util.List;
import java.util.Objects;
import java.util.Set;
@Slf4j
@Component("EmbeddingEntityResolver")
public class EmbeddingEntityResolver {
private ConfigService configService;
public EmbeddingEntityResolver(ConfigService configService) {
@@ -39,8 +39,8 @@ public class EmbeddingEntityResolver {
}
}
entityId = getEntityValueFromSchemaMapInfo(modelId, queryCtx.getMapInfo(), entityElementId);
log.info("get entity id:{} from schema map Info :{} ", entityId,
JSONObject.toJSONString(queryCtx.getMapInfo()));
log.info("get entity id:{} from schema map Info :{} ",
entityId, JSONObject.toJSONString(queryCtx.getMapInfo()));
if (entityId == null || entityId == 0) {
Long entityIdFromChat = getEntityValueFromParseInfo(chatCtx.getParseInfo(), entityElementId);
if (entityIdFromChat != null && entityIdFromChat > 0) {
@@ -95,4 +95,4 @@ public class EmbeddingEntityResolver {
return null;
}
}
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.embedding;
package com.tencent.supersonic.chat.parser.plugin.embedding;
import java.util.List;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.embedding;
package com.tencent.supersonic.chat.parser.plugin.embedding;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.function;
package com.tencent.supersonic.chat.parser.plugin.function;
import com.alibaba.fastjson.JSON;
import com.tencent.supersonic.chat.api.component.SemanticParser;
@@ -15,13 +15,18 @@ import com.tencent.supersonic.chat.plugin.PluginParseConfig;
import com.tencent.supersonic.chat.plugin.PluginParseResult;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.query.dsl.DSLQuery;
import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
import com.tencent.supersonic.chat.service.PluginService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils;
import java.net.URI;
import java.util.*;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.Objects;
import java.util.Map;
import java.util.HashMap;
import java.util.stream.Collectors;
import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j;
@@ -39,11 +44,6 @@ import org.springframework.web.util.UriComponentsBuilder;
@Slf4j
public class FunctionBasedParser implements SemanticParser {
public static final double FUNCTION_BONUS_THRESHOLD = 200;
public static final double SKIP_DSL_LENGTH = 10;
@Override
public void parse(QueryContext queryCtx, ChatContext chatCtx) {
FunctionCallInfoConfig functionCallConfig = ContextUtils.getBean(FunctionCallInfoConfig.class);
@@ -59,12 +59,17 @@ public class FunctionBasedParser implements SemanticParser {
log.info("function call parser, plugin is empty, skip");
return;
}
FunctionReq functionReq = FunctionReq.builder()
.queryText(queryCtx.getRequest().getQueryText())
.pluginConfigs(functionDOList).build();
FunctionResp functionResp = requestFunction(functionUrl, functionReq);
FunctionResp functionResp = new FunctionResp();
if (functionDOList.size() == 1) {
functionResp.setToolSelection(functionDOList.iterator().next().getName());
} else {
FunctionReq functionReq = FunctionReq.builder()
.queryText(queryCtx.getRequest().getQueryText())
.pluginConfigs(functionDOList).build();
functionResp = requestFunction(functionUrl, functionReq);
}
log.info("requestFunction result:{}", functionResp.getToolSelection());
if (skipFunction(queryCtx, functionResp)) {
if (skipFunction(functionResp)) {
return;
}
PluginParseResult functionCallParseResult = new PluginParseResult();
@@ -80,10 +85,10 @@ public class FunctionBasedParser implements SemanticParser {
functionCallParseResult.setPlugin(plugin);
log.info("QueryManager PluginQueryModes:{}", QueryManager.getPluginQueryModes());
PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(toolSelection);
ModelResolver ModelResolver = ComponentFactory.getModelResolver();
ModelResolver modelResolver = ComponentFactory.getModelResolver();
log.info("plugin ModelList:{}", plugin.getModelList());
Pair<Boolean, Set<Long>> pluginResolveResult = PluginManager.resolve(plugin, queryCtx);
Long modelId = ModelResolver.resolve(queryCtx, chatCtx, pluginResolveResult.getRight());
Long modelId = modelResolver.resolve(queryCtx, chatCtx, pluginResolveResult.getRight());
log.info("FunctionBasedParser modelId:{}", modelId);
if ((Objects.isNull(modelId) || modelId <= 0) && !plugin.isContainsAllModel()) {
log.info("Model is null, skip the parse, select tool: {}", toolSelection);
@@ -102,35 +107,24 @@ public class FunctionBasedParser implements SemanticParser {
properties.put("type", "plugin");
properties.put("name", plugin.getName());
parseInfo.setProperties(properties);
parseInfo.setScore(FUNCTION_BONUS_THRESHOLD);
parseInfo.setScore(queryCtx.getRequest().getQueryText().length());
parseInfo.setQueryMode(semanticQuery.getQueryMode());
SchemaElement Model = new SchemaElement();
Model.setModel(modelId);
Model.setId(modelId);
parseInfo.setModel(Model);
SchemaElement model = new SchemaElement();
model.setModel(modelId);
model.setId(modelId);
parseInfo.setModel(model);
queryCtx.getCandidateQueries().add(semanticQuery);
}
private boolean skipFunction(QueryContext queryCtx, FunctionResp functionResp) {
if (Objects.isNull(functionResp) || StringUtils.isBlank(functionResp.getToolSelection())) {
return true;
}
String queryText = queryCtx.getRequest().getQueryText();
if (functionResp.getToolSelection().equalsIgnoreCase(DSLQuery.QUERY_MODE)
&& queryText.length() < SKIP_DSL_LENGTH) {
log.info("queryText length is :{}, less than the threshold :{}, skip dsl.", queryText.length(),
SKIP_DSL_LENGTH);
return true;
}
return false;
private boolean skipFunction(FunctionResp functionResp) {
return Objects.isNull(functionResp) || StringUtils.isBlank(functionResp.getToolSelection());
}
private List<PluginParseConfig> getFunctionDO(Long modelId, QueryContext queryContext) {
log.info("user decide Model:{}", modelId);
List<Plugin> plugins = getPluginList(queryContext);
List<PluginParseConfig> functionDOList = plugins.stream().filter(plugin -> {
if (DSLQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) {
if (DslQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) {
return false;
}
if (plugin.getParseModeConfig() == null) {

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.function;
package com.tencent.supersonic.chat.parser.plugin.function;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.function;
package com.tencent.supersonic.chat.parser.plugin.function;
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
import java.util.List;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.function;
package com.tencent.supersonic.chat.parser.plugin.function;
import lombok.Data;

View File

@@ -1,29 +1,40 @@
package com.tencent.supersonic.chat.parser.function;
package com.tencent.supersonic.chat.parser.plugin.function;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import lombok.extern.slf4j.Slf4j;
import java.util.*;
import java.util.Map;
import java.util.HashMap;
import java.util.Objects;
import java.util.List;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.collections.CollectionUtils;
@Slf4j
public class HeuristicModelResolver implements ModelResolver {
protected static Long selectModelBySchemaElementCount(Map<Long, SemanticQuery> ModelQueryModes,
SchemaMapInfo schemaMap) {
Map<Long, ModelMatchResult> ModelTypeMap = getModelTypeMap(schemaMap);
if (ModelTypeMap.size() == 1) {
Long ModelSelect = ModelTypeMap.entrySet().stream().collect(Collectors.toList()).get(0).getKey();
if (ModelQueryModes.containsKey(ModelSelect)) {
log.info("selectModel with only one Model [{}]", ModelSelect);
return ModelSelect;
protected static Long selectModelBySchemaElementCount(Map<Long, SemanticQuery> modelQueryModes,
SchemaMapInfo schemaMap) {
Map<Long, ModelMatchResult> modelTypeMap = getModelTypeMap(schemaMap);
if (modelTypeMap.size() == 1) {
Long modelSelect = modelTypeMap.entrySet().stream().collect(Collectors.toList()).get(0).getKey();
if (modelQueryModes.containsKey(modelSelect)) {
log.info("selectModel with only one Model [{}]", modelSelect);
return modelSelect;
}
} else {
Map.Entry<Long, ModelMatchResult> maxModel = ModelTypeMap.entrySet().stream()
.filter(entry -> ModelQueryModes.containsKey(entry.getKey()))
Map.Entry<Long, ModelMatchResult> maxModel = modelTypeMap.entrySet().stream()
.filter(entry -> modelQueryModes.containsKey(entry.getKey()))
.sorted((o1, o2) -> {
int difference = o2.getValue().getCount() - o1.getValue().getCount();
if (difference == 0) {
@@ -45,23 +56,24 @@ public class HeuristicModelResolver implements ModelResolver {
*
* @return false will use context Model, true will use other Model , maybe include context Model
*/
protected static boolean isAllowSwitch(Map<Long, SemanticQuery> ModelQueryModes, SchemaMapInfo schemaMap,
ChatContext chatCtx, QueryReq searchCtx, Long modelId, Set<Long> restrictiveModels) {
protected static boolean isAllowSwitch(Map<Long, SemanticQuery> modelQueryModes, SchemaMapInfo schemaMap,
ChatContext chatCtx, QueryReq searchCtx,
Long modelId, Set<Long> restrictiveModels) {
if (!Objects.nonNull(modelId) || modelId <= 0) {
return true;
}
// except content Model, calculate the number of types for each Model, if numbers<=1 will not switch
Map<Long, ModelMatchResult> ModelTypeMap = getModelTypeMap(schemaMap);
log.info("isAllowSwitch ModelTypeMap [{}]", ModelTypeMap);
long otherModelTypeNumBigOneCount = ModelTypeMap.entrySet().stream()
.filter(entry -> ModelQueryModes.containsKey(entry.getKey()) && !entry.getKey().equals(modelId))
Map<Long, ModelMatchResult> modelTypeMap = getModelTypeMap(schemaMap);
log.info("isAllowSwitch ModelTypeMap [{}]", modelTypeMap);
long otherModelTypeNumBigOneCount = modelTypeMap.entrySet().stream()
.filter(entry -> modelQueryModes.containsKey(entry.getKey()) && !entry.getKey().equals(modelId))
.filter(entry -> entry.getValue().getCount() > 1).count();
if (otherModelTypeNumBigOneCount >= 1) {
return true;
}
// if query text only contain time , will not switch
if (!CollectionUtils.isEmpty(ModelQueryModes.values())) {
for (SemanticQuery semanticQuery : ModelQueryModes.values()) {
if (!CollectionUtils.isEmpty(modelQueryModes.values())) {
for (SemanticQuery semanticQuery : modelQueryModes.values()) {
if (semanticQuery == null) {
continue;
}
@@ -71,7 +83,8 @@ public class HeuristicModelResolver implements ModelResolver {
}
if (searchCtx.getQueryText() != null && semanticParseInfo.getDateInfo() != null) {
if (semanticParseInfo.getDateInfo().getDetectWord() != null) {
if (semanticParseInfo.getDateInfo().getDetectWord().equalsIgnoreCase(searchCtx.getQueryText())) {
if (semanticParseInfo.getDateInfo().getDetectWord()
.equalsIgnoreCase(searchCtx.getQueryText())) {
log.info("timeParseResults is not null , can not switch context , timeParseResults:{},",
semanticParseInfo.getDateInfo());
return false;
@@ -94,14 +107,14 @@ public class HeuristicModelResolver implements ModelResolver {
}
public static Map<Long, ModelMatchResult> getModelTypeMap(SchemaMapInfo schemaMap) {
Map<Long, ModelMatchResult> ModelCount = new HashMap<>();
Map<Long, ModelMatchResult> modelCount = new HashMap<>();
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMap.getModelElementMatches().entrySet()) {
List<SchemaElementMatch> schemaElementMatches = schemaMap.getMatchedElements(entry.getKey());
if (schemaElementMatches != null && schemaElementMatches.size() > 0) {
if (!ModelCount.containsKey(entry.getKey())) {
ModelCount.put(entry.getKey(), new ModelMatchResult());
if (!modelCount.containsKey(entry.getKey())) {
modelCount.put(entry.getKey(), new ModelMatchResult());
}
ModelMatchResult ModelMatchResult = ModelCount.get(entry.getKey());
ModelMatchResult modelMatchResult = modelCount.get(entry.getKey());
Set<SchemaElementType> schemaElementTypes = new HashSet<>();
schemaElementMatches.stream()
.forEach(schemaElementMatch -> schemaElementTypes.add(
@@ -111,13 +124,13 @@ public class HeuristicModelResolver implements ModelResolver {
((int) ((o2.getSimilarity() - o1.getSimilarity()) * 100))
).findFirst().orElse(null);
if (schemaElementMatchMax != null) {
ModelMatchResult.setMaxSimilarity(schemaElementMatchMax.getSimilarity());
modelMatchResult.setMaxSimilarity(schemaElementMatchMax.getSimilarity());
}
ModelMatchResult.setCount(schemaElementTypes.size());
modelMatchResult.setCount(schemaElementTypes.size());
}
}
return ModelCount;
return modelCount;
}
@@ -137,40 +150,41 @@ public class HeuristicModelResolver implements ModelResolver {
.filter(restrictiveModels::contains)
.collect(Collectors.toSet());
}
Map<Long, SemanticQuery> ModelQueryModes = new HashMap<>();
Map<Long, SemanticQuery> modelQueryModes = new HashMap<>();
for (Long matchedModel : matchedModels) {
ModelQueryModes.put(matchedModel, null);
modelQueryModes.put(matchedModel, null);
}
if(ModelQueryModes.size()==1){
return ModelQueryModes.keySet().stream().findFirst().get();
if (modelQueryModes.size() == 1) {
return modelQueryModes.keySet().stream().findFirst().get();
}
return resolve(ModelQueryModes, queryContext, chatCtx,
queryContext.getMapInfo(),restrictiveModels);
return resolve(modelQueryModes, queryContext, chatCtx,
queryContext.getMapInfo(), restrictiveModels);
}
public Long resolve(Map<Long, SemanticQuery> ModelQueryModes, QueryContext queryContext,
ChatContext chatCtx, SchemaMapInfo schemaMap, Set<Long> restrictiveModels) {
Long selectModel = selectModel(ModelQueryModes, queryContext.getRequest(), chatCtx, schemaMap,restrictiveModels);
public Long resolve(Map<Long, SemanticQuery> modelQueryModes, QueryContext queryContext,
ChatContext chatCtx, SchemaMapInfo schemaMap, Set<Long> restrictiveModels) {
Long selectModel = selectModel(modelQueryModes, queryContext.getRequest(),
chatCtx, schemaMap, restrictiveModels);
if (selectModel > 0) {
log.info("selectModel {} ", selectModel);
return selectModel;
}
// get the max SchemaElementType number
return selectModelBySchemaElementCount(ModelQueryModes, schemaMap);
return selectModelBySchemaElementCount(modelQueryModes, schemaMap);
}
public Long selectModel(Map<Long, SemanticQuery> ModelQueryModes, QueryReq queryContext,
ChatContext chatCtx,
SchemaMapInfo schemaMap, Set<Long> restrictiveModels) {
public Long selectModel(Map<Long, SemanticQuery> modelQueryModes, QueryReq queryContext,
ChatContext chatCtx,
SchemaMapInfo schemaMap, Set<Long> restrictiveModels) {
// if QueryContext has modelId and in ModelQueryModes
if (ModelQueryModes.containsKey(queryContext.getModelId())) {
if (modelQueryModes.containsKey(queryContext.getModelId())) {
log.info("selectModel from QueryContext [{}]", queryContext.getModelId());
return queryContext.getModelId();
}
// if ChatContext has modelId and in ModelQueryModes
if (chatCtx.getParseInfo().getModelId() > 0) {
Long modelId = chatCtx.getParseInfo().getModelId();
if (!isAllowSwitch(ModelQueryModes, schemaMap, chatCtx, queryContext, modelId,restrictiveModels)) {
if (!isAllowSwitch(modelQueryModes, schemaMap, chatCtx, queryContext, modelId, restrictiveModels)) {
log.info("selectModel from ChatContext [{}]", modelId);
return modelId;
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.function;
package com.tencent.supersonic.chat.parser.plugin.function;
import lombok.Data;

View File

@@ -1,13 +1,12 @@
package com.tencent.supersonic.chat.parser.function;
package com.tencent.supersonic.chat.parser.plugin.function;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import java.util.List;
import java.util.Set;
public interface ModelResolver {
Long resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels);
}
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.function;
package com.tencent.supersonic.chat.parser.plugin.function;
import java.util.List;
import java.util.Map;

View File

@@ -13,7 +13,6 @@ import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;
@@ -32,16 +31,24 @@ public class AgentCheckParser implements SemanticParser {
if (agent == null) {
return;
}
List<String> queryModes = getRuleTools(agentId).stream().map(RuleQueryTool::getQueryModes)
.flatMap(Collection::stream).collect(Collectors.toList());
if (CollectionUtils.isEmpty(queries)) {
List<RuleQueryTool> queryTools = getRuleTools(agentId);
if (CollectionUtils.isEmpty(queryTools)) {
queries.clear();
return;
}
log.info("queries resolved:{} {}", agent.getName(),
queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
queries.removeIf(query ->
!queryModes.contains(query.getQueryMode()));
queries.removeIf(query -> {
for (RuleQueryTool tool : queryTools) {
if (!tool.getQueryModes().contains(query.getQueryMode())) {
return true;
}
if (tool.isContainsAllModel() || tool.getModelIds().contains(query.getParseInfo().getModelId())) {
return false;
}
}
return true;
});
log.info("rule queries witch can be supported by agent :{} {}", agent.getName(),
queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
}

View File

@@ -14,6 +14,7 @@ import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import java.util.AbstractMap;
import java.util.HashMap;
import java.util.Map;
@@ -21,6 +22,7 @@ import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;

View File

@@ -1,12 +1,5 @@
package com.tencent.supersonic.chat.parser.rule;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.DIMENSION;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ENTITY;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ID;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.METRIC;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.MODEL;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.VALUE;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
@@ -15,8 +8,8 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricSemanticQuery;
import java.util.AbstractMap;
import java.util.ArrayList;
@@ -28,6 +21,13 @@ import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.extern.slf4j.Slf4j;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.METRIC;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.DIMENSION;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.VALUE;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ENTITY;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.MODEL;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ID;
@Slf4j
public class ContextInheritParser implements SemanticParser {

View File

@@ -1,9 +1,12 @@
package com.tencent.supersonic.chat.parser.rule;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import java.util.*;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
@Slf4j

View File

@@ -4,21 +4,23 @@ import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DateConf;
import com.xkzhangsan.time.nlp.TimeNLP;
import com.xkzhangsan.time.nlp.TimeNLPUtil;
import java.text.DateFormat;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.time.LocalDate;
import java.util.Stack;
import java.util.Date;
import java.util.List;
import java.util.Stack;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import com.xkzhangsan.time.nlp.TimeNLP;
import com.xkzhangsan.time.nlp.TimeNLPUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.logging.log4j.util.Strings;

View File

@@ -4,17 +4,14 @@ import java.util.Date;
public class AgentDO {
/**
*
*/
private Integer id;
/**
*
*/
private String name;
/**
*
*/
private String description;
@@ -24,83 +21,70 @@ public class AgentDO {
private Integer status;
/**
*
*/
private String examples;
/**
*
*/
private String config;
/**
*
*/
private String createdBy;
/**
*
*/
private Date createdAt;
/**
*
*/
private String updatedBy;
/**
*
*/
private Date updatedAt;
/**
*
*/
private Integer enableSearch;
/**
*
* @return id
* @return id
*/
public Integer getId() {
return id;
}
/**
*
* @param id
* @param id
*/
public void setId(Integer id) {
this.id = id;
}
/**
*
* @return name
* @return name
*/
public String getName() {
return name;
}
/**
*
* @param name
* @param name
*/
public void setName(String name) {
this.name = name == null ? null : name.trim();
}
/**
*
* @return description
* @return description
*/
public String getDescription() {
return description;
}
/**
*
* @param description
* @param description
*/
public void setDescription(String description) {
this.description = description == null ? null : description.trim();
@@ -123,114 +107,100 @@ public class AgentDO {
}
/**
*
* @return examples
* @return examples
*/
public String getExamples() {
return examples;
}
/**
*
* @param examples
* @param examples
*/
public void setExamples(String examples) {
this.examples = examples == null ? null : examples.trim();
}
/**
*
* @return config
* @return config
*/
public String getConfig() {
return config;
}
/**
*
* @param config
* @param config
*/
public void setConfig(String config) {
this.config = config == null ? null : config.trim();
}
/**
*
* @return created_by
* @return created_by
*/
public String getCreatedBy() {
return createdBy;
}
/**
*
* @param createdBy
* @param createdBy
*/
public void setCreatedBy(String createdBy) {
this.createdBy = createdBy == null ? null : createdBy.trim();
}
/**
*
* @return created_at
* @return created_at
*/
public Date getCreatedAt() {
return createdAt;
}
/**
*
* @param createdAt
* @param createdAt
*/
public void setCreatedAt(Date createdAt) {
this.createdAt = createdAt;
}
/**
*
* @return updated_by
* @return updated_by
*/
public String getUpdatedBy() {
return updatedBy;
}
/**
*
* @param updatedBy
* @param updatedBy
*/
public void setUpdatedBy(String updatedBy) {
this.updatedBy = updatedBy == null ? null : updatedBy.trim();
}
/**
*
* @return updated_at
* @return updated_at
*/
public Date getUpdatedAt() {
return updatedAt;
}
/**
*
* @param updatedAt
* @param updatedAt
*/
public void setUpdatedAt(Date updatedAt) {
this.updatedAt = updatedAt;
}
/**
*
* @return enable_search
* @return enable_search
*/
public Integer getEnableSearch() {
return enableSearch;
}
/**
*
* @param enableSearch
* @param enableSearch
*/
public void setEnableSearch(Integer enableSearch) {
this.enableSearch = enableSearch;
}
}
}

View File

@@ -31,7 +31,6 @@ public class AgentDOExample {
protected Integer limitEnd;
/**
*
* @mbg.generated
*/
public AgentDOExample() {
@@ -39,7 +38,6 @@ public class AgentDOExample {
}
/**
*
* @mbg.generated
*/
public void setOrderByClause(String orderByClause) {
@@ -47,7 +45,6 @@ public class AgentDOExample {
}
/**
*
* @mbg.generated
*/
public String getOrderByClause() {
@@ -55,7 +52,6 @@ public class AgentDOExample {
}
/**
*
* @mbg.generated
*/
public void setDistinct(boolean distinct) {
@@ -63,7 +59,6 @@ public class AgentDOExample {
}
/**
*
* @mbg.generated
*/
public boolean isDistinct() {
@@ -71,7 +66,6 @@ public class AgentDOExample {
}
/**
*
* @mbg.generated
*/
public List<Criteria> getOredCriteria() {
@@ -79,7 +73,6 @@ public class AgentDOExample {
}
/**
*
* @mbg.generated
*/
public void or(Criteria criteria) {
@@ -87,7 +80,6 @@ public class AgentDOExample {
}
/**
*
* @mbg.generated
*/
public Criteria or() {
@@ -97,7 +89,6 @@ public class AgentDOExample {
}
/**
*
* @mbg.generated
*/
public Criteria createCriteria() {
@@ -109,7 +100,6 @@ public class AgentDOExample {
}
/**
*
* @mbg.generated
*/
protected Criteria createCriteriaInternal() {
@@ -118,7 +108,6 @@ public class AgentDOExample {
}
/**
*
* @mbg.generated
*/
public void clear() {
@@ -128,15 +117,13 @@ public class AgentDOExample {
}
/**
*
* @mbg.generated
*/
public void setLimitStart(Integer limitStart) {
this.limitStart=limitStart;
this.limitStart = limitStart;
}
/**
*
* @mbg.generated
*/
public Integer getLimitStart() {
@@ -144,15 +131,13 @@ public class AgentDOExample {
}
/**
*
* @mbg.generated
*/
public void setLimitEnd(Integer limitEnd) {
this.limitEnd=limitEnd;
this.limitEnd = limitEnd;
}
/**
*
* @mbg.generated
*/
public Integer getLimitEnd() {
@@ -954,38 +939,6 @@ public class AgentDOExample {
private String typeHandler;
public String getCondition() {
return condition;
}
public Object getValue() {
return value;
}
public Object getSecondValue() {
return secondValue;
}
public boolean isNoValue() {
return noValue;
}
public boolean isSingleValue() {
return singleValue;
}
public boolean isBetweenValue() {
return betweenValue;
}
public boolean isListValue() {
return listValue;
}
public String getTypeHandler() {
return typeHandler;
}
protected Criterion(String condition) {
super();
this.condition = condition;
@@ -1021,5 +974,37 @@ public class AgentDOExample {
protected Criterion(String condition, Object value, Object secondValue) {
this(condition, value, secondValue, null);
}
public String getCondition() {
return condition;
}
public Object getValue() {
return value;
}
public Object getSecondValue() {
return secondValue;
}
public boolean isNoValue() {
return noValue;
}
public boolean isSingleValue() {
return singleValue;
}
public boolean isBetweenValue() {
return betweenValue;
}
public boolean isListValue() {
return listValue;
}
public String getTypeHandler() {
return typeHandler;
}
}
}
}

View File

@@ -0,0 +1,142 @@
package com.tencent.supersonic.chat.persistence.dataobject;
import java.util.Date;
public class ChatParseDO {
/**
* questionId
*/
private Long questionId;
/**
* chatId
*/
private Long chatId;
/**
* parseId
*/
private Integer parseId;
/**
* createTime
*/
private Date createTime;
/**
* queryText
*/
private String queryText;
/**
* userName
*/
private String userName;
/**
* parseInfo
*/
private String parseInfo;
/**
* isCandidate
*/
private Integer isCandidate;
/**
* return question_id
*/
public Long getQuestionId() {
return questionId;
}
/**
* questionId
*/
public void setQuestionId(Long questionId) {
this.questionId = questionId;
}
/**
* return create_time
*/
public Date getCreateTime() {
return createTime;
}
/**
* createTime
*/
public void setCreateTime(Date createTime) {
this.createTime = createTime;
}
/**
* return user_name
*/
public String getUserName() {
return userName;
}
/**
* userName
*/
public void setUserName(String userName) {
this.userName = userName == null ? null : userName.trim();
}
/**
* return chat_id
*/
public Long getChatId() {
return chatId;
}
/**
* chatId
*/
public void setChatId(Long chatId) {
this.chatId = chatId;
}
/**
* return query_text
*/
public String getQueryText() {
return queryText;
}
/**
* queryText
*/
public void setQueryText(String queryText) {
this.queryText = queryText == null ? null : queryText.trim();
}
public Integer getIsCandidate() {
return isCandidate;
}
public Integer getParseId() {
return parseId;
}
public String getParseInfo() {
return parseInfo;
}
public void setParseId(Integer parseId) {
this.parseId = parseId;
}
public void setIsCandidate(Integer isCandidate) {
this.isCandidate = isCandidate;
}
public void setParseInfo(String parseInfo) {
this.parseInfo = parseInfo;
}
}

View File

@@ -2,177 +2,196 @@ package com.tencent.supersonic.chat.persistence.dataobject;
import java.util.Date;
public class ChatQueryDO {
/**
* questionId
*/
private Long questionId;
/**
* createTime
*/
private Integer agentId;
/**
*/
private Date createTime;
/**
* userName
*/
private String userName;
/**
* queryState
*/
private Integer queryState;
/**
* chatId
*/
private Long chatId;
/**
* score
*/
private Integer score;
/**
* feedback
*/
private String feedback;
/**
* queryText
*/
private String queryText;
/**
* queryResponse
*/
private String queryResult;
/**
* return question_id
* @return question_id
*/
public Long getQuestionId() {
return questionId;
}
/**
* questionId
* @param questionId
*/
public void setQuestionId(Long questionId) {
this.questionId = questionId;
}
/**
* return create_time
* @return agent_id
*/
public Integer getAgentId() {
return agentId;
}
/**
* @param agentId
*/
public void setAgentId(Integer agentId) {
this.agentId = agentId;
}
/**
* @return create_time
*/
public Date getCreateTime() {
return createTime;
}
/**
* createTime
* @param createTime
*/
public void setCreateTime(Date createTime) {
this.createTime = createTime;
}
/**
* return user_name
* @return user_name
*/
public String getUserName() {
return userName;
}
/**
* userName
* @param userName
*/
public void setUserName(String userName) {
this.userName = userName == null ? null : userName.trim();
}
/**
* return query_state
*
* @return query_state
*/
public Integer getQueryState() {
return queryState;
}
/**
* queryState
*
* @param queryState
*/
public void setQueryState(Integer queryState) {
this.queryState = queryState;
}
/**
* return chat_id
*
* @return chat_id
*/
public Long getChatId() {
return chatId;
}
/**
* chatId
*
* @param chatId
*/
public void setChatId(Long chatId) {
this.chatId = chatId;
}
/**
* return score
*
* @return score
*/
public Integer getScore() {
return score;
}
/**
* score
*
* @param score
*/
public void setScore(Integer score) {
this.score = score;
}
/**
* return feedback
*
* @return feedback
*/
public String getFeedback() {
return feedback;
}
/**
* feedback
*
* @param feedback
*/
public void setFeedback(String feedback) {
this.feedback = feedback == null ? null : feedback.trim();
}
/**
* return query_text
*
* @return query_text
*/
public String getQueryText() {
return queryText;
}
/**
* queryText
*
* @param queryText
*/
public void setQueryText(String queryText) {
this.queryText = queryText == null ? null : queryText.trim();
}
/**
* return query_response
*
* @return query_result
*/
public String getQueryResult() {
return queryResult;
}
/**
* queryResponse
*
* @param queryResult
*/
public void setQueryResult(String queryResult) {
this.queryResult = queryResult == null ? null : queryResult.trim();
}
}
}

View File

@@ -5,47 +5,92 @@ import java.util.Date;
import java.util.List;
public class ChatQueryDOExample {
/**
* s2_chat_query
*/
protected String orderByClause;
/**
* s2_chat_query
*/
protected boolean distinct;
/**
* s2_chat_query
*/
protected List<Criteria> oredCriteria;
/**
* s2_chat_query
*/
protected Integer limitStart;
/**
* s2_chat_query
*/
protected Integer limitEnd;
/**
* @mbg.generated
*/
public ChatQueryDOExample() {
oredCriteria = new ArrayList<Criteria>();
}
public String getOrderByClause() {
return orderByClause;
}
/**
* @mbg.generated
*/
public void setOrderByClause(String orderByClause) {
this.orderByClause = orderByClause;
}
public boolean isDistinct() {
return distinct;
/**
* @mbg.generated
*/
public String getOrderByClause() {
return orderByClause;
}
/**
* @mbg.generated
*/
public void setDistinct(boolean distinct) {
this.distinct = distinct;
}
/**
* @mbg.generated
*/
public boolean isDistinct() {
return distinct;
}
/**
* @mbg.generated
*/
public List<Criteria> getOredCriteria() {
return oredCriteria;
}
/**
* @mbg.generated
*/
public void or(Criteria criteria) {
oredCriteria.add(criteria);
}
/**
* @mbg.generated
*/
public Criteria or() {
Criteria criteria = createCriteriaInternal();
oredCriteria.add(criteria);
return criteria;
}
/**
* @mbg.generated
*/
public Criteria createCriteria() {
Criteria criteria = createCriteriaInternal();
if (oredCriteria.size() == 0) {
@@ -54,35 +99,55 @@ public class ChatQueryDOExample {
return criteria;
}
/**
* @mbg.generated
*/
protected Criteria createCriteriaInternal() {
Criteria criteria = new Criteria();
return criteria;
}
/**
* @mbg.generated
*/
public void clear() {
oredCriteria.clear();
orderByClause = null;
distinct = false;
}
public Integer getLimitStart() {
return limitStart;
}
/**
* @mbg.generated
*/
public void setLimitStart(Integer limitStart) {
this.limitStart = limitStart;
}
public Integer getLimitEnd() {
return limitEnd;
/**
* @mbg.generated
*/
public Integer getLimitStart() {
return limitStart;
}
/**
* @mbg.generated
*/
public void setLimitEnd(Integer limitEnd) {
this.limitEnd = limitEnd;
}
protected abstract static class GeneratedCriteria {
/**
* @mbg.generated
*/
public Integer getLimitEnd() {
return limitEnd;
}
/**
* s2_chat_query null
*/
protected abstract static class GeneratedCriteria {
protected List<Criterion> criteria;
protected GeneratedCriteria() {
@@ -183,6 +248,66 @@ public class ChatQueryDOExample {
return (Criteria) this;
}
public Criteria andAgentIdIsNull() {
addCriterion("agent_id is null");
return (Criteria) this;
}
public Criteria andAgentIdIsNotNull() {
addCriterion("agent_id is not null");
return (Criteria) this;
}
public Criteria andAgentIdEqualTo(Integer value) {
addCriterion("agent_id =", value, "agentId");
return (Criteria) this;
}
public Criteria andAgentIdNotEqualTo(Integer value) {
addCriterion("agent_id <>", value, "agentId");
return (Criteria) this;
}
public Criteria andAgentIdGreaterThan(Integer value) {
addCriterion("agent_id >", value, "agentId");
return (Criteria) this;
}
public Criteria andAgentIdGreaterThanOrEqualTo(Integer value) {
addCriterion("agent_id >=", value, "agentId");
return (Criteria) this;
}
public Criteria andAgentIdLessThan(Integer value) {
addCriterion("agent_id <", value, "agentId");
return (Criteria) this;
}
public Criteria andAgentIdLessThanOrEqualTo(Integer value) {
addCriterion("agent_id <=", value, "agentId");
return (Criteria) this;
}
public Criteria andAgentIdIn(List<Integer> values) {
addCriterion("agent_id in", values, "agentId");
return (Criteria) this;
}
public Criteria andAgentIdNotIn(List<Integer> values) {
addCriterion("agent_id not in", values, "agentId");
return (Criteria) this;
}
public Criteria andAgentIdBetween(Integer value1, Integer value2) {
addCriterion("agent_id between", value1, value2, "agentId");
return (Criteria) this;
}
public Criteria andAgentIdNotBetween(Integer value1, Integer value2) {
addCriterion("agent_id not between", value1, value2, "agentId");
return (Criteria) this;
}
public Criteria andCreateTimeIsNull() {
addCriterion("create_time is null");
return (Criteria) this;
@@ -564,6 +689,9 @@ public class ChatQueryDOExample {
}
}
/**
* s2_chat_query
*/
public static class Criteria extends GeneratedCriteria {
protected Criteria() {
@@ -571,8 +699,10 @@ public class ChatQueryDOExample {
}
}
/**
* s2_chat_query null
*/
public static class Criterion {
private String condition;
private Object value;
@@ -657,4 +787,4 @@ public class ChatQueryDOExample {
return typeHandler;
}
}
}
}

View File

@@ -0,0 +1,23 @@
package com.tencent.supersonic.chat.persistence.dataobject;
public enum CostType {
MAPPER(1, "mapper"),
PARSER(2, "parser"),
QUERY(3, "query");
private Integer type;
private String name;
CostType(Integer type, String name) {
this.type = type;
this.name = name;
}
public Integer getType() {
return type;
}
public String getName() {
return name;
}
}

View File

@@ -254,4 +254,4 @@ public class PluginDO {
public void setComment(String comment) {
this.comment = comment == null ? null : comment.trim();
}
}
}

View File

@@ -892,38 +892,6 @@ public class PluginDOExample {
private String typeHandler;
public String getCondition() {
return condition;
}
public Object getValue() {
return value;
}
public Object getSecondValue() {
return secondValue;
}
public boolean isNoValue() {
return noValue;
}
public boolean isSingleValue() {
return singleValue;
}
public boolean isBetweenValue() {
return betweenValue;
}
public boolean isListValue() {
return listValue;
}
public String getTypeHandler() {
return typeHandler;
}
protected Criterion(String condition) {
super();
this.condition = condition;
@@ -959,5 +927,37 @@ public class PluginDOExample {
protected Criterion(String condition, Object value, Object secondValue) {
this(condition, value, secondValue, null);
}
public String getCondition() {
return condition;
}
public Object getValue() {
return value;
}
public Object getSecondValue() {
return secondValue;
}
public boolean isNoValue() {
return noValue;
}
public boolean isSingleValue() {
return singleValue;
}
public boolean isBetweenValue() {
return betweenValue;
}
public boolean isListValue() {
return listValue;
}
public String getTypeHandler() {
return typeHandler;
}
}
}
}

View File

@@ -0,0 +1,54 @@
package com.tencent.supersonic.chat.persistence.dataobject;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import java.util.Date;
@Data
@Builder
@NoArgsConstructor
@Getter
@AllArgsConstructor
public class StatisticsDO {
/**
* questionId
*/
private Long questionId;
/**
* chatId
*/
private Long chatId;
/**
* createTime
*/
private Date createTime;
/**
* queryText
*/
private String queryText;
/**
* userName
*/
private String userName;
/**
* interface
*/
private String interfaceName;
/**
* cost
*/
private Integer cost;
private Integer type;
}

View File

@@ -0,0 +1,17 @@
package com.tencent.supersonic.chat.persistence.mapper;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
import java.util.List;
@Mapper
public interface ChatParseMapper {
boolean batchSaveParseInfo(@Param("list") List<ChatParseDO> list);
ChatParseDO getParseInfo(Long questionId, String userName, int parseId);
}

View File

@@ -14,4 +14,5 @@ public interface ChatQueryDOMapper {
int updateByPrimaryKeyWithBLOBs(ChatQueryDO record);
Boolean deleteByPrimaryKey(Long questionId);
}

View File

@@ -0,0 +1,12 @@
package com.tencent.supersonic.chat.persistence.mapper;
import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
import java.util.List;
@Mapper
public interface StatisticsMapper {
boolean batchSaveStatistics(@Param("list") List<StatisticsDO> list);
}

View File

@@ -0,0 +1,12 @@
package com.tencent.supersonic.chat.persistence.mapper.custom;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import org.apache.ibatis.annotations.Mapper;
import java.util.List;
@Mapper
public interface ShowCaseCustomMapper {
List<ChatQueryDO> queryShowCase(int start, int limit, int agentId);
}

View File

@@ -2,18 +2,37 @@ package com.tencent.supersonic.chat.persistence.repository;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import java.util.List;
public interface ChatQueryRepository {
PageInfo<QueryResp> getChatQuery(PageQueryInfoReq pageQueryInfoCommend, long chatId);
List<QueryResp> queryShowCase(PageQueryInfoReq pageQueryInfoCommend, int agentId);
void createChatQuery(QueryResult queryResult, ChatContext chatCtx);
ChatQueryDO getLastChatQuery(long chatId);
int updateChatQuery(ChatQueryDO chatQueryDO);
Long createChatParse(ParseResp parseResult, ChatContext chatCtx, QueryReq queryReq);
Boolean batchSaveParseInfo(ChatContext chatCtx, QueryReq queryReq,
ParseResp parseResult,
List<SemanticParseInfo> candidateParses,
List<SemanticParseInfo> selectedParses);
public ChatParseDO getParseInfo(Long questionId, String userName, int parseId);
Boolean deleteChatQuery(Long questionId);
}

View File

@@ -0,0 +1,11 @@
package com.tencent.supersonic.chat.persistence.repository;
import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO;
import java.util.List;
public interface StatisticsRepository {
boolean batchSaveStatistics(List<StatisticsDO> list);
}

View File

@@ -1,13 +1,14 @@
package com.tencent.supersonic.chat.persistence.repository.impl;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.config.ChatConfig;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
import com.tencent.supersonic.chat.config.ChatConfigFilterInternal;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.persistence.dataobject.ChatConfigDO;
import com.tencent.supersonic.chat.persistence.mapper.ChatConfigMapper;
import com.tencent.supersonic.chat.persistence.repository.ChatConfigRepository;
import com.tencent.supersonic.chat.utils.ChatConfigHelper;
import com.tencent.supersonic.chat.persistence.mapper.ChatConfigMapper;
import java.util.ArrayList;
import java.util.List;
import org.springframework.beans.BeanUtils;
@@ -23,7 +24,7 @@ public class ChatConfigRepositoryImpl implements ChatConfigRepository {
private final ChatConfigMapper chatConfigMapper;
public ChatConfigRepositoryImpl(ChatConfigHelper chatConfigHelper,
ChatConfigMapper chatConfigMapper) {
ChatConfigMapper chatConfigMapper) {
this.chatConfigHelper = chatConfigHelper;
this.chatConfigMapper = chatConfigMapper;
}
@@ -52,8 +53,8 @@ public class ChatConfigRepositoryImpl implements ChatConfigRepository {
List<ChatConfigDO> chaConfigDOList = chatConfigMapper.search(filterInternal);
if (!CollectionUtils.isEmpty(chaConfigDOList)) {
chaConfigDOList.stream().forEach(chaConfigDO ->
chaConfigDescriptorList.add(
chatConfigHelper.chatConfigDO2Descriptor(chaConfigDO.getModelId(), chaConfigDO)));
chaConfigDescriptorList.add(chatConfigHelper
.chatConfigDO2Descriptor(chaConfigDO.getModelId(), chaConfigDO)));
}
return chaConfigDescriptorList;
}

View File

@@ -52,8 +52,9 @@ public class ChatContextRepositoryImpl implements ChatContextRepository {
chatContext.setUser(contextDO.getUser());
chatContext.setQueryText(contextDO.getQueryText());
if (contextDO.getSemanticParse() != null && !contextDO.getSemanticParse().isEmpty()) {
log.info("--->: {}",contextDO.getSemanticParse());
SemanticParseInfo semanticParseInfo = JsonUtil.toObject(contextDO.getSemanticParse(), SemanticParseInfo.class);
log.info("--->: {}", contextDO.getSemanticParse());
SemanticParseInfo semanticParseInfo = JsonUtil.toObject(contextDO.getSemanticParse(),
SemanticParseInfo.class);
chatContext.setParseInfo(semanticParseInfo);
}
return chatContext;

View File

@@ -3,20 +3,30 @@ package com.tencent.supersonic.chat.persistence.repository.impl;
import com.github.pagehelper.PageHelper;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample.Criteria;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.persistence.mapper.ChatParseMapper;
import com.tencent.supersonic.chat.persistence.mapper.ChatQueryDOMapper;
import com.tencent.supersonic.chat.persistence.mapper.custom.ShowCaseCustomMapper;
import com.tencent.supersonic.chat.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.PageUtils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Repository;
@@ -29,8 +39,16 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
private final ChatQueryDOMapper chatQueryDOMapper;
public ChatQueryRepositoryImpl(ChatQueryDOMapper chatQueryDOMapper) {
private final ChatParseMapper chatParseMapper;
private final ShowCaseCustomMapper showCaseCustomMapper;
public ChatQueryRepositoryImpl(ChatQueryDOMapper chatQueryDOMapper,
ChatParseMapper chatParseMapper,
ShowCaseCustomMapper showCaseCustomMapper) {
this.chatQueryDOMapper = chatQueryDOMapper;
this.chatParseMapper = chatParseMapper;
this.showCaseCustomMapper = showCaseCustomMapper;
}
@Override
@@ -47,18 +65,27 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
PageInfo<QueryResp> chatQueryVOPageInfo = PageUtils.pageInfo2PageInfoVo(pageInfo);
chatQueryVOPageInfo.setList(
pageInfo.getList().stream().map(this::convertTo)
pageInfo.getList().stream().filter(o -> !StringUtils.isEmpty(o.getQueryResult())).map(this::convertTo)
.sorted(Comparator.comparingInt(o -> o.getQuestionId().intValue()))
.collect(Collectors.toList()));
return chatQueryVOPageInfo;
}
@Override
public List<QueryResp> queryShowCase(PageQueryInfoReq pageQueryInfoCommend, int agentId) {
return showCaseCustomMapper.queryShowCase(pageQueryInfoCommend.getCurrent(),
pageQueryInfoCommend.getPageSize(), agentId).stream().map(this::convertTo)
.collect(Collectors.toList());
}
private QueryResp convertTo(ChatQueryDO chatQueryDO) {
QueryResp queryResponse = new QueryResp();
BeanUtils.copyProperties(chatQueryDO, queryResponse);
QueryResult queryResult = JsonUtil.toObject(chatQueryDO.getQueryResult(), QueryResult.class);
queryResult.setQueryId(chatQueryDO.getQuestionId());
queryResponse.setQueryResult(queryResult);
if (queryResult != null) {
queryResult.setQueryId(chatQueryDO.getQuestionId());
queryResponse.setQueryResult(queryResult);
}
return queryResponse;
}
@@ -71,12 +98,63 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
chatQueryDO.setQueryState(queryResult.getQueryState().ordinal());
chatQueryDO.setQueryText(chatCtx.getQueryText());
chatQueryDO.setQueryResult(JsonUtil.toString(queryResult));
chatQueryDO.setAgentId(chatCtx.getAgentId());
chatQueryDOMapper.insert(chatQueryDO);
ChatQueryDO lastChatQuery = getLastChatQuery(chatCtx.getChatId());
Long queryId = lastChatQuery.getQuestionId();
queryResult.setQueryId(queryId);
}
public Long createChatParse(ParseResp parseResult, ChatContext chatCtx, QueryReq queryReq) {
ChatQueryDO chatQueryDO = new ChatQueryDO();
chatQueryDO.setChatId(Long.valueOf(chatCtx.getChatId()));
chatQueryDO.setCreateTime(new java.util.Date());
chatQueryDO.setUserName(queryReq.getUser().getName());
chatQueryDO.setQueryText(queryReq.getQueryText());
chatQueryDO.setAgentId(queryReq.getAgentId());
chatQueryDO.setQueryResult("");
try {
chatQueryDOMapper.insert(chatQueryDO);
} catch (Exception e) {
log.info("database insert has an exception:{}", e.toString());
}
ChatQueryDO lastChatQuery = getLastChatQuery(chatCtx.getChatId());
Long queryId = lastChatQuery.getQuestionId();
parseResult.setQueryId(queryId);
return queryId;
}
public Boolean batchSaveParseInfo(ChatContext chatCtx, QueryReq queryReq,
ParseResp parseResult,
List<SemanticParseInfo> candidateParses,
List<SemanticParseInfo> selectedParses) {
Long queryId = createChatParse(parseResult, chatCtx, queryReq);
List<ChatParseDO> chatParseDOList = new ArrayList<>();
log.info("candidateParses size:{},selectedParses size:{}", candidateParses.size(), selectedParses.size());
getChatParseDO(chatCtx, queryReq, queryId, 0, 1, candidateParses, chatParseDOList);
getChatParseDO(chatCtx, queryReq, queryId, candidateParses.size(), 0, selectedParses, chatParseDOList);
Boolean save = chatParseMapper.batchSaveParseInfo(chatParseDOList);
return save;
}
public void getChatParseDO(ChatContext chatCtx, QueryReq queryReq, Long queryId, int base, int isCandidate,
List<SemanticParseInfo> parses, List<ChatParseDO> chatParseDOList) {
for (int i = 0; i < parses.size(); i++) {
ChatParseDO chatParseDO = new ChatParseDO();
parses.get(i).setId(base + i + 1);
chatParseDO.setChatId(Long.valueOf(chatCtx.getChatId()));
chatParseDO.setQuestionId(queryId);
chatParseDO.setQueryText(queryReq.getQueryText());
chatParseDO.setParseInfo(JsonUtil.toString(parses.get(i)));
chatParseDO.setIsCandidate(isCandidate);
chatParseDO.setParseId(base + i + 1);
chatParseDO.setCreateTime(new java.util.Date());
chatParseDO.setUserName(queryReq.getUser().getName());
chatParseDOList.add(chatParseDO);
}
}
@Override
public ChatQueryDO getLastChatQuery(long chatId) {
ChatQueryDOExample example = new ChatQueryDOExample();
@@ -96,4 +174,13 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
public int updateChatQuery(ChatQueryDO chatQueryDO) {
return chatQueryDOMapper.updateByPrimaryKeyWithBLOBs(chatQueryDO);
}
public ChatParseDO getParseInfo(Long questionId, String userName, int parseId) {
return chatParseMapper.getParseInfo(questionId, userName, parseId);
}
@Override
public Boolean deleteChatQuery(Long questionId) {
return chatQueryDOMapper.deleteByPrimaryKey(questionId);
}
}

View File

@@ -5,13 +5,14 @@ import com.tencent.supersonic.chat.persistence.dataobject.PluginDOExample;
import com.tencent.supersonic.chat.persistence.mapper.PluginDOMapper;
import com.tencent.supersonic.chat.persistence.repository.PluginRepository;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.logging.log4j.util.Strings;
import org.springframework.stereotype.Repository;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import lombok.extern.slf4j.Slf4j;
import org.apache.logging.log4j.util.Strings;
import org.springframework.stereotype.Repository;
@Repository
@Slf4j

View File

@@ -0,0 +1,29 @@
package com.tencent.supersonic.chat.persistence.repository.impl;
import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.chat.persistence.mapper.StatisticsMapper;
import com.tencent.supersonic.chat.persistence.repository.StatisticsRepository;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Repository;
import java.util.List;
@Repository
@Primary
@Slf4j
public class StatisticsRepositoryImpl implements StatisticsRepository {
private final StatisticsMapper statisticsMapper;
public StatisticsRepositoryImpl(StatisticsMapper statisticsMapper) {
this.statisticsMapper = statisticsMapper;
}
public boolean batchSaveStatistics(List<StatisticsDO> list) {
return statisticsMapper.batchSaveStatistics(list);
}
;
}

View File

@@ -3,15 +3,17 @@ package com.tencent.supersonic.chat.plugin;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.agent.tool.DslTool;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.PluginTool;
import com.tencent.supersonic.chat.parser.ParseMode;
import com.tencent.supersonic.chat.parser.embedding.EmbeddingConfig;
import com.tencent.supersonic.chat.parser.embedding.EmbeddingResp;
import com.tencent.supersonic.chat.parser.embedding.RecallRetrieval;
import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingConfig;
import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingResp;
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval;
import com.tencent.supersonic.chat.plugin.event.PluginAddEvent;
import com.tencent.supersonic.chat.plugin.event.PluginUpdateEvent;
import com.tencent.supersonic.chat.query.plugin.ParamOption;
@@ -20,10 +22,16 @@ import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.service.PluginService;
import com.tencent.supersonic.common.util.ContextUtils;
import java.net.URI;
import java.util.*;
import java.util.List;
import java.util.Collection;
import java.util.Set;
import java.util.Optional;
import java.util.HashSet;
import java.util.HashMap;
import java.util.Objects;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
@@ -31,7 +39,11 @@ import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.util.Strings;
import org.springframework.context.event.EventListener;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.*;
import org.springframework.http.ResponseEntity;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.HttpEntity;
import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;
@@ -149,9 +161,6 @@ public class PluginManager {
}
public void requestEmbeddingPluginAddALL(List<Plugin> plugins) {
plugins = plugins.stream()
.filter(plugin -> ParseMode.EMBEDDING_RECALL.equals(plugin.getParseMode()))
.collect(Collectors.toList());
requestEmbeddingPluginAdd(convert(plugins));
}
@@ -229,11 +238,11 @@ public class PluginManager {
}
List<ParamOption> paramOptions = getSemanticOption(plugin);
if (CollectionUtils.isEmpty(paramOptions)) {
return Pair.of(true, Sets.newHashSet());
return Pair.of(true, pluginMatchedModel);
}
Set<Long> matchedModel = Sets.newHashSet();
Map<Long, List<ParamOption>> paramOptionMap = paramOptions.stream().
collect(Collectors.groupingBy(ParamOption::getModelId));
Map<Long, List<ParamOption>> paramOptionMap = paramOptions.stream()
.collect(Collectors.groupingBy(ParamOption::getModelId));
for (Long modelId : paramOptionMap.keySet()) {
List<ParamOption> params = paramOptionMap.get(modelId);
if (CollectionUtils.isEmpty(params)) {
@@ -268,8 +277,8 @@ public class PluginManager {
return Sets.newHashSet();
}
return schemaElementMatches.stream().filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()) ||
SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
.map(SchemaElementMatch::getElement)
.map(SchemaElement::getId)
.collect(Collectors.toSet());

View File

@@ -1,12 +1,13 @@
package com.tencent.supersonic.chat.plugin;
import com.tencent.supersonic.chat.parser.function.Parameters;
import java.io.Serializable;
import java.util.List;
import com.tencent.supersonic.chat.parser.plugin.function.Parameters;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import java.io.Serializable;
import java.util.List;
import lombok.NoArgsConstructor;
import lombok.ToString;
@@ -17,12 +18,12 @@ import lombok.ToString;
@NoArgsConstructor
public class PluginParseConfig implements Serializable {
private String name;
private String description;
public Parameters parameters;
public List<String> examples;
private String name;
private String description;
}

View File

@@ -5,11 +5,13 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import java.util.List;
import java.util.ArrayList;
import java.util.OptionalDouble;
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
import java.util.ArrayList;
import java.util.List;
import java.util.OptionalDouble;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
@@ -50,8 +52,8 @@ public class HeuristicQuerySelector implements QuerySelector {
return true;
}
for (SemanticQuery candidateQuery : candidateQueries) {
if (candidateQuery.getQueryMode().equals(MetricEntityQuery.QUERY_MODE) &&
semanticQuery.getParseInfo().getScore() == candidateQuery.getParseInfo().getScore()) {
if (candidateQuery.getQueryMode().equals(MetricEntityQuery.QUERY_MODE)
&& semanticQuery.getParseInfo().getScore() == candidateQuery.getParseInfo().getScore()) {
return false;
}
}

View File

@@ -1,17 +0,0 @@
package com.tencent.supersonic.chat.query.dsl.optimizer;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class FieldCorrector extends BaseDSLOptimizer {
@Override
public CorrectionInfo rewriter(CorrectionInfo correctionInfo) {
String replaceFields = CCJSqlParserUtils.replaceFields(correctionInfo.getSql(),
getFieldToBizName(correctionInfo.getParseInfo().getModelId()));
correctionInfo.setSql(replaceFields);
return correctionInfo;
}
}

View File

@@ -1,16 +0,0 @@
package com.tencent.supersonic.chat.query.dsl.optimizer;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class FunctionCorrector extends BaseDSLOptimizer {
@Override
public CorrectionInfo rewriter(CorrectionInfo correctionInfo) {
String replaceFunction = CCJSqlParserUtils.replaceFunction(correctionInfo.getSql());
correctionInfo.setSql(replaceFunction);
return correctionInfo;
}
}

View File

@@ -1,15 +1,12 @@
package com.tencent.supersonic.chat.query.dsl;
package com.tencent.supersonic.chat.query.llm.dsl;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticLayer;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.api.component.DSLOptimizer;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
@@ -29,12 +26,12 @@ import org.springframework.stereotype.Component;
@Slf4j
@Component
public class DSLQuery extends PluginSemanticQuery {
public class DslQuery extends PluginSemanticQuery {
public static final String QUERY_MODE = "DSL";
protected SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer();
public DSLQuery() {
public DslQuery() {
QueryManager.register(this);
}
@@ -48,31 +45,12 @@ public class DSLQuery extends PluginSemanticQuery {
String json = JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT));
DSLParseResult dslParseResult = JsonUtil.toObject(json, DSLParseResult.class);
LLMResp llmResp = dslParseResult.getLlmResp();
QueryReq queryReq = dslParseResult.getRequest();
CorrectionInfo correctionInfo = CorrectionInfo.builder()
.queryFilters(queryReq.getQueryFilters())
.sql(llmResp.getSqlOutput())
.parseInfo(parseInfo)
.build();
List<DSLOptimizer> DSLCorrections = ComponentFactory.getSqlCorrections();
DSLCorrections.forEach(DSLCorrection -> {
try {
DSLCorrection.rewriter(correctionInfo);
log.info("sqlCorrection:{} sql:{}", DSLCorrection.getClass().getSimpleName(), correctionInfo.getSql());
} catch (Exception e) {
log.error("sqlCorrection:{} execute error,correctionInfo:{}", DSLCorrection, correctionInfo, e);
}
});
String querySql = correctionInfo.getSql();
long startTime = System.currentTimeMillis();
QueryDslReq queryDslReq = QueryReqBuilder.buildDslReq(querySql, parseInfo.getModelId());
QueryDslReq queryDslReq = QueryReqBuilder.buildDslReq(llmResp.getCorrectorSql(), parseInfo.getModelId());
QueryResultWithSchemaResp queryResp = semanticLayer.queryByDsl(queryDslReq, user);
log.info("queryByDsl cost:{},querySql:{}", System.currentTimeMillis() - startTime, querySql);
log.info("queryByDsl cost:{},querySql:{}", System.currentTimeMillis() - startTime, llmResp.getSqlOutput());
QueryResult queryResult = new QueryResult();
if (Objects.nonNull(queryResp)) {

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.query.dsl;
package com.tencent.supersonic.chat.query.llm.dsl;
import java.util.List;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.query.dsl;
package com.tencent.supersonic.chat.query.llm.dsl;
import java.util.List;
import lombok.Data;
@@ -17,4 +17,6 @@ public class LLMResp {
private String schemaLinkingOutput;
private String schemaLinkStr;
private String correctorSql;
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.query.metricInterpret;
package com.tencent.supersonic.chat.query.metricinterpret;
import lombok.Data;

View File

@@ -1,11 +1,9 @@
package com.tencent.supersonic.chat.query.metricInterpret;
package com.tencent.supersonic.chat.query.metricinterpret;
import lombok.Data;
@Data
public class LLmAnswerResp {
private String assistant_message;
private String assistantMessage;
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.query.metricInterpret;
package com.tencent.supersonic.chat.query.metricinterpret;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
@@ -26,7 +26,11 @@ import org.apache.commons.lang3.StringUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.util.*;
import java.util.Map;
import java.util.HashMap;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
@Slf4j
@@ -34,17 +38,17 @@ import java.util.stream.Collectors;
public class MetricInterpretQuery extends PluginSemanticQuery {
public final static String QUERY_MODE = "METRIC_INTERPRET";
public static final String QUERY_MODE = "METRIC_INTERPRET";
public MetricInterpretQuery() {
QueryManager.register(this);
}
@Override
public String getQueryMode() {
return QUERY_MODE;
}
public MetricInterpretQuery() {
QueryManager.register(this);
}
@Override
public QueryResult execute(User user) throws SqlParseException {
QueryStructReq queryStructReq = QueryReqBuilder.buildStructReq(parseInfo);
@@ -55,10 +59,11 @@ public class MetricInterpretQuery extends PluginSemanticQuery {
String text = generateTableText(queryResultWithSchemaResp);
Map<String, Object> properties = parseInfo.getProperties();
Map<String, String> replacedMap = new HashMap<>();
String textReplaced = replaceText((String) properties.get("queryText"), parseInfo.getElementMatches(), replacedMap);
String textReplaced = replaceText((String) properties.get("queryText"),
parseInfo.getElementMatches(), replacedMap);
String answer = replaceAnswer(fetchInterpret(textReplaced, text), replacedMap);
QueryResult queryResult = new QueryResult();
List<QueryColumn> queryColumns = Lists.newArrayList(new QueryColumn("结果","string","answer"));
List<QueryColumn> queryColumns = Lists.newArrayList(new QueryColumn("结果", "string", "answer"));
Map<String, Object> result = new HashMap<>();
result.put("answer", answer);
List<Map<String, Object>> resultList = Lists.newArrayList();
@@ -70,7 +75,8 @@ public class MetricInterpretQuery extends PluginSemanticQuery {
return queryResult;
}
private String replaceText(String text, List<SchemaElementMatch> schemaElementMatches, Map<String, String> replacedMap) {
private String replaceText(String text, List<SchemaElementMatch> schemaElementMatches,
Map<String, String> replacedMap) {
if (CollectionUtils.isEmpty(schemaElementMatches)) {
return text;
}
@@ -134,10 +140,10 @@ public class MetricInterpretQuery extends PluginSemanticQuery {
JSONObject.toJSONString(lLmAnswerReq));
LLmAnswerResp lLmAnswerResp = JSONObject.parseObject(responseEntity.getBody(), LLmAnswerResp.class);
if (lLmAnswerResp != null) {
return lLmAnswerResp.getAssistant_message();
return lLmAnswerResp.getAssistantMessage();
}
return null;
}
}
}

Some files were not shown because too many files have changed in this diff Show More