diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/JavaLLMProxy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/JavaLLMProxy.java index 5af8e1d4c..10d093f10 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/JavaLLMProxy.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/JavaLLMProxy.java @@ -12,7 +12,6 @@ import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode; import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.common.util.ContextUtils; import dev.langchain4j.model.chat.ChatLanguageModel; -import java.util.Map; import java.util.Objects; import lombok.extern.slf4j.Slf4j; import org.slf4j.Logger; @@ -43,12 +42,9 @@ public class JavaLLMProxy implements LLMProxy { SqlGeneration sqlGeneration = SqlGenerationFactory.get( SqlGenerationMode.getMode(llmReq.getSqlGenerationMode())); String modelName = llmReq.getSchema().getModelName(); - Map sqlWeight = sqlGeneration.generation(llmReq, modelClusterKey); - - LLMResp result = new LLMResp(); + LLMResp result = sqlGeneration.generation(llmReq, modelClusterKey); result.setQuery(llmReq.getQueryText()); result.setModelName(modelName); - result.setSqlWeight(sqlWeight); return result; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/PythonLLMProxy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/PythonLLMProxy.java index 51f052919..cc75215c6 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/PythonLLMProxy.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/PythonLLMProxy.java @@ -6,13 +6,16 @@ import com.tencent.supersonic.chat.config.LLMParserConfig; import com.tencent.supersonic.chat.parser.plugin.function.FunctionCallConfig; import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq; import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp; +import com.tencent.supersonic.chat.parser.sql.llm.OutputFormat; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.JsonUtil; import java.net.URI; import java.net.URL; +import java.util.ArrayList; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections4.MapUtils; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -59,10 +62,15 @@ public class PythonLLMProxy implements LLMProxy { ResponseEntity responseEntity = restTemplate.exchange(url.toString(), HttpMethod.POST, entity, LLMResp.class); + LLMResp llmResp = responseEntity.getBody(); log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}", - System.currentTimeMillis() - startTime, url, entity, responseEntity.getBody()); - keyPipelineLog.info("LLMResp:{}", responseEntity.getBody()); - return responseEntity.getBody(); + System.currentTimeMillis() - startTime, url, entity, llmResp); + keyPipelineLog.info("LLMResp:{}", llmResp); + + if (MapUtils.isEmpty(llmResp.getSqlRespMap())) { + llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(new ArrayList<>(), llmResp.getSqlWeight())); + } + return llmResp; } catch (Exception e) { log.error("requestLLM error", e); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/LLMResponseService.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/LLMResponseService.java index f706523f5..5385cf461 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/LLMResponseService.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/LLMResponseService.java @@ -7,6 +7,7 @@ import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery; import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlQuery; +import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlResp; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.util.jsqlparser.SqlParserEqualHelper; import lombok.extern.slf4j.Slf4j; @@ -46,12 +47,12 @@ public class LLMResponseService { return parseInfo; } - public Map getDeduplicationSqlWeight(LLMResp llmResp) { - if (MapUtils.isEmpty(llmResp.getSqlWeight())) { - return llmResp.getSqlWeight(); + public Map getDeduplicationSqlResp(LLMResp llmResp) { + if (MapUtils.isEmpty(llmResp.getSqlRespMap())) { + return llmResp.getSqlRespMap(); } - Map result = new HashMap<>(); - for (Map.Entry entry : llmResp.getSqlWeight().entrySet()) { + Map result = new HashMap<>(); + for (Map.Entry entry : llmResp.getSqlRespMap().entrySet()) { String key = entry.getKey(); if (result.keySet().stream().anyMatch(existKey -> SqlParserEqualHelper.equals(existKey, key))) { continue; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/LLMSqlParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/LLMSqlParser.java index a7f8eac28..e6c711219 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/LLMSqlParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/LLMSqlParser.java @@ -9,6 +9,7 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue; import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp; +import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlResp; import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.common.pojo.ModelCluster; import com.tencent.supersonic.common.util.ContextUtils; @@ -56,7 +57,7 @@ public class LLMSqlParser implements SemanticParser { //5. deduplicate the SQL result list and build parserInfo modelCluster.buildName(semanticSchema.getModelIdToName()); LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class); - Map deduplicationSqlWeight = responseService.getDeduplicationSqlWeight(llmResp); + Map deduplicationSqlResp = responseService.getDeduplicationSqlResp(llmResp); ParseResult parseResult = ParseResult.builder() .request(request) .modelCluster(modelCluster) @@ -66,11 +67,11 @@ public class LLMSqlParser implements SemanticParser { .linkingValues(linkingValues) .build(); - if (MapUtils.isEmpty(deduplicationSqlWeight)) { + if (MapUtils.isEmpty(deduplicationSqlResp)) { responseService.addParseInfo(queryCtx, parseResult, llmResp.getSqlOutput(), 1D); } else { - deduplicationSqlWeight.forEach((sql, weight) -> { - responseService.addParseInfo(queryCtx, parseResult, sql, weight); + deduplicationSqlResp.forEach((sql, sqlResp) -> { + responseService.addParseInfo(queryCtx, parseResult, sql, sqlResp.getSqlWeight()); }); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/OnePassSCSqlGeneration.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/OnePassSCSqlGeneration.java index f9450b345..5a96d4c5b 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/OnePassSCSqlGeneration.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/OnePassSCSqlGeneration.java @@ -4,6 +4,7 @@ package com.tencent.supersonic.chat.parser.sql.llm; import com.tencent.supersonic.chat.config.OptimizationConfig; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode; +import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.common.util.JsonUtil; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.model.chat.ChatLanguageModel; @@ -41,9 +42,10 @@ public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean { private SqlPromptGenerator sqlPromptGenerator; @Override - public Map generation(LLMReq llmReq, String modelClusterKey) { + public LLMResp generation(LLMReq llmReq, String modelClusterKey) { //1.retriever sqlExamples and generate exampleListPool keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq); + List> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(), optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum()); @@ -70,9 +72,14 @@ public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean { Pair> linkingMap = OutputFormat.selfConsistencyVote(candidateSortedList); List sqlList = llmResults.stream() .map(llmResult -> OutputFormat.getSql(llmResult)).collect(Collectors.toList()); - Pair> sqlMap = OutputFormat.selfConsistencyVote(sqlList); - keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMap); - return sqlMap.getRight(); + + Pair> sqlMapPair = OutputFormat.selfConsistencyVote(sqlList); + keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMapPair.getRight()); + + LLMResp result = new LLMResp(); + result.setQuery(llmReq.getQueryText()); + result.setSqlRespMap(OutputFormat.buildSqlRespMap(sqlExamples, sqlMapPair.getRight())); + return result; } @Override diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/OnePassSqlGeneration.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/OnePassSqlGeneration.java index 1cb7d5959..6127ef1a4 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/OnePassSqlGeneration.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/OnePassSqlGeneration.java @@ -4,6 +4,8 @@ package com.tencent.supersonic.chat.parser.sql.llm; import com.tencent.supersonic.chat.config.OptimizationConfig; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode; +import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp; +import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlResp; import com.tencent.supersonic.common.util.JsonUtil; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.model.chat.ChatLanguageModel; @@ -38,7 +40,7 @@ public class OnePassSqlGeneration implements SqlGeneration, InitializingBean { private SqlPromptGenerator sqlPromptGenerator; @Override - public Map generation(LLMReq llmReq, String modelClusterKey) { + public LLMResp generation(LLMReq llmReq, String modelClusterKey) { //1.retriever sqlExamples keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq); List> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(), @@ -55,10 +57,14 @@ public class OnePassSqlGeneration implements SqlGeneration, InitializingBean { //3.format response. String schemaLinkStr = OutputFormat.getSchemaLinks(response.content().text()); String sql = OutputFormat.getSql(response.content().text()); - Map sqlMap = new HashMap<>(); - sqlMap.put(sql, 1D); - keyPipelineLog.info("schemaLinkStr:{},sqlMap:{}", schemaLinkStr, sqlMap); - return sqlMap; + Map sqlRespMap = new HashMap<>(); + sqlRespMap.put(sql, LLMSqlResp.builder().sqlWeight(1D).fewShots(sqlExamples).build()); + keyPipelineLog.info("schemaLinkStr:{},sqlRespMap:{}", schemaLinkStr, sqlRespMap); + + LLMResp llmResp = new LLMResp(); + llmResp.setQuery(llmReq.getQueryText()); + llmResp.setSqlRespMap(sqlRespMap); + return llmResp; } @Override diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/OutputFormat.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/OutputFormat.java index 191f10d7c..a60dedd9a 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/OutputFormat.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/OutputFormat.java @@ -3,6 +3,7 @@ package com.tencent.supersonic.chat.parser.sql.llm; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp; +import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlResp; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -10,6 +11,7 @@ import java.util.List; import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; +import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.tuple.Pair; @@ -19,8 +21,6 @@ import org.apache.commons.lang3.tuple.Pair; @Slf4j public class OutputFormat { - public static final String PATTERN = "\\{[^{}]+\\}"; - public static String getSchemaLink(String schemaLink) { String reult = ""; try { @@ -126,4 +126,13 @@ public class OutputFormat { } return null; } + + public static Map buildSqlRespMap(List> sqlExamples, + Map sqlMap) { + return sqlMap.entrySet().stream() + .collect(Collectors.toMap( + Map.Entry::getKey, + entry -> LLMSqlResp.builder().sqlWeight(entry.getValue()).fewShots(sqlExamples).build()) + ); + } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/SqlGeneration.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/SqlGeneration.java index 5b4794e1f..7372101b6 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/SqlGeneration.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/SqlGeneration.java @@ -2,7 +2,7 @@ package com.tencent.supersonic.chat.parser.sql.llm; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; -import java.util.Map; +import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp; /** * Sql Generation interface, generating SQL using a large model. @@ -10,11 +10,11 @@ import java.util.Map; public interface SqlGeneration { /*** - * generate SQL through LLMReq. + * generate llmResp (sql, schemaLink, prompt, etc.) through LLMReq. * @param llmReq * @param modelClusterKey * @return */ - Map generation(LLMReq llmReq, String modelClusterKey); + LLMResp generation(LLMReq llmReq, String modelClusterKey); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/TwoPassSCSqlGeneration.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/TwoPassSCSqlGeneration.java index 1aeb9d4e9..5c42d3f8f 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/TwoPassSCSqlGeneration.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/TwoPassSCSqlGeneration.java @@ -4,6 +4,7 @@ package com.tencent.supersonic.chat.parser.sql.llm; import com.tencent.supersonic.chat.config.OptimizationConfig; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode; +import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.common.util.JsonUtil; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.model.chat.ChatLanguageModel; @@ -38,7 +39,7 @@ public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean { private SqlPromptGenerator sqlPromptGenerator; @Override - public Map generation(LLMReq llmReq, String modelClusterKey) { + public LLMResp generation(LLMReq llmReq, String modelClusterKey) { //1.retriever sqlExamples and generate exampleListPool keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq); List> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(), @@ -74,9 +75,13 @@ public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean { sqlTaskPool.add(result); }); //4.format response. - Pair> sqlMap = OutputFormat.selfConsistencyVote(sqlTaskPool); - keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMap); - return sqlMap.getRight(); + Pair> sqlMapPair = OutputFormat.selfConsistencyVote(sqlTaskPool); + keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMapPair.getRight()); + + LLMResp llmResp = new LLMResp(); + llmResp.setQuery(llmReq.getQueryText()); + llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(sqlExamples, sqlMapPair.getRight())); + return llmResp; } @Override diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/TwoPassSqlGeneration.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/TwoPassSqlGeneration.java index fba015ca0..59eb5400f 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/TwoPassSqlGeneration.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/TwoPassSqlGeneration.java @@ -4,6 +4,7 @@ package com.tencent.supersonic.chat.parser.sql.llm; import com.tencent.supersonic.chat.config.OptimizationConfig; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode; +import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.common.util.JsonUtil; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.model.chat.ChatLanguageModel; @@ -38,7 +39,7 @@ public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean { private SqlPromptGenerator sqlPromptGenerator; @Override - public Map generation(LLMReq llmReq, String modelClusterKey) { + public LLMResp generation(LLMReq llmReq, String modelClusterKey) { keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq); List> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(), optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum()); @@ -60,7 +61,11 @@ public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean { Map sqlMap = new HashMap<>(); sqlMap.put(result, 1D); keyPipelineLog.info("schemaLinkStr:{},sqlMap:{}", schemaLinkStr, sqlMap); - return sqlMap; + + LLMResp llmResp = new LLMResp(); + llmResp.setQuery(llmReq.getQueryText()); + llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(sqlExamples, sqlMap)); + return llmResp; } @Override diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2sql/LLMResp.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2sql/LLMResp.java index 2c1d39ac4..a29ae537d 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2sql/LLMResp.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2sql/LLMResp.java @@ -15,9 +15,11 @@ public class LLMResp { private List fields; - private String schemaLinkingOutput; - - private String schemaLinkStr; + private Map sqlRespMap; + /** + * Only for compatibility with python code, later deleted + */ private Map sqlWeight; + } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2sql/LLMSqlResp.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2sql/LLMSqlResp.java new file mode 100644 index 000000000..38c36c66e --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2sql/LLMSqlResp.java @@ -0,0 +1,16 @@ +package com.tencent.supersonic.chat.query.llm.s2sql; + +import java.util.List; +import java.util.Map; +import lombok.Builder; +import lombok.Data; + +@Data +@Builder +public class LLMSqlResp { + + private double sqlWeight; + + private List> fewShots; + +} diff --git a/chat/core/src/test/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMResponseServiceTest.java b/chat/core/src/test/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMResponseServiceTest.java index c25e6fd7f..6a044ae8d 100644 --- a/chat/core/src/test/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMResponseServiceTest.java +++ b/chat/core/src/test/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMResponseServiceTest.java @@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.parser.llm.s2sql; import com.tencent.supersonic.chat.parser.sql.llm.LLMResponseService; import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp; +import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlResp; import java.util.HashMap; import java.util.Map; import org.junit.Assert; @@ -15,38 +16,40 @@ class LLMResponseServiceTest { String sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a"; LLMResp llmResp = new LLMResp(); - Map sqlWeight = new HashMap<>(); - sqlWeight.put(sql1, 0.2D); - sqlWeight.put(sql2, 0.8D); - llmResp.setSqlWeight(sqlWeight); - LLMResponseService llmResponseService = new LLMResponseService(); - Map deduplicationSqlWeight = llmResponseService.getDeduplicationSqlWeight(llmResp); + Map sqlWeight = new HashMap<>(); + sqlWeight.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build()); + sqlWeight.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build()); - Assert.assertEquals(deduplicationSqlWeight.size(), 1); + llmResp.setSqlRespMap(sqlWeight); + LLMResponseService llmResponseService = new LLMResponseService(); + Map deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp); + + Assert.assertEquals(deduplicationSqlResp.size(), 1); sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a"; sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a"; LLMResp llmResp2 = new LLMResp(); - Map sqlWeight2 = new HashMap<>(); - sqlWeight2.put(sql1, 0.2D); - sqlWeight2.put(sql2, 0.8D); - llmResp2.setSqlWeight(sqlWeight2); - deduplicationSqlWeight = llmResponseService.getDeduplicationSqlWeight(llmResp2); + Map sqlWeight2 = new HashMap<>(); + sqlWeight2.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build()); + sqlWeight2.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build()); - Assert.assertEquals(deduplicationSqlWeight.size(), 1); + llmResp2.setSqlRespMap(sqlWeight2); + deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp2); + + Assert.assertEquals(deduplicationSqlResp.size(), 1); sql1 = "SELECT a,b,c,d,e FROM table1 WHERE column1 = 1 AND column2 = 2 order by a"; sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a"; LLMResp llmResp3 = new LLMResp(); - Map sqlWeight3 = new HashMap<>(); - sqlWeight3.put(sql1, 0.2D); - sqlWeight3.put(sql2, 0.8D); - llmResp3.setSqlWeight(sqlWeight3); - deduplicationSqlWeight = llmResponseService.getDeduplicationSqlWeight(llmResp3); + Map sqlWeight3 = new HashMap<>(); + sqlWeight3.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build()); + sqlWeight3.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build()); + llmResp3.setSqlRespMap(sqlWeight3); + deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp3); - Assert.assertEquals(deduplicationSqlWeight.size(), 2); + Assert.assertEquals(deduplicationSqlResp.size(), 2); } } \ No newline at end of file