mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-14 22:25:19 +00:00
(improvement)(chat) Add 'Few-shot Examples' display to the Chat chart. (#589)
This commit is contained in:
@@ -12,7 +12,6 @@ import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.slf4j.Logger;
|
||||
@@ -43,12 +42,9 @@ public class JavaLLMProxy implements LLMProxy {
|
||||
SqlGeneration sqlGeneration = SqlGenerationFactory.get(
|
||||
SqlGenerationMode.getMode(llmReq.getSqlGenerationMode()));
|
||||
String modelName = llmReq.getSchema().getModelName();
|
||||
Map<String, Double> sqlWeight = sqlGeneration.generation(llmReq, modelClusterKey);
|
||||
|
||||
LLMResp result = new LLMResp();
|
||||
LLMResp result = sqlGeneration.generation(llmReq, modelClusterKey);
|
||||
result.setQuery(llmReq.getQueryText());
|
||||
result.setModelName(modelName);
|
||||
result.setSqlWeight(sqlWeight);
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
@@ -6,13 +6,16 @@ import com.tencent.supersonic.chat.config.LLMParserConfig;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionCallConfig;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.OutputFormat;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import java.net.URI;
|
||||
import java.net.URL;
|
||||
import java.util.ArrayList;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
@@ -59,10 +62,15 @@ public class PythonLLMProxy implements LLMProxy {
|
||||
ResponseEntity<LLMResp> responseEntity = restTemplate.exchange(url.toString(), HttpMethod.POST, entity,
|
||||
LLMResp.class);
|
||||
|
||||
LLMResp llmResp = responseEntity.getBody();
|
||||
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
|
||||
System.currentTimeMillis() - startTime, url, entity, responseEntity.getBody());
|
||||
keyPipelineLog.info("LLMResp:{}", responseEntity.getBody());
|
||||
return responseEntity.getBody();
|
||||
System.currentTimeMillis() - startTime, url, entity, llmResp);
|
||||
keyPipelineLog.info("LLMResp:{}", llmResp);
|
||||
|
||||
if (MapUtils.isEmpty(llmResp.getSqlRespMap())) {
|
||||
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(new ArrayList<>(), llmResp.getSqlWeight()));
|
||||
}
|
||||
return llmResp;
|
||||
} catch (Exception e) {
|
||||
log.error("requestLLM error", e);
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlResp;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserEqualHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -46,12 +47,12 @@ public class LLMResponseService {
|
||||
return parseInfo;
|
||||
}
|
||||
|
||||
public Map<String, Double> getDeduplicationSqlWeight(LLMResp llmResp) {
|
||||
if (MapUtils.isEmpty(llmResp.getSqlWeight())) {
|
||||
return llmResp.getSqlWeight();
|
||||
public Map<String, LLMSqlResp> getDeduplicationSqlResp(LLMResp llmResp) {
|
||||
if (MapUtils.isEmpty(llmResp.getSqlRespMap())) {
|
||||
return llmResp.getSqlRespMap();
|
||||
}
|
||||
Map<String, Double> result = new HashMap<>();
|
||||
for (Map.Entry<String, Double> entry : llmResp.getSqlWeight().entrySet()) {
|
||||
Map<String, LLMSqlResp> result = new HashMap<>();
|
||||
for (Map.Entry<String, LLMSqlResp> entry : llmResp.getSqlRespMap().entrySet()) {
|
||||
String key = entry.getKey();
|
||||
if (result.keySet().stream().anyMatch(existKey -> SqlParserEqualHelper.equals(existKey, key))) {
|
||||
continue;
|
||||
|
||||
@@ -9,6 +9,7 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlResp;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
@@ -56,7 +57,7 @@ public class LLMSqlParser implements SemanticParser {
|
||||
//5. deduplicate the SQL result list and build parserInfo
|
||||
modelCluster.buildName(semanticSchema.getModelIdToName());
|
||||
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
|
||||
Map<String, Double> deduplicationSqlWeight = responseService.getDeduplicationSqlWeight(llmResp);
|
||||
Map<String, LLMSqlResp> deduplicationSqlResp = responseService.getDeduplicationSqlResp(llmResp);
|
||||
ParseResult parseResult = ParseResult.builder()
|
||||
.request(request)
|
||||
.modelCluster(modelCluster)
|
||||
@@ -66,11 +67,11 @@ public class LLMSqlParser implements SemanticParser {
|
||||
.linkingValues(linkingValues)
|
||||
.build();
|
||||
|
||||
if (MapUtils.isEmpty(deduplicationSqlWeight)) {
|
||||
if (MapUtils.isEmpty(deduplicationSqlResp)) {
|
||||
responseService.addParseInfo(queryCtx, parseResult, llmResp.getSqlOutput(), 1D);
|
||||
} else {
|
||||
deduplicationSqlWeight.forEach((sql, weight) -> {
|
||||
responseService.addParseInfo(queryCtx, parseResult, sql, weight);
|
||||
deduplicationSqlResp.forEach((sql, sqlResp) -> {
|
||||
responseService.addParseInfo(queryCtx, parseResult, sql, sqlResp.getSqlWeight());
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
@@ -41,9 +42,10 @@ public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
private SqlPromptGenerator sqlPromptGenerator;
|
||||
|
||||
@Override
|
||||
public Map<String, Double> generation(LLMReq llmReq, String modelClusterKey) {
|
||||
public LLMResp generation(LLMReq llmReq, String modelClusterKey) {
|
||||
//1.retriever sqlExamples and generate exampleListPool
|
||||
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
|
||||
|
||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
|
||||
|
||||
@@ -70,9 +72,14 @@ public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
Pair<String, Map<String, Double>> linkingMap = OutputFormat.selfConsistencyVote(candidateSortedList);
|
||||
List<String> sqlList = llmResults.stream()
|
||||
.map(llmResult -> OutputFormat.getSql(llmResult)).collect(Collectors.toList());
|
||||
Pair<String, Map<String, Double>> sqlMap = OutputFormat.selfConsistencyVote(sqlList);
|
||||
keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMap);
|
||||
return sqlMap.getRight();
|
||||
|
||||
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlList);
|
||||
keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMapPair.getRight());
|
||||
|
||||
LLMResp result = new LLMResp();
|
||||
result.setQuery(llmReq.getQueryText());
|
||||
result.setSqlRespMap(OutputFormat.buildSqlRespMap(sqlExamples, sqlMapPair.getRight()));
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -4,6 +4,8 @@ package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlResp;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
@@ -38,7 +40,7 @@ public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
private SqlPromptGenerator sqlPromptGenerator;
|
||||
|
||||
@Override
|
||||
public Map<String, Double> generation(LLMReq llmReq, String modelClusterKey) {
|
||||
public LLMResp generation(LLMReq llmReq, String modelClusterKey) {
|
||||
//1.retriever sqlExamples
|
||||
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
|
||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
@@ -55,10 +57,14 @@ public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
//3.format response.
|
||||
String schemaLinkStr = OutputFormat.getSchemaLinks(response.content().text());
|
||||
String sql = OutputFormat.getSql(response.content().text());
|
||||
Map<String, Double> sqlMap = new HashMap<>();
|
||||
sqlMap.put(sql, 1D);
|
||||
keyPipelineLog.info("schemaLinkStr:{},sqlMap:{}", schemaLinkStr, sqlMap);
|
||||
return sqlMap;
|
||||
Map<String, LLMSqlResp> sqlRespMap = new HashMap<>();
|
||||
sqlRespMap.put(sql, LLMSqlResp.builder().sqlWeight(1D).fewShots(sqlExamples).build());
|
||||
keyPipelineLog.info("schemaLinkStr:{},sqlRespMap:{}", schemaLinkStr, sqlRespMap);
|
||||
|
||||
LLMResp llmResp = new LLMResp();
|
||||
llmResp.setQuery(llmReq.getQueryText());
|
||||
llmResp.setSqlRespMap(sqlRespMap);
|
||||
return llmResp;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -3,6 +3,7 @@ package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlResp;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
@@ -10,6 +11,7 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
|
||||
@@ -19,8 +21,6 @@ import org.apache.commons.lang3.tuple.Pair;
|
||||
@Slf4j
|
||||
public class OutputFormat {
|
||||
|
||||
public static final String PATTERN = "\\{[^{}]+\\}";
|
||||
|
||||
public static String getSchemaLink(String schemaLink) {
|
||||
String reult = "";
|
||||
try {
|
||||
@@ -126,4 +126,13 @@ public class OutputFormat {
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public static Map<String, LLMSqlResp> buildSqlRespMap(List<Map<String, String>> sqlExamples,
|
||||
Map<String, Double> sqlMap) {
|
||||
return sqlMap.entrySet().stream()
|
||||
.collect(Collectors.toMap(
|
||||
Map.Entry::getKey,
|
||||
entry -> LLMSqlResp.builder().sqlWeight(entry.getValue()).fewShots(sqlExamples).build())
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import java.util.Map;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
|
||||
/**
|
||||
* Sql Generation interface, generating SQL using a large model.
|
||||
@@ -10,11 +10,11 @@ import java.util.Map;
|
||||
public interface SqlGeneration {
|
||||
|
||||
/***
|
||||
* generate SQL through LLMReq.
|
||||
* generate llmResp (sql, schemaLink, prompt, etc.) through LLMReq.
|
||||
* @param llmReq
|
||||
* @param modelClusterKey
|
||||
* @return
|
||||
*/
|
||||
Map<String, Double> generation(LLMReq llmReq, String modelClusterKey);
|
||||
LLMResp generation(LLMReq llmReq, String modelClusterKey);
|
||||
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
@@ -38,7 +39,7 @@ public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
private SqlPromptGenerator sqlPromptGenerator;
|
||||
|
||||
@Override
|
||||
public Map<String, Double> generation(LLMReq llmReq, String modelClusterKey) {
|
||||
public LLMResp generation(LLMReq llmReq, String modelClusterKey) {
|
||||
//1.retriever sqlExamples and generate exampleListPool
|
||||
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
|
||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
@@ -74,9 +75,13 @@ public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
sqlTaskPool.add(result);
|
||||
});
|
||||
//4.format response.
|
||||
Pair<String, Map<String, Double>> sqlMap = OutputFormat.selfConsistencyVote(sqlTaskPool);
|
||||
keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMap);
|
||||
return sqlMap.getRight();
|
||||
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlTaskPool);
|
||||
keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMapPair.getRight());
|
||||
|
||||
LLMResp llmResp = new LLMResp();
|
||||
llmResp.setQuery(llmReq.getQueryText());
|
||||
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(sqlExamples, sqlMapPair.getRight()));
|
||||
return llmResp;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -4,6 +4,7 @@ package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
@@ -38,7 +39,7 @@ public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
private SqlPromptGenerator sqlPromptGenerator;
|
||||
|
||||
@Override
|
||||
public Map<String, Double> generation(LLMReq llmReq, String modelClusterKey) {
|
||||
public LLMResp generation(LLMReq llmReq, String modelClusterKey) {
|
||||
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
|
||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
|
||||
@@ -60,7 +61,11 @@ public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
Map<String, Double> sqlMap = new HashMap<>();
|
||||
sqlMap.put(result, 1D);
|
||||
keyPipelineLog.info("schemaLinkStr:{},sqlMap:{}", schemaLinkStr, sqlMap);
|
||||
return sqlMap;
|
||||
|
||||
LLMResp llmResp = new LLMResp();
|
||||
llmResp.setQuery(llmReq.getQueryText());
|
||||
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(sqlExamples, sqlMap));
|
||||
return llmResp;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -15,9 +15,11 @@ public class LLMResp {
|
||||
|
||||
private List<String> fields;
|
||||
|
||||
private String schemaLinkingOutput;
|
||||
|
||||
private String schemaLinkStr;
|
||||
private Map<String, LLMSqlResp> sqlRespMap;
|
||||
|
||||
/**
|
||||
* Only for compatibility with python code, later deleted
|
||||
*/
|
||||
private Map<String, Double> sqlWeight;
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
package com.tencent.supersonic.chat.query.llm.s2sql;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
public class LLMSqlResp {
|
||||
|
||||
private double sqlWeight;
|
||||
|
||||
private List<Map<String, String>> fewShots;
|
||||
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.LLMResponseService;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlResp;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.junit.Assert;
|
||||
@@ -15,38 +16,40 @@ class LLMResponseServiceTest {
|
||||
String sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
||||
|
||||
LLMResp llmResp = new LLMResp();
|
||||
Map<String, Double> sqlWeight = new HashMap<>();
|
||||
sqlWeight.put(sql1, 0.2D);
|
||||
sqlWeight.put(sql2, 0.8D);
|
||||
llmResp.setSqlWeight(sqlWeight);
|
||||
LLMResponseService llmResponseService = new LLMResponseService();
|
||||
Map<String, Double> deduplicationSqlWeight = llmResponseService.getDeduplicationSqlWeight(llmResp);
|
||||
Map<String, LLMSqlResp> sqlWeight = new HashMap<>();
|
||||
sqlWeight.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
|
||||
sqlWeight.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
|
||||
|
||||
Assert.assertEquals(deduplicationSqlWeight.size(), 1);
|
||||
llmResp.setSqlRespMap(sqlWeight);
|
||||
LLMResponseService llmResponseService = new LLMResponseService();
|
||||
Map<String, LLMSqlResp> deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp);
|
||||
|
||||
Assert.assertEquals(deduplicationSqlResp.size(), 1);
|
||||
|
||||
sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
|
||||
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
||||
|
||||
LLMResp llmResp2 = new LLMResp();
|
||||
Map<String, Double> sqlWeight2 = new HashMap<>();
|
||||
sqlWeight2.put(sql1, 0.2D);
|
||||
sqlWeight2.put(sql2, 0.8D);
|
||||
llmResp2.setSqlWeight(sqlWeight2);
|
||||
deduplicationSqlWeight = llmResponseService.getDeduplicationSqlWeight(llmResp2);
|
||||
Map<String, LLMSqlResp> sqlWeight2 = new HashMap<>();
|
||||
sqlWeight2.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
|
||||
sqlWeight2.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
|
||||
|
||||
Assert.assertEquals(deduplicationSqlWeight.size(), 1);
|
||||
llmResp2.setSqlRespMap(sqlWeight2);
|
||||
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp2);
|
||||
|
||||
Assert.assertEquals(deduplicationSqlResp.size(), 1);
|
||||
|
||||
sql1 = "SELECT a,b,c,d,e FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
|
||||
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
||||
|
||||
LLMResp llmResp3 = new LLMResp();
|
||||
Map<String, Double> sqlWeight3 = new HashMap<>();
|
||||
sqlWeight3.put(sql1, 0.2D);
|
||||
sqlWeight3.put(sql2, 0.8D);
|
||||
llmResp3.setSqlWeight(sqlWeight3);
|
||||
deduplicationSqlWeight = llmResponseService.getDeduplicationSqlWeight(llmResp3);
|
||||
Map<String, LLMSqlResp> sqlWeight3 = new HashMap<>();
|
||||
sqlWeight3.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
|
||||
sqlWeight3.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
|
||||
llmResp3.setSqlRespMap(sqlWeight3);
|
||||
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp3);
|
||||
|
||||
Assert.assertEquals(deduplicationSqlWeight.size(), 2);
|
||||
Assert.assertEquals(deduplicationSqlResp.size(), 2);
|
||||
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user