mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 20:25:12 +00:00
[fix](headless)Fix a number of issues. (#2026)
This commit is contained in:
@@ -113,7 +113,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
}
|
||||
|
||||
public void removeFilterIfNotInLinkingValue(ChatQueryContext chatQueryContext,
|
||||
public void removeUnmappedFilterValue(ChatQueryContext chatQueryContext,
|
||||
SemanticParseInfo semanticParseInfo) {
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String correctS2SQL = sqlInfo.getCorrectedS2SQL();
|
||||
|
||||
@@ -3,12 +3,9 @@ package com.tencent.supersonic.headless.chat.corrector;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.core.env.Environment;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
@@ -45,13 +42,6 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
||||
Set<String> selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL));
|
||||
Set<String> needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL));
|
||||
|
||||
// decide whether add order by expression field to select
|
||||
Environment environment = ContextUtils.getBean(Environment.class);
|
||||
String correctorAdditionalInfo = environment.getProperty(ADDITIONAL_INFORMATION);
|
||||
if (StringUtils.isNotBlank(correctorAdditionalInfo)
|
||||
&& Boolean.parseBoolean(correctorAdditionalInfo)) {
|
||||
needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL));
|
||||
}
|
||||
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(needAddFields)) {
|
||||
return correctS2SQL;
|
||||
}
|
||||
|
||||
@@ -70,20 +70,24 @@ public class TimeRangeParser implements SemanticParser {
|
||||
}
|
||||
|
||||
private DateConf parseDateCN(String queryText) {
|
||||
List<TimeNLP> times = TimeNLPUtil.parse(queryText);
|
||||
if (times.isEmpty()) {
|
||||
try {
|
||||
List<TimeNLP> times = TimeNLPUtil.parse(queryText);
|
||||
if (times.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
|
||||
Date startDate = times.get(0).getTime();
|
||||
String detectWord = times.get(0).getTimeExpression();
|
||||
Date endDate = times.size() > 1 ? times.get(1).getTime() : startDate;
|
||||
|
||||
if (times.size() > 1) {
|
||||
detectWord += "~" + times.get(1).getTimeExpression();
|
||||
}
|
||||
|
||||
return getDateConf(startDate, endDate, detectWord);
|
||||
} catch (Exception e) {
|
||||
return null;
|
||||
}
|
||||
|
||||
Date startDate = times.get(0).getTime();
|
||||
String detectWord = times.get(0).getTimeExpression();
|
||||
Date endDate = times.size() > 1 ? times.get(1).getTime() : startDate;
|
||||
|
||||
if (times.size() > 1) {
|
||||
detectWord += "~" + times.get(1).getTimeExpression();
|
||||
}
|
||||
|
||||
return getDateConf(startDate, endDate, detectWord);
|
||||
}
|
||||
|
||||
private DateConf parseDateNumber(String queryText) {
|
||||
|
||||
@@ -3,17 +3,12 @@ package com.tencent.supersonic.headless.chat.corrector;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.*;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.parser.llm.ParseResult;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
@@ -21,7 +16,6 @@ import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
@Disabled
|
||||
class SchemaCorrectorTest {
|
||||
|
||||
private String json = "{\n" + " \"dataSetId\": 1,\n" + " \"llmReq\": {\n"
|
||||
@@ -40,52 +34,54 @@ class SchemaCorrectorTest {
|
||||
+ " },\n" + " \"request\": null\n" + "}";
|
||||
|
||||
@Test
|
||||
void doCorrect() throws JsonProcessingException {
|
||||
Long dataSetId = 1L;
|
||||
ChatQueryContext chatQueryContext = buildQueryContext(dataSetId);
|
||||
void testCorrectWrongColumnName() {
|
||||
String sql = "SELECT 歌曲 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' ORDER BY SUM(播放) DESC LIMIT 10";
|
||||
ChatQueryContext chatQueryContext = buildQueryContext(sql);
|
||||
SemanticParseInfo parseInfo = chatQueryContext.getCandidateQueries().get(0).getParseInfo();
|
||||
|
||||
SchemaCorrector schemaCorrector = new SchemaCorrector();
|
||||
schemaCorrector.correct(chatQueryContext, parseInfo);
|
||||
|
||||
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' ORDER BY SUM(播放量) DESC LIMIT 10",
|
||||
parseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testRemoveUnmappedFilterValue() throws JsonProcessingException {
|
||||
String sql = "SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10";
|
||||
ChatQueryContext chatQueryContext = buildQueryContext(sql);
|
||||
SemanticParseInfo parseInfo = chatQueryContext.getCandidateQueries().get(0).getParseInfo();
|
||||
|
||||
ObjectMapper objectMapper = new ObjectMapper();
|
||||
ParseResult parseResult = objectMapper.readValue(json, ParseResult.class);
|
||||
|
||||
String sql = "select 歌曲名 from 歌曲 where 发行日期 >= '2024-01-01' "
|
||||
+ "and 商务组 = 'xxx' order by 播放量 desc limit 10";
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
SqlInfo sqlInfo = new SqlInfo();
|
||||
sqlInfo.setParsedS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||
|
||||
SchemaElement schemaElement = new SchemaElement();
|
||||
schemaElement.setDataSetId(dataSetId);
|
||||
semanticParseInfo.setDataSet(schemaElement);
|
||||
|
||||
semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult);
|
||||
parseInfo.getProperties().put(Constants.CONTEXT, parseResult);
|
||||
|
||||
SchemaCorrector schemaCorrector = new SchemaCorrector();
|
||||
schemaCorrector.removeFilterIfNotInLinkingValue(chatQueryContext, semanticParseInfo);
|
||||
schemaCorrector.removeUnmappedFilterValue(chatQueryContext, parseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' " + "ORDER BY 播放量 DESC LIMIT 10",
|
||||
semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||
|
||||
parseResult = objectMapper.readValue(json, ParseResult.class);
|
||||
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' ORDER BY 播放量 DESC LIMIT 10",
|
||||
parseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||
|
||||
List<LLMReq.ElementValue> linkingValues = new ArrayList<>();
|
||||
LLMReq.ElementValue elementValue = new LLMReq.ElementValue();
|
||||
elementValue.setFieldName("商务组");
|
||||
elementValue.setFieldValue("xxx");
|
||||
linkingValues.add(elementValue);
|
||||
semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult);
|
||||
parseResult.getLlmReq().getSchema().setValues(linkingValues);
|
||||
parseInfo.getProperties().put(Constants.CONTEXT, parseResult);
|
||||
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(sql);
|
||||
semanticParseInfo.getSqlInfo().setParsedS2SQL(sql);
|
||||
schemaCorrector.removeFilterIfNotInLinkingValue(chatQueryContext, semanticParseInfo);
|
||||
Assert.assertEquals(
|
||||
"SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
|
||||
parseInfo.getSqlInfo().setCorrectedS2SQL(sql);
|
||||
parseInfo.getSqlInfo().setParsedS2SQL(sql);
|
||||
schemaCorrector.removeUnmappedFilterValue(chatQueryContext, parseInfo);
|
||||
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
|
||||
+ "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10",
|
||||
semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||
parseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||
}
|
||||
|
||||
private ChatQueryContext buildQueryContext(Long dataSetId) {
|
||||
private ChatQueryContext buildQueryContext(String sql) {
|
||||
Long dataSetId = 1L;
|
||||
|
||||
ChatQueryContext chatQueryContext = new ChatQueryContext();
|
||||
List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
|
||||
DataSetSchema dataSetSchema = new DataSetSchema();
|
||||
@@ -94,27 +90,29 @@ class SchemaCorrectorTest {
|
||||
SchemaElement schemaElement = new SchemaElement();
|
||||
schemaElement.setDataSetId(dataSetId);
|
||||
dataSetSchema.setDataSet(schemaElement);
|
||||
|
||||
Set<SchemaElement> dimensions = new HashSet<>();
|
||||
SchemaElement element1 = new SchemaElement();
|
||||
element1.setDataSetId(1L);
|
||||
element1.setName("歌曲名");
|
||||
dimensions.add(element1);
|
||||
|
||||
SchemaElement element2 = new SchemaElement();
|
||||
element2.setDataSetId(1L);
|
||||
element2.setName("商务组");
|
||||
dimensions.add(element2);
|
||||
|
||||
SchemaElement element3 = new SchemaElement();
|
||||
element3.setDataSetId(1L);
|
||||
element3.setName("发行日期");
|
||||
dimensions.add(element3);
|
||||
|
||||
dimensions.add(SchemaElement.builder().name("歌曲名").dataSetId(dataSetId).build());
|
||||
dimensions.add(SchemaElement.builder().name("商务组").dataSetId(dataSetId).build());
|
||||
dimensions.add(SchemaElement.builder().name("发行日期").dataSetId(dataSetId).build());
|
||||
dimensions.add(SchemaElement.builder().name("播放量").dataSetId(dataSetId).build());
|
||||
dataSetSchema.setDimensions(dimensions);
|
||||
dataSetSchemaList.add(dataSetSchema);
|
||||
|
||||
SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList);
|
||||
chatQueryContext.setSemanticSchema(semanticSchema);
|
||||
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
SqlInfo sqlInfo = new SqlInfo();
|
||||
sqlInfo.setParsedS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||
semanticParseInfo.setDataSet(dataSetSchema.getDataSet());
|
||||
LLMSqlQuery sqlQuery = new LLMSqlQuery();
|
||||
sqlQuery.setParseInfo(semanticParseInfo);
|
||||
chatQueryContext.getCandidateQueries().add(sqlQuery);
|
||||
|
||||
return chatQueryContext;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user