mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-15 06:27:21 +00:00
[release](project)update version 0.7.4 backend (#66)
This commit is contained in:
@@ -1,37 +1,45 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.mockito.MockedStatic;
|
||||
import org.mockito.Mockito;
|
||||
|
||||
class DateFieldCorrectorTest {
|
||||
|
||||
@Test
|
||||
void rewriter() {
|
||||
void corrector() {
|
||||
MockedStatic<DSLDateHelper> dslDateHelper = Mockito.mockStatic(DSLDateHelper.class);
|
||||
|
||||
dslDateHelper.when(() -> DSLDateHelper.getReferenceDate(any())).thenReturn("2023-08-14");
|
||||
DateFieldCorrector dateFieldCorrector = new DateFieldCorrector();
|
||||
SemanticParseInfo parseInfo = new SemanticParseInfo();
|
||||
SchemaElement model = new SchemaElement();
|
||||
model.setId(2L);
|
||||
parseInfo.setModel(model);
|
||||
CorrectionInfo correctionInfo = CorrectionInfo.builder()
|
||||
SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder()
|
||||
.sql("select count(歌曲名) from 歌曲库 ")
|
||||
.parseInfo(parseInfo)
|
||||
.build();
|
||||
|
||||
CorrectionInfo rewriter = dateFieldCorrector.corrector(correctionInfo);
|
||||
dateFieldCorrector.correct(semanticCorrectInfo);
|
||||
|
||||
Assert.assertEquals("SELECT count(歌曲名) FROM 歌曲库 WHERE 数据日期 = '2023-08-14'", rewriter.getSql());
|
||||
Assert.assertEquals("SELECT count(歌曲名) FROM 歌曲库 WHERE 数据日期 = '2023-08-14'", semanticCorrectInfo.getSql());
|
||||
|
||||
correctionInfo = CorrectionInfo.builder()
|
||||
semanticCorrectInfo = SemanticCorrectInfo.builder()
|
||||
.sql("select count(歌曲名) from 歌曲库 where 数据日期 = '2023-08-14'")
|
||||
.parseInfo(parseInfo)
|
||||
.build();
|
||||
|
||||
rewriter = dateFieldCorrector.corrector(correctionInfo);
|
||||
dateFieldCorrector.correct(semanticCorrectInfo);
|
||||
|
||||
Assert.assertEquals("select count(歌曲名) from 歌曲库 where 数据日期 = '2023-08-14'", rewriter.getSql());
|
||||
Assert.assertEquals("select count(歌曲名) from 歌曲库 where 数据日期 = '2023-08-14'", semanticCorrectInfo.getSql());
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
|
||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
|
||||
@@ -16,10 +16,10 @@ import org.junit.jupiter.api.Test;
|
||||
class FieldNameCorrectorTest {
|
||||
|
||||
@Test
|
||||
void rewriter() {
|
||||
void corrector() {
|
||||
|
||||
FieldNameCorrector corrector = new FieldNameCorrector();
|
||||
CorrectionInfo correctionInfo = CorrectionInfo.builder()
|
||||
SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder()
|
||||
.sql("select 歌曲名 from 歌曲库 where 专辑照片 = '七里香' and 专辑名 = '流行' and 数据日期 = '2023-08-19'")
|
||||
.build();
|
||||
|
||||
@@ -55,11 +55,11 @@ class FieldNameCorrectorTest {
|
||||
properties.put(Constants.CONTEXT, dslParseResult);
|
||||
|
||||
parseInfo.setProperties(properties);
|
||||
correctionInfo.setParseInfo(parseInfo);
|
||||
semanticCorrectInfo.setParseInfo(parseInfo);
|
||||
|
||||
CorrectionInfo rewriter = corrector.corrector(correctionInfo);
|
||||
corrector.correct(semanticCorrectInfo);
|
||||
|
||||
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲库 WHERE 歌曲名 = '七里香' AND 歌曲流派 = '流行' AND 数据日期 = '2023-08-19'",
|
||||
rewriter.getSql());
|
||||
semanticCorrectInfo.getSql());
|
||||
}
|
||||
}
|
||||
@@ -2,9 +2,9 @@ package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
@@ -53,19 +53,19 @@ class FieldValueCorrectorTest {
|
||||
SchemaElement model = new SchemaElement();
|
||||
model.setId(2L);
|
||||
parseInfo.setModel(model);
|
||||
CorrectionInfo correctionInfo = CorrectionInfo.builder()
|
||||
SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder()
|
||||
.sql("select count(song_name) from 歌曲库 where singer_name = '周先生'")
|
||||
.parseInfo(parseInfo)
|
||||
.build();
|
||||
|
||||
FieldValueCorrector corrector = new FieldValueCorrector();
|
||||
CorrectionInfo info = corrector.corrector(correctionInfo);
|
||||
corrector.correct(semanticCorrectInfo);
|
||||
|
||||
Assert.assertEquals("SELECT count(song_name) FROM 歌曲库 WHERE singer_name = '周杰伦'", info.getSql());
|
||||
Assert.assertEquals("SELECT count(song_name) FROM 歌曲库 WHERE singer_name = '周杰伦'", semanticCorrectInfo.getSql());
|
||||
|
||||
correctionInfo.setSql("select count(song_name) from 歌曲库 where singer_name = '杰伦'");
|
||||
info = corrector.corrector(correctionInfo);
|
||||
semanticCorrectInfo.setSql("select count(song_name) from 歌曲库 where singer_name = '杰伦'");
|
||||
corrector.correct(semanticCorrectInfo);
|
||||
|
||||
Assert.assertEquals("SELECT count(song_name) FROM 歌曲库 WHERE singer_name = '周杰伦'", info.getSql());
|
||||
Assert.assertEquals("SELECT count(song_name) FROM 歌曲库 WHERE singer_name = '周杰伦'", semanticCorrectInfo.getSql());
|
||||
}
|
||||
}
|
||||
@@ -1,25 +1,26 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class SelectFieldAppendCorrectorTest {
|
||||
|
||||
@Test
|
||||
void rewriter() {
|
||||
void corrector() {
|
||||
SelectFieldAppendCorrector corrector = new SelectFieldAppendCorrector();
|
||||
CorrectionInfo correctionInfo = CorrectionInfo.builder()
|
||||
SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder()
|
||||
.sql("select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 and 歌手名 = '邓紫棋' "
|
||||
+ "and sys_imp_date = '2023-08-09' and 歌曲发布时 = '2023-08-01' order by 播放量 desc limit 11")
|
||||
.build();
|
||||
|
||||
CorrectionInfo rewriter = corrector.corrector(correctionInfo);
|
||||
corrector.correct(semanticCorrectInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 歌曲名, 歌手名, 歌曲发布时, 发布日期 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 "
|
||||
+ "AND 歌手名 = '邓紫棋' AND sys_imp_date = '2023-08-09' "
|
||||
+ "AND 歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11", rewriter.getSql());
|
||||
"SELECT 歌曲名, 歌手名, 播放量, 歌曲发布时, 发布日期 FROM 歌曲库 WHERE "
|
||||
+ "datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '邓紫棋' "
|
||||
+ "AND sys_imp_date = '2023-08-09' AND 歌曲发布时 = '2023-08-01'"
|
||||
+ " ORDER BY 播放量 DESC LIMIT 11", semanticCorrectInfo.getSql());
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,9 +2,9 @@ package com.tencent.supersonic.chat.parser.llm.dsl;
|
||||
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
@@ -67,14 +67,14 @@ class LLMDslParserTest {
|
||||
SchemaElement model = new SchemaElement();
|
||||
model.setId(2L);
|
||||
parseInfo.setModel(model);
|
||||
CorrectionInfo correctionInfo = CorrectionInfo.builder()
|
||||
SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder()
|
||||
.sql("select count(song_name) from 歌曲库 where singer_name = '周先生' and YEAR(publish_time) >= 2023 and ")
|
||||
.parseInfo(parseInfo)
|
||||
.build();
|
||||
|
||||
LLMDslParser llmDslParser = new LLMDslParser();
|
||||
|
||||
llmDslParser.setFilter(correctionInfo, 2L, parseInfo);
|
||||
llmDslParser.setFilter(semanticCorrectInfo, 2L, parseInfo);
|
||||
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user