[improvement] llm supports all models (#615)

This commit is contained in:
mainmain
2024-01-10 22:01:21 +08:00
committed by GitHub
parent 9c8039c499
commit 1e93282c9f
3 changed files with 32 additions and 1 deletions

View File

@@ -22,4 +22,6 @@ public class LLMParserConfig {
@Value("${metric.topn:5}")
private Integer metricTopN;
@Value("${all.model:false}")
private Boolean allModel;
}

View File

@@ -13,9 +13,12 @@ import com.tencent.supersonic.chat.core.knowledge.semantic.SemanticInterpreter;
import com.tencent.supersonic.chat.core.parser.SatisfactionChecker;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.QueryManager;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.ModelCluster;
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
@@ -66,6 +69,26 @@ public class LLMRequestService {
if (Objects.nonNull(agent)) {
distinctModelIds = agent.getModelIds(AgentToolType.NL2SQL_LLM);
}
if (llmParserConfig.getAllModel()) {
ModelCluster modelCluster = ModelCluster.build(distinctModelIds);
if (!CollectionUtils.isEmpty(queryCtx.getCandidateQueries())) {
queryCtx.getCandidateQueries().stream().forEach(o -> {
if (LLMSqlQuery.QUERY_MODE.equals(o.getParseInfo().getQueryMode())) {
o.getParseInfo().setModel(modelCluster);
}
});
}
SemanticQuery semanticQuery = QueryManager.createQuery(LLMSqlQuery.QUERY_MODE);
semanticQuery.getParseInfo().setModel(modelCluster);
List<SchemaElementMatch> schemaElementMatches = new ArrayList<>();
distinctModelIds.stream().forEach(o -> {
if (!CollectionUtils.isEmpty(queryCtx.getMapInfo().getMatchedElements(o))) {
schemaElementMatches.addAll(queryCtx.getMapInfo().getMatchedElements(o));
}
});
queryCtx.getModelClusterMapInfo().setMatchedElements(modelCluster.getKey(), schemaElementMatches);
return modelCluster;
}
if (Agent.containsAllModel(distinctModelIds)) {
distinctModelIds = new HashSet<>();
}

View File

@@ -2,11 +2,17 @@ package com.tencent.supersonic.chat.core.query.llm.s2sql;
import java.util.List;
import java.util.Map;
import lombok.Builder;
import lombok.Data;
import lombok.Builder;
import lombok.AllArgsConstructor;
import lombok.NoArgsConstructor;
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class LLMSqlResp {
private double sqlWeight;