mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
[improvement](chat) Add a LLMParserLayer to interact with a Python service and add comments to certain classes (#388)
This commit is contained in:
@@ -22,6 +22,10 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* basic semantic correction functionality, offering common methods and an
|
||||
* abstract method called doCorrect
|
||||
*/
|
||||
@Slf4j
|
||||
public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
|
||||
|
||||
@@ -16,6 +16,9 @@ import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* Perform SQL corrections on the "group by" section in S2SQL.
|
||||
*/
|
||||
@Slf4j
|
||||
public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package com.tencent.supersonic.chat.corrector;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserRemoveHelper;
|
||||
@@ -16,6 +17,9 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* Perform SQL corrections on the "Having" section in S2SQL.
|
||||
*/
|
||||
@Slf4j
|
||||
public class HavingCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@@ -29,9 +33,7 @@ public class HavingCorrector extends BaseSemanticCorrector {
|
||||
addHavingToSelect(semanticParseInfo);
|
||||
|
||||
//remove number condition
|
||||
String correctorSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
correctorSql = SqlParserRemoveHelper.removeNumberCondition(correctorSql);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctorSql);
|
||||
removeNumberCondition(semanticParseInfo);
|
||||
}
|
||||
|
||||
private void addHaving(SemanticParseInfo semanticParseInfo) {
|
||||
@@ -62,4 +64,10 @@ public class HavingCorrector extends BaseSemanticCorrector {
|
||||
return;
|
||||
}
|
||||
|
||||
private void removeNumberCondition(SemanticParseInfo semanticParseInfo) {
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String correctorSql = SqlParserRemoveHelper.removeNumberCondition(sqlInfo.getCorrectS2SQL());
|
||||
sqlInfo.setCorrectS2SQL(correctorSql);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -17,6 +17,9 @@ import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* Perform schema corrections on the Schema information in S2QL.
|
||||
*/
|
||||
@Slf4j
|
||||
public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
|
||||
|
||||
@@ -7,6 +7,9 @@ import java.util.List;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* Perform SQL corrections on the "Select" section in S2SQL.
|
||||
*/
|
||||
@Slf4j
|
||||
public class SelectCorrector extends BaseSemanticCorrector {
|
||||
|
||||
|
||||
@@ -28,6 +28,9 @@ import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.logging.log4j.util.Strings;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* Perform SQL corrections on the "Where" section in S2SQL.
|
||||
*/
|
||||
@Slf4j
|
||||
public class WhereCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@@ -110,7 +113,6 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
String correctS2SQL = SqlParserReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
|
||||
aliasAndBizNameToTechName);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
return;
|
||||
}
|
||||
|
||||
private Map<String, Map<String, String>> getAliasAndBizNameToTechName(List<SchemaElement> dimensions) {
|
||||
|
||||
@@ -17,15 +17,14 @@ import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.chat.service.LLMParserLayer;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
|
||||
import java.net.URL;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
@@ -40,14 +39,8 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
@@ -61,9 +54,9 @@ public class LLMRequestService {
|
||||
@Autowired
|
||||
private SchemaService schemaService;
|
||||
@Autowired
|
||||
private RestTemplate restTemplate;
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
@Autowired
|
||||
private LLMParserLayer llmParserLayer;
|
||||
|
||||
public boolean check(QueryContext queryCtx) {
|
||||
QueryReq request = queryCtx.getRequest();
|
||||
@@ -144,23 +137,7 @@ public class LLMRequestService {
|
||||
}
|
||||
|
||||
public LLMResp requestLLM(LLMReq llmReq, Long modelId) {
|
||||
long startTime = System.currentTimeMillis();
|
||||
log.info("requestLLM request, modelId:{},llmReq:{}", modelId, llmReq);
|
||||
try {
|
||||
URL url = new URL(new URL(llmParserConfig.getUrl()), llmParserConfig.getQueryToSqlPath());
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
HttpEntity<String> entity = new HttpEntity<>(JsonUtil.toString(llmReq), headers);
|
||||
ResponseEntity<LLMResp> responseEntity = restTemplate.exchange(url.toString(), HttpMethod.POST, entity,
|
||||
LLMResp.class);
|
||||
|
||||
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
|
||||
System.currentTimeMillis() - startTime, url, entity, responseEntity.getBody());
|
||||
return responseEntity.getBody();
|
||||
} catch (Exception e) {
|
||||
log.error("requestLLM error", e);
|
||||
}
|
||||
return null;
|
||||
return llmParserLayer.query2sql(llmReq, modelId);
|
||||
}
|
||||
|
||||
protected List<String> getFieldNameList(QueryContext queryCtx, Long modelId, LLMParserConfig llmParserConfig) {
|
||||
|
||||
@@ -45,7 +45,7 @@ public class LLMS2SQLParser implements SemanticParser {
|
||||
if (Objects.isNull(llmResp)) {
|
||||
return;
|
||||
}
|
||||
//5. get and update parserInfo and corrector sql
|
||||
//5. get and update parserInfo
|
||||
Map<String, Double> sqlWeight = llmResp.getSqlWeight();
|
||||
ParseResult parseResult = ParseResult.builder()
|
||||
.request(request)
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
package com.tencent.supersonic.chat.service;
|
||||
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
|
||||
/**
|
||||
* Unified wrapper for invoking the llmparser Python service layer.
|
||||
*/
|
||||
public interface LLMParserLayer {
|
||||
|
||||
LLMResp query2sql(LLMReq llmReq, Long modelId);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
package com.tencent.supersonic.chat.service.impl;
|
||||
|
||||
import com.tencent.supersonic.chat.config.LLMParserConfig;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.chat.service.LLMParserLayer;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import java.net.URL;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class LLMParserLayerImpl implements LLMParserLayer {
|
||||
|
||||
@Autowired
|
||||
private RestTemplate restTemplate;
|
||||
@Autowired
|
||||
private LLMParserConfig llmParserConfig;
|
||||
|
||||
public LLMResp query2sql(LLMReq llmReq, Long modelId) {
|
||||
long startTime = System.currentTimeMillis();
|
||||
log.info("requestLLM request, modelId:{},llmReq:{}", modelId, llmReq);
|
||||
try {
|
||||
URL url = new URL(new URL(llmParserConfig.getUrl()), llmParserConfig.getQueryToSqlPath());
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
HttpEntity<String> entity = new HttpEntity<>(JsonUtil.toString(llmReq), headers);
|
||||
ResponseEntity<LLMResp> responseEntity = restTemplate.exchange(url.toString(), HttpMethod.POST, entity,
|
||||
LLMResp.class);
|
||||
|
||||
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
|
||||
System.currentTimeMillis() - startTime, url, entity, responseEntity.getBody());
|
||||
return responseEntity.getBody();
|
||||
} catch (Exception e) {
|
||||
log.error("requestLLM error", e);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
@@ -166,7 +166,7 @@ public class SqlParserSelectHelper {
|
||||
return new ArrayList<>(results);
|
||||
}
|
||||
|
||||
private static ArrayList<String> getFieldsByPlainSelect(PlainSelect plainSelect) {
|
||||
private static List<String> getFieldsByPlainSelect(PlainSelect plainSelect) {
|
||||
if (Objects.isNull(plainSelect)) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
@@ -396,9 +396,7 @@ public class SqlParserSelectHelper {
|
||||
}
|
||||
SelectBody selectBody = selectStatement.getSelectBody();
|
||||
PlainSelect plainSelect = (PlainSelect) selectBody;
|
||||
|
||||
Table table = (Table) plainSelect.getFromItem();
|
||||
return table;
|
||||
return (Table) plainSelect.getFromItem();
|
||||
}
|
||||
|
||||
public static String getDbTableName(String sql) {
|
||||
@@ -406,5 +404,12 @@ public class SqlParserSelectHelper {
|
||||
return table.getFullyQualifiedName();
|
||||
}
|
||||
|
||||
public static String getNormalizedSql(String sql) {
|
||||
Select selectStatement = getSelect(sql);
|
||||
if (selectStatement == null) {
|
||||
return null;
|
||||
}
|
||||
return selectStatement.toString();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user