diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java index fc6166769..439e695a8 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/BaseSemanticCorrector.java @@ -22,6 +22,10 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.util.CollectionUtils; +/** + * basic semantic correction functionality, offering common methods and an + * abstract method called doCorrect + */ @Slf4j public abstract class BaseSemanticCorrector implements SemanticCorrector { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java index c7147d720..5d6a8a7b4 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java @@ -16,6 +16,9 @@ import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.springframework.util.CollectionUtils; +/** + * Perform SQL corrections on the "group by" section in S2SQL. + */ @Slf4j public class GroupByCorrector extends BaseSemanticCorrector { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java index 22c7a3d4e..826e6b942 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java @@ -3,6 +3,7 @@ package com.tencent.supersonic.chat.corrector; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; +import com.tencent.supersonic.chat.api.pojo.response.SqlInfo; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserRemoveHelper; @@ -16,6 +17,9 @@ import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.expression.Expression; import org.springframework.util.CollectionUtils; +/** + * Perform SQL corrections on the "Having" section in S2SQL. + */ @Slf4j public class HavingCorrector extends BaseSemanticCorrector { @@ -29,9 +33,7 @@ public class HavingCorrector extends BaseSemanticCorrector { addHavingToSelect(semanticParseInfo); //remove number condition - String correctorSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); - correctorSql = SqlParserRemoveHelper.removeNumberCondition(correctorSql); - semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctorSql); + removeNumberCondition(semanticParseInfo); } private void addHaving(SemanticParseInfo semanticParseInfo) { @@ -62,4 +64,10 @@ public class HavingCorrector extends BaseSemanticCorrector { return; } + private void removeNumberCondition(SemanticParseInfo semanticParseInfo) { + SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); + String correctorSql = SqlParserRemoveHelper.removeNumberCondition(sqlInfo.getCorrectS2SQL()); + sqlInfo.setCorrectS2SQL(correctorSql); + } + } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SchemaCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SchemaCorrector.java index 1623a1749..7eb6ed8fa 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SchemaCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SchemaCorrector.java @@ -17,6 +17,9 @@ import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.springframework.util.CollectionUtils; +/** + * Perform schema corrections on the Schema information in S2QL. + */ @Slf4j public class SchemaCorrector extends BaseSemanticCorrector { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectCorrector.java index 3036fbb82..ee8dfa5e0 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectCorrector.java @@ -7,6 +7,9 @@ import java.util.List; import lombok.extern.slf4j.Slf4j; import org.springframework.util.CollectionUtils; +/** + * Perform SQL corrections on the "Select" section in S2SQL. + */ @Slf4j public class SelectCorrector extends BaseSemanticCorrector { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java index 1812b107a..8c78b3e81 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java @@ -28,6 +28,9 @@ import org.apache.commons.lang3.StringUtils; import org.apache.logging.log4j.util.Strings; import org.springframework.util.CollectionUtils; +/** + * Perform SQL corrections on the "Where" section in S2SQL. + */ @Slf4j public class WhereCorrector extends BaseSemanticCorrector { @@ -110,7 +113,6 @@ public class WhereCorrector extends BaseSemanticCorrector { String correctS2SQL = SqlParserReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), aliasAndBizNameToTechName); semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL); - return; } private Map> getAliasAndBizNameToTechName(List dimensions) { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMRequestService.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMRequestService.java index 44935fc40..eab046dc9 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMRequestService.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMRequestService.java @@ -17,15 +17,14 @@ import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue; import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.chat.service.AgentService; +import com.tencent.supersonic.chat.service.LLMParserLayer; import com.tencent.supersonic.chat.utils.ComponentFactory; import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.util.DateUtils; -import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.knowledge.service.SchemaService; import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem; import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp; -import java.net.URL; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; @@ -40,14 +39,8 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.http.HttpEntity; -import org.springframework.http.HttpHeaders; -import org.springframework.http.HttpMethod; -import org.springframework.http.MediaType; -import org.springframework.http.ResponseEntity; import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; -import org.springframework.web.client.RestTemplate; @Slf4j @Service @@ -61,9 +54,9 @@ public class LLMRequestService { @Autowired private SchemaService schemaService; @Autowired - private RestTemplate restTemplate; - @Autowired private OptimizationConfig optimizationConfig; + @Autowired + private LLMParserLayer llmParserLayer; public boolean check(QueryContext queryCtx) { QueryReq request = queryCtx.getRequest(); @@ -144,23 +137,7 @@ public class LLMRequestService { } public LLMResp requestLLM(LLMReq llmReq, Long modelId) { - long startTime = System.currentTimeMillis(); - log.info("requestLLM request, modelId:{},llmReq:{}", modelId, llmReq); - try { - URL url = new URL(new URL(llmParserConfig.getUrl()), llmParserConfig.getQueryToSqlPath()); - HttpHeaders headers = new HttpHeaders(); - headers.setContentType(MediaType.APPLICATION_JSON); - HttpEntity entity = new HttpEntity<>(JsonUtil.toString(llmReq), headers); - ResponseEntity responseEntity = restTemplate.exchange(url.toString(), HttpMethod.POST, entity, - LLMResp.class); - - log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}", - System.currentTimeMillis() - startTime, url, entity, responseEntity.getBody()); - return responseEntity.getBody(); - } catch (Exception e) { - log.error("requestLLM error", e); - } - return null; + return llmParserLayer.query2sql(llmReq, modelId); } protected List getFieldNameList(QueryContext queryCtx, Long modelId, LLMParserConfig llmParserConfig) { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMS2SQLParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMS2SQLParser.java index 5f312c1c6..948018c35 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMS2SQLParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMS2SQLParser.java @@ -45,7 +45,7 @@ public class LLMS2SQLParser implements SemanticParser { if (Objects.isNull(llmResp)) { return; } - //5. get and update parserInfo and corrector sql + //5. get and update parserInfo Map sqlWeight = llmResp.getSqlWeight(); ParseResult parseResult = ParseResult.builder() .request(request) diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/LLMParserLayer.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/LLMParserLayer.java new file mode 100644 index 000000000..bdbdeeb24 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/LLMParserLayer.java @@ -0,0 +1,13 @@ +package com.tencent.supersonic.chat.service; + +import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; +import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp; + +/** + * Unified wrapper for invoking the llmparser Python service layer. + */ +public interface LLMParserLayer { + + LLMResp query2sql(LLMReq llmReq, Long modelId); + +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/LLMParserLayerImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/LLMParserLayerImpl.java new file mode 100644 index 000000000..51b6266d4 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/LLMParserLayerImpl.java @@ -0,0 +1,47 @@ +package com.tencent.supersonic.chat.service.impl; + +import com.tencent.supersonic.chat.config.LLMParserConfig; +import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; +import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp; +import com.tencent.supersonic.chat.service.LLMParserLayer; +import com.tencent.supersonic.common.util.JsonUtil; +import java.net.URL; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.stereotype.Service; +import org.springframework.web.client.RestTemplate; + +@Service +@Slf4j +public class LLMParserLayerImpl implements LLMParserLayer { + + @Autowired + private RestTemplate restTemplate; + @Autowired + private LLMParserConfig llmParserConfig; + + public LLMResp query2sql(LLMReq llmReq, Long modelId) { + long startTime = System.currentTimeMillis(); + log.info("requestLLM request, modelId:{},llmReq:{}", modelId, llmReq); + try { + URL url = new URL(new URL(llmParserConfig.getUrl()), llmParserConfig.getQueryToSqlPath()); + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + HttpEntity entity = new HttpEntity<>(JsonUtil.toString(llmReq), headers); + ResponseEntity responseEntity = restTemplate.exchange(url.toString(), HttpMethod.POST, entity, + LLMResp.class); + + log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}", + System.currentTimeMillis() - startTime, url, entity, responseEntity.getBody()); + return responseEntity.getBody(); + } catch (Exception e) { + log.error("requestLLM error", e); + } + return null; + } +} diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java index 53eb01031..80ccf957f 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java @@ -166,7 +166,7 @@ public class SqlParserSelectHelper { return new ArrayList<>(results); } - private static ArrayList getFieldsByPlainSelect(PlainSelect plainSelect) { + private static List getFieldsByPlainSelect(PlainSelect plainSelect) { if (Objects.isNull(plainSelect)) { return new ArrayList<>(); } @@ -396,9 +396,7 @@ public class SqlParserSelectHelper { } SelectBody selectBody = selectStatement.getSelectBody(); PlainSelect plainSelect = (PlainSelect) selectBody; - - Table table = (Table) plainSelect.getFromItem(); - return table; + return (Table) plainSelect.getFromItem(); } public static String getDbTableName(String sql) { @@ -406,5 +404,12 @@ public class SqlParserSelectHelper { return table.getFullyQualifiedName(); } + public static String getNormalizedSql(String sql) { + Select selectStatement = getSelect(sql); + if (selectStatement == null) { + return null; + } + return selectStatement.toString(); + } }