diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlEqualHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlValidHelper.java similarity index 84% rename from common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlEqualHelper.java rename to common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlValidHelper.java index fc76a3239..9ac38858a 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlEqualHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlValidHelper.java @@ -2,13 +2,15 @@ package com.tencent.supersonic.common.jsqlparser; import java.util.List; import lombok.extern.slf4j.Slf4j; +import net.sf.jsqlparser.JSQLParserException; +import net.sf.jsqlparser.parser.CCJSqlParserUtil; import org.apache.commons.collections.CollectionUtils; /** - * Sql Parser equal Helper + * Sql Parser valid Helper */ @Slf4j -public class SqlEqualHelper { +public class SqlValidHelper { /** * determine if two SQL statements are equal. @@ -63,5 +65,15 @@ public class SqlEqualHelper { return true; } + public static boolean isValidSQL(String sql) { + try { + CCJSqlParserUtil.parse(sql); + return true; + } catch (JSQLParserException e) { + log.error("isValidSQL parse:{}", e); + return false; + } + } + } diff --git a/common/src/main/java/dev/langchain4j/zhipu/spring/ChatModelProperties.java b/common/src/main/java/dev/langchain4j/zhipu/spring/ChatModelProperties.java index 000d29997..ceaca061d 100644 --- a/common/src/main/java/dev/langchain4j/zhipu/spring/ChatModelProperties.java +++ b/common/src/main/java/dev/langchain4j/zhipu/spring/ChatModelProperties.java @@ -11,7 +11,7 @@ class ChatModelProperties { String apiKey; Double temperature; Double topP; - String model; + String modelName; Integer maxRetries; Integer maxToken; Boolean logRequests; diff --git a/common/src/main/java/dev/langchain4j/zhipu/spring/ZhipuAutoConfig.java b/common/src/main/java/dev/langchain4j/zhipu/spring/ZhipuAutoConfig.java index a2d0d3f25..d40f5b465 100644 --- a/common/src/main/java/dev/langchain4j/zhipu/spring/ZhipuAutoConfig.java +++ b/common/src/main/java/dev/langchain4j/zhipu/spring/ZhipuAutoConfig.java @@ -21,7 +21,7 @@ public class ZhipuAutoConfig { return ZhipuAiChatModel.builder() .baseUrl(chatModelProperties.getBaseUrl()) .apiKey(chatModelProperties.getApiKey()) - .model(chatModelProperties.getModel()) + .model(chatModelProperties.getModelName()) .temperature(chatModelProperties.getTemperature()) .topP(chatModelProperties.getTopP()) .maxRetries(chatModelProperties.getMaxRetries()) @@ -38,7 +38,7 @@ public class ZhipuAutoConfig { return ZhipuAiStreamingChatModel.builder() .baseUrl(chatModelProperties.getBaseUrl()) .apiKey(chatModelProperties.getApiKey()) - .model(chatModelProperties.getModel()) + .model(chatModelProperties.getModelName()) .temperature(chatModelProperties.getTemperature()) .topP(chatModelProperties.getTopP()) .maxToken(chatModelProperties.getMaxToken()) diff --git a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlEqualHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlValidHelperTest.java similarity index 63% rename from common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlEqualHelperTest.java rename to common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlValidHelperTest.java index ab115e9f3..1a7dbd782 100644 --- a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlEqualHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlValidHelperTest.java @@ -4,39 +4,33 @@ package com.tencent.supersonic.common.jsqlparser; import org.junit.Assert; import org.junit.jupiter.api.Test; -/** - * @author lex luo - * @date 2023/11/15 15:04 - */ -class SqlEqualHelperTest { +class SqlValidHelperTest { @Test void testEquals() { String sql1 = "SELECT * FROM table1 WHERE column1 = 1 AND column2 = 2"; String sql2 = "SELECT * FROM table1 WHERE column2 = 2 AND column1 = 1"; - Assert.assertEquals(SqlEqualHelper.equals(sql1, sql2), true); + Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), true); sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a"; sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a"; - Assert.assertEquals(SqlEqualHelper.equals(sql1, sql2), true); + Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), true); sql1 = "SELECT a,sum(b),sum(c),sum(d) FROM table1 WHERE column1 = 1 AND column2 = 2 group by a order by a"; sql2 = "SELECT sum(d),sum(c),sum(b),a FROM table1 WHERE column2 = 2 AND column1 = 1 group by a order by a"; - Assert.assertEquals(SqlEqualHelper.equals(sql1, sql2), true); + Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), true); sql1 = "SELECT a,sum(b),sum(c),sum(d) FROM table1 WHERE column1 = 1 AND column2 = 2 group by a order by a"; sql2 = "SELECT sum(d),sum(c),sum(b),a FROM table1 WHERE column2 = 2 AND column1 = 1 group by a order by a"; - Assert.assertEquals(SqlEqualHelper.equals(sql1, sql2), true); - + Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), true); sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a"; sql2 = "SELECT d,c,b,f FROM table1 WHERE column2 = 2 AND column1 = 1 order by a"; - Assert.assertEquals(SqlEqualHelper.equals(sql1, sql2), false); - + Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), false); sql1 = "SELECT\n" + "页面,\n" @@ -65,6 +59,27 @@ class SqlEqualHelperTest { + "页面\n" + "LIMIT\n" + "365"; - Assert.assertEquals(SqlEqualHelper.equals(sql1, sql2), true); + Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), true); + } + + @Test + void testIsValidSQL() { + String sql1 = "SELECT * FROM table1 WHERE column1 = 1 AND column2 = 2"; + Assert.assertEquals(SqlValidHelper.isValidSQL(sql1), true); + + sql1 = "SELECT sum(b),sum(c),sum(d) FROM table1 WHERE column1 = 1 AND column2 = 2"; + + Assert.assertEquals(SqlValidHelper.isValidSQL(sql1), true); + + sql1 = "SELECT a,b,c, FROM table1 WHERE column1 = 1 AND column2 = 2 order by a"; + Assert.assertEquals(SqlValidHelper.isValidSQL(sql1), false); + + sql1 = "SELECTa,b,c,d FROM table1"; + + Assert.assertEquals(SqlValidHelper.isValidSQL(sql1), false); + + sql1 = "SELECT sum(b),sum(c),sum(d) FROM table1 WHERE"; + + Assert.assertEquals(SqlValidHelper.isValidSQL(sql1), false); } } \ No newline at end of file diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMParserConfig.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMParserConfig.java index ad32c9fa5..b8b7808b2 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMParserConfig.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMParserConfig.java @@ -15,6 +15,9 @@ public class LLMParserConfig { @Value("${s2.query2sql.path:/query2sql}") private String queryToSqlPath; + @Value("${s2.recall.max.retries:3}") + private int recallMaxRetries; + @Value("${s2.dimension.topn:10}") private Integer dimensionTopN; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java index 101ebfa1a..e8d128839 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java @@ -1,25 +1,23 @@ package com.tencent.supersonic.headless.chat.parser.llm; +import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_LINKING_VALUE_ENABLE; +import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_STRATEGY_TYPE; + import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; +import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.DateUtils; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; -import com.tencent.supersonic.headless.chat.utils.ComponentFactory; -import com.tencent.supersonic.headless.chat.parser.ParserConfig; import com.tencent.supersonic.headless.chat.QueryContext; -import com.tencent.supersonic.headless.chat.utils.S2SqlDateHelper; +import com.tencent.supersonic.headless.chat.parser.ParserConfig; +import com.tencent.supersonic.headless.chat.parser.SatisfactionChecker; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; -import com.tencent.supersonic.headless.chat.parser.SatisfactionChecker; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.StringUtils; -import org.apache.commons.lang3.tuple.Pair; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.stereotype.Service; -import org.springframework.util.CollectionUtils; +import com.tencent.supersonic.headless.chat.utils.ComponentFactory; +import com.tencent.supersonic.headless.chat.utils.S2SqlDateHelper; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; @@ -29,9 +27,12 @@ import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; - -import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_LINKING_VALUE_ENABLE; -import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_STRATEGY_TYPE; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.Pair; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; +import org.springframework.util.CollectionUtils; @Slf4j @Service @@ -62,8 +63,10 @@ public class LLMRequestService { return dataSetResolver.resolve(queryCtx, queryCtx.getDataSetIds()); } - public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId, - SemanticSchema semanticSchema, List linkingValues) { + public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId) { + LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class); + List linkingValues = requestService.getValues(queryCtx, dataSetId); + SemanticSchema semanticSchema = queryCtx.getSemanticSchema(); Map dataSetIdToName = semanticSchema.getDataSetIdToName(); String queryText = queryCtx.getQueryText(); @@ -114,7 +117,7 @@ public class LLMRequestService { } protected List getFieldNameList(QueryContext queryCtx, Long dataSetId, - LLMParserConfig llmParserConfig) { + LLMParserConfig llmParserConfig) { Set results = getTopNFieldNames(queryCtx, dataSetId, llmParserConfig); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java index 4ed7ec1b6..1146daa3d 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java @@ -1,7 +1,7 @@ package com.tencent.supersonic.headless.chat.parser.llm; import com.tencent.supersonic.common.pojo.Constants; -import com.tencent.supersonic.common.jsqlparser.SqlEqualHelper; +import com.tencent.supersonic.common.jsqlparser.SqlValidHelper; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.chat.query.QueryManager; import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery; @@ -9,6 +9,7 @@ import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp; import com.tencent.supersonic.headless.chat.QueryContext; +import java.util.ArrayList; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.MapUtils; import org.springframework.stereotype.Service; @@ -41,14 +42,20 @@ public class LLMResponseService { return parseInfo; } - public Map getDeduplicationSqlResp(LLMResp llmResp) { - if (MapUtils.isEmpty(llmResp.getSqlRespMap())) { - return llmResp.getSqlRespMap(); + public Map getDeduplicationSqlResp(int currentRetry, LLMResp llmResp) { + Map sqlRespMap = llmResp.getSqlRespMap(); + if (MapUtils.isEmpty(sqlRespMap)) { + LLMSqlResp llmSqlResp = new LLMSqlResp(1D, new ArrayList<>()); + sqlRespMap.put(llmResp.getSqlOutput(), llmSqlResp); } Map result = new HashMap<>(); - for (Map.Entry entry : llmResp.getSqlRespMap().entrySet()) { + for (Map.Entry entry : sqlRespMap.entrySet()) { String key = entry.getKey(); - if (result.keySet().stream().anyMatch(existKey -> SqlEqualHelper.equals(existKey, key))) { + if (result.keySet().stream().anyMatch(existKey -> SqlValidHelper.equals(existKey, key))) { + continue; + } + if (!SqlValidHelper.isValidSQL(key)) { + log.error("currentRetry:{},sql is not valid:{}", currentRetry, key); continue; } result.put(key, entry.getValue()); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMSqlParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMSqlParser.java index 302cc7623..3f8d6a7af 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMSqlParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMSqlParser.java @@ -2,20 +2,18 @@ package com.tencent.supersonic.headless.chat.parser.llm; import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.headless.api.pojo.SemanticSchema; +import com.tencent.supersonic.headless.chat.ChatContext; import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.parser.SemanticParser; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp; -import com.tencent.supersonic.headless.chat.parser.SemanticParser; -import com.tencent.supersonic.headless.chat.ChatContext; +import java.util.HashMap; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Objects; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.MapUtils; -import org.apache.commons.lang3.StringUtils; - -import java.util.List; -import java.util.Map; -import java.util.Objects; /** * LLMSqlParser uses large language model to understand query semantics and @@ -26,12 +24,12 @@ public class LLMSqlParser implements SemanticParser { @Override public void parse(QueryContext queryCtx, ChatContext chatCtx) { - LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class); - //1.determine whether to skip this parser. - if (requestService.isSkip(queryCtx)) { - return; - } try { + LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class); + //1.determine whether to skip this parser. + if (requestService.isSkip(queryCtx)) { + return; + } //2.get dataSetId from queryCtx and chatCtx. Long dataSetId = requestService.getDataSetId(queryCtx); if (dataSetId == null) { @@ -40,39 +38,46 @@ public class LLMSqlParser implements SemanticParser { log.info("Generate query statement for dataSetId:{}", dataSetId); //3.invoke LLM service to do parsing. - List linkingValues = requestService.getValues(queryCtx, dataSetId); - SemanticSchema semanticSchema = queryCtx.getSemanticSchema(); - LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId, semanticSchema, linkingValues); - LLMResp llmResp = requestService.runText2SQL(llmReq); - if (Objects.isNull(llmResp)) { - return; - } - - //4. deduplicate the S2SQL result list and build parserInfo - LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class); - Map deduplicationSqlResp = responseService.getDeduplicationSqlResp(llmResp); - ParseResult parseResult = ParseResult.builder() - .dataSetId(dataSetId) - .llmReq(llmReq) - .llmResp(llmResp) - .linkingValues(linkingValues) - .build(); - - if (MapUtils.isEmpty(deduplicationSqlResp)) { - if (StringUtils.isNotBlank(llmResp.getSqlOutput())) { - responseService.addParseInfo(queryCtx, parseResult, llmResp.getSqlOutput(), 1D); - } - } else { - deduplicationSqlResp.forEach((sql, sqlResp) -> { - if (StringUtils.isNotBlank(sql)) { - responseService.addParseInfo(queryCtx, parseResult, sql, sqlResp.getSqlWeight()); - } - }); - } - + LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId); + recall(queryCtx, dataSetId, llmReq); } catch (Exception e) { log.error("Failed to parse query:", e); } } + private void recall(QueryContext queryCtx, Long dataSetId, LLMReq llmReq) { + LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class); + LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class); + int maxRetries = ContextUtils.getBean(LLMParserConfig.class).getRecallMaxRetries(); + int currentRetry = 0; + Map sqlRespMap = new HashMap<>(); + ParseResult parseResult = null; + while (currentRetry < maxRetries) { + log.info("currentRetry:{},start runText2SQL", currentRetry); + try { + LLMResp llmResp = requestService.runText2SQL(llmReq); + if (Objects.nonNull(llmResp)) { + //deduplicate the S2SQL result list and build parserInfo + sqlRespMap = responseService.getDeduplicationSqlResp(currentRetry, llmResp); + if (MapUtils.isNotEmpty(sqlRespMap)) { + parseResult = ParseResult.builder() + .dataSetId(dataSetId).llmReq(llmReq).llmResp(llmResp) + .linkingValues(llmReq.getLinking()).build(); + break; + } + } + } catch (Exception e) { + log.error("currentRetry:{},runText2SQL error", currentRetry, e); + } + currentRetry++; + } + if (MapUtils.isNotEmpty(sqlRespMap)) { + for (Entry entry : sqlRespMap.entrySet()) { + String sql = entry.getKey(); + double sqlWeight = entry.getValue().getSqlWeight(); + responseService.addParseInfo(queryCtx, parseResult, sql, sqlWeight); + } + } + } + } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java index 9e9b78b25..014923b7c 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java @@ -26,7 +26,6 @@ public class LLMReq { private SqlGenType sqlGenType; private LLMConfig llmConfig; - @Data public static class ElementValue { diff --git a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/s2sql/LLMResponseServiceTest.java b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/s2sql/LLMResponseServiceTest.java index 76b787c01..cb633663b 100644 --- a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/s2sql/LLMResponseServiceTest.java +++ b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/s2sql/LLMResponseServiceTest.java @@ -23,7 +23,7 @@ class LLMResponseServiceTest { llmResp.setSqlRespMap(sqlWeight); LLMResponseService llmResponseService = new LLMResponseService(); - Map deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp); + Map deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(0, llmResp); Assert.assertEquals(deduplicationSqlResp.size(), 1); @@ -36,7 +36,7 @@ class LLMResponseServiceTest { sqlWeight2.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build()); llmResp2.setSqlRespMap(sqlWeight2); - deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp2); + deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(0, llmResp2); Assert.assertEquals(deduplicationSqlResp.size(), 1); @@ -48,7 +48,7 @@ class LLMResponseServiceTest { sqlWeight3.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build()); sqlWeight3.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build()); llmResp3.setSqlRespMap(sqlWeight3); - deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp3); + deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(0, llmResp3); Assert.assertEquals(deduplicationSqlResp.size(), 2);