mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-17 16:02:14 +00:00
(improvement)(chat) Add 'Few-shot Examples' display to the Chat chart. (#589)
This commit is contained in:
@@ -12,7 +12,6 @@ import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
|||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
@@ -43,12 +42,9 @@ public class JavaLLMProxy implements LLMProxy {
|
|||||||
SqlGeneration sqlGeneration = SqlGenerationFactory.get(
|
SqlGeneration sqlGeneration = SqlGenerationFactory.get(
|
||||||
SqlGenerationMode.getMode(llmReq.getSqlGenerationMode()));
|
SqlGenerationMode.getMode(llmReq.getSqlGenerationMode()));
|
||||||
String modelName = llmReq.getSchema().getModelName();
|
String modelName = llmReq.getSchema().getModelName();
|
||||||
Map<String, Double> sqlWeight = sqlGeneration.generation(llmReq, modelClusterKey);
|
LLMResp result = sqlGeneration.generation(llmReq, modelClusterKey);
|
||||||
|
|
||||||
LLMResp result = new LLMResp();
|
|
||||||
result.setQuery(llmReq.getQueryText());
|
result.setQuery(llmReq.getQueryText());
|
||||||
result.setModelName(modelName);
|
result.setModelName(modelName);
|
||||||
result.setSqlWeight(sqlWeight);
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,13 +6,16 @@ import com.tencent.supersonic.chat.config.LLMParserConfig;
|
|||||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionCallConfig;
|
import com.tencent.supersonic.chat.parser.plugin.function.FunctionCallConfig;
|
||||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
|
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
|
||||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
|
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
|
||||||
|
import com.tencent.supersonic.chat.parser.sql.llm.OutputFormat;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
import java.net.URL;
|
import java.net.URL;
|
||||||
|
import java.util.ArrayList;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections4.MapUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
@@ -59,10 +62,15 @@ public class PythonLLMProxy implements LLMProxy {
|
|||||||
ResponseEntity<LLMResp> responseEntity = restTemplate.exchange(url.toString(), HttpMethod.POST, entity,
|
ResponseEntity<LLMResp> responseEntity = restTemplate.exchange(url.toString(), HttpMethod.POST, entity,
|
||||||
LLMResp.class);
|
LLMResp.class);
|
||||||
|
|
||||||
|
LLMResp llmResp = responseEntity.getBody();
|
||||||
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
|
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
|
||||||
System.currentTimeMillis() - startTime, url, entity, responseEntity.getBody());
|
System.currentTimeMillis() - startTime, url, entity, llmResp);
|
||||||
keyPipelineLog.info("LLMResp:{}", responseEntity.getBody());
|
keyPipelineLog.info("LLMResp:{}", llmResp);
|
||||||
return responseEntity.getBody();
|
|
||||||
|
if (MapUtils.isEmpty(llmResp.getSqlRespMap())) {
|
||||||
|
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(new ArrayList<>(), llmResp.getSqlWeight()));
|
||||||
|
}
|
||||||
|
return llmResp;
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("requestLLM error", e);
|
log.error("requestLLM error", e);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import com.tencent.supersonic.chat.query.QueryManager;
|
|||||||
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
|
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlQuery;
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlQuery;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlResp;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserEqualHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserEqualHelper;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
@@ -46,12 +47,12 @@ public class LLMResponseService {
|
|||||||
return parseInfo;
|
return parseInfo;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Map<String, Double> getDeduplicationSqlWeight(LLMResp llmResp) {
|
public Map<String, LLMSqlResp> getDeduplicationSqlResp(LLMResp llmResp) {
|
||||||
if (MapUtils.isEmpty(llmResp.getSqlWeight())) {
|
if (MapUtils.isEmpty(llmResp.getSqlRespMap())) {
|
||||||
return llmResp.getSqlWeight();
|
return llmResp.getSqlRespMap();
|
||||||
}
|
}
|
||||||
Map<String, Double> result = new HashMap<>();
|
Map<String, LLMSqlResp> result = new HashMap<>();
|
||||||
for (Map.Entry<String, Double> entry : llmResp.getSqlWeight().entrySet()) {
|
for (Map.Entry<String, LLMSqlResp> entry : llmResp.getSqlRespMap().entrySet()) {
|
||||||
String key = entry.getKey();
|
String key = entry.getKey();
|
||||||
if (result.keySet().stream().anyMatch(existKey -> SqlParserEqualHelper.equals(existKey, key))) {
|
if (result.keySet().stream().anyMatch(existKey -> SqlParserEqualHelper.equals(existKey, key))) {
|
||||||
continue;
|
continue;
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
|||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlResp;
|
||||||
import com.tencent.supersonic.chat.service.SemanticService;
|
import com.tencent.supersonic.chat.service.SemanticService;
|
||||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
@@ -56,7 +57,7 @@ public class LLMSqlParser implements SemanticParser {
|
|||||||
//5. deduplicate the SQL result list and build parserInfo
|
//5. deduplicate the SQL result list and build parserInfo
|
||||||
modelCluster.buildName(semanticSchema.getModelIdToName());
|
modelCluster.buildName(semanticSchema.getModelIdToName());
|
||||||
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
|
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
|
||||||
Map<String, Double> deduplicationSqlWeight = responseService.getDeduplicationSqlWeight(llmResp);
|
Map<String, LLMSqlResp> deduplicationSqlResp = responseService.getDeduplicationSqlResp(llmResp);
|
||||||
ParseResult parseResult = ParseResult.builder()
|
ParseResult parseResult = ParseResult.builder()
|
||||||
.request(request)
|
.request(request)
|
||||||
.modelCluster(modelCluster)
|
.modelCluster(modelCluster)
|
||||||
@@ -66,11 +67,11 @@ public class LLMSqlParser implements SemanticParser {
|
|||||||
.linkingValues(linkingValues)
|
.linkingValues(linkingValues)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
if (MapUtils.isEmpty(deduplicationSqlWeight)) {
|
if (MapUtils.isEmpty(deduplicationSqlResp)) {
|
||||||
responseService.addParseInfo(queryCtx, parseResult, llmResp.getSqlOutput(), 1D);
|
responseService.addParseInfo(queryCtx, parseResult, llmResp.getSqlOutput(), 1D);
|
||||||
} else {
|
} else {
|
||||||
deduplicationSqlWeight.forEach((sql, weight) -> {
|
deduplicationSqlResp.forEach((sql, sqlResp) -> {
|
||||||
responseService.addParseInfo(queryCtx, parseResult, sql, weight);
|
responseService.addParseInfo(queryCtx, parseResult, sql, sqlResp.getSqlWeight());
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ package com.tencent.supersonic.chat.parser.sql.llm;
|
|||||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import dev.langchain4j.data.message.AiMessage;
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
@@ -41,9 +42,10 @@ public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
|||||||
private SqlPromptGenerator sqlPromptGenerator;
|
private SqlPromptGenerator sqlPromptGenerator;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<String, Double> generation(LLMReq llmReq, String modelClusterKey) {
|
public LLMResp generation(LLMReq llmReq, String modelClusterKey) {
|
||||||
//1.retriever sqlExamples and generate exampleListPool
|
//1.retriever sqlExamples and generate exampleListPool
|
||||||
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
|
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
|
||||||
|
|
||||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||||
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
|
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
|
||||||
|
|
||||||
@@ -70,9 +72,14 @@ public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
|||||||
Pair<String, Map<String, Double>> linkingMap = OutputFormat.selfConsistencyVote(candidateSortedList);
|
Pair<String, Map<String, Double>> linkingMap = OutputFormat.selfConsistencyVote(candidateSortedList);
|
||||||
List<String> sqlList = llmResults.stream()
|
List<String> sqlList = llmResults.stream()
|
||||||
.map(llmResult -> OutputFormat.getSql(llmResult)).collect(Collectors.toList());
|
.map(llmResult -> OutputFormat.getSql(llmResult)).collect(Collectors.toList());
|
||||||
Pair<String, Map<String, Double>> sqlMap = OutputFormat.selfConsistencyVote(sqlList);
|
|
||||||
keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMap);
|
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlList);
|
||||||
return sqlMap.getRight();
|
keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMapPair.getRight());
|
||||||
|
|
||||||
|
LLMResp result = new LLMResp();
|
||||||
|
result.setQuery(llmReq.getQueryText());
|
||||||
|
result.setSqlRespMap(OutputFormat.buildSqlRespMap(sqlExamples, sqlMapPair.getRight()));
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ package com.tencent.supersonic.chat.parser.sql.llm;
|
|||||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlResp;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import dev.langchain4j.data.message.AiMessage;
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
@@ -38,7 +40,7 @@ public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
|
|||||||
private SqlPromptGenerator sqlPromptGenerator;
|
private SqlPromptGenerator sqlPromptGenerator;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<String, Double> generation(LLMReq llmReq, String modelClusterKey) {
|
public LLMResp generation(LLMReq llmReq, String modelClusterKey) {
|
||||||
//1.retriever sqlExamples
|
//1.retriever sqlExamples
|
||||||
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
|
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
|
||||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||||
@@ -55,10 +57,14 @@ public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
|
|||||||
//3.format response.
|
//3.format response.
|
||||||
String schemaLinkStr = OutputFormat.getSchemaLinks(response.content().text());
|
String schemaLinkStr = OutputFormat.getSchemaLinks(response.content().text());
|
||||||
String sql = OutputFormat.getSql(response.content().text());
|
String sql = OutputFormat.getSql(response.content().text());
|
||||||
Map<String, Double> sqlMap = new HashMap<>();
|
Map<String, LLMSqlResp> sqlRespMap = new HashMap<>();
|
||||||
sqlMap.put(sql, 1D);
|
sqlRespMap.put(sql, LLMSqlResp.builder().sqlWeight(1D).fewShots(sqlExamples).build());
|
||||||
keyPipelineLog.info("schemaLinkStr:{},sqlMap:{}", schemaLinkStr, sqlMap);
|
keyPipelineLog.info("schemaLinkStr:{},sqlRespMap:{}", schemaLinkStr, sqlRespMap);
|
||||||
return sqlMap;
|
|
||||||
|
LLMResp llmResp = new LLMResp();
|
||||||
|
llmResp.setQuery(llmReq.getQueryText());
|
||||||
|
llmResp.setSqlRespMap(sqlRespMap);
|
||||||
|
return llmResp;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package com.tencent.supersonic.chat.parser.sql.llm;
|
|||||||
import com.fasterxml.jackson.databind.JsonNode;
|
import com.fasterxml.jackson.databind.JsonNode;
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
|
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlResp;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
@@ -10,6 +11,7 @@ import java.util.List;
|
|||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.regex.Matcher;
|
import java.util.regex.Matcher;
|
||||||
import java.util.regex.Pattern;
|
import java.util.regex.Pattern;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
|
|
||||||
@@ -19,8 +21,6 @@ import org.apache.commons.lang3.tuple.Pair;
|
|||||||
@Slf4j
|
@Slf4j
|
||||||
public class OutputFormat {
|
public class OutputFormat {
|
||||||
|
|
||||||
public static final String PATTERN = "\\{[^{}]+\\}";
|
|
||||||
|
|
||||||
public static String getSchemaLink(String schemaLink) {
|
public static String getSchemaLink(String schemaLink) {
|
||||||
String reult = "";
|
String reult = "";
|
||||||
try {
|
try {
|
||||||
@@ -126,4 +126,13 @@ public class OutputFormat {
|
|||||||
}
|
}
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static Map<String, LLMSqlResp> buildSqlRespMap(List<Map<String, String>> sqlExamples,
|
||||||
|
Map<String, Double> sqlMap) {
|
||||||
|
return sqlMap.entrySet().stream()
|
||||||
|
.collect(Collectors.toMap(
|
||||||
|
Map.Entry::getKey,
|
||||||
|
entry -> LLMSqlResp.builder().sqlWeight(entry.getValue()).fewShots(sqlExamples).build())
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.chat.parser.sql.llm;
|
|||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||||
import java.util.Map;
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sql Generation interface, generating SQL using a large model.
|
* Sql Generation interface, generating SQL using a large model.
|
||||||
@@ -10,11 +10,11 @@ import java.util.Map;
|
|||||||
public interface SqlGeneration {
|
public interface SqlGeneration {
|
||||||
|
|
||||||
/***
|
/***
|
||||||
* generate SQL through LLMReq.
|
* Generates an LLMResp (SQL, schema links, prompt, etc.) from the given LLMReq.
|
||||||
* @param llmReq
|
* @param llmReq
|
||||||
* @param modelClusterKey
|
* @param modelClusterKey
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
Map<String, Double> generation(LLMReq llmReq, String modelClusterKey);
|
LLMResp generation(LLMReq llmReq, String modelClusterKey);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ package com.tencent.supersonic.chat.parser.sql.llm;
|
|||||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import dev.langchain4j.data.message.AiMessage;
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
@@ -38,7 +39,7 @@ public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
|||||||
private SqlPromptGenerator sqlPromptGenerator;
|
private SqlPromptGenerator sqlPromptGenerator;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<String, Double> generation(LLMReq llmReq, String modelClusterKey) {
|
public LLMResp generation(LLMReq llmReq, String modelClusterKey) {
|
||||||
//1.retriever sqlExamples and generate exampleListPool
|
//1.retriever sqlExamples and generate exampleListPool
|
||||||
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
|
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
|
||||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||||
@@ -74,9 +75,13 @@ public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
|||||||
sqlTaskPool.add(result);
|
sqlTaskPool.add(result);
|
||||||
});
|
});
|
||||||
//4.format response.
|
//4.format response.
|
||||||
Pair<String, Map<String, Double>> sqlMap = OutputFormat.selfConsistencyVote(sqlTaskPool);
|
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlTaskPool);
|
||||||
keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMap);
|
keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMapPair.getRight());
|
||||||
return sqlMap.getRight();
|
|
||||||
|
LLMResp llmResp = new LLMResp();
|
||||||
|
llmResp.setQuery(llmReq.getQueryText());
|
||||||
|
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(sqlExamples, sqlMapPair.getRight()));
|
||||||
|
return llmResp;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ package com.tencent.supersonic.chat.parser.sql.llm;
|
|||||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import dev.langchain4j.data.message.AiMessage;
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
@@ -38,7 +39,7 @@ public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
|
|||||||
private SqlPromptGenerator sqlPromptGenerator;
|
private SqlPromptGenerator sqlPromptGenerator;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<String, Double> generation(LLMReq llmReq, String modelClusterKey) {
|
public LLMResp generation(LLMReq llmReq, String modelClusterKey) {
|
||||||
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
|
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
|
||||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||||
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
|
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
|
||||||
@@ -60,7 +61,11 @@ public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
|
|||||||
Map<String, Double> sqlMap = new HashMap<>();
|
Map<String, Double> sqlMap = new HashMap<>();
|
||||||
sqlMap.put(result, 1D);
|
sqlMap.put(result, 1D);
|
||||||
keyPipelineLog.info("schemaLinkStr:{},sqlMap:{}", schemaLinkStr, sqlMap);
|
keyPipelineLog.info("schemaLinkStr:{},sqlMap:{}", schemaLinkStr, sqlMap);
|
||||||
return sqlMap;
|
|
||||||
|
LLMResp llmResp = new LLMResp();
|
||||||
|
llmResp.setQuery(llmReq.getQueryText());
|
||||||
|
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(sqlExamples, sqlMap));
|
||||||
|
return llmResp;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -15,9 +15,11 @@ public class LLMResp {
|
|||||||
|
|
||||||
private List<String> fields;
|
private List<String> fields;
|
||||||
|
|
||||||
private String schemaLinkingOutput;
|
private Map<String, LLMSqlResp> sqlRespMap;
|
||||||
|
|
||||||
private String schemaLinkStr;
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Kept only for compatibility with the Python code; to be removed later.
|
||||||
|
*/
|
||||||
private Map<String, Double> sqlWeight;
|
private Map<String, Double> sqlWeight;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,16 @@
|
|||||||
|
package com.tencent.supersonic.chat.query.llm.s2sql;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@Builder
|
||||||
|
public class LLMSqlResp {
|
||||||
|
|
||||||
|
private double sqlWeight;
|
||||||
|
|
||||||
|
private List<Map<String, String>> fewShots;
|
||||||
|
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.parser.llm.s2sql;
|
|||||||
|
|
||||||
import com.tencent.supersonic.chat.parser.sql.llm.LLMResponseService;
|
import com.tencent.supersonic.chat.parser.sql.llm.LLMResponseService;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlResp;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
@@ -15,38 +16,40 @@ class LLMResponseServiceTest {
|
|||||||
String sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
String sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
||||||
|
|
||||||
LLMResp llmResp = new LLMResp();
|
LLMResp llmResp = new LLMResp();
|
||||||
Map<String, Double> sqlWeight = new HashMap<>();
|
Map<String, LLMSqlResp> sqlWeight = new HashMap<>();
|
||||||
sqlWeight.put(sql1, 0.2D);
|
sqlWeight.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
|
||||||
sqlWeight.put(sql2, 0.8D);
|
sqlWeight.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
|
||||||
llmResp.setSqlWeight(sqlWeight);
|
|
||||||
LLMResponseService llmResponseService = new LLMResponseService();
|
|
||||||
Map<String, Double> deduplicationSqlWeight = llmResponseService.getDeduplicationSqlWeight(llmResp);
|
|
||||||
|
|
||||||
Assert.assertEquals(deduplicationSqlWeight.size(), 1);
|
llmResp.setSqlRespMap(sqlWeight);
|
||||||
|
LLMResponseService llmResponseService = new LLMResponseService();
|
||||||
|
Map<String, LLMSqlResp> deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp);
|
||||||
|
|
||||||
|
Assert.assertEquals(deduplicationSqlResp.size(), 1);
|
||||||
|
|
||||||
sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
|
sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
|
||||||
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
||||||
|
|
||||||
LLMResp llmResp2 = new LLMResp();
|
LLMResp llmResp2 = new LLMResp();
|
||||||
Map<String, Double> sqlWeight2 = new HashMap<>();
|
Map<String, LLMSqlResp> sqlWeight2 = new HashMap<>();
|
||||||
sqlWeight2.put(sql1, 0.2D);
|
sqlWeight2.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
|
||||||
sqlWeight2.put(sql2, 0.8D);
|
sqlWeight2.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
|
||||||
llmResp2.setSqlWeight(sqlWeight2);
|
|
||||||
deduplicationSqlWeight = llmResponseService.getDeduplicationSqlWeight(llmResp2);
|
|
||||||
|
|
||||||
Assert.assertEquals(deduplicationSqlWeight.size(), 1);
|
llmResp2.setSqlRespMap(sqlWeight2);
|
||||||
|
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp2);
|
||||||
|
|
||||||
|
Assert.assertEquals(deduplicationSqlResp.size(), 1);
|
||||||
|
|
||||||
sql1 = "SELECT a,b,c,d,e FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
|
sql1 = "SELECT a,b,c,d,e FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
|
||||||
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
||||||
|
|
||||||
LLMResp llmResp3 = new LLMResp();
|
LLMResp llmResp3 = new LLMResp();
|
||||||
Map<String, Double> sqlWeight3 = new HashMap<>();
|
Map<String, LLMSqlResp> sqlWeight3 = new HashMap<>();
|
||||||
sqlWeight3.put(sql1, 0.2D);
|
sqlWeight3.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
|
||||||
sqlWeight3.put(sql2, 0.8D);
|
sqlWeight3.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
|
||||||
llmResp3.setSqlWeight(sqlWeight3);
|
llmResp3.setSqlRespMap(sqlWeight3);
|
||||||
deduplicationSqlWeight = llmResponseService.getDeduplicationSqlWeight(llmResp3);
|
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp3);
|
||||||
|
|
||||||
Assert.assertEquals(deduplicationSqlWeight.size(), 2);
|
Assert.assertEquals(deduplicationSqlResp.size(), 2);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user