mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(chat) The large model parsing supports SQL result verification and adds three retries (#1194)
This commit is contained in:
@@ -2,13 +2,15 @@ package com.tencent.supersonic.common.jsqlparser;
|
||||
|
||||
import java.util.List;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
/**
|
||||
* Sql Parser equal Helper
|
||||
* Sql Parser valid Helper
|
||||
*/
|
||||
@Slf4j
|
||||
public class SqlEqualHelper {
|
||||
public class SqlValidHelper {
|
||||
|
||||
/**
|
||||
* determine if two SQL statements are equal.
|
||||
@@ -63,5 +65,15 @@ public class SqlEqualHelper {
|
||||
return true;
|
||||
}
|
||||
|
||||
public static boolean isValidSQL(String sql) {
|
||||
try {
|
||||
CCJSqlParserUtil.parse(sql);
|
||||
return true;
|
||||
} catch (JSQLParserException e) {
|
||||
log.error("isValidSQL parse:{}", e);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ class ChatModelProperties {
|
||||
String apiKey;
|
||||
Double temperature;
|
||||
Double topP;
|
||||
String model;
|
||||
String modelName;
|
||||
Integer maxRetries;
|
||||
Integer maxToken;
|
||||
Boolean logRequests;
|
||||
|
||||
@@ -21,7 +21,7 @@ public class ZhipuAutoConfig {
|
||||
return ZhipuAiChatModel.builder()
|
||||
.baseUrl(chatModelProperties.getBaseUrl())
|
||||
.apiKey(chatModelProperties.getApiKey())
|
||||
.model(chatModelProperties.getModel())
|
||||
.model(chatModelProperties.getModelName())
|
||||
.temperature(chatModelProperties.getTemperature())
|
||||
.topP(chatModelProperties.getTopP())
|
||||
.maxRetries(chatModelProperties.getMaxRetries())
|
||||
@@ -38,7 +38,7 @@ public class ZhipuAutoConfig {
|
||||
return ZhipuAiStreamingChatModel.builder()
|
||||
.baseUrl(chatModelProperties.getBaseUrl())
|
||||
.apiKey(chatModelProperties.getApiKey())
|
||||
.model(chatModelProperties.getModel())
|
||||
.model(chatModelProperties.getModelName())
|
||||
.temperature(chatModelProperties.getTemperature())
|
||||
.topP(chatModelProperties.getTopP())
|
||||
.maxToken(chatModelProperties.getMaxToken())
|
||||
|
||||
@@ -4,39 +4,33 @@ package com.tencent.supersonic.common.jsqlparser;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
/**
|
||||
* @author lex luo
|
||||
* @date 2023/11/15 15:04
|
||||
*/
|
||||
class SqlEqualHelperTest {
|
||||
class SqlValidHelperTest {
|
||||
|
||||
@Test
|
||||
void testEquals() {
|
||||
String sql1 = "SELECT * FROM table1 WHERE column1 = 1 AND column2 = 2";
|
||||
String sql2 = "SELECT * FROM table1 WHERE column2 = 2 AND column1 = 1";
|
||||
Assert.assertEquals(SqlEqualHelper.equals(sql1, sql2), true);
|
||||
Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), true);
|
||||
|
||||
sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
|
||||
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
||||
Assert.assertEquals(SqlEqualHelper.equals(sql1, sql2), true);
|
||||
Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), true);
|
||||
|
||||
sql1 = "SELECT a,sum(b),sum(c),sum(d) FROM table1 WHERE column1 = 1 AND column2 = 2 group by a order by a";
|
||||
|
||||
sql2 = "SELECT sum(d),sum(c),sum(b),a FROM table1 WHERE column2 = 2 AND column1 = 1 group by a order by a";
|
||||
|
||||
Assert.assertEquals(SqlEqualHelper.equals(sql1, sql2), true);
|
||||
Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), true);
|
||||
|
||||
sql1 = "SELECT a,sum(b),sum(c),sum(d) FROM table1 WHERE column1 = 1 AND column2 = 2 group by a order by a";
|
||||
|
||||
sql2 = "SELECT sum(d),sum(c),sum(b),a FROM table1 WHERE column2 = 2 AND column1 = 1 group by a order by a";
|
||||
|
||||
Assert.assertEquals(SqlEqualHelper.equals(sql1, sql2), true);
|
||||
|
||||
Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), true);
|
||||
|
||||
sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
|
||||
sql2 = "SELECT d,c,b,f FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
||||
Assert.assertEquals(SqlEqualHelper.equals(sql1, sql2), false);
|
||||
|
||||
Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), false);
|
||||
|
||||
sql1 = "SELECT\n"
|
||||
+ "页面,\n"
|
||||
@@ -65,6 +59,27 @@ class SqlEqualHelperTest {
|
||||
+ "页面\n"
|
||||
+ "LIMIT\n"
|
||||
+ "365";
|
||||
Assert.assertEquals(SqlEqualHelper.equals(sql1, sql2), true);
|
||||
Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), true);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testIsValidSQL() {
|
||||
String sql1 = "SELECT * FROM table1 WHERE column1 = 1 AND column2 = 2";
|
||||
Assert.assertEquals(SqlValidHelper.isValidSQL(sql1), true);
|
||||
|
||||
sql1 = "SELECT sum(b),sum(c),sum(d) FROM table1 WHERE column1 = 1 AND column2 = 2";
|
||||
|
||||
Assert.assertEquals(SqlValidHelper.isValidSQL(sql1), true);
|
||||
|
||||
sql1 = "SELECT a,b,c, FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
|
||||
Assert.assertEquals(SqlValidHelper.isValidSQL(sql1), false);
|
||||
|
||||
sql1 = "SELECTa,b,c,d FROM table1";
|
||||
|
||||
Assert.assertEquals(SqlValidHelper.isValidSQL(sql1), false);
|
||||
|
||||
sql1 = "SELECT sum(b),sum(c),sum(d) FROM table1 WHERE";
|
||||
|
||||
Assert.assertEquals(SqlValidHelper.isValidSQL(sql1), false);
|
||||
}
|
||||
}
|
||||
@@ -15,6 +15,9 @@ public class LLMParserConfig {
|
||||
@Value("${s2.query2sql.path:/query2sql}")
|
||||
private String queryToSqlPath;
|
||||
|
||||
@Value("${s2.recall.max.retries:3}")
|
||||
private int recallMaxRetries;
|
||||
|
||||
@Value("${s2.dimension.topn:10}")
|
||||
private Integer dimensionTopN;
|
||||
|
||||
|
||||
@@ -1,25 +1,23 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_LINKING_VALUE_ENABLE;
|
||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_STRATEGY_TYPE;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||
import com.tencent.supersonic.headless.chat.utils.S2SqlDateHelper;
|
||||
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
||||
import com.tencent.supersonic.headless.chat.parser.SatisfactionChecker;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.headless.chat.parser.SatisfactionChecker;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.headless.chat.utils.S2SqlDateHelper;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
@@ -29,9 +27,12 @@ import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_LINKING_VALUE_ENABLE;
|
||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_STRATEGY_TYPE;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
@@ -62,8 +63,10 @@ public class LLMRequestService {
|
||||
return dataSetResolver.resolve(queryCtx, queryCtx.getDataSetIds());
|
||||
}
|
||||
|
||||
public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId,
|
||||
SemanticSchema semanticSchema, List<LLMReq.ElementValue> linkingValues) {
|
||||
public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId) {
|
||||
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
|
||||
List<LLMReq.ElementValue> linkingValues = requestService.getValues(queryCtx, dataSetId);
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
Map<Long, String> dataSetIdToName = semanticSchema.getDataSetIdToName();
|
||||
String queryText = queryCtx.getQueryText();
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlEqualHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery;
|
||||
@@ -9,6 +9,7 @@ import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp;
|
||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||
import java.util.ArrayList;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.MapUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
@@ -41,14 +42,20 @@ public class LLMResponseService {
|
||||
return parseInfo;
|
||||
}
|
||||
|
||||
public Map<String, LLMSqlResp> getDeduplicationSqlResp(LLMResp llmResp) {
|
||||
if (MapUtils.isEmpty(llmResp.getSqlRespMap())) {
|
||||
return llmResp.getSqlRespMap();
|
||||
public Map<String, LLMSqlResp> getDeduplicationSqlResp(int currentRetry, LLMResp llmResp) {
|
||||
Map<String, LLMSqlResp> sqlRespMap = llmResp.getSqlRespMap();
|
||||
if (MapUtils.isEmpty(sqlRespMap)) {
|
||||
LLMSqlResp llmSqlResp = new LLMSqlResp(1D, new ArrayList<>());
|
||||
sqlRespMap.put(llmResp.getSqlOutput(), llmSqlResp);
|
||||
}
|
||||
Map<String, LLMSqlResp> result = new HashMap<>();
|
||||
for (Map.Entry<String, LLMSqlResp> entry : llmResp.getSqlRespMap().entrySet()) {
|
||||
for (Map.Entry<String, LLMSqlResp> entry : sqlRespMap.entrySet()) {
|
||||
String key = entry.getKey();
|
||||
if (result.keySet().stream().anyMatch(existKey -> SqlEqualHelper.equals(existKey, key))) {
|
||||
if (result.keySet().stream().anyMatch(existKey -> SqlValidHelper.equals(existKey, key))) {
|
||||
continue;
|
||||
}
|
||||
if (!SqlValidHelper.isValidSQL(key)) {
|
||||
log.error("currentRetry:{},sql is not valid:{}", currentRetry, key);
|
||||
continue;
|
||||
}
|
||||
result.put(key, entry.getValue());
|
||||
|
||||
@@ -2,20 +2,18 @@ package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp;
|
||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.MapUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* LLMSqlParser uses large language model to understand query semantics and
|
||||
@@ -26,12 +24,12 @@ public class LLMSqlParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryCtx, ChatContext chatCtx) {
|
||||
try {
|
||||
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
|
||||
//1.determine whether to skip this parser.
|
||||
if (requestService.isSkip(queryCtx)) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
//2.get dataSetId from queryCtx and chatCtx.
|
||||
Long dataSetId = requestService.getDataSetId(queryCtx);
|
||||
if (dataSetId == null) {
|
||||
@@ -40,39 +38,46 @@ public class LLMSqlParser implements SemanticParser {
|
||||
log.info("Generate query statement for dataSetId:{}", dataSetId);
|
||||
|
||||
//3.invoke LLM service to do parsing.
|
||||
List<LLMReq.ElementValue> linkingValues = requestService.getValues(queryCtx, dataSetId);
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId, semanticSchema, linkingValues);
|
||||
LLMResp llmResp = requestService.runText2SQL(llmReq);
|
||||
if (Objects.isNull(llmResp)) {
|
||||
return;
|
||||
}
|
||||
|
||||
//4. deduplicate the S2SQL result list and build parserInfo
|
||||
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
|
||||
Map<String, LLMSqlResp> deduplicationSqlResp = responseService.getDeduplicationSqlResp(llmResp);
|
||||
ParseResult parseResult = ParseResult.builder()
|
||||
.dataSetId(dataSetId)
|
||||
.llmReq(llmReq)
|
||||
.llmResp(llmResp)
|
||||
.linkingValues(linkingValues)
|
||||
.build();
|
||||
|
||||
if (MapUtils.isEmpty(deduplicationSqlResp)) {
|
||||
if (StringUtils.isNotBlank(llmResp.getSqlOutput())) {
|
||||
responseService.addParseInfo(queryCtx, parseResult, llmResp.getSqlOutput(), 1D);
|
||||
}
|
||||
} else {
|
||||
deduplicationSqlResp.forEach((sql, sqlResp) -> {
|
||||
if (StringUtils.isNotBlank(sql)) {
|
||||
responseService.addParseInfo(queryCtx, parseResult, sql, sqlResp.getSqlWeight());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId);
|
||||
recall(queryCtx, dataSetId, llmReq);
|
||||
} catch (Exception e) {
|
||||
log.error("Failed to parse query:", e);
|
||||
}
|
||||
}
|
||||
|
||||
private void recall(QueryContext queryCtx, Long dataSetId, LLMReq llmReq) {
|
||||
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
|
||||
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
|
||||
int maxRetries = ContextUtils.getBean(LLMParserConfig.class).getRecallMaxRetries();
|
||||
int currentRetry = 0;
|
||||
Map<String, LLMSqlResp> sqlRespMap = new HashMap<>();
|
||||
ParseResult parseResult = null;
|
||||
while (currentRetry < maxRetries) {
|
||||
log.info("currentRetry:{},start runText2SQL", currentRetry);
|
||||
try {
|
||||
LLMResp llmResp = requestService.runText2SQL(llmReq);
|
||||
if (Objects.nonNull(llmResp)) {
|
||||
//deduplicate the S2SQL result list and build parserInfo
|
||||
sqlRespMap = responseService.getDeduplicationSqlResp(currentRetry, llmResp);
|
||||
if (MapUtils.isNotEmpty(sqlRespMap)) {
|
||||
parseResult = ParseResult.builder()
|
||||
.dataSetId(dataSetId).llmReq(llmReq).llmResp(llmResp)
|
||||
.linkingValues(llmReq.getLinking()).build();
|
||||
break;
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("currentRetry:{},runText2SQL error", currentRetry, e);
|
||||
}
|
||||
currentRetry++;
|
||||
}
|
||||
if (MapUtils.isNotEmpty(sqlRespMap)) {
|
||||
for (Entry<String, LLMSqlResp> entry : sqlRespMap.entrySet()) {
|
||||
String sql = entry.getKey();
|
||||
double sqlWeight = entry.getValue().getSqlWeight();
|
||||
responseService.addParseInfo(queryCtx, parseResult, sql, sqlWeight);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -26,7 +26,6 @@ public class LLMReq {
|
||||
private SqlGenType sqlGenType;
|
||||
|
||||
private LLMConfig llmConfig;
|
||||
|
||||
@Data
|
||||
public static class ElementValue {
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ class LLMResponseServiceTest {
|
||||
|
||||
llmResp.setSqlRespMap(sqlWeight);
|
||||
LLMResponseService llmResponseService = new LLMResponseService();
|
||||
Map<String, LLMSqlResp> deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp);
|
||||
Map<String, LLMSqlResp> deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(0, llmResp);
|
||||
|
||||
Assert.assertEquals(deduplicationSqlResp.size(), 1);
|
||||
|
||||
@@ -36,7 +36,7 @@ class LLMResponseServiceTest {
|
||||
sqlWeight2.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
|
||||
|
||||
llmResp2.setSqlRespMap(sqlWeight2);
|
||||
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp2);
|
||||
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(0, llmResp2);
|
||||
|
||||
Assert.assertEquals(deduplicationSqlResp.size(), 1);
|
||||
|
||||
@@ -48,7 +48,7 @@ class LLMResponseServiceTest {
|
||||
sqlWeight3.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
|
||||
sqlWeight3.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
|
||||
llmResp3.setSqlRespMap(sqlWeight3);
|
||||
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp3);
|
||||
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(0, llmResp3);
|
||||
|
||||
Assert.assertEquals(deduplicationSqlResp.size(), 2);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user