(improvement)(chat) The large model parsing supports SQL result verification and adds three retries (#1194)

This commit is contained in:
lexluo09
2024-06-22 22:21:51 +08:00
committed by GitHub
parent 32e2c1e39d
commit 29694be64e
10 changed files with 131 additions and 87 deletions

View File

@@ -2,13 +2,15 @@ package com.tencent.supersonic.common.jsqlparser;
import java.util.List; import java.util.List;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
/** /**
* Sql Parser equal Helper * Sql Parser valid Helper
*/ */
@Slf4j @Slf4j
public class SqlEqualHelper { public class SqlValidHelper {
/** /**
* determine if two SQL statements are equal. * determine if two SQL statements are equal.
@@ -63,5 +65,15 @@ public class SqlEqualHelper {
return true; return true;
} }
public static boolean isValidSQL(String sql) {
try {
CCJSqlParserUtil.parse(sql);
return true;
} catch (JSQLParserException e) {
log.error("isValidSQL parse:{}", e);
return false;
}
}
} }

View File

@@ -11,7 +11,7 @@ class ChatModelProperties {
String apiKey; String apiKey;
Double temperature; Double temperature;
Double topP; Double topP;
String model; String modelName;
Integer maxRetries; Integer maxRetries;
Integer maxToken; Integer maxToken;
Boolean logRequests; Boolean logRequests;

View File

@@ -21,7 +21,7 @@ public class ZhipuAutoConfig {
return ZhipuAiChatModel.builder() return ZhipuAiChatModel.builder()
.baseUrl(chatModelProperties.getBaseUrl()) .baseUrl(chatModelProperties.getBaseUrl())
.apiKey(chatModelProperties.getApiKey()) .apiKey(chatModelProperties.getApiKey())
.model(chatModelProperties.getModel()) .model(chatModelProperties.getModelName())
.temperature(chatModelProperties.getTemperature()) .temperature(chatModelProperties.getTemperature())
.topP(chatModelProperties.getTopP()) .topP(chatModelProperties.getTopP())
.maxRetries(chatModelProperties.getMaxRetries()) .maxRetries(chatModelProperties.getMaxRetries())
@@ -38,7 +38,7 @@ public class ZhipuAutoConfig {
return ZhipuAiStreamingChatModel.builder() return ZhipuAiStreamingChatModel.builder()
.baseUrl(chatModelProperties.getBaseUrl()) .baseUrl(chatModelProperties.getBaseUrl())
.apiKey(chatModelProperties.getApiKey()) .apiKey(chatModelProperties.getApiKey())
.model(chatModelProperties.getModel()) .model(chatModelProperties.getModelName())
.temperature(chatModelProperties.getTemperature()) .temperature(chatModelProperties.getTemperature())
.topP(chatModelProperties.getTopP()) .topP(chatModelProperties.getTopP())
.maxToken(chatModelProperties.getMaxToken()) .maxToken(chatModelProperties.getMaxToken())

View File

@@ -4,39 +4,33 @@ package com.tencent.supersonic.common.jsqlparser;
import org.junit.Assert; import org.junit.Assert;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
/** class SqlValidHelperTest {
* @author lex luo
* @date 2023/11/15 15:04
*/
class SqlEqualHelperTest {
@Test @Test
void testEquals() { void testEquals() {
String sql1 = "SELECT * FROM table1 WHERE column1 = 1 AND column2 = 2"; String sql1 = "SELECT * FROM table1 WHERE column1 = 1 AND column2 = 2";
String sql2 = "SELECT * FROM table1 WHERE column2 = 2 AND column1 = 1"; String sql2 = "SELECT * FROM table1 WHERE column2 = 2 AND column1 = 1";
Assert.assertEquals(SqlEqualHelper.equals(sql1, sql2), true); Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), true);
sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a"; sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a"; sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
Assert.assertEquals(SqlEqualHelper.equals(sql1, sql2), true); Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), true);
sql1 = "SELECT a,sum(b),sum(c),sum(d) FROM table1 WHERE column1 = 1 AND column2 = 2 group by a order by a"; sql1 = "SELECT a,sum(b),sum(c),sum(d) FROM table1 WHERE column1 = 1 AND column2 = 2 group by a order by a";
sql2 = "SELECT sum(d),sum(c),sum(b),a FROM table1 WHERE column2 = 2 AND column1 = 1 group by a order by a"; sql2 = "SELECT sum(d),sum(c),sum(b),a FROM table1 WHERE column2 = 2 AND column1 = 1 group by a order by a";
Assert.assertEquals(SqlEqualHelper.equals(sql1, sql2), true); Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), true);
sql1 = "SELECT a,sum(b),sum(c),sum(d) FROM table1 WHERE column1 = 1 AND column2 = 2 group by a order by a"; sql1 = "SELECT a,sum(b),sum(c),sum(d) FROM table1 WHERE column1 = 1 AND column2 = 2 group by a order by a";
sql2 = "SELECT sum(d),sum(c),sum(b),a FROM table1 WHERE column2 = 2 AND column1 = 1 group by a order by a"; sql2 = "SELECT sum(d),sum(c),sum(b),a FROM table1 WHERE column2 = 2 AND column1 = 1 group by a order by a";
Assert.assertEquals(SqlEqualHelper.equals(sql1, sql2), true); Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), true);
sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a"; sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
sql2 = "SELECT d,c,b,f FROM table1 WHERE column2 = 2 AND column1 = 1 order by a"; sql2 = "SELECT d,c,b,f FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
Assert.assertEquals(SqlEqualHelper.equals(sql1, sql2), false); Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), false);
sql1 = "SELECT\n" sql1 = "SELECT\n"
+ "页面,\n" + "页面,\n"
@@ -65,6 +59,27 @@ class SqlEqualHelperTest {
+ "页面\n" + "页面\n"
+ "LIMIT\n" + "LIMIT\n"
+ "365"; + "365";
Assert.assertEquals(SqlEqualHelper.equals(sql1, sql2), true); Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), true);
}
@Test
void testIsValidSQL() {
String sql1 = "SELECT * FROM table1 WHERE column1 = 1 AND column2 = 2";
Assert.assertEquals(SqlValidHelper.isValidSQL(sql1), true);
sql1 = "SELECT sum(b),sum(c),sum(d) FROM table1 WHERE column1 = 1 AND column2 = 2";
Assert.assertEquals(SqlValidHelper.isValidSQL(sql1), true);
sql1 = "SELECT a,b,c, FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
Assert.assertEquals(SqlValidHelper.isValidSQL(sql1), false);
sql1 = "SELECTa,b,c,d FROM table1";
Assert.assertEquals(SqlValidHelper.isValidSQL(sql1), false);
sql1 = "SELECT sum(b),sum(c),sum(d) FROM table1 WHERE";
Assert.assertEquals(SqlValidHelper.isValidSQL(sql1), false);
} }
} }

View File

@@ -15,6 +15,9 @@ public class LLMParserConfig {
@Value("${s2.query2sql.path:/query2sql}") @Value("${s2.query2sql.path:/query2sql}")
private String queryToSqlPath; private String queryToSqlPath;
@Value("${s2.recall.max.retries:3}")
private int recallMaxRetries;
@Value("${s2.dimension.topn:10}") @Value("${s2.dimension.topn:10}")
private Integer dimensionTopN; private Integer dimensionTopN;

View File

@@ -1,25 +1,23 @@
package com.tencent.supersonic.headless.chat.parser.llm; package com.tencent.supersonic.headless.chat.parser.llm;
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_LINKING_VALUE_ENABLE;
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_STRATEGY_TYPE;
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum; import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils; import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
import com.tencent.supersonic.headless.chat.QueryContext; import com.tencent.supersonic.headless.chat.QueryContext;
import com.tencent.supersonic.headless.chat.utils.S2SqlDateHelper; import com.tencent.supersonic.headless.chat.parser.ParserConfig;
import com.tencent.supersonic.headless.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.headless.chat.parser.SatisfactionChecker; import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
import lombok.extern.slf4j.Slf4j; import com.tencent.supersonic.headless.chat.utils.S2SqlDateHelper;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.Comparator; import java.util.Comparator;
@@ -29,9 +27,12 @@ import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_LINKING_VALUE_ENABLE; import org.apache.commons.lang3.StringUtils;
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_STRATEGY_TYPE; import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@Slf4j @Slf4j
@Service @Service
@@ -62,8 +63,10 @@ public class LLMRequestService {
return dataSetResolver.resolve(queryCtx, queryCtx.getDataSetIds()); return dataSetResolver.resolve(queryCtx, queryCtx.getDataSetIds());
} }
public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId, public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId) {
SemanticSchema semanticSchema, List<LLMReq.ElementValue> linkingValues) { LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
List<LLMReq.ElementValue> linkingValues = requestService.getValues(queryCtx, dataSetId);
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
Map<Long, String> dataSetIdToName = semanticSchema.getDataSetIdToName(); Map<Long, String> dataSetIdToName = semanticSchema.getDataSetIdToName();
String queryText = queryCtx.getQueryText(); String queryText = queryCtx.getQueryText();
@@ -114,7 +117,7 @@ public class LLMRequestService {
} }
protected List<String> getFieldNameList(QueryContext queryCtx, Long dataSetId, protected List<String> getFieldNameList(QueryContext queryCtx, Long dataSetId,
LLMParserConfig llmParserConfig) { LLMParserConfig llmParserConfig) {
Set<String> results = getTopNFieldNames(queryCtx, dataSetId, llmParserConfig); Set<String> results = getTopNFieldNames(queryCtx, dataSetId, llmParserConfig);

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.headless.chat.parser.llm; package com.tencent.supersonic.headless.chat.parser.llm;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.jsqlparser.SqlEqualHelper; import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.query.QueryManager; import com.tencent.supersonic.headless.chat.query.QueryManager;
import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery; import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery;
@@ -9,6 +9,7 @@ import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp;
import com.tencent.supersonic.headless.chat.QueryContext; import com.tencent.supersonic.headless.chat.QueryContext;
import java.util.ArrayList;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.MapUtils; import org.apache.commons.collections.MapUtils;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@@ -41,14 +42,20 @@ public class LLMResponseService {
return parseInfo; return parseInfo;
} }
public Map<String, LLMSqlResp> getDeduplicationSqlResp(LLMResp llmResp) { public Map<String, LLMSqlResp> getDeduplicationSqlResp(int currentRetry, LLMResp llmResp) {
if (MapUtils.isEmpty(llmResp.getSqlRespMap())) { Map<String, LLMSqlResp> sqlRespMap = llmResp.getSqlRespMap();
return llmResp.getSqlRespMap(); if (MapUtils.isEmpty(sqlRespMap)) {
LLMSqlResp llmSqlResp = new LLMSqlResp(1D, new ArrayList<>());
sqlRespMap.put(llmResp.getSqlOutput(), llmSqlResp);
} }
Map<String, LLMSqlResp> result = new HashMap<>(); Map<String, LLMSqlResp> result = new HashMap<>();
for (Map.Entry<String, LLMSqlResp> entry : llmResp.getSqlRespMap().entrySet()) { for (Map.Entry<String, LLMSqlResp> entry : sqlRespMap.entrySet()) {
String key = entry.getKey(); String key = entry.getKey();
if (result.keySet().stream().anyMatch(existKey -> SqlEqualHelper.equals(existKey, key))) { if (result.keySet().stream().anyMatch(existKey -> SqlValidHelper.equals(existKey, key))) {
continue;
}
if (!SqlValidHelper.isValidSQL(key)) {
log.error("currentRetry:{},sql is not valid:{}", currentRetry, key);
continue; continue;
} }
result.put(key, entry.getValue()); result.put(key, entry.getValue());

View File

@@ -2,20 +2,18 @@ package com.tencent.supersonic.headless.chat.parser.llm;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.QueryContext; import com.tencent.supersonic.headless.chat.QueryContext;
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp;
import com.tencent.supersonic.headless.chat.parser.SemanticParser; import java.util.HashMap;
import com.tencent.supersonic.headless.chat.ChatContext; import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.MapUtils; import org.apache.commons.collections.MapUtils;
import org.apache.commons.lang3.StringUtils;
import java.util.List;
import java.util.Map;
import java.util.Objects;
/** /**
* LLMSqlParser uses large language model to understand query semantics and * LLMSqlParser uses large language model to understand query semantics and
@@ -26,12 +24,12 @@ public class LLMSqlParser implements SemanticParser {
@Override @Override
public void parse(QueryContext queryCtx, ChatContext chatCtx) { public void parse(QueryContext queryCtx, ChatContext chatCtx) {
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
//1.determine whether to skip this parser.
if (requestService.isSkip(queryCtx)) {
return;
}
try { try {
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
//1.determine whether to skip this parser.
if (requestService.isSkip(queryCtx)) {
return;
}
//2.get dataSetId from queryCtx and chatCtx. //2.get dataSetId from queryCtx and chatCtx.
Long dataSetId = requestService.getDataSetId(queryCtx); Long dataSetId = requestService.getDataSetId(queryCtx);
if (dataSetId == null) { if (dataSetId == null) {
@@ -40,39 +38,46 @@ public class LLMSqlParser implements SemanticParser {
log.info("Generate query statement for dataSetId:{}", dataSetId); log.info("Generate query statement for dataSetId:{}", dataSetId);
//3.invoke LLM service to do parsing. //3.invoke LLM service to do parsing.
List<LLMReq.ElementValue> linkingValues = requestService.getValues(queryCtx, dataSetId); LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId);
SemanticSchema semanticSchema = queryCtx.getSemanticSchema(); recall(queryCtx, dataSetId, llmReq);
LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId, semanticSchema, linkingValues);
LLMResp llmResp = requestService.runText2SQL(llmReq);
if (Objects.isNull(llmResp)) {
return;
}
//4. deduplicate the S2SQL result list and build parserInfo
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
Map<String, LLMSqlResp> deduplicationSqlResp = responseService.getDeduplicationSqlResp(llmResp);
ParseResult parseResult = ParseResult.builder()
.dataSetId(dataSetId)
.llmReq(llmReq)
.llmResp(llmResp)
.linkingValues(linkingValues)
.build();
if (MapUtils.isEmpty(deduplicationSqlResp)) {
if (StringUtils.isNotBlank(llmResp.getSqlOutput())) {
responseService.addParseInfo(queryCtx, parseResult, llmResp.getSqlOutput(), 1D);
}
} else {
deduplicationSqlResp.forEach((sql, sqlResp) -> {
if (StringUtils.isNotBlank(sql)) {
responseService.addParseInfo(queryCtx, parseResult, sql, sqlResp.getSqlWeight());
}
});
}
} catch (Exception e) { } catch (Exception e) {
log.error("Failed to parse query:", e); log.error("Failed to parse query:", e);
} }
} }
private void recall(QueryContext queryCtx, Long dataSetId, LLMReq llmReq) {
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
int maxRetries = ContextUtils.getBean(LLMParserConfig.class).getRecallMaxRetries();
int currentRetry = 0;
Map<String, LLMSqlResp> sqlRespMap = new HashMap<>();
ParseResult parseResult = null;
while (currentRetry < maxRetries) {
log.info("currentRetry:{},start runText2SQL", currentRetry);
try {
LLMResp llmResp = requestService.runText2SQL(llmReq);
if (Objects.nonNull(llmResp)) {
//deduplicate the S2SQL result list and build parserInfo
sqlRespMap = responseService.getDeduplicationSqlResp(currentRetry, llmResp);
if (MapUtils.isNotEmpty(sqlRespMap)) {
parseResult = ParseResult.builder()
.dataSetId(dataSetId).llmReq(llmReq).llmResp(llmResp)
.linkingValues(llmReq.getLinking()).build();
break;
}
}
} catch (Exception e) {
log.error("currentRetry:{},runText2SQL error", currentRetry, e);
}
currentRetry++;
}
if (MapUtils.isNotEmpty(sqlRespMap)) {
for (Entry<String, LLMSqlResp> entry : sqlRespMap.entrySet()) {
String sql = entry.getKey();
double sqlWeight = entry.getValue().getSqlWeight();
responseService.addParseInfo(queryCtx, parseResult, sql, sqlWeight);
}
}
}
} }

View File

@@ -26,7 +26,6 @@ public class LLMReq {
private SqlGenType sqlGenType; private SqlGenType sqlGenType;
private LLMConfig llmConfig; private LLMConfig llmConfig;
@Data @Data
public static class ElementValue { public static class ElementValue {

View File

@@ -23,7 +23,7 @@ class LLMResponseServiceTest {
llmResp.setSqlRespMap(sqlWeight); llmResp.setSqlRespMap(sqlWeight);
LLMResponseService llmResponseService = new LLMResponseService(); LLMResponseService llmResponseService = new LLMResponseService();
Map<String, LLMSqlResp> deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp); Map<String, LLMSqlResp> deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(0, llmResp);
Assert.assertEquals(deduplicationSqlResp.size(), 1); Assert.assertEquals(deduplicationSqlResp.size(), 1);
@@ -36,7 +36,7 @@ class LLMResponseServiceTest {
sqlWeight2.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build()); sqlWeight2.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
llmResp2.setSqlRespMap(sqlWeight2); llmResp2.setSqlRespMap(sqlWeight2);
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp2); deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(0, llmResp2);
Assert.assertEquals(deduplicationSqlResp.size(), 1); Assert.assertEquals(deduplicationSqlResp.size(), 1);
@@ -48,7 +48,7 @@ class LLMResponseServiceTest {
sqlWeight3.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build()); sqlWeight3.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
sqlWeight3.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build()); sqlWeight3.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
llmResp3.setSqlRespMap(sqlWeight3); llmResp3.setSqlRespMap(sqlWeight3);
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp3); deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(0, llmResp3);
Assert.assertEquals(deduplicationSqlResp.size(), 2); Assert.assertEquals(deduplicationSqlResp.size(), 2);