mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-14 13:47:09 +00:00
(improvement)(Chat) Determine whether to enable LLM based on agent information (#810)
Co-authored-by: jolunoluo
This commit is contained in:
@@ -66,6 +66,10 @@ public class Agent extends RecordInfo {
|
|||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public boolean containsLLMParserTool() {
|
||||||
|
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM));
|
||||||
|
}
|
||||||
|
|
||||||
public Set<Long> getDataSetIds() {
|
public Set<Long> getDataSetIds() {
|
||||||
Set<Long> dataSetIds = getDataSetIds(null);
|
Set<Long> dataSetIds = getDataSetIds(null);
|
||||||
if (containsAllModel(dataSetIds)) {
|
if (containsAllModel(dataSetIds)) {
|
||||||
|
|||||||
@@ -106,6 +106,9 @@ public class ChatServiceImpl implements ChatService {
|
|||||||
if (agent == null) {
|
if (agent == null) {
|
||||||
return queryReq;
|
return queryReq;
|
||||||
}
|
}
|
||||||
|
if (agent.containsLLMParserTool()) {
|
||||||
|
queryReq.setEnableLLM(true);
|
||||||
|
}
|
||||||
queryReq.setDataSetIds(agent.getDataSetIds());
|
queryReq.setDataSetIds(agent.getDataSetIds());
|
||||||
return queryReq;
|
return queryReq;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,4 +14,5 @@ public class QueryReq {
|
|||||||
private User user;
|
private User user;
|
||||||
private QueryFilters queryFilters;
|
private QueryFilters queryFilters;
|
||||||
private boolean saveAnswer = true;
|
private boolean saveAnswer = true;
|
||||||
|
private boolean enableLLM;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -75,7 +75,6 @@ import org.apache.commons.lang3.StringUtils;
|
|||||||
import org.apache.commons.lang3.tuple.Pair;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
import org.springframework.beans.BeanUtils;
|
import org.springframework.beans.BeanUtils;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
@@ -102,9 +101,6 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
@Autowired
|
@Autowired
|
||||||
private DataSetService dataSetService;
|
private DataSetService dataSetService;
|
||||||
|
|
||||||
@Value("${time.threshold: 100}")
|
|
||||||
private Integer timeThreshold;
|
|
||||||
|
|
||||||
private List<SchemaMapper> schemaMappers = ComponentFactory.getSchemaMappers();
|
private List<SchemaMapper> schemaMappers = ComponentFactory.getSchemaMappers();
|
||||||
private List<SemanticParser> semanticParsers = ComponentFactory.getSemanticParsers();
|
private List<SemanticParser> semanticParsers = ComponentFactory.getSemanticParsers();
|
||||||
private List<SemanticCorrector> semanticCorrectors = ComponentFactory.getSemanticCorrectors();
|
private List<SemanticCorrector> semanticCorrectors = ComponentFactory.getSemanticCorrectors();
|
||||||
@@ -118,22 +114,15 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
|
|
||||||
// in order to support multi-turn conversation, chat context is needed
|
// in order to support multi-turn conversation, chat context is needed
|
||||||
ChatContext chatCtx = chatContextService.getOrCreateContext(queryReq.getChatId());
|
ChatContext chatCtx = chatContextService.getOrCreateContext(queryReq.getChatId());
|
||||||
List<StatisticsDO> timeCostDOList = new ArrayList<>();
|
|
||||||
|
|
||||||
// 1. mapper
|
// 1. mapper
|
||||||
schemaMappers.forEach(mapper -> {
|
schemaMappers.forEach(mapper -> {
|
||||||
long startTime = System.currentTimeMillis();
|
|
||||||
mapper.map(queryCtx);
|
mapper.map(queryCtx);
|
||||||
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
|
|
||||||
.interfaceName(mapper.getClass().getSimpleName()).type(CostType.MAPPER.getType()).build());
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// 2. parser
|
// 2. parser
|
||||||
semanticParsers.forEach(parser -> {
|
semanticParsers.forEach(parser -> {
|
||||||
long startTime = System.currentTimeMillis();
|
|
||||||
parser.parse(queryCtx, chatCtx);
|
parser.parse(queryCtx, chatCtx);
|
||||||
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
|
|
||||||
.interfaceName(parser.getClass().getSimpleName()).type(CostType.PARSER.getType()).build());
|
|
||||||
log.debug("{} result:{}", parser.getClass().getSimpleName(), JsonUtil.toString(queryCtx));
|
log.debug("{} result:{}", parser.getClass().getSimpleName(), JsonUtil.toString(queryCtx));
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -168,6 +157,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
.candidateQueries(new ArrayList<>())
|
.candidateQueries(new ArrayList<>())
|
||||||
.mapInfo(new SchemaMapInfo())
|
.mapInfo(new SchemaMapInfo())
|
||||||
.modelIdToDataSetIds(modelIdToDataSetIds)
|
.modelIdToDataSetIds(modelIdToDataSetIds)
|
||||||
|
.enableLLM(queryReq.isEnableLLM())
|
||||||
.build();
|
.build();
|
||||||
BeanUtils.copyProperties(queryReq, queryCtx);
|
BeanUtils.copyProperties(queryReq, queryCtx);
|
||||||
return queryCtx;
|
return queryCtx;
|
||||||
|
|||||||
Reference in New Issue
Block a user