(improvement)(chat) Add 'Few-shot Examples' display to the Chat chart. (#589)

This commit is contained in:
lexluo09
2024-01-02 18:07:15 +08:00
committed by GitHub
parent e7f13572d7
commit d72166944c
13 changed files with 118 additions and 59 deletions

View File

@@ -12,7 +12,6 @@ import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import java.util.Map;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
@@ -43,12 +42,9 @@ public class JavaLLMProxy implements LLMProxy {
SqlGeneration sqlGeneration = SqlGenerationFactory.get(
SqlGenerationMode.getMode(llmReq.getSqlGenerationMode()));
String modelName = llmReq.getSchema().getModelName();
Map<String, Double> sqlWeight = sqlGeneration.generation(llmReq, modelClusterKey);
LLMResp result = new LLMResp();
LLMResp result = sqlGeneration.generation(llmReq, modelClusterKey);
result.setQuery(llmReq.getQueryText());
result.setModelName(modelName);
result.setSqlWeight(sqlWeight);
return result;
}

View File

@@ -6,13 +6,16 @@ import com.tencent.supersonic.chat.config.LLMParserConfig;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionCallConfig;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
import com.tencent.supersonic.chat.parser.sql.llm.OutputFormat;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import java.net.URI;
import java.net.URL;
import java.util.ArrayList;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -59,10 +62,15 @@ public class PythonLLMProxy implements LLMProxy {
ResponseEntity<LLMResp> responseEntity = restTemplate.exchange(url.toString(), HttpMethod.POST, entity,
LLMResp.class);
LLMResp llmResp = responseEntity.getBody();
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
System.currentTimeMillis() - startTime, url, entity, responseEntity.getBody());
keyPipelineLog.info("LLMResp:{}", responseEntity.getBody());
return responseEntity.getBody();
System.currentTimeMillis() - startTime, url, entity, llmResp);
keyPipelineLog.info("LLMResp:{}", llmResp);
if (MapUtils.isEmpty(llmResp.getSqlRespMap())) {
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(new ArrayList<>(), llmResp.getSqlWeight()));
}
return llmResp;
} catch (Exception e) {
log.error("requestLLM error", e);
}

View File

@@ -7,6 +7,7 @@ import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlResp;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserEqualHelper;
import lombok.extern.slf4j.Slf4j;
@@ -46,12 +47,12 @@ public class LLMResponseService {
return parseInfo;
}
public Map<String, Double> getDeduplicationSqlWeight(LLMResp llmResp) {
if (MapUtils.isEmpty(llmResp.getSqlWeight())) {
return llmResp.getSqlWeight();
public Map<String, LLMSqlResp> getDeduplicationSqlResp(LLMResp llmResp) {
if (MapUtils.isEmpty(llmResp.getSqlRespMap())) {
return llmResp.getSqlRespMap();
}
Map<String, Double> result = new HashMap<>();
for (Map.Entry<String, Double> entry : llmResp.getSqlWeight().entrySet()) {
Map<String, LLMSqlResp> result = new HashMap<>();
for (Map.Entry<String, LLMSqlResp> entry : llmResp.getSqlRespMap().entrySet()) {
String key = entry.getKey();
if (result.keySet().stream().anyMatch(existKey -> SqlParserEqualHelper.equals(existKey, key))) {
continue;

View File

@@ -9,6 +9,7 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlResp;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.pojo.ModelCluster;
import com.tencent.supersonic.common.util.ContextUtils;
@@ -56,7 +57,7 @@ public class LLMSqlParser implements SemanticParser {
//5. deduplicate the SQL result list and build parserInfo
modelCluster.buildName(semanticSchema.getModelIdToName());
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
Map<String, Double> deduplicationSqlWeight = responseService.getDeduplicationSqlWeight(llmResp);
Map<String, LLMSqlResp> deduplicationSqlResp = responseService.getDeduplicationSqlResp(llmResp);
ParseResult parseResult = ParseResult.builder()
.request(request)
.modelCluster(modelCluster)
@@ -66,11 +67,11 @@ public class LLMSqlParser implements SemanticParser {
.linkingValues(linkingValues)
.build();
if (MapUtils.isEmpty(deduplicationSqlWeight)) {
if (MapUtils.isEmpty(deduplicationSqlResp)) {
responseService.addParseInfo(queryCtx, parseResult, llmResp.getSqlOutput(), 1D);
} else {
deduplicationSqlWeight.forEach((sql, weight) -> {
responseService.addParseInfo(queryCtx, parseResult, sql, weight);
deduplicationSqlResp.forEach((sql, sqlResp) -> {
responseService.addParseInfo(queryCtx, parseResult, sql, sqlResp.getSqlWeight());
});
}

View File

@@ -4,6 +4,7 @@ package com.tencent.supersonic.chat.parser.sql.llm;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.JsonUtil;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
@@ -41,9 +42,10 @@ public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
private SqlPromptGenerator sqlPromptGenerator;
@Override
public Map<String, Double> generation(LLMReq llmReq, String modelClusterKey) {
public LLMResp generation(LLMReq llmReq, String modelClusterKey) {
//1.retriever sqlExamples and generate exampleListPool
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
@@ -70,9 +72,14 @@ public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
Pair<String, Map<String, Double>> linkingMap = OutputFormat.selfConsistencyVote(candidateSortedList);
List<String> sqlList = llmResults.stream()
.map(llmResult -> OutputFormat.getSql(llmResult)).collect(Collectors.toList());
Pair<String, Map<String, Double>> sqlMap = OutputFormat.selfConsistencyVote(sqlList);
keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMap);
return sqlMap.getRight();
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlList);
keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMapPair.getRight());
LLMResp result = new LLMResp();
result.setQuery(llmReq.getQueryText());
result.setSqlRespMap(OutputFormat.buildSqlRespMap(sqlExamples, sqlMapPair.getRight()));
return result;
}
@Override

View File

@@ -4,6 +4,8 @@ package com.tencent.supersonic.chat.parser.sql.llm;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlResp;
import com.tencent.supersonic.common.util.JsonUtil;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
@@ -38,7 +40,7 @@ public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
private SqlPromptGenerator sqlPromptGenerator;
@Override
public Map<String, Double> generation(LLMReq llmReq, String modelClusterKey) {
public LLMResp generation(LLMReq llmReq, String modelClusterKey) {
//1.retriever sqlExamples
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
@@ -55,10 +57,14 @@ public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
//3.format response.
String schemaLinkStr = OutputFormat.getSchemaLinks(response.content().text());
String sql = OutputFormat.getSql(response.content().text());
Map<String, Double> sqlMap = new HashMap<>();
sqlMap.put(sql, 1D);
keyPipelineLog.info("schemaLinkStr:{},sqlMap:{}", schemaLinkStr, sqlMap);
return sqlMap;
Map<String, LLMSqlResp> sqlRespMap = new HashMap<>();
sqlRespMap.put(sql, LLMSqlResp.builder().sqlWeight(1D).fewShots(sqlExamples).build());
keyPipelineLog.info("schemaLinkStr:{},sqlRespMap:{}", schemaLinkStr, sqlRespMap);
LLMResp llmResp = new LLMResp();
llmResp.setQuery(llmReq.getQueryText());
llmResp.setSqlRespMap(sqlRespMap);
return llmResp;
}
@Override

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.chat.parser.sql.llm;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlResp;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
@@ -10,6 +11,7 @@ import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair;
@@ -19,8 +21,6 @@ import org.apache.commons.lang3.tuple.Pair;
@Slf4j
public class OutputFormat {
public static final String PATTERN = "\\{[^{}]+\\}";
public static String getSchemaLink(String schemaLink) {
String reult = "";
try {
@@ -126,4 +126,13 @@ public class OutputFormat {
}
return null;
}
public static Map<String, LLMSqlResp> buildSqlRespMap(List<Map<String, String>> sqlExamples,
Map<String, Double> sqlMap) {
return sqlMap.entrySet().stream()
.collect(Collectors.toMap(
Map.Entry::getKey,
entry -> LLMSqlResp.builder().sqlWeight(entry.getValue()).fewShots(sqlExamples).build())
);
}
}

View File

@@ -2,7 +2,7 @@ package com.tencent.supersonic.chat.parser.sql.llm;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import java.util.Map;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
/**
* Sql Generation interface, generating SQL using a large model.
@@ -10,11 +10,11 @@ import java.util.Map;
public interface SqlGeneration {
/***
* generate SQL through LLMReq.
* generate llmResp (sql, schemaLink, prompt, etc.) through LLMReq.
* @param llmReq
* @param modelClusterKey
* @return
*/
Map<String, Double> generation(LLMReq llmReq, String modelClusterKey);
LLMResp generation(LLMReq llmReq, String modelClusterKey);
}

View File

@@ -4,6 +4,7 @@ package com.tencent.supersonic.chat.parser.sql.llm;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.JsonUtil;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
@@ -38,7 +39,7 @@ public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
private SqlPromptGenerator sqlPromptGenerator;
@Override
public Map<String, Double> generation(LLMReq llmReq, String modelClusterKey) {
public LLMResp generation(LLMReq llmReq, String modelClusterKey) {
//1.retriever sqlExamples and generate exampleListPool
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
@@ -74,9 +75,13 @@ public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
sqlTaskPool.add(result);
});
//4.format response.
Pair<String, Map<String, Double>> sqlMap = OutputFormat.selfConsistencyVote(sqlTaskPool);
keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMap);
return sqlMap.getRight();
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlTaskPool);
keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMapPair.getRight());
LLMResp llmResp = new LLMResp();
llmResp.setQuery(llmReq.getQueryText());
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(sqlExamples, sqlMapPair.getRight()));
return llmResp;
}
@Override

View File

@@ -4,6 +4,7 @@ package com.tencent.supersonic.chat.parser.sql.llm;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.JsonUtil;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
@@ -38,7 +39,7 @@ public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
private SqlPromptGenerator sqlPromptGenerator;
@Override
public Map<String, Double> generation(LLMReq llmReq, String modelClusterKey) {
public LLMResp generation(LLMReq llmReq, String modelClusterKey) {
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
@@ -60,7 +61,11 @@ public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
Map<String, Double> sqlMap = new HashMap<>();
sqlMap.put(result, 1D);
keyPipelineLog.info("schemaLinkStr:{},sqlMap:{}", schemaLinkStr, sqlMap);
return sqlMap;
LLMResp llmResp = new LLMResp();
llmResp.setQuery(llmReq.getQueryText());
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(sqlExamples, sqlMap));
return llmResp;
}
@Override

View File

@@ -15,9 +15,11 @@ public class LLMResp {
private List<String> fields;
private String schemaLinkingOutput;
private String schemaLinkStr;
private Map<String, LLMSqlResp> sqlRespMap;
/**
* Only for compatibility with python code, later deleted
*/
private Map<String, Double> sqlWeight;
}

View File

@@ -0,0 +1,16 @@
package com.tencent.supersonic.chat.query.llm.s2sql;
import java.util.List;
import java.util.Map;
import lombok.Builder;
import lombok.Data;
@Data
@Builder
public class LLMSqlResp {
private double sqlWeight;
private List<Map<String, String>> fewShots;
}

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.parser.llm.s2sql;
import com.tencent.supersonic.chat.parser.sql.llm.LLMResponseService;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlResp;
import java.util.HashMap;
import java.util.Map;
import org.junit.Assert;
@@ -15,38 +16,40 @@ class LLMResponseServiceTest {
String sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
LLMResp llmResp = new LLMResp();
Map<String, Double> sqlWeight = new HashMap<>();
sqlWeight.put(sql1, 0.2D);
sqlWeight.put(sql2, 0.8D);
llmResp.setSqlWeight(sqlWeight);
LLMResponseService llmResponseService = new LLMResponseService();
Map<String, Double> deduplicationSqlWeight = llmResponseService.getDeduplicationSqlWeight(llmResp);
Map<String, LLMSqlResp> sqlWeight = new HashMap<>();
sqlWeight.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
sqlWeight.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
Assert.assertEquals(deduplicationSqlWeight.size(), 1);
llmResp.setSqlRespMap(sqlWeight);
LLMResponseService llmResponseService = new LLMResponseService();
Map<String, LLMSqlResp> deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp);
Assert.assertEquals(deduplicationSqlResp.size(), 1);
sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
LLMResp llmResp2 = new LLMResp();
Map<String, Double> sqlWeight2 = new HashMap<>();
sqlWeight2.put(sql1, 0.2D);
sqlWeight2.put(sql2, 0.8D);
llmResp2.setSqlWeight(sqlWeight2);
deduplicationSqlWeight = llmResponseService.getDeduplicationSqlWeight(llmResp2);
Map<String, LLMSqlResp> sqlWeight2 = new HashMap<>();
sqlWeight2.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
sqlWeight2.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
Assert.assertEquals(deduplicationSqlWeight.size(), 1);
llmResp2.setSqlRespMap(sqlWeight2);
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp2);
Assert.assertEquals(deduplicationSqlResp.size(), 1);
sql1 = "SELECT a,b,c,d,e FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
LLMResp llmResp3 = new LLMResp();
Map<String, Double> sqlWeight3 = new HashMap<>();
sqlWeight3.put(sql1, 0.2D);
sqlWeight3.put(sql2, 0.8D);
llmResp3.setSqlWeight(sqlWeight3);
deduplicationSqlWeight = llmResponseService.getDeduplicationSqlWeight(llmResp3);
Map<String, LLMSqlResp> sqlWeight3 = new HashMap<>();
sqlWeight3.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
sqlWeight3.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
llmResp3.setSqlRespMap(sqlWeight3);
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp3);
Assert.assertEquals(deduplicationSqlWeight.size(), 2);
Assert.assertEquals(deduplicationSqlResp.size(), 2);
}
}