[fix](headless)Fix a number of issues. (#2026)
Some checks are pending
supersonic CentOS CI / build (21) (push) Waiting to run
supersonic mac CI / build (21) (push) Waiting to run
supersonic ubuntu CI / build (21) (push) Waiting to run
supersonic windows CI / build (21) (push) Waiting to run

This commit is contained in:
Jun Zhang
2025-02-02 12:50:29 +08:00
committed by GitHub
parent de92b357df
commit d294fec2a0
19 changed files with 184 additions and 239 deletions

View File

@@ -113,7 +113,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
sqlInfo.setCorrectedS2SQL(sql);
}
public void removeFilterIfNotInLinkingValue(ChatQueryContext chatQueryContext,
public void removeUnmappedFilterValue(ChatQueryContext chatQueryContext,
SemanticParseInfo semanticParseInfo) {
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String correctS2SQL = sqlInfo.getCorrectedS2SQL();

View File

@@ -3,12 +3,9 @@ package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.core.env.Environment;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
@@ -45,13 +42,6 @@ public class SelectCorrector extends BaseSemanticCorrector {
Set<String> selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL));
Set<String> needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL));
// decide whether add order by expression field to select
Environment environment = ContextUtils.getBean(Environment.class);
String correctorAdditionalInfo = environment.getProperty(ADDITIONAL_INFORMATION);
if (StringUtils.isNotBlank(correctorAdditionalInfo)
&& Boolean.parseBoolean(correctorAdditionalInfo)) {
needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL));
}
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(needAddFields)) {
return correctS2SQL;
}

View File

@@ -70,20 +70,24 @@ public class TimeRangeParser implements SemanticParser {
}
private DateConf parseDateCN(String queryText) {
List<TimeNLP> times = TimeNLPUtil.parse(queryText);
if (times.isEmpty()) {
try {
List<TimeNLP> times = TimeNLPUtil.parse(queryText);
if (times.isEmpty()) {
return null;
}
Date startDate = times.get(0).getTime();
String detectWord = times.get(0).getTimeExpression();
Date endDate = times.size() > 1 ? times.get(1).getTime() : startDate;
if (times.size() > 1) {
detectWord += "~" + times.get(1).getTimeExpression();
}
return getDateConf(startDate, endDate, detectWord);
} catch (Exception e) {
return null;
}
Date startDate = times.get(0).getTime();
String detectWord = times.get(0).getTimeExpression();
Date endDate = times.size() > 1 ? times.get(1).getTime() : startDate;
if (times.size() > 1) {
detectWord += "~" + times.get(1).getTimeExpression();
}
return getDateConf(startDate, endDate, detectWord);
}
private DateConf parseDateNumber(String queryText) {

View File

@@ -3,17 +3,12 @@ package com.tencent.supersonic.headless.chat.corrector;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.*;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.parser.llm.ParseResult;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
import org.junit.Assert;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
@@ -21,7 +16,6 @@ import java.util.HashSet;
import java.util.List;
import java.util.Set;
@Disabled
class SchemaCorrectorTest {
private String json = "{\n" + " \"dataSetId\": 1,\n" + " \"llmReq\": {\n"
@@ -40,52 +34,54 @@ class SchemaCorrectorTest {
+ " },\n" + " \"request\": null\n" + "}";
@Test
void doCorrect() throws JsonProcessingException {
Long dataSetId = 1L;
ChatQueryContext chatQueryContext = buildQueryContext(dataSetId);
void testCorrectWrongColumnName() {
String sql = "SELECT 歌曲 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' ORDER BY SUM(播放) DESC LIMIT 10";
ChatQueryContext chatQueryContext = buildQueryContext(sql);
SemanticParseInfo parseInfo = chatQueryContext.getCandidateQueries().get(0).getParseInfo();
SchemaCorrector schemaCorrector = new SchemaCorrector();
schemaCorrector.correct(chatQueryContext, parseInfo);
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' ORDER BY SUM(播放量) DESC LIMIT 10",
parseInfo.getSqlInfo().getCorrectedS2SQL());
}
@Test
void testRemoveUnmappedFilterValue() throws JsonProcessingException {
String sql = "SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10";
ChatQueryContext chatQueryContext = buildQueryContext(sql);
SemanticParseInfo parseInfo = chatQueryContext.getCandidateQueries().get(0).getParseInfo();
ObjectMapper objectMapper = new ObjectMapper();
ParseResult parseResult = objectMapper.readValue(json, ParseResult.class);
String sql = "select 歌曲名 from 歌曲 where 发行日期 >= '2024-01-01' "
+ "and 商务组 = 'xxx' order by 播放量 desc limit 10";
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
SqlInfo sqlInfo = new SqlInfo();
sqlInfo.setParsedS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
semanticParseInfo.setSqlInfo(sqlInfo);
SchemaElement schemaElement = new SchemaElement();
schemaElement.setDataSetId(dataSetId);
semanticParseInfo.setDataSet(schemaElement);
semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult);
parseInfo.getProperties().put(Constants.CONTEXT, parseResult);
SchemaCorrector schemaCorrector = new SchemaCorrector();
schemaCorrector.removeFilterIfNotInLinkingValue(chatQueryContext, semanticParseInfo);
schemaCorrector.removeUnmappedFilterValue(chatQueryContext, parseInfo);
Assert.assertEquals(
"SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' " + "ORDER BY 播放量 DESC LIMIT 10",
semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
parseResult = objectMapper.readValue(json, ParseResult.class);
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' ORDER BY 播放量 DESC LIMIT 10",
parseInfo.getSqlInfo().getCorrectedS2SQL());
List<LLMReq.ElementValue> linkingValues = new ArrayList<>();
LLMReq.ElementValue elementValue = new LLMReq.ElementValue();
elementValue.setFieldName("商务组");
elementValue.setFieldValue("xxx");
linkingValues.add(elementValue);
semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult);
parseResult.getLlmReq().getSchema().setValues(linkingValues);
parseInfo.getProperties().put(Constants.CONTEXT, parseResult);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(sql);
semanticParseInfo.getSqlInfo().setParsedS2SQL(sql);
schemaCorrector.removeFilterIfNotInLinkingValue(chatQueryContext, semanticParseInfo);
Assert.assertEquals(
"SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
parseInfo.getSqlInfo().setCorrectedS2SQL(sql);
parseInfo.getSqlInfo().setParsedS2SQL(sql);
schemaCorrector.removeUnmappedFilterValue(chatQueryContext, parseInfo);
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
+ "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10",
semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
parseInfo.getSqlInfo().getCorrectedS2SQL());
}
private ChatQueryContext buildQueryContext(Long dataSetId) {
private ChatQueryContext buildQueryContext(String sql) {
Long dataSetId = 1L;
ChatQueryContext chatQueryContext = new ChatQueryContext();
List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
DataSetSchema dataSetSchema = new DataSetSchema();
@@ -94,27 +90,29 @@ class SchemaCorrectorTest {
SchemaElement schemaElement = new SchemaElement();
schemaElement.setDataSetId(dataSetId);
dataSetSchema.setDataSet(schemaElement);
Set<SchemaElement> dimensions = new HashSet<>();
SchemaElement element1 = new SchemaElement();
element1.setDataSetId(1L);
element1.setName("歌曲名");
dimensions.add(element1);
SchemaElement element2 = new SchemaElement();
element2.setDataSetId(1L);
element2.setName("商务组");
dimensions.add(element2);
SchemaElement element3 = new SchemaElement();
element3.setDataSetId(1L);
element3.setName("发行日期");
dimensions.add(element3);
dimensions.add(SchemaElement.builder().name("歌曲名").dataSetId(dataSetId).build());
dimensions.add(SchemaElement.builder().name("商务组").dataSetId(dataSetId).build());
dimensions.add(SchemaElement.builder().name("发行日期").dataSetId(dataSetId).build());
dimensions.add(SchemaElement.builder().name("播放量").dataSetId(dataSetId).build());
dataSetSchema.setDimensions(dimensions);
dataSetSchemaList.add(dataSetSchema);
SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList);
chatQueryContext.setSemanticSchema(semanticSchema);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
SqlInfo sqlInfo = new SqlInfo();
sqlInfo.setParsedS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
semanticParseInfo.setSqlInfo(sqlInfo);
semanticParseInfo.setDataSet(dataSetSchema.getDataSet());
LLMSqlQuery sqlQuery = new LLMSqlQuery();
sqlQuery.setParseInfo(semanticParseInfo);
chatQueryContext.getCandidateQueries().add(sqlQuery);
return chatQueryContext;
}
}

View File

@@ -20,7 +20,7 @@ public class DefaultQueryCache implements QueryCache {
if (isCache(semanticQueryReq)) {
Object result = cacheManager.get(cacheKey);
if (Objects.nonNull(result)) {
log.info("query from cache, key:{},result:{}", cacheKey,
log.debug("query from cache, key:{},result:{}", cacheKey,
StringUtils.normalizeSpace(result.toString()));
}
return result;

View File

@@ -45,7 +45,7 @@ public class JdbcExecutor implements QueryExecutor {
sqlUtil.queryInternal(queryStatement.getSql(), queryResultWithColumns);
queryResultWithColumns.setSql(sql);
} catch (Exception e) {
log.error("queryInternal with error [{}]", StringUtils.normalizeSpace(e.getMessage()));
log.error("queryInternal with error ", e);
queryResultWithColumns.setErrorMsg(e.getMessage());
}
return queryResultWithColumns;

View File

@@ -50,7 +50,8 @@ public class DefaultSemanticTranslator implements SemanticTranslator {
optimizer.rewrite(queryStatement);
}
}
log.info("translated query SQL: [{}]", StringUtils.normalizeSpace(queryStatement.getSql()));
log.debug("translated query SQL: [{}]",
StringUtils.normalizeSpace(queryStatement.getSql()));
}
private void mergeOntologyQuery(QueryStatement queryStatement) throws Exception {

View File

@@ -288,8 +288,6 @@ public class S2SemanticLayerService implements SemanticLayerService {
queryStatement.setSql(semanticQueryReq.getSqlInfo().getQuerySQL());
queryStatement.setIsTranslated(true);
}
queryStatement.setDataSetId(semanticQueryReq.getDataSetId());
queryStatement.setDataSetName(semanticQueryReq.getDataSetName());
return queryStatement;
}
@@ -321,6 +319,11 @@ public class S2SemanticLayerService implements SemanticLayerService {
Long dataSetId = dataSetService.getDataSetIdFromSql(querySqlReq.getSql(), user);
querySqlReq.setDataSetId(dataSetId);
}
if (querySqlReq.getDataSetId() != null) {
DataSetResp dataSetResp = dataSetService.getDataSet(querySqlReq.getDataSetId());
queryStatement.setDataSetId(dataSetResp.getId());
queryStatement.setDataSetName(dataSetResp.getName());
}
return queryStatement;
}

View File

@@ -273,9 +273,10 @@ public class DictUtils {
private QuerySqlReq constructQuerySqlReq(DictItemResp dictItemResp) {
ModelResp model = modelService.getModel(dictItemResp.getModelId());
String sqlPattern =
"select %s,count(1) from tbl %s group by %s order by count(1) desc limit %d";
String bizName = dictItemResp.getBizName();
"select %s,count(1) from %s %s group by %s order by count(1) desc limit %d";
String dimBizName = dictItemResp.getBizName();
String whereStr = generateWhereStr(dictItemResp);
String where = StringUtils.isEmpty(whereStr) ? "" : "WHERE" + whereStr;
ItemValueConfig config = dictItemResp.getConfig();
@@ -286,7 +287,8 @@ public class DictUtils {
limit = Integer.MAX_VALUE;
}
String sql = String.format(sqlPattern, bizName, where, bizName, limit);
String sql =
String.format(sqlPattern, dimBizName, model.getBizName(), where, dimBizName, limit);
Set<Long> modelIds = new HashSet<>();
modelIds.add(dictItemResp.getModelId());
QuerySqlReq querySqlReq = new QuerySqlReq();