mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-15 14:36:47 +00:00
[improvement][chat]llm parser corrector is simplified by sql distribution (#120)
This commit is contained in:
@@ -1,45 +0,0 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.mockito.MockedStatic;
|
||||
import org.mockito.Mockito;
|
||||
|
||||
class DateFieldCorrectorTest {
|
||||
|
||||
@Test
|
||||
void corrector() {
|
||||
MockedStatic<DSLDateHelper> dslDateHelper = Mockito.mockStatic(DSLDateHelper.class);
|
||||
|
||||
dslDateHelper.when(() -> DSLDateHelper.getReferenceDate(any())).thenReturn("2023-08-14");
|
||||
DateFieldCorrector dateFieldCorrector = new DateFieldCorrector();
|
||||
SemanticParseInfo parseInfo = new SemanticParseInfo();
|
||||
SchemaElement model = new SchemaElement();
|
||||
model.setId(2L);
|
||||
parseInfo.setModel(model);
|
||||
SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder()
|
||||
.sql("select count(歌曲名) from 歌曲库 ")
|
||||
.parseInfo(parseInfo)
|
||||
.build();
|
||||
|
||||
dateFieldCorrector.correct(semanticCorrectInfo);
|
||||
|
||||
Assert.assertEquals("SELECT count(歌曲名) FROM 歌曲库 WHERE 数据日期 = '2023-08-14'", semanticCorrectInfo.getSql());
|
||||
|
||||
semanticCorrectInfo = SemanticCorrectInfo.builder()
|
||||
.sql("select count(歌曲名) from 歌曲库 where 数据日期 = '2023-08-14'")
|
||||
.parseInfo(parseInfo)
|
||||
.build();
|
||||
|
||||
dateFieldCorrector.correct(semanticCorrectInfo);
|
||||
|
||||
Assert.assertEquals("select count(歌曲名) from 歌曲库 where 数据日期 = '2023-08-14'", semanticCorrectInfo.getSql());
|
||||
|
||||
}
|
||||
}
|
||||
@@ -1,65 +0,0 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
|
||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class FieldNameCorrectorTest {
|
||||
|
||||
@Test
|
||||
void corrector() {
|
||||
|
||||
FieldNameCorrector corrector = new FieldNameCorrector();
|
||||
SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder()
|
||||
.sql("select 歌曲名 from 歌曲库 where 专辑照片 = '七里香' and 专辑名 = '流行' and 数据日期 = '2023-08-19'")
|
||||
.build();
|
||||
|
||||
SemanticParseInfo parseInfo = new SemanticParseInfo();
|
||||
|
||||
DSLParseResult dslParseResult = new DSLParseResult();
|
||||
LLMReq llmReq = new LLMReq();
|
||||
List<ElementValue> linking = new ArrayList<>();
|
||||
ElementValue elementValue = new ElementValue();
|
||||
elementValue.setFieldValue("流行");
|
||||
elementValue.setFieldName("歌曲风格");
|
||||
linking.add(elementValue);
|
||||
|
||||
ElementValue elementValue2 = new ElementValue();
|
||||
elementValue2.setFieldValue("七里香");
|
||||
elementValue2.setFieldName("歌曲名");
|
||||
linking.add(elementValue2);
|
||||
|
||||
ElementValue elementValue3 = new ElementValue();
|
||||
elementValue3.setFieldValue("周杰伦");
|
||||
elementValue3.setFieldName("歌手名");
|
||||
linking.add(elementValue3);
|
||||
|
||||
ElementValue elementValue4 = new ElementValue();
|
||||
elementValue4.setFieldValue("流行");
|
||||
elementValue4.setFieldName("歌曲流派");
|
||||
linking.add(elementValue4);
|
||||
|
||||
llmReq.setLinking(linking);
|
||||
dslParseResult.setLlmReq(llmReq);
|
||||
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
properties.put(Constants.CONTEXT, dslParseResult);
|
||||
|
||||
parseInfo.setProperties(properties);
|
||||
semanticCorrectInfo.setParseInfo(parseInfo);
|
||||
|
||||
corrector.correct(semanticCorrectInfo);
|
||||
|
||||
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲库 WHERE 歌曲名 = '七里香' AND 歌曲流派 = '流行' AND 数据日期 = '2023-08-19'",
|
||||
semanticCorrectInfo.getSql());
|
||||
}
|
||||
}
|
||||
@@ -1,71 +0,0 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.mockito.MockedStatic;
|
||||
import org.mockito.Mockito;
|
||||
|
||||
class FieldValueCorrectorTest {
|
||||
|
||||
|
||||
@Test
|
||||
void corrector() {
|
||||
|
||||
MockedStatic<ContextUtils> mockContextUtils = Mockito.mockStatic(ContextUtils.class);
|
||||
|
||||
SchemaService mockSchemaService = Mockito.mock(SchemaService.class);
|
||||
|
||||
SemanticSchema mockSemanticSchema = Mockito.mock(SemanticSchema.class);
|
||||
|
||||
List<SchemaElement> dimensions = new ArrayList<>();
|
||||
List<SchemaValueMap> schemaValueMaps = new ArrayList<>();
|
||||
SchemaValueMap value1 = new SchemaValueMap();
|
||||
value1.setBizName("杰伦");
|
||||
value1.setTechName("周杰伦");
|
||||
value1.setAlias(Arrays.asList("周杰倫", "Jay Chou", "周董", "周先生"));
|
||||
schemaValueMaps.add(value1);
|
||||
|
||||
SchemaElement schemaElement = SchemaElement.builder()
|
||||
.bizName("singer_name")
|
||||
.name("歌手名")
|
||||
.model(2L)
|
||||
.schemaValueMaps(schemaValueMaps)
|
||||
.build();
|
||||
dimensions.add(schemaElement);
|
||||
|
||||
when(mockSemanticSchema.getDimensions()).thenReturn(dimensions);
|
||||
when(mockSchemaService.getSemanticSchema()).thenReturn(mockSemanticSchema);
|
||||
mockContextUtils.when(() -> ContextUtils.getBean(SchemaService.class)).thenReturn(mockSchemaService);
|
||||
|
||||
SemanticParseInfo parseInfo = new SemanticParseInfo();
|
||||
SchemaElement model = new SchemaElement();
|
||||
model.setId(2L);
|
||||
parseInfo.setModel(model);
|
||||
SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder()
|
||||
.sql("select count(song_name) from 歌曲库 where singer_name = '周先生'")
|
||||
.parseInfo(parseInfo)
|
||||
.build();
|
||||
|
||||
FieldValueCorrector corrector = new FieldValueCorrector();
|
||||
corrector.correct(semanticCorrectInfo);
|
||||
|
||||
Assert.assertEquals("SELECT count(song_name) FROM 歌曲库 WHERE singer_name = '周杰伦'", semanticCorrectInfo.getSql());
|
||||
|
||||
semanticCorrectInfo.setSql("select count(song_name) from 歌曲库 where singer_name = '杰伦'");
|
||||
corrector.correct(semanticCorrectInfo);
|
||||
|
||||
Assert.assertEquals("SELECT count(song_name) FROM 歌曲库 WHERE singer_name = '周杰伦'", semanticCorrectInfo.getSql());
|
||||
}
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class SelectFieldAppendCorrectorTest {
|
||||
|
||||
@Test
|
||||
void corrector() {
|
||||
SelectFieldAppendCorrector corrector = new SelectFieldAppendCorrector();
|
||||
SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder()
|
||||
.sql("select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 and 歌手名 = '邓紫棋' "
|
||||
+ "and sys_imp_date = '2023-08-09' and 歌曲发布时 = '2023-08-01' order by 播放量 desc limit 11")
|
||||
.build();
|
||||
|
||||
corrector.correct(semanticCorrectInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 歌曲名, 歌手名, 播放量, 歌曲发布时, 发布日期 FROM 歌曲库 WHERE "
|
||||
+ "datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '邓紫棋' "
|
||||
+ "AND sys_imp_date = '2023-08-09' AND 歌曲发布时 = '2023-08-01'"
|
||||
+ " ORDER BY 播放量 DESC LIMIT 11", semanticCorrectInfo.getSql());
|
||||
|
||||
semanticCorrectInfo.setSql("select 用户名 from 内容库产品 where datediff('day', 数据日期, '2023-09-14') <= 30"
|
||||
+ " group by 用户名 having sum(访问次数) > 2000");
|
||||
|
||||
corrector.correct(semanticCorrectInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 用户名, sum(访问次数) FROM 内容库产品 WHERE "
|
||||
+ "datediff('day', 数据日期, '2023-09-14') <= 30 "
|
||||
+ "GROUP BY 用户名 HAVING sum(访问次数) > 2000", semanticCorrectInfo.getSql());
|
||||
|
||||
semanticCorrectInfo.setSql("SELECT 用户名, sum(访问次数) FROM 内容库产品 WHERE "
|
||||
+ "datediff('day', 数据日期, '2023-09-14') <= 30 "
|
||||
+ "GROUP BY 用户名 HAVING sum(访问次数) > 2000");
|
||||
|
||||
corrector.correct(semanticCorrectInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 用户名, sum(访问次数) FROM 内容库产品 WHERE "
|
||||
+ "datediff('day', 数据日期, '2023-09-14') <= 30 "
|
||||
+ "GROUP BY 用户名 HAVING sum(访问次数) > 2000", semanticCorrectInfo.getSql());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user