(improvement)(chat) Add 'Few-shot Examples' display to the Chat chart. (#589)

This commit is contained in:
lexluo09
2024-01-02 18:07:15 +08:00
committed by GitHub
parent e7f13572d7
commit d72166944c
13 changed files with 118 additions and 59 deletions

View File

@@ -12,7 +12,6 @@ import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
import java.util.Map;
import java.util.Objects; import java.util.Objects;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger; import org.slf4j.Logger;
@@ -43,12 +42,9 @@ public class JavaLLMProxy implements LLMProxy {
SqlGeneration sqlGeneration = SqlGenerationFactory.get( SqlGeneration sqlGeneration = SqlGenerationFactory.get(
SqlGenerationMode.getMode(llmReq.getSqlGenerationMode())); SqlGenerationMode.getMode(llmReq.getSqlGenerationMode()));
String modelName = llmReq.getSchema().getModelName(); String modelName = llmReq.getSchema().getModelName();
Map<String, Double> sqlWeight = sqlGeneration.generation(llmReq, modelClusterKey); LLMResp result = sqlGeneration.generation(llmReq, modelClusterKey);
LLMResp result = new LLMResp();
result.setQuery(llmReq.getQueryText()); result.setQuery(llmReq.getQueryText());
result.setModelName(modelName); result.setModelName(modelName);
result.setSqlWeight(sqlWeight);
return result; return result;
} }

View File

@@ -6,13 +6,16 @@ import com.tencent.supersonic.chat.config.LLMParserConfig;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionCallConfig; import com.tencent.supersonic.chat.parser.plugin.function.FunctionCallConfig;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq; import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp; import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
import com.tencent.supersonic.chat.parser.sql.llm.OutputFormat;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import java.net.URI; import java.net.URI;
import java.net.URL; import java.net.URL;
import java.util.ArrayList;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@@ -59,10 +62,15 @@ public class PythonLLMProxy implements LLMProxy {
ResponseEntity<LLMResp> responseEntity = restTemplate.exchange(url.toString(), HttpMethod.POST, entity, ResponseEntity<LLMResp> responseEntity = restTemplate.exchange(url.toString(), HttpMethod.POST, entity,
LLMResp.class); LLMResp.class);
LLMResp llmResp = responseEntity.getBody();
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}", log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
System.currentTimeMillis() - startTime, url, entity, responseEntity.getBody()); System.currentTimeMillis() - startTime, url, entity, llmResp);
keyPipelineLog.info("LLMResp:{}", responseEntity.getBody()); keyPipelineLog.info("LLMResp:{}", llmResp);
return responseEntity.getBody();
if (MapUtils.isEmpty(llmResp.getSqlRespMap())) {
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(new ArrayList<>(), llmResp.getSqlWeight()));
}
return llmResp;
} catch (Exception e) { } catch (Exception e) {
log.error("requestLLM error", e); log.error("requestLLM error", e);
} }

View File

@@ -7,6 +7,7 @@ import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery; import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlQuery; import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlResp;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserEqualHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserEqualHelper;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@@ -46,12 +47,12 @@ public class LLMResponseService {
return parseInfo; return parseInfo;
} }
public Map<String, Double> getDeduplicationSqlWeight(LLMResp llmResp) { public Map<String, LLMSqlResp> getDeduplicationSqlResp(LLMResp llmResp) {
if (MapUtils.isEmpty(llmResp.getSqlWeight())) { if (MapUtils.isEmpty(llmResp.getSqlRespMap())) {
return llmResp.getSqlWeight(); return llmResp.getSqlRespMap();
} }
Map<String, Double> result = new HashMap<>(); Map<String, LLMSqlResp> result = new HashMap<>();
for (Map.Entry<String, Double> entry : llmResp.getSqlWeight().entrySet()) { for (Map.Entry<String, LLMSqlResp> entry : llmResp.getSqlRespMap().entrySet()) {
String key = entry.getKey(); String key = entry.getKey();
if (result.keySet().stream().anyMatch(existKey -> SqlParserEqualHelper.equals(existKey, key))) { if (result.keySet().stream().anyMatch(existKey -> SqlParserEqualHelper.equals(existKey, key))) {
continue; continue;

View File

@@ -9,6 +9,7 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlResp;
import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.pojo.ModelCluster; import com.tencent.supersonic.common.pojo.ModelCluster;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
@@ -56,7 +57,7 @@ public class LLMSqlParser implements SemanticParser {
//5. deduplicate the SQL result list and build parserInfo //5. deduplicate the SQL result list and build parserInfo
modelCluster.buildName(semanticSchema.getModelIdToName()); modelCluster.buildName(semanticSchema.getModelIdToName());
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class); LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
Map<String, Double> deduplicationSqlWeight = responseService.getDeduplicationSqlWeight(llmResp); Map<String, LLMSqlResp> deduplicationSqlResp = responseService.getDeduplicationSqlResp(llmResp);
ParseResult parseResult = ParseResult.builder() ParseResult parseResult = ParseResult.builder()
.request(request) .request(request)
.modelCluster(modelCluster) .modelCluster(modelCluster)
@@ -66,11 +67,11 @@ public class LLMSqlParser implements SemanticParser {
.linkingValues(linkingValues) .linkingValues(linkingValues)
.build(); .build();
if (MapUtils.isEmpty(deduplicationSqlWeight)) { if (MapUtils.isEmpty(deduplicationSqlResp)) {
responseService.addParseInfo(queryCtx, parseResult, llmResp.getSqlOutput(), 1D); responseService.addParseInfo(queryCtx, parseResult, llmResp.getSqlOutput(), 1D);
} else { } else {
deduplicationSqlWeight.forEach((sql, weight) -> { deduplicationSqlResp.forEach((sql, sqlResp) -> {
responseService.addParseInfo(queryCtx, parseResult, sql, weight); responseService.addParseInfo(queryCtx, parseResult, sql, sqlResp.getSqlWeight());
}); });
} }

View File

@@ -4,6 +4,7 @@ package com.tencent.supersonic.chat.parser.sql.llm;
import com.tencent.supersonic.chat.config.OptimizationConfig; import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
@@ -41,9 +42,10 @@ public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
private SqlPromptGenerator sqlPromptGenerator; private SqlPromptGenerator sqlPromptGenerator;
@Override @Override
public Map<String, Double> generation(LLMReq llmReq, String modelClusterKey) { public LLMResp generation(LLMReq llmReq, String modelClusterKey) {
//1.retriever sqlExamples and generate exampleListPool //1.retriever sqlExamples and generate exampleListPool
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq); keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(), List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum()); optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
@@ -70,9 +72,14 @@ public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
Pair<String, Map<String, Double>> linkingMap = OutputFormat.selfConsistencyVote(candidateSortedList); Pair<String, Map<String, Double>> linkingMap = OutputFormat.selfConsistencyVote(candidateSortedList);
List<String> sqlList = llmResults.stream() List<String> sqlList = llmResults.stream()
.map(llmResult -> OutputFormat.getSql(llmResult)).collect(Collectors.toList()); .map(llmResult -> OutputFormat.getSql(llmResult)).collect(Collectors.toList());
Pair<String, Map<String, Double>> sqlMap = OutputFormat.selfConsistencyVote(sqlList);
keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMap); Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlList);
return sqlMap.getRight(); keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMapPair.getRight());
LLMResp result = new LLMResp();
result.setQuery(llmReq.getQueryText());
result.setSqlRespMap(OutputFormat.buildSqlRespMap(sqlExamples, sqlMapPair.getRight()));
return result;
} }
@Override @Override

View File

@@ -4,6 +4,8 @@ package com.tencent.supersonic.chat.parser.sql.llm;
import com.tencent.supersonic.chat.config.OptimizationConfig; import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlResp;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
@@ -38,7 +40,7 @@ public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
private SqlPromptGenerator sqlPromptGenerator; private SqlPromptGenerator sqlPromptGenerator;
@Override @Override
public Map<String, Double> generation(LLMReq llmReq, String modelClusterKey) { public LLMResp generation(LLMReq llmReq, String modelClusterKey) {
//1.retriever sqlExamples //1.retriever sqlExamples
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq); keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(), List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
@@ -55,10 +57,14 @@ public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
//3.format response. //3.format response.
String schemaLinkStr = OutputFormat.getSchemaLinks(response.content().text()); String schemaLinkStr = OutputFormat.getSchemaLinks(response.content().text());
String sql = OutputFormat.getSql(response.content().text()); String sql = OutputFormat.getSql(response.content().text());
Map<String, Double> sqlMap = new HashMap<>(); Map<String, LLMSqlResp> sqlRespMap = new HashMap<>();
sqlMap.put(sql, 1D); sqlRespMap.put(sql, LLMSqlResp.builder().sqlWeight(1D).fewShots(sqlExamples).build());
keyPipelineLog.info("schemaLinkStr:{},sqlMap:{}", schemaLinkStr, sqlMap); keyPipelineLog.info("schemaLinkStr:{},sqlRespMap:{}", schemaLinkStr, sqlRespMap);
return sqlMap;
LLMResp llmResp = new LLMResp();
llmResp.setQuery(llmReq.getQueryText());
llmResp.setSqlRespMap(sqlRespMap);
return llmResp;
} }
@Override @Override

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.chat.parser.sql.llm;
import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp; import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlResp;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
@@ -10,6 +11,7 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.regex.Matcher; import java.util.regex.Matcher;
import java.util.regex.Pattern; import java.util.regex.Pattern;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Pair;
@@ -19,8 +21,6 @@ import org.apache.commons.lang3.tuple.Pair;
@Slf4j @Slf4j
public class OutputFormat { public class OutputFormat {
public static final String PATTERN = "\\{[^{}]+\\}";
public static String getSchemaLink(String schemaLink) { public static String getSchemaLink(String schemaLink) {
String reult = ""; String reult = "";
try { try {
@@ -126,4 +126,13 @@ public class OutputFormat {
} }
return null; return null;
} }
public static Map<String, LLMSqlResp> buildSqlRespMap(List<Map<String, String>> sqlExamples,
Map<String, Double> sqlMap) {
return sqlMap.entrySet().stream()
.collect(Collectors.toMap(
Map.Entry::getKey,
entry -> LLMSqlResp.builder().sqlWeight(entry.getValue()).fewShots(sqlExamples).build())
);
}
} }

View File

@@ -2,7 +2,7 @@ package com.tencent.supersonic.chat.parser.sql.llm;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import java.util.Map; import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
/** /**
* Sql Generation interface, generating SQL using a large model. * Sql Generation interface, generating SQL using a large model.
@@ -10,11 +10,11 @@ import java.util.Map;
public interface SqlGeneration { public interface SqlGeneration {
/*** /***
* generate SQL through LLMReq. * generate llmResp (sql, schemaLink, prompt, etc.) through LLMReq.
* @param llmReq * @param llmReq
* @param modelClusterKey * @param modelClusterKey
* @return * @return
*/ */
Map<String, Double> generation(LLMReq llmReq, String modelClusterKey); LLMResp generation(LLMReq llmReq, String modelClusterKey);
} }

View File

@@ -4,6 +4,7 @@ package com.tencent.supersonic.chat.parser.sql.llm;
import com.tencent.supersonic.chat.config.OptimizationConfig; import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
@@ -38,7 +39,7 @@ public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
private SqlPromptGenerator sqlPromptGenerator; private SqlPromptGenerator sqlPromptGenerator;
@Override @Override
public Map<String, Double> generation(LLMReq llmReq, String modelClusterKey) { public LLMResp generation(LLMReq llmReq, String modelClusterKey) {
//1.retriever sqlExamples and generate exampleListPool //1.retriever sqlExamples and generate exampleListPool
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq); keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(), List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
@@ -74,9 +75,13 @@ public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
sqlTaskPool.add(result); sqlTaskPool.add(result);
}); });
//4.format response. //4.format response.
Pair<String, Map<String, Double>> sqlMap = OutputFormat.selfConsistencyVote(sqlTaskPool); Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlTaskPool);
keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMap); keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMapPair.getRight());
return sqlMap.getRight();
LLMResp llmResp = new LLMResp();
llmResp.setQuery(llmReq.getQueryText());
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(sqlExamples, sqlMapPair.getRight()));
return llmResp;
} }
@Override @Override

View File

@@ -4,6 +4,7 @@ package com.tencent.supersonic.chat.parser.sql.llm;
import com.tencent.supersonic.chat.config.OptimizationConfig; import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
@@ -38,7 +39,7 @@ public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
private SqlPromptGenerator sqlPromptGenerator; private SqlPromptGenerator sqlPromptGenerator;
@Override @Override
public Map<String, Double> generation(LLMReq llmReq, String modelClusterKey) { public LLMResp generation(LLMReq llmReq, String modelClusterKey) {
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq); keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(), List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum()); optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
@@ -60,7 +61,11 @@ public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
Map<String, Double> sqlMap = new HashMap<>(); Map<String, Double> sqlMap = new HashMap<>();
sqlMap.put(result, 1D); sqlMap.put(result, 1D);
keyPipelineLog.info("schemaLinkStr:{},sqlMap:{}", schemaLinkStr, sqlMap); keyPipelineLog.info("schemaLinkStr:{},sqlMap:{}", schemaLinkStr, sqlMap);
return sqlMap;
LLMResp llmResp = new LLMResp();
llmResp.setQuery(llmReq.getQueryText());
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(sqlExamples, sqlMap));
return llmResp;
} }
@Override @Override

View File

@@ -15,9 +15,11 @@ public class LLMResp {
private List<String> fields; private List<String> fields;
private String schemaLinkingOutput; private Map<String, LLMSqlResp> sqlRespMap;
private String schemaLinkStr;
/**
* Only for compatibility with python code, later deleted
*/
private Map<String, Double> sqlWeight; private Map<String, Double> sqlWeight;
} }

View File

@@ -0,0 +1,16 @@
package com.tencent.supersonic.chat.query.llm.s2sql;
import java.util.List;
import java.util.Map;
import lombok.Builder;
import lombok.Data;
@Data
@Builder
public class LLMSqlResp {
private double sqlWeight;
private List<Map<String, String>> fewShots;
}

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.parser.llm.s2sql;
import com.tencent.supersonic.chat.parser.sql.llm.LLMResponseService; import com.tencent.supersonic.chat.parser.sql.llm.LLMResponseService;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlResp;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import org.junit.Assert; import org.junit.Assert;
@@ -15,38 +16,40 @@ class LLMResponseServiceTest {
String sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a"; String sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
LLMResp llmResp = new LLMResp(); LLMResp llmResp = new LLMResp();
Map<String, Double> sqlWeight = new HashMap<>(); Map<String, LLMSqlResp> sqlWeight = new HashMap<>();
sqlWeight.put(sql1, 0.2D); sqlWeight.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
sqlWeight.put(sql2, 0.8D); sqlWeight.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
llmResp.setSqlWeight(sqlWeight);
LLMResponseService llmResponseService = new LLMResponseService();
Map<String, Double> deduplicationSqlWeight = llmResponseService.getDeduplicationSqlWeight(llmResp);
Assert.assertEquals(deduplicationSqlWeight.size(), 1); llmResp.setSqlRespMap(sqlWeight);
LLMResponseService llmResponseService = new LLMResponseService();
Map<String, LLMSqlResp> deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp);
Assert.assertEquals(deduplicationSqlResp.size(), 1);
sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a"; sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a"; sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
LLMResp llmResp2 = new LLMResp(); LLMResp llmResp2 = new LLMResp();
Map<String, Double> sqlWeight2 = new HashMap<>(); Map<String, LLMSqlResp> sqlWeight2 = new HashMap<>();
sqlWeight2.put(sql1, 0.2D); sqlWeight2.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
sqlWeight2.put(sql2, 0.8D); sqlWeight2.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
llmResp2.setSqlWeight(sqlWeight2);
deduplicationSqlWeight = llmResponseService.getDeduplicationSqlWeight(llmResp2);
Assert.assertEquals(deduplicationSqlWeight.size(), 1); llmResp2.setSqlRespMap(sqlWeight2);
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp2);
Assert.assertEquals(deduplicationSqlResp.size(), 1);
sql1 = "SELECT a,b,c,d,e FROM table1 WHERE column1 = 1 AND column2 = 2 order by a"; sql1 = "SELECT a,b,c,d,e FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a"; sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
LLMResp llmResp3 = new LLMResp(); LLMResp llmResp3 = new LLMResp();
Map<String, Double> sqlWeight3 = new HashMap<>(); Map<String, LLMSqlResp> sqlWeight3 = new HashMap<>();
sqlWeight3.put(sql1, 0.2D); sqlWeight3.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
sqlWeight3.put(sql2, 0.8D); sqlWeight3.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
llmResp3.setSqlWeight(sqlWeight3); llmResp3.setSqlRespMap(sqlWeight3);
deduplicationSqlWeight = llmResponseService.getDeduplicationSqlWeight(llmResp3); deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp3);
Assert.assertEquals(deduplicationSqlWeight.size(), 2); Assert.assertEquals(deduplicationSqlResp.size(), 2);
} }
} }