From 1e93282c9f194012547e5ecbc1e1cbc08a1a08c3 Mon Sep 17 00:00:00 2001 From: mainmain <57514971+mainmainer@users.noreply.github.com> Date: Wed, 10 Jan 2024 22:01:21 +0800 Subject: [PATCH] [improvement] llm supports all models (#615) --- .../chat/core/config/LLMParserConfig.java | 2 ++ .../parser/sql/llm/LLMRequestService.java | 23 +++++++++++++++++++ .../chat/core/query/llm/s2sql/LLMSqlResp.java | 8 ++++++- 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/config/LLMParserConfig.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/config/LLMParserConfig.java index b05bc0502..c38cfbe5e 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/config/LLMParserConfig.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/config/LLMParserConfig.java @@ -22,4 +22,6 @@ public class LLMParserConfig { @Value("${metric.topn:5}") private Integer metricTopN; + @Value("${all.model:false}") + private Boolean allModel; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/LLMRequestService.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/LLMRequestService.java index 990bd5863..901e1de36 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/LLMRequestService.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/LLMRequestService.java @@ -13,9 +13,12 @@ import com.tencent.supersonic.chat.core.knowledge.semantic.SemanticInterpreter; import com.tencent.supersonic.chat.core.parser.SatisfactionChecker; import com.tencent.supersonic.chat.core.pojo.ChatContext; import com.tencent.supersonic.chat.core.pojo.QueryContext; +import com.tencent.supersonic.chat.core.query.QueryManager; +import com.tencent.supersonic.chat.core.query.SemanticQuery; import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq; import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue; import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp; +import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery; import com.tencent.supersonic.chat.core.utils.ComponentFactory; import com.tencent.supersonic.common.pojo.ModelCluster; import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum; @@ -66,6 +69,26 @@ public class LLMRequestService { if (Objects.nonNull(agent)) { distinctModelIds = agent.getModelIds(AgentToolType.NL2SQL_LLM); } + if (llmParserConfig.getAllModel()) { + ModelCluster modelCluster = ModelCluster.build(distinctModelIds); + if (!CollectionUtils.isEmpty(queryCtx.getCandidateQueries())) { + queryCtx.getCandidateQueries().stream().forEach(o -> { + if (LLMSqlQuery.QUERY_MODE.equals(o.getParseInfo().getQueryMode())) { + o.getParseInfo().setModel(modelCluster); + } + }); + } + SemanticQuery semanticQuery = QueryManager.createQuery(LLMSqlQuery.QUERY_MODE); + semanticQuery.getParseInfo().setModel(modelCluster); + List schemaElementMatches = new ArrayList<>(); + distinctModelIds.stream().forEach(o -> { + if (!CollectionUtils.isEmpty(queryCtx.getMapInfo().getMatchedElements(o))) { + schemaElementMatches.addAll(queryCtx.getMapInfo().getMatchedElements(o)); + } + }); + queryCtx.getModelClusterMapInfo().setMatchedElements(modelCluster.getKey(), schemaElementMatches); + return modelCluster; + } if (Agent.containsAllModel(distinctModelIds)) { distinctModelIds = new HashSet<>(); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/query/llm/s2sql/LLMSqlResp.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/query/llm/s2sql/LLMSqlResp.java index 530cce791..6883bad73 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/query/llm/s2sql/LLMSqlResp.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/query/llm/s2sql/LLMSqlResp.java @@ -2,11 +2,17 @@ package com.tencent.supersonic.chat.core.query.llm.s2sql; import java.util.List; import java.util.Map; -import lombok.Builder; + import lombok.Data; +import lombok.Builder; +import lombok.AllArgsConstructor; +import lombok.NoArgsConstructor; + @Data @Builder +@AllArgsConstructor +@NoArgsConstructor public class LLMSqlResp { private double sqlWeight;