(improvement)(Chat) Determine whether to enable LLM based on agent information (#810)

Co-authored-by: jolunoluo
This commit is contained in:
LXW
2024-03-13 10:22:16 +08:00
committed by GitHub
parent 46910fbfcf
commit 8a8370164f
4 changed files with 9 additions and 11 deletions

View File

@@ -66,6 +66,10 @@ public class Agent extends RecordInfo {
.collect(Collectors.toList());
}
public boolean containsLLMParserTool() {
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM));
}
public Set<Long> getDataSetIds() {
Set<Long> dataSetIds = getDataSetIds(null);
if (containsAllModel(dataSetIds)) {

View File

@@ -106,6 +106,9 @@ public class ChatServiceImpl implements ChatService {
if (agent == null) {
return queryReq;
}
if (agent.containsLLMParserTool()) {
queryReq.setEnableLLM(true);
}
queryReq.setDataSetIds(agent.getDataSetIds());
return queryReq;
}

View File

@@ -14,4 +14,5 @@ public class QueryReq {
private User user;
private QueryFilters queryFilters;
private boolean saveAnswer = true;
private boolean enableLLM;
}

View File

@@ -75,7 +75,6 @@ import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
@@ -102,9 +101,6 @@ public class ChatQueryServiceImpl implements ChatQueryService {
@Autowired
private DataSetService dataSetService;
@Value("${time.threshold: 100}")
private Integer timeThreshold;
private List<SchemaMapper> schemaMappers = ComponentFactory.getSchemaMappers();
private List<SemanticParser> semanticParsers = ComponentFactory.getSemanticParsers();
private List<SemanticCorrector> semanticCorrectors = ComponentFactory.getSemanticCorrectors();
@@ -118,22 +114,15 @@ public class ChatQueryServiceImpl implements ChatQueryService {
// in order to support multi-turn conversation, chat context is needed
ChatContext chatCtx = chatContextService.getOrCreateContext(queryReq.getChatId());
List<StatisticsDO> timeCostDOList = new ArrayList<>();
// 1. mapper
schemaMappers.forEach(mapper -> {
long startTime = System.currentTimeMillis();
mapper.map(queryCtx);
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
.interfaceName(mapper.getClass().getSimpleName()).type(CostType.MAPPER.getType()).build());
});
// 2. parser
semanticParsers.forEach(parser -> {
long startTime = System.currentTimeMillis();
parser.parse(queryCtx, chatCtx);
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
.interfaceName(parser.getClass().getSimpleName()).type(CostType.PARSER.getType()).build());
log.debug("{} result:{}", parser.getClass().getSimpleName(), JsonUtil.toString(queryCtx));
});
@@ -168,6 +157,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
.candidateQueries(new ArrayList<>())
.mapInfo(new SchemaMapInfo())
.modelIdToDataSetIds(modelIdToDataSetIds)
.enableLLM(queryReq.isEnableLLM())
.build();
BeanUtils.copyProperties(queryReq, queryCtx);
return queryCtx;