mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
(improvement)(Chat) Determine whether to enable LLM based on agent information (#810)
Co-authored-by: jolunoluo
This commit is contained in:
@@ -66,6 +66,10 @@ public class Agent extends RecordInfo {
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public boolean containsLLMParserTool() {
|
||||
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM));
|
||||
}
|
||||
|
||||
public Set<Long> getDataSetIds() {
|
||||
Set<Long> dataSetIds = getDataSetIds(null);
|
||||
if (containsAllModel(dataSetIds)) {
|
||||
|
||||
@@ -106,6 +106,9 @@ public class ChatServiceImpl implements ChatService {
|
||||
if (agent == null) {
|
||||
return queryReq;
|
||||
}
|
||||
if (agent.containsLLMParserTool()) {
|
||||
queryReq.setEnableLLM(true);
|
||||
}
|
||||
queryReq.setDataSetIds(agent.getDataSetIds());
|
||||
return queryReq;
|
||||
}
|
||||
|
||||
@@ -14,4 +14,5 @@ public class QueryReq {
|
||||
private User user;
|
||||
private QueryFilters queryFilters;
|
||||
private boolean saveAnswer = true;
|
||||
private boolean enableLLM;
|
||||
}
|
||||
|
||||
@@ -75,7 +75,6 @@ import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
@@ -102,9 +101,6 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
@Autowired
|
||||
private DataSetService dataSetService;
|
||||
|
||||
@Value("${time.threshold: 100}")
|
||||
private Integer timeThreshold;
|
||||
|
||||
private List<SchemaMapper> schemaMappers = ComponentFactory.getSchemaMappers();
|
||||
private List<SemanticParser> semanticParsers = ComponentFactory.getSemanticParsers();
|
||||
private List<SemanticCorrector> semanticCorrectors = ComponentFactory.getSemanticCorrectors();
|
||||
@@ -118,22 +114,15 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
|
||||
// in order to support multi-turn conversation, chat context is needed
|
||||
ChatContext chatCtx = chatContextService.getOrCreateContext(queryReq.getChatId());
|
||||
List<StatisticsDO> timeCostDOList = new ArrayList<>();
|
||||
|
||||
// 1. mapper
|
||||
schemaMappers.forEach(mapper -> {
|
||||
long startTime = System.currentTimeMillis();
|
||||
mapper.map(queryCtx);
|
||||
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
|
||||
.interfaceName(mapper.getClass().getSimpleName()).type(CostType.MAPPER.getType()).build());
|
||||
});
|
||||
|
||||
// 2. parser
|
||||
semanticParsers.forEach(parser -> {
|
||||
long startTime = System.currentTimeMillis();
|
||||
parser.parse(queryCtx, chatCtx);
|
||||
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
|
||||
.interfaceName(parser.getClass().getSimpleName()).type(CostType.PARSER.getType()).build());
|
||||
log.debug("{} result:{}", parser.getClass().getSimpleName(), JsonUtil.toString(queryCtx));
|
||||
});
|
||||
|
||||
@@ -168,6 +157,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
.candidateQueries(new ArrayList<>())
|
||||
.mapInfo(new SchemaMapInfo())
|
||||
.modelIdToDataSetIds(modelIdToDataSetIds)
|
||||
.enableLLM(queryReq.isEnableLLM())
|
||||
.build();
|
||||
BeanUtils.copyProperties(queryReq, queryCtx);
|
||||
return queryCtx;
|
||||
|
||||
Reference in New Issue
Block a user