mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(chat) The large model parsing supports SQL result verification and adds three retries (#1194)
This commit is contained in:
@@ -2,13 +2,15 @@ package com.tencent.supersonic.common.jsqlparser;
|
|||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import net.sf.jsqlparser.JSQLParserException;
|
||||||
|
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sql Parser equal Helper
|
* Sql Parser valid Helper
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class SqlEqualHelper {
|
public class SqlValidHelper {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* determine if two SQL statements are equal.
|
* determine if two SQL statements are equal.
|
||||||
@@ -63,5 +65,15 @@ public class SqlEqualHelper {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static boolean isValidSQL(String sql) {
|
||||||
|
try {
|
||||||
|
CCJSqlParserUtil.parse(sql);
|
||||||
|
return true;
|
||||||
|
} catch (JSQLParserException e) {
|
||||||
|
log.error("isValidSQL parse:{}", e);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -11,7 +11,7 @@ class ChatModelProperties {
|
|||||||
String apiKey;
|
String apiKey;
|
||||||
Double temperature;
|
Double temperature;
|
||||||
Double topP;
|
Double topP;
|
||||||
String model;
|
String modelName;
|
||||||
Integer maxRetries;
|
Integer maxRetries;
|
||||||
Integer maxToken;
|
Integer maxToken;
|
||||||
Boolean logRequests;
|
Boolean logRequests;
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ public class ZhipuAutoConfig {
|
|||||||
return ZhipuAiChatModel.builder()
|
return ZhipuAiChatModel.builder()
|
||||||
.baseUrl(chatModelProperties.getBaseUrl())
|
.baseUrl(chatModelProperties.getBaseUrl())
|
||||||
.apiKey(chatModelProperties.getApiKey())
|
.apiKey(chatModelProperties.getApiKey())
|
||||||
.model(chatModelProperties.getModel())
|
.model(chatModelProperties.getModelName())
|
||||||
.temperature(chatModelProperties.getTemperature())
|
.temperature(chatModelProperties.getTemperature())
|
||||||
.topP(chatModelProperties.getTopP())
|
.topP(chatModelProperties.getTopP())
|
||||||
.maxRetries(chatModelProperties.getMaxRetries())
|
.maxRetries(chatModelProperties.getMaxRetries())
|
||||||
@@ -38,7 +38,7 @@ public class ZhipuAutoConfig {
|
|||||||
return ZhipuAiStreamingChatModel.builder()
|
return ZhipuAiStreamingChatModel.builder()
|
||||||
.baseUrl(chatModelProperties.getBaseUrl())
|
.baseUrl(chatModelProperties.getBaseUrl())
|
||||||
.apiKey(chatModelProperties.getApiKey())
|
.apiKey(chatModelProperties.getApiKey())
|
||||||
.model(chatModelProperties.getModel())
|
.model(chatModelProperties.getModelName())
|
||||||
.temperature(chatModelProperties.getTemperature())
|
.temperature(chatModelProperties.getTemperature())
|
||||||
.topP(chatModelProperties.getTopP())
|
.topP(chatModelProperties.getTopP())
|
||||||
.maxToken(chatModelProperties.getMaxToken())
|
.maxToken(chatModelProperties.getMaxToken())
|
||||||
|
|||||||
@@ -4,39 +4,33 @@ package com.tencent.supersonic.common.jsqlparser;
|
|||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
/**
|
class SqlValidHelperTest {
|
||||||
* @author lex luo
|
|
||||||
* @date 2023/11/15 15:04
|
|
||||||
*/
|
|
||||||
class SqlEqualHelperTest {
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testEquals() {
|
void testEquals() {
|
||||||
String sql1 = "SELECT * FROM table1 WHERE column1 = 1 AND column2 = 2";
|
String sql1 = "SELECT * FROM table1 WHERE column1 = 1 AND column2 = 2";
|
||||||
String sql2 = "SELECT * FROM table1 WHERE column2 = 2 AND column1 = 1";
|
String sql2 = "SELECT * FROM table1 WHERE column2 = 2 AND column1 = 1";
|
||||||
Assert.assertEquals(SqlEqualHelper.equals(sql1, sql2), true);
|
Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), true);
|
||||||
|
|
||||||
sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
|
sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
|
||||||
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
||||||
Assert.assertEquals(SqlEqualHelper.equals(sql1, sql2), true);
|
Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), true);
|
||||||
|
|
||||||
sql1 = "SELECT a,sum(b),sum(c),sum(d) FROM table1 WHERE column1 = 1 AND column2 = 2 group by a order by a";
|
sql1 = "SELECT a,sum(b),sum(c),sum(d) FROM table1 WHERE column1 = 1 AND column2 = 2 group by a order by a";
|
||||||
|
|
||||||
sql2 = "SELECT sum(d),sum(c),sum(b),a FROM table1 WHERE column2 = 2 AND column1 = 1 group by a order by a";
|
sql2 = "SELECT sum(d),sum(c),sum(b),a FROM table1 WHERE column2 = 2 AND column1 = 1 group by a order by a";
|
||||||
|
|
||||||
Assert.assertEquals(SqlEqualHelper.equals(sql1, sql2), true);
|
Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), true);
|
||||||
|
|
||||||
sql1 = "SELECT a,sum(b),sum(c),sum(d) FROM table1 WHERE column1 = 1 AND column2 = 2 group by a order by a";
|
sql1 = "SELECT a,sum(b),sum(c),sum(d) FROM table1 WHERE column1 = 1 AND column2 = 2 group by a order by a";
|
||||||
|
|
||||||
sql2 = "SELECT sum(d),sum(c),sum(b),a FROM table1 WHERE column2 = 2 AND column1 = 1 group by a order by a";
|
sql2 = "SELECT sum(d),sum(c),sum(b),a FROM table1 WHERE column2 = 2 AND column1 = 1 group by a order by a";
|
||||||
|
|
||||||
Assert.assertEquals(SqlEqualHelper.equals(sql1, sql2), true);
|
Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), true);
|
||||||
|
|
||||||
|
|
||||||
sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
|
sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
|
||||||
sql2 = "SELECT d,c,b,f FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
sql2 = "SELECT d,c,b,f FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
||||||
Assert.assertEquals(SqlEqualHelper.equals(sql1, sql2), false);
|
Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), false);
|
||||||
|
|
||||||
|
|
||||||
sql1 = "SELECT\n"
|
sql1 = "SELECT\n"
|
||||||
+ "页面,\n"
|
+ "页面,\n"
|
||||||
@@ -65,6 +59,27 @@ class SqlEqualHelperTest {
|
|||||||
+ "页面\n"
|
+ "页面\n"
|
||||||
+ "LIMIT\n"
|
+ "LIMIT\n"
|
||||||
+ "365";
|
+ "365";
|
||||||
Assert.assertEquals(SqlEqualHelper.equals(sql1, sql2), true);
|
Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), true);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testIsValidSQL() {
|
||||||
|
String sql1 = "SELECT * FROM table1 WHERE column1 = 1 AND column2 = 2";
|
||||||
|
Assert.assertEquals(SqlValidHelper.isValidSQL(sql1), true);
|
||||||
|
|
||||||
|
sql1 = "SELECT sum(b),sum(c),sum(d) FROM table1 WHERE column1 = 1 AND column2 = 2";
|
||||||
|
|
||||||
|
Assert.assertEquals(SqlValidHelper.isValidSQL(sql1), true);
|
||||||
|
|
||||||
|
sql1 = "SELECT a,b,c, FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
|
||||||
|
Assert.assertEquals(SqlValidHelper.isValidSQL(sql1), false);
|
||||||
|
|
||||||
|
sql1 = "SELECTa,b,c,d FROM table1";
|
||||||
|
|
||||||
|
Assert.assertEquals(SqlValidHelper.isValidSQL(sql1), false);
|
||||||
|
|
||||||
|
sql1 = "SELECT sum(b),sum(c),sum(d) FROM table1 WHERE";
|
||||||
|
|
||||||
|
Assert.assertEquals(SqlValidHelper.isValidSQL(sql1), false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -15,6 +15,9 @@ public class LLMParserConfig {
|
|||||||
@Value("${s2.query2sql.path:/query2sql}")
|
@Value("${s2.query2sql.path:/query2sql}")
|
||||||
private String queryToSqlPath;
|
private String queryToSqlPath;
|
||||||
|
|
||||||
|
@Value("${s2.recall.max.retries:3}")
|
||||||
|
private int recallMaxRetries;
|
||||||
|
|
||||||
@Value("${s2.dimension.topn:10}")
|
@Value("${s2.dimension.topn:10}")
|
||||||
private Integer dimensionTopN;
|
private Integer dimensionTopN;
|
||||||
|
|
||||||
|
|||||||
@@ -1,25 +1,23 @@
|
|||||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||||
|
|
||||||
|
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_LINKING_VALUE_ENABLE;
|
||||||
|
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_STRATEGY_TYPE;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
|
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
|
||||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.DateUtils;
|
import com.tencent.supersonic.common.util.DateUtils;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
|
|
||||||
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.utils.S2SqlDateHelper;
|
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
||||||
|
import com.tencent.supersonic.headless.chat.parser.SatisfactionChecker;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||||
import com.tencent.supersonic.headless.chat.parser.SatisfactionChecker;
|
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import com.tencent.supersonic.headless.chat.utils.S2SqlDateHelper;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
@@ -29,9 +27,12 @@ import java.util.Map;
|
|||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_LINKING_VALUE_ENABLE;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_STRATEGY_TYPE;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Service
|
@Service
|
||||||
@@ -62,8 +63,10 @@ public class LLMRequestService {
|
|||||||
return dataSetResolver.resolve(queryCtx, queryCtx.getDataSetIds());
|
return dataSetResolver.resolve(queryCtx, queryCtx.getDataSetIds());
|
||||||
}
|
}
|
||||||
|
|
||||||
public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId,
|
public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId) {
|
||||||
SemanticSchema semanticSchema, List<LLMReq.ElementValue> linkingValues) {
|
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
|
||||||
|
List<LLMReq.ElementValue> linkingValues = requestService.getValues(queryCtx, dataSetId);
|
||||||
|
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||||
Map<Long, String> dataSetIdToName = semanticSchema.getDataSetIdToName();
|
Map<Long, String> dataSetIdToName = semanticSchema.getDataSetIdToName();
|
||||||
String queryText = queryCtx.getQueryText();
|
String queryText = queryCtx.getQueryText();
|
||||||
|
|
||||||
@@ -114,7 +117,7 @@ public class LLMRequestService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
protected List<String> getFieldNameList(QueryContext queryCtx, Long dataSetId,
|
protected List<String> getFieldNameList(QueryContext queryCtx, Long dataSetId,
|
||||||
LLMParserConfig llmParserConfig) {
|
LLMParserConfig llmParserConfig) {
|
||||||
|
|
||||||
Set<String> results = getTopNFieldNames(queryCtx, dataSetId, llmParserConfig);
|
Set<String> results = getTopNFieldNames(queryCtx, dataSetId, llmParserConfig);
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.common.jsqlparser.SqlEqualHelper;
|
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery;
|
import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery;
|
||||||
@@ -9,6 +9,7 @@ import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
|||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||||
|
import java.util.ArrayList;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections.MapUtils;
|
import org.apache.commons.collections.MapUtils;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
@@ -41,14 +42,20 @@ public class LLMResponseService {
|
|||||||
return parseInfo;
|
return parseInfo;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Map<String, LLMSqlResp> getDeduplicationSqlResp(LLMResp llmResp) {
|
public Map<String, LLMSqlResp> getDeduplicationSqlResp(int currentRetry, LLMResp llmResp) {
|
||||||
if (MapUtils.isEmpty(llmResp.getSqlRespMap())) {
|
Map<String, LLMSqlResp> sqlRespMap = llmResp.getSqlRespMap();
|
||||||
return llmResp.getSqlRespMap();
|
if (MapUtils.isEmpty(sqlRespMap)) {
|
||||||
|
LLMSqlResp llmSqlResp = new LLMSqlResp(1D, new ArrayList<>());
|
||||||
|
sqlRespMap.put(llmResp.getSqlOutput(), llmSqlResp);
|
||||||
}
|
}
|
||||||
Map<String, LLMSqlResp> result = new HashMap<>();
|
Map<String, LLMSqlResp> result = new HashMap<>();
|
||||||
for (Map.Entry<String, LLMSqlResp> entry : llmResp.getSqlRespMap().entrySet()) {
|
for (Map.Entry<String, LLMSqlResp> entry : sqlRespMap.entrySet()) {
|
||||||
String key = entry.getKey();
|
String key = entry.getKey();
|
||||||
if (result.keySet().stream().anyMatch(existKey -> SqlEqualHelper.equals(existKey, key))) {
|
if (result.keySet().stream().anyMatch(existKey -> SqlValidHelper.equals(existKey, key))) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!SqlValidHelper.isValidSQL(key)) {
|
||||||
|
log.error("currentRetry:{},sql is not valid:{}", currentRetry, key);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
result.put(key, entry.getValue());
|
result.put(key, entry.getValue());
|
||||||
|
|||||||
@@ -2,20 +2,18 @@ package com.tencent.supersonic.headless.chat.parser.llm;
|
|||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||||
|
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp;
|
||||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
import java.util.HashMap;
|
||||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
import java.util.Map;
|
||||||
|
import java.util.Map.Entry;
|
||||||
|
import java.util.Objects;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections.MapUtils;
|
import org.apache.commons.collections.MapUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Objects;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* LLMSqlParser uses large language model to understand query semantics and
|
* LLMSqlParser uses large language model to understand query semantics and
|
||||||
@@ -26,12 +24,12 @@ public class LLMSqlParser implements SemanticParser {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void parse(QueryContext queryCtx, ChatContext chatCtx) {
|
public void parse(QueryContext queryCtx, ChatContext chatCtx) {
|
||||||
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
|
|
||||||
//1.determine whether to skip this parser.
|
|
||||||
if (requestService.isSkip(queryCtx)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
try {
|
try {
|
||||||
|
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
|
||||||
|
//1.determine whether to skip this parser.
|
||||||
|
if (requestService.isSkip(queryCtx)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
//2.get dataSetId from queryCtx and chatCtx.
|
//2.get dataSetId from queryCtx and chatCtx.
|
||||||
Long dataSetId = requestService.getDataSetId(queryCtx);
|
Long dataSetId = requestService.getDataSetId(queryCtx);
|
||||||
if (dataSetId == null) {
|
if (dataSetId == null) {
|
||||||
@@ -40,39 +38,46 @@ public class LLMSqlParser implements SemanticParser {
|
|||||||
log.info("Generate query statement for dataSetId:{}", dataSetId);
|
log.info("Generate query statement for dataSetId:{}", dataSetId);
|
||||||
|
|
||||||
//3.invoke LLM service to do parsing.
|
//3.invoke LLM service to do parsing.
|
||||||
List<LLMReq.ElementValue> linkingValues = requestService.getValues(queryCtx, dataSetId);
|
LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId);
|
||||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
recall(queryCtx, dataSetId, llmReq);
|
||||||
LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId, semanticSchema, linkingValues);
|
|
||||||
LLMResp llmResp = requestService.runText2SQL(llmReq);
|
|
||||||
if (Objects.isNull(llmResp)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
//4. deduplicate the S2SQL result list and build parserInfo
|
|
||||||
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
|
|
||||||
Map<String, LLMSqlResp> deduplicationSqlResp = responseService.getDeduplicationSqlResp(llmResp);
|
|
||||||
ParseResult parseResult = ParseResult.builder()
|
|
||||||
.dataSetId(dataSetId)
|
|
||||||
.llmReq(llmReq)
|
|
||||||
.llmResp(llmResp)
|
|
||||||
.linkingValues(linkingValues)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
if (MapUtils.isEmpty(deduplicationSqlResp)) {
|
|
||||||
if (StringUtils.isNotBlank(llmResp.getSqlOutput())) {
|
|
||||||
responseService.addParseInfo(queryCtx, parseResult, llmResp.getSqlOutput(), 1D);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
deduplicationSqlResp.forEach((sql, sqlResp) -> {
|
|
||||||
if (StringUtils.isNotBlank(sql)) {
|
|
||||||
responseService.addParseInfo(queryCtx, parseResult, sql, sqlResp.getSqlWeight());
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("Failed to parse query:", e);
|
log.error("Failed to parse query:", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void recall(QueryContext queryCtx, Long dataSetId, LLMReq llmReq) {
|
||||||
|
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
|
||||||
|
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
|
||||||
|
int maxRetries = ContextUtils.getBean(LLMParserConfig.class).getRecallMaxRetries();
|
||||||
|
int currentRetry = 0;
|
||||||
|
Map<String, LLMSqlResp> sqlRespMap = new HashMap<>();
|
||||||
|
ParseResult parseResult = null;
|
||||||
|
while (currentRetry < maxRetries) {
|
||||||
|
log.info("currentRetry:{},start runText2SQL", currentRetry);
|
||||||
|
try {
|
||||||
|
LLMResp llmResp = requestService.runText2SQL(llmReq);
|
||||||
|
if (Objects.nonNull(llmResp)) {
|
||||||
|
//deduplicate the S2SQL result list and build parserInfo
|
||||||
|
sqlRespMap = responseService.getDeduplicationSqlResp(currentRetry, llmResp);
|
||||||
|
if (MapUtils.isNotEmpty(sqlRespMap)) {
|
||||||
|
parseResult = ParseResult.builder()
|
||||||
|
.dataSetId(dataSetId).llmReq(llmReq).llmResp(llmResp)
|
||||||
|
.linkingValues(llmReq.getLinking()).build();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("currentRetry:{},runText2SQL error", currentRetry, e);
|
||||||
|
}
|
||||||
|
currentRetry++;
|
||||||
|
}
|
||||||
|
if (MapUtils.isNotEmpty(sqlRespMap)) {
|
||||||
|
for (Entry<String, LLMSqlResp> entry : sqlRespMap.entrySet()) {
|
||||||
|
String sql = entry.getKey();
|
||||||
|
double sqlWeight = entry.getValue().getSqlWeight();
|
||||||
|
responseService.addParseInfo(queryCtx, parseResult, sql, sqlWeight);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ public class LLMReq {
|
|||||||
private SqlGenType sqlGenType;
|
private SqlGenType sqlGenType;
|
||||||
|
|
||||||
private LLMConfig llmConfig;
|
private LLMConfig llmConfig;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public static class ElementValue {
|
public static class ElementValue {
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ class LLMResponseServiceTest {
|
|||||||
|
|
||||||
llmResp.setSqlRespMap(sqlWeight);
|
llmResp.setSqlRespMap(sqlWeight);
|
||||||
LLMResponseService llmResponseService = new LLMResponseService();
|
LLMResponseService llmResponseService = new LLMResponseService();
|
||||||
Map<String, LLMSqlResp> deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp);
|
Map<String, LLMSqlResp> deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(0, llmResp);
|
||||||
|
|
||||||
Assert.assertEquals(deduplicationSqlResp.size(), 1);
|
Assert.assertEquals(deduplicationSqlResp.size(), 1);
|
||||||
|
|
||||||
@@ -36,7 +36,7 @@ class LLMResponseServiceTest {
|
|||||||
sqlWeight2.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
|
sqlWeight2.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
|
||||||
|
|
||||||
llmResp2.setSqlRespMap(sqlWeight2);
|
llmResp2.setSqlRespMap(sqlWeight2);
|
||||||
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp2);
|
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(0, llmResp2);
|
||||||
|
|
||||||
Assert.assertEquals(deduplicationSqlResp.size(), 1);
|
Assert.assertEquals(deduplicationSqlResp.size(), 1);
|
||||||
|
|
||||||
@@ -48,7 +48,7 @@ class LLMResponseServiceTest {
|
|||||||
sqlWeight3.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
|
sqlWeight3.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
|
||||||
sqlWeight3.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
|
sqlWeight3.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
|
||||||
llmResp3.setSqlRespMap(sqlWeight3);
|
llmResp3.setSqlRespMap(sqlWeight3);
|
||||||
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp3);
|
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(0, llmResp3);
|
||||||
|
|
||||||
Assert.assertEquals(deduplicationSqlResp.size(), 2);
|
Assert.assertEquals(deduplicationSqlResp.size(), 2);
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user