mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
[improvement] llm supports all models (#615)
This commit is contained in:
@@ -22,4 +22,6 @@ public class LLMParserConfig {
|
||||
@Value("${metric.topn:5}")
|
||||
private Integer metricTopN;
|
||||
|
||||
@Value("${all.model:false}")
|
||||
private Boolean allModel;
|
||||
}
|
||||
|
||||
@@ -13,9 +13,12 @@ import com.tencent.supersonic.chat.core.knowledge.semantic.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.core.parser.SatisfactionChecker;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
|
||||
@@ -66,6 +69,26 @@ public class LLMRequestService {
|
||||
if (Objects.nonNull(agent)) {
|
||||
distinctModelIds = agent.getModelIds(AgentToolType.NL2SQL_LLM);
|
||||
}
|
||||
if (llmParserConfig.getAllModel()) {
|
||||
ModelCluster modelCluster = ModelCluster.build(distinctModelIds);
|
||||
if (!CollectionUtils.isEmpty(queryCtx.getCandidateQueries())) {
|
||||
queryCtx.getCandidateQueries().stream().forEach(o -> {
|
||||
if (LLMSqlQuery.QUERY_MODE.equals(o.getParseInfo().getQueryMode())) {
|
||||
o.getParseInfo().setModel(modelCluster);
|
||||
}
|
||||
});
|
||||
}
|
||||
SemanticQuery semanticQuery = QueryManager.createQuery(LLMSqlQuery.QUERY_MODE);
|
||||
semanticQuery.getParseInfo().setModel(modelCluster);
|
||||
List<SchemaElementMatch> schemaElementMatches = new ArrayList<>();
|
||||
distinctModelIds.stream().forEach(o -> {
|
||||
if (!CollectionUtils.isEmpty(queryCtx.getMapInfo().getMatchedElements(o))) {
|
||||
schemaElementMatches.addAll(queryCtx.getMapInfo().getMatchedElements(o));
|
||||
}
|
||||
});
|
||||
queryCtx.getModelClusterMapInfo().setMatchedElements(modelCluster.getKey(), schemaElementMatches);
|
||||
return modelCluster;
|
||||
}
|
||||
if (Agent.containsAllModel(distinctModelIds)) {
|
||||
distinctModelIds = new HashSet<>();
|
||||
}
|
||||
|
||||
@@ -2,11 +2,17 @@ package com.tencent.supersonic.chat.core.query.llm.s2sql;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.Builder;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.Builder;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class LLMSqlResp {
|
||||
|
||||
private double sqlWeight;
|
||||
|
||||
Reference in New Issue
Block a user