(improvement)(chat) The large model parsing supports SQL result verification and adds three retries (#1194)

This commit is contained in:
lexluo09
2024-06-22 22:21:51 +08:00
committed by GitHub
parent 32e2c1e39d
commit 29694be64e
10 changed files with 131 additions and 87 deletions

View File

@@ -2,13 +2,15 @@ package com.tencent.supersonic.common.jsqlparser;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import org.apache.commons.collections.CollectionUtils;
/**
* Sql Parser equal Helper
* Sql Parser valid Helper
*/
@Slf4j
public class SqlEqualHelper {
public class SqlValidHelper {
/**
* determine if two SQL statements are equal.
@@ -63,5 +65,15 @@ public class SqlEqualHelper {
return true;
}
public static boolean isValidSQL(String sql) {
try {
CCJSqlParserUtil.parse(sql);
return true;
} catch (JSQLParserException e) {
log.error("isValidSQL parse:{}", e);
return false;
}
}
}

View File

@@ -11,7 +11,7 @@ class ChatModelProperties {
String apiKey;
Double temperature;
Double topP;
String model;
String modelName;
Integer maxRetries;
Integer maxToken;
Boolean logRequests;

View File

@@ -21,7 +21,7 @@ public class ZhipuAutoConfig {
return ZhipuAiChatModel.builder()
.baseUrl(chatModelProperties.getBaseUrl())
.apiKey(chatModelProperties.getApiKey())
.model(chatModelProperties.getModel())
.model(chatModelProperties.getModelName())
.temperature(chatModelProperties.getTemperature())
.topP(chatModelProperties.getTopP())
.maxRetries(chatModelProperties.getMaxRetries())
@@ -38,7 +38,7 @@ public class ZhipuAutoConfig {
return ZhipuAiStreamingChatModel.builder()
.baseUrl(chatModelProperties.getBaseUrl())
.apiKey(chatModelProperties.getApiKey())
.model(chatModelProperties.getModel())
.model(chatModelProperties.getModelName())
.temperature(chatModelProperties.getTemperature())
.topP(chatModelProperties.getTopP())
.maxToken(chatModelProperties.getMaxToken())

View File

@@ -4,39 +4,33 @@ package com.tencent.supersonic.common.jsqlparser;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
/**
* @author lex luo
* @date 2023/11/15 15:04
*/
class SqlEqualHelperTest {
class SqlValidHelperTest {
@Test
void testEquals() {
String sql1 = "SELECT * FROM table1 WHERE column1 = 1 AND column2 = 2";
String sql2 = "SELECT * FROM table1 WHERE column2 = 2 AND column1 = 1";
Assert.assertEquals(SqlEqualHelper.equals(sql1, sql2), true);
Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), true);
sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
Assert.assertEquals(SqlEqualHelper.equals(sql1, sql2), true);
Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), true);
sql1 = "SELECT a,sum(b),sum(c),sum(d) FROM table1 WHERE column1 = 1 AND column2 = 2 group by a order by a";
sql2 = "SELECT sum(d),sum(c),sum(b),a FROM table1 WHERE column2 = 2 AND column1 = 1 group by a order by a";
Assert.assertEquals(SqlEqualHelper.equals(sql1, sql2), true);
Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), true);
sql1 = "SELECT a,sum(b),sum(c),sum(d) FROM table1 WHERE column1 = 1 AND column2 = 2 group by a order by a";
sql2 = "SELECT sum(d),sum(c),sum(b),a FROM table1 WHERE column2 = 2 AND column1 = 1 group by a order by a";
Assert.assertEquals(SqlEqualHelper.equals(sql1, sql2), true);
Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), true);
sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
sql2 = "SELECT d,c,b,f FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
Assert.assertEquals(SqlEqualHelper.equals(sql1, sql2), false);
Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), false);
sql1 = "SELECT\n"
+ "页面,\n"
@@ -65,6 +59,27 @@ class SqlEqualHelperTest {
+ "页面\n"
+ "LIMIT\n"
+ "365";
Assert.assertEquals(SqlEqualHelper.equals(sql1, sql2), true);
Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), true);
}
@Test
void testIsValidSQL() {
String sql1 = "SELECT * FROM table1 WHERE column1 = 1 AND column2 = 2";
Assert.assertEquals(SqlValidHelper.isValidSQL(sql1), true);
sql1 = "SELECT sum(b),sum(c),sum(d) FROM table1 WHERE column1 = 1 AND column2 = 2";
Assert.assertEquals(SqlValidHelper.isValidSQL(sql1), true);
sql1 = "SELECT a,b,c, FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
Assert.assertEquals(SqlValidHelper.isValidSQL(sql1), false);
sql1 = "SELECTa,b,c,d FROM table1";
Assert.assertEquals(SqlValidHelper.isValidSQL(sql1), false);
sql1 = "SELECT sum(b),sum(c),sum(d) FROM table1 WHERE";
Assert.assertEquals(SqlValidHelper.isValidSQL(sql1), false);
}
}

View File

@@ -15,6 +15,9 @@ public class LLMParserConfig {
@Value("${s2.query2sql.path:/query2sql}")
private String queryToSqlPath;
@Value("${s2.recall.max.retries:3}")
private int recallMaxRetries;
@Value("${s2.dimension.topn:10}")
private Integer dimensionTopN;

View File

@@ -1,25 +1,23 @@
package com.tencent.supersonic.headless.chat.parser.llm;
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_LINKING_VALUE_ENABLE;
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_STRATEGY_TYPE;
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
import com.tencent.supersonic.headless.chat.QueryContext;
import com.tencent.supersonic.headless.chat.utils.S2SqlDateHelper;
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
import com.tencent.supersonic.headless.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.headless.chat.parser.SatisfactionChecker;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
import com.tencent.supersonic.headless.chat.utils.S2SqlDateHelper;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
@@ -29,9 +27,12 @@ import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_LINKING_VALUE_ENABLE;
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_STRATEGY_TYPE;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@Slf4j
@Service
@@ -62,8 +63,10 @@ public class LLMRequestService {
return dataSetResolver.resolve(queryCtx, queryCtx.getDataSetIds());
}
public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId,
SemanticSchema semanticSchema, List<LLMReq.ElementValue> linkingValues) {
public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId) {
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
List<LLMReq.ElementValue> linkingValues = requestService.getValues(queryCtx, dataSetId);
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
Map<Long, String> dataSetIdToName = semanticSchema.getDataSetIdToName();
String queryText = queryCtx.getQueryText();

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.headless.chat.parser.llm;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.jsqlparser.SqlEqualHelper;
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.query.QueryManager;
import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery;
@@ -9,6 +9,7 @@ import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp;
import com.tencent.supersonic.headless.chat.QueryContext;
import java.util.ArrayList;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.MapUtils;
import org.springframework.stereotype.Service;
@@ -41,14 +42,20 @@ public class LLMResponseService {
return parseInfo;
}
public Map<String, LLMSqlResp> getDeduplicationSqlResp(LLMResp llmResp) {
if (MapUtils.isEmpty(llmResp.getSqlRespMap())) {
return llmResp.getSqlRespMap();
public Map<String, LLMSqlResp> getDeduplicationSqlResp(int currentRetry, LLMResp llmResp) {
Map<String, LLMSqlResp> sqlRespMap = llmResp.getSqlRespMap();
if (MapUtils.isEmpty(sqlRespMap)) {
LLMSqlResp llmSqlResp = new LLMSqlResp(1D, new ArrayList<>());
sqlRespMap.put(llmResp.getSqlOutput(), llmSqlResp);
}
Map<String, LLMSqlResp> result = new HashMap<>();
for (Map.Entry<String, LLMSqlResp> entry : llmResp.getSqlRespMap().entrySet()) {
for (Map.Entry<String, LLMSqlResp> entry : sqlRespMap.entrySet()) {
String key = entry.getKey();
if (result.keySet().stream().anyMatch(existKey -> SqlEqualHelper.equals(existKey, key))) {
if (result.keySet().stream().anyMatch(existKey -> SqlValidHelper.equals(existKey, key))) {
continue;
}
if (!SqlValidHelper.isValidSQL(key)) {
log.error("currentRetry:{},sql is not valid:{}", currentRetry, key);
continue;
}
result.put(key, entry.getValue());

View File

@@ -2,20 +2,18 @@ package com.tencent.supersonic.headless.chat.parser.llm;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.QueryContext;
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp;
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
import com.tencent.supersonic.headless.chat.ChatContext;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.lang3.StringUtils;
import java.util.List;
import java.util.Map;
import java.util.Objects;
/**
* LLMSqlParser uses large language model to understand query semantics and
@@ -26,12 +24,12 @@ public class LLMSqlParser implements SemanticParser {
@Override
public void parse(QueryContext queryCtx, ChatContext chatCtx) {
try {
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
//1.determine whether to skip this parser.
if (requestService.isSkip(queryCtx)) {
return;
}
try {
//2.get dataSetId from queryCtx and chatCtx.
Long dataSetId = requestService.getDataSetId(queryCtx);
if (dataSetId == null) {
@@ -40,39 +38,46 @@ public class LLMSqlParser implements SemanticParser {
log.info("Generate query statement for dataSetId:{}", dataSetId);
//3.invoke LLM service to do parsing.
List<LLMReq.ElementValue> linkingValues = requestService.getValues(queryCtx, dataSetId);
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId, semanticSchema, linkingValues);
LLMResp llmResp = requestService.runText2SQL(llmReq);
if (Objects.isNull(llmResp)) {
return;
}
//4. deduplicate the S2SQL result list and build parserInfo
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
Map<String, LLMSqlResp> deduplicationSqlResp = responseService.getDeduplicationSqlResp(llmResp);
ParseResult parseResult = ParseResult.builder()
.dataSetId(dataSetId)
.llmReq(llmReq)
.llmResp(llmResp)
.linkingValues(linkingValues)
.build();
if (MapUtils.isEmpty(deduplicationSqlResp)) {
if (StringUtils.isNotBlank(llmResp.getSqlOutput())) {
responseService.addParseInfo(queryCtx, parseResult, llmResp.getSqlOutput(), 1D);
}
} else {
deduplicationSqlResp.forEach((sql, sqlResp) -> {
if (StringUtils.isNotBlank(sql)) {
responseService.addParseInfo(queryCtx, parseResult, sql, sqlResp.getSqlWeight());
}
});
}
LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId);
recall(queryCtx, dataSetId, llmReq);
} catch (Exception e) {
log.error("Failed to parse query:", e);
}
}
private void recall(QueryContext queryCtx, Long dataSetId, LLMReq llmReq) {
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
int maxRetries = ContextUtils.getBean(LLMParserConfig.class).getRecallMaxRetries();
int currentRetry = 0;
Map<String, LLMSqlResp> sqlRespMap = new HashMap<>();
ParseResult parseResult = null;
while (currentRetry < maxRetries) {
log.info("currentRetry:{},start runText2SQL", currentRetry);
try {
LLMResp llmResp = requestService.runText2SQL(llmReq);
if (Objects.nonNull(llmResp)) {
//deduplicate the S2SQL result list and build parserInfo
sqlRespMap = responseService.getDeduplicationSqlResp(currentRetry, llmResp);
if (MapUtils.isNotEmpty(sqlRespMap)) {
parseResult = ParseResult.builder()
.dataSetId(dataSetId).llmReq(llmReq).llmResp(llmResp)
.linkingValues(llmReq.getLinking()).build();
break;
}
}
} catch (Exception e) {
log.error("currentRetry:{},runText2SQL error", currentRetry, e);
}
currentRetry++;
}
if (MapUtils.isNotEmpty(sqlRespMap)) {
for (Entry<String, LLMSqlResp> entry : sqlRespMap.entrySet()) {
String sql = entry.getKey();
double sqlWeight = entry.getValue().getSqlWeight();
responseService.addParseInfo(queryCtx, parseResult, sql, sqlWeight);
}
}
}
}

View File

@@ -26,7 +26,6 @@ public class LLMReq {
private SqlGenType sqlGenType;
private LLMConfig llmConfig;
@Data
public static class ElementValue {

View File

@@ -23,7 +23,7 @@ class LLMResponseServiceTest {
llmResp.setSqlRespMap(sqlWeight);
LLMResponseService llmResponseService = new LLMResponseService();
Map<String, LLMSqlResp> deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp);
Map<String, LLMSqlResp> deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(0, llmResp);
Assert.assertEquals(deduplicationSqlResp.size(), 1);
@@ -36,7 +36,7 @@ class LLMResponseServiceTest {
sqlWeight2.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
llmResp2.setSqlRespMap(sqlWeight2);
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp2);
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(0, llmResp2);
Assert.assertEquals(deduplicationSqlResp.size(), 1);
@@ -48,7 +48,7 @@ class LLMResponseServiceTest {
sqlWeight3.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
sqlWeight3.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
llmResp3.setSqlRespMap(sqlWeight3);
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp3);
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(0, llmResp3);
Assert.assertEquals(deduplicationSqlResp.size(), 2);