[improvement](chat) Add a LLMParserLayer to interact with a Python service and add comments to certain classes (#388)

This commit is contained in:
lexluo09
2023-11-15 14:49:22 +08:00
committed by GitHub
parent 7ef3d92f2c
commit aa448b1ba3
11 changed files with 101 additions and 36 deletions

View File

@@ -22,6 +22,10 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
/**
* basic semantic correction functionality, offering common methods and an
* abstract method called doCorrect
*/
@Slf4j
public abstract class BaseSemanticCorrector implements SemanticCorrector {

View File

@@ -16,6 +16,9 @@ import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
/**
* Perform SQL corrections on the "group by" section in S2SQL.
*/
@Slf4j
public class GroupByCorrector extends BaseSemanticCorrector {

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserRemoveHelper;
@@ -16,6 +17,9 @@ import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import org.springframework.util.CollectionUtils;
/**
* Perform SQL corrections on the "Having" section in S2SQL.
*/
@Slf4j
public class HavingCorrector extends BaseSemanticCorrector {
@@ -29,9 +33,7 @@ public class HavingCorrector extends BaseSemanticCorrector {
addHavingToSelect(semanticParseInfo);
//remove number condition
String correctorSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
correctorSql = SqlParserRemoveHelper.removeNumberCondition(correctorSql);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctorSql);
removeNumberCondition(semanticParseInfo);
}
private void addHaving(SemanticParseInfo semanticParseInfo) {
@@ -62,4 +64,10 @@ public class HavingCorrector extends BaseSemanticCorrector {
return;
}
private void removeNumberCondition(SemanticParseInfo semanticParseInfo) {
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String correctorSql = SqlParserRemoveHelper.removeNumberCondition(sqlInfo.getCorrectS2SQL());
sqlInfo.setCorrectS2SQL(correctorSql);
}
}

View File

@@ -17,6 +17,9 @@ import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
/**
* Perform schema corrections on the Schema information in S2QL.
*/
@Slf4j
public class SchemaCorrector extends BaseSemanticCorrector {

View File

@@ -7,6 +7,9 @@ import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
/**
* Perform SQL corrections on the "Select" section in S2SQL.
*/
@Slf4j
public class SelectCorrector extends BaseSemanticCorrector {

View File

@@ -28,6 +28,9 @@ import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.util.Strings;
import org.springframework.util.CollectionUtils;
/**
* Perform SQL corrections on the "Where" section in S2SQL.
*/
@Slf4j
public class WhereCorrector extends BaseSemanticCorrector {
@@ -110,7 +113,6 @@ public class WhereCorrector extends BaseSemanticCorrector {
String correctS2SQL = SqlParserReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
aliasAndBizNameToTechName);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
return;
}
private Map<String, Map<String, String>> getAliasAndBizNameToTechName(List<SchemaElement> dimensions) {

View File

@@ -17,15 +17,14 @@ import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.service.LLMParserLayer;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem;
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
@@ -40,14 +39,8 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import org.springframework.web.client.RestTemplate;
@Slf4j
@Service
@@ -61,9 +54,9 @@ public class LLMRequestService {
@Autowired
private SchemaService schemaService;
@Autowired
private RestTemplate restTemplate;
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private LLMParserLayer llmParserLayer;
public boolean check(QueryContext queryCtx) {
QueryReq request = queryCtx.getRequest();
@@ -144,23 +137,7 @@ public class LLMRequestService {
}
public LLMResp requestLLM(LLMReq llmReq, Long modelId) {
long startTime = System.currentTimeMillis();
log.info("requestLLM request, modelId:{},llmReq:{}", modelId, llmReq);
try {
URL url = new URL(new URL(llmParserConfig.getUrl()), llmParserConfig.getQueryToSqlPath());
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
HttpEntity<String> entity = new HttpEntity<>(JsonUtil.toString(llmReq), headers);
ResponseEntity<LLMResp> responseEntity = restTemplate.exchange(url.toString(), HttpMethod.POST, entity,
LLMResp.class);
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
System.currentTimeMillis() - startTime, url, entity, responseEntity.getBody());
return responseEntity.getBody();
} catch (Exception e) {
log.error("requestLLM error", e);
}
return null;
return llmParserLayer.query2sql(llmReq, modelId);
}
protected List<String> getFieldNameList(QueryContext queryCtx, Long modelId, LLMParserConfig llmParserConfig) {

View File

@@ -45,7 +45,7 @@ public class LLMS2SQLParser implements SemanticParser {
if (Objects.isNull(llmResp)) {
return;
}
//5. get and update parserInfo and corrector sql
//5. get and update parserInfo
Map<String, Double> sqlWeight = llmResp.getSqlWeight();
ParseResult parseResult = ParseResult.builder()
.request(request)

View File

@@ -0,0 +1,13 @@
package com.tencent.supersonic.chat.service;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
/**
* Unified wrapper for invoking the llmparser Python service layer.
*/
public interface LLMParserLayer {
LLMResp query2sql(LLMReq llmReq, Long modelId);
}

View File

@@ -0,0 +1,47 @@
package com.tencent.supersonic.chat.service.impl;
import com.tencent.supersonic.chat.config.LLMParserConfig;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.service.LLMParserLayer;
import com.tencent.supersonic.common.util.JsonUtil;
import java.net.URL;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;
@Service
@Slf4j
public class LLMParserLayerImpl implements LLMParserLayer {
@Autowired
private RestTemplate restTemplate;
@Autowired
private LLMParserConfig llmParserConfig;
public LLMResp query2sql(LLMReq llmReq, Long modelId) {
long startTime = System.currentTimeMillis();
log.info("requestLLM request, modelId:{},llmReq:{}", modelId, llmReq);
try {
URL url = new URL(new URL(llmParserConfig.getUrl()), llmParserConfig.getQueryToSqlPath());
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
HttpEntity<String> entity = new HttpEntity<>(JsonUtil.toString(llmReq), headers);
ResponseEntity<LLMResp> responseEntity = restTemplate.exchange(url.toString(), HttpMethod.POST, entity,
LLMResp.class);
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
System.currentTimeMillis() - startTime, url, entity, responseEntity.getBody());
return responseEntity.getBody();
} catch (Exception e) {
log.error("requestLLM error", e);
}
return null;
}
}

View File

@@ -166,7 +166,7 @@ public class SqlParserSelectHelper {
return new ArrayList<>(results);
}
private static ArrayList<String> getFieldsByPlainSelect(PlainSelect plainSelect) {
private static List<String> getFieldsByPlainSelect(PlainSelect plainSelect) {
if (Objects.isNull(plainSelect)) {
return new ArrayList<>();
}
@@ -396,9 +396,7 @@ public class SqlParserSelectHelper {
}
SelectBody selectBody = selectStatement.getSelectBody();
PlainSelect plainSelect = (PlainSelect) selectBody;
Table table = (Table) plainSelect.getFromItem();
return table;
return (Table) plainSelect.getFromItem();
}
public static String getDbTableName(String sql) {
@@ -406,5 +404,12 @@ public class SqlParserSelectHelper {
return table.getFullyQualifiedName();
}
public static String getNormalizedSql(String sql) {
Select selectStatement = getSelect(sql);
if (selectStatement == null) {
return null;
}
return selectStatement.toString();
}
}