mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-16 23:23:10 +00:00
(improvement)(project) support for modifying filter conditions and fix group by pushdown and add windows script (#49)
Co-authored-by: lexluo <lexluo@tencent.com>
This commit is contained in:
@@ -1,9 +1,8 @@
|
||||
package com.tencent.supersonic.chat.query.llm.dsl.corrector;
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.corrector.DateFieldCorrector;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
package com.tencent.supersonic.chat.query.llm.dsl.corrector;
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.corrector.FieldValueCorrector;
|
||||
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
|
||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
|
||||
@@ -14,12 +13,12 @@ import java.util.Map;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class FieldValueCorrectorTest {
|
||||
class FieldNameCorrectorTest {
|
||||
|
||||
@Test
|
||||
void rewriter() {
|
||||
|
||||
FieldValueCorrector corrector = new FieldValueCorrector();
|
||||
FieldNameCorrector corrector = new FieldNameCorrector();
|
||||
CorrectionInfo correctionInfo = CorrectionInfo.builder()
|
||||
.sql("select 歌曲名 from 歌曲库 where 专辑照片 = '七里香' and 专辑名 = '流行' and 数据日期 = '2023-08-19'")
|
||||
.build();
|
||||
@@ -0,0 +1,71 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.mockito.MockedStatic;
|
||||
import org.mockito.Mockito;
|
||||
|
||||
class FieldValueCorrectorTest {
|
||||
|
||||
|
||||
@Test
|
||||
void corrector() {
|
||||
|
||||
MockedStatic<ContextUtils> mockContextUtils = Mockito.mockStatic(ContextUtils.class);
|
||||
|
||||
SchemaService mockSchemaService = Mockito.mock(SchemaService.class);
|
||||
|
||||
SemanticSchema mockSemanticSchema = Mockito.mock(SemanticSchema.class);
|
||||
|
||||
List<SchemaElement> dimensions = new ArrayList<>();
|
||||
List<SchemaValueMap> schemaValueMaps = new ArrayList<>();
|
||||
SchemaValueMap value1 = new SchemaValueMap();
|
||||
value1.setBizName("杰伦");
|
||||
value1.setTechName("周杰伦");
|
||||
value1.setAlias(Arrays.asList("周杰倫", "Jay Chou", "周董", "周先生"));
|
||||
schemaValueMaps.add(value1);
|
||||
|
||||
SchemaElement schemaElement = SchemaElement.builder()
|
||||
.bizName("singer_name")
|
||||
.name("歌手名")
|
||||
.model(2L)
|
||||
.schemaValueMaps(schemaValueMaps)
|
||||
.build();
|
||||
dimensions.add(schemaElement);
|
||||
|
||||
when(mockSemanticSchema.getDimensions()).thenReturn(dimensions);
|
||||
when(mockSchemaService.getSemanticSchema()).thenReturn(mockSemanticSchema);
|
||||
mockContextUtils.when(() -> ContextUtils.getBean(SchemaService.class)).thenReturn(mockSchemaService);
|
||||
|
||||
SemanticParseInfo parseInfo = new SemanticParseInfo();
|
||||
SchemaElement model = new SchemaElement();
|
||||
model.setId(2L);
|
||||
parseInfo.setModel(model);
|
||||
CorrectionInfo correctionInfo = CorrectionInfo.builder()
|
||||
.sql("select count(song_name) from 歌曲库 where singer_name = '周先生'")
|
||||
.parseInfo(parseInfo)
|
||||
.build();
|
||||
|
||||
FieldValueCorrector corrector = new FieldValueCorrector();
|
||||
CorrectionInfo info = corrector.corrector(correctionInfo);
|
||||
|
||||
Assert.assertEquals("SELECT count(song_name) FROM 歌曲库 WHERE singer_name = '周杰伦'", info.getSql());
|
||||
|
||||
correctionInfo.setSql("select count(song_name) from 歌曲库 where singer_name = '杰伦'");
|
||||
info = corrector.corrector(correctionInfo);
|
||||
|
||||
Assert.assertEquals("SELECT count(song_name) FROM 歌曲库 WHERE singer_name = '周杰伦'", info.getSql());
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.tencent.supersonic.chat.query.llm.dsl.corrector;
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import com.tencent.supersonic.chat.corrector.SelectFieldAppendCorrector;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
@@ -1,12 +1,80 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.dsl;
|
||||
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.mockito.MockedStatic;
|
||||
import org.mockito.Mockito;
|
||||
|
||||
class LLMDslParserTest {
|
||||
|
||||
|
||||
@Test
|
||||
void getDimensionFilter() {
|
||||
}
|
||||
void setFilter() {
|
||||
MockedStatic<ContextUtils> mockContextUtils = Mockito.mockStatic(ContextUtils.class);
|
||||
|
||||
SchemaService mockSchemaService = Mockito.mock(SchemaService.class);
|
||||
SemanticSchema mockSemanticSchema = Mockito.mock(SemanticSchema.class);
|
||||
|
||||
List<SchemaElement> dimensions = new ArrayList<>();
|
||||
List<SchemaValueMap> schemaValueMaps = new ArrayList<>();
|
||||
SchemaValueMap value1 = new SchemaValueMap();
|
||||
value1.setBizName("杰伦");
|
||||
value1.setTechName("周杰伦");
|
||||
value1.setAlias(Arrays.asList("周杰倫", "Jay Chou", "周董", "周先生"));
|
||||
schemaValueMaps.add(value1);
|
||||
|
||||
SchemaElement schemaElement = SchemaElement.builder()
|
||||
.bizName("singer_name")
|
||||
.name("歌手名")
|
||||
.model(2L)
|
||||
.schemaValueMaps(schemaValueMaps)
|
||||
.build();
|
||||
dimensions.add(schemaElement);
|
||||
|
||||
SchemaElement schemaElement2 = SchemaElement.builder()
|
||||
.bizName("publish_time")
|
||||
.name("发布时间")
|
||||
.model(2L)
|
||||
.build();
|
||||
dimensions.add(schemaElement2);
|
||||
|
||||
when(mockSemanticSchema.getDimensions()).thenReturn(dimensions);
|
||||
|
||||
List<SchemaElement> metrics = new ArrayList<>();
|
||||
SchemaElement metric = SchemaElement.builder()
|
||||
.bizName("play_count")
|
||||
.name("播放量")
|
||||
.model(2L)
|
||||
.build();
|
||||
metrics.add(metric);
|
||||
|
||||
when(mockSemanticSchema.getMetrics()).thenReturn(metrics);
|
||||
|
||||
when(mockSchemaService.getSemanticSchema()).thenReturn(mockSemanticSchema);
|
||||
mockContextUtils.when(() -> ContextUtils.getBean(SchemaService.class)).thenReturn(mockSchemaService);
|
||||
|
||||
SemanticParseInfo parseInfo = new SemanticParseInfo();
|
||||
SchemaElement model = new SchemaElement();
|
||||
model.setId(2L);
|
||||
parseInfo.setModel(model);
|
||||
CorrectionInfo correctionInfo = CorrectionInfo.builder()
|
||||
.sql("select count(song_name) from 歌曲库 where singer_name = '周先生' and YEAR(publish_time) >= 2023 and ")
|
||||
.parseInfo(parseInfo)
|
||||
.build();
|
||||
|
||||
LLMDslParser llmDslParser = new LLMDslParser();
|
||||
|
||||
llmDslParser.setFilter(correctionInfo, 2L, parseInfo);
|
||||
|
||||
}
|
||||
}
|
||||
@@ -1,13 +1,13 @@
|
||||
package com.tencent.supersonic.chat.test.context;
|
||||
|
||||
import com.tencent.supersonic.chat.persistence.mapper.ChatContextMapper;
|
||||
import com.tencent.supersonic.chat.persistence.repository.impl.ChatContextRepositoryImpl;
|
||||
import com.tencent.supersonic.chat.test.ChatBizLauncher;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.persistence.mapper.ChatContextMapper;
|
||||
import com.tencent.supersonic.knowledge.semantic.RemoteSemanticLayer;
|
||||
import com.tencent.supersonic.chat.test.ChatBizLauncher;
|
||||
import com.tencent.supersonic.semantic.model.domain.DimensionService;
|
||||
import com.tencent.supersonic.semantic.model.domain.MetricService;
|
||||
import com.tencent.supersonic.semantic.model.domain.ModelService;
|
||||
import com.tencent.supersonic.semantic.model.domain.MetricService;
|
||||
import com.tencent.supersonic.semantic.query.service.QueryService;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.slf4j.Logger;
|
||||
|
||||
@@ -2,11 +2,12 @@ package com.tencent.supersonic.chat.test.context;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
|
||||
@@ -1,2 +1 @@
|
||||
insert into chat_context (chat_id, modified_at, `user`, `query_text`, `semantic_parse`, ext_data)
|
||||
VALUES (1, '2023-05-24 00:00:00', 'admin', '超音数访问次数', '', 'admin');
|
||||
insert into chat_context (chat_id, modified_at , `user`, `query_text`, `semantic_parse` ,ext_data) VALUES(1, '2023-05-24 00:00:00', 'admin', '超音数访问次数', '', 'admin');
|
||||
@@ -1,59 +1,59 @@
|
||||
CREATE TABLE `chat_context`
|
||||
(
|
||||
`chat_id` BIGINT NOT NULL, -- context chat id
|
||||
`modified_at` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, -- row modify time
|
||||
`user` varchar(64) DEFAULT NULL, -- row modify user
|
||||
`query_text` LONGVARCHAR DEFAULT NULL, -- query text
|
||||
`semantic_parse` LONGVARCHAR DEFAULT NULL, -- parse data
|
||||
`ext_data` LONGVARCHAR DEFAULT NULL, -- extend data
|
||||
`chat_id` BIGINT NOT NULL , -- context chat id
|
||||
`modified_at` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP , -- row modify time
|
||||
`user` varchar(64) DEFAULT NULL , -- row modify user
|
||||
`query_text` LONGVARCHAR DEFAULT NULL , -- query text
|
||||
`semantic_parse` LONGVARCHAR DEFAULT NULL , -- parse data
|
||||
`ext_data` LONGVARCHAR DEFAULT NULL , -- extend data
|
||||
PRIMARY KEY (`chat_id`)
|
||||
);
|
||||
|
||||
|
||||
CREATE TABLE `chat`
|
||||
(
|
||||
`chat_id` BIGINT NOT NULL,-- AUTO_INCREMENT,
|
||||
`chat_name` varchar(100) DEFAULT NULL,
|
||||
`create_time` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
`last_time` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
`creator` varchar(30) DEFAULT NULL,
|
||||
`last_question` varchar(200) DEFAULT NULL,
|
||||
`is_delete` INT DEFAULT '0' COMMENT 'is deleted',
|
||||
`is_top` INT DEFAULT '0' COMMENT 'is top',
|
||||
`chat_id` BIGINT NOT NULL ,-- AUTO_INCREMENT,
|
||||
`chat_name` varchar(100) DEFAULT NULL,
|
||||
`create_time` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP ,
|
||||
`last_time` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP ,
|
||||
`creator` varchar(30) DEFAULT NULL,
|
||||
`last_question` varchar(200) DEFAULT NULL,
|
||||
`is_delete` INT DEFAULT '0' COMMENT 'is deleted',
|
||||
`is_top` INT DEFAULT '0' COMMENT 'is top',
|
||||
PRIMARY KEY (`chat_id`)
|
||||
);
|
||||
) ;
|
||||
|
||||
CREATE TABLE `chat_query`
|
||||
(
|
||||
`id` BIGINT NOT NULL,--AUTO_INCREMENT,
|
||||
`question_id` BIGINT DEFAULT NULL,
|
||||
`id` BIGINT NOT NULL ,--AUTO_INCREMENT,
|
||||
`question_id` BIGINT DEFAULT NULL,
|
||||
`create_time` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
`user_name` varchar(150) DEFAULT NULL COMMENT '',
|
||||
`question` varchar(300) DEFAULT NULL COMMENT '',
|
||||
`user_name` varchar(150) DEFAULT NULL COMMENT '',
|
||||
`question` varchar(300) DEFAULT NULL COMMENT '',
|
||||
`query_result` LONGVARCHAR,
|
||||
`time` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
`time` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP ,
|
||||
`state` int(1) DEFAULT NULL,
|
||||
`data_content` varchar(30) DEFAULT NULL,
|
||||
`name` varchar(100) DEFAULT NULL,
|
||||
`data_content` varchar(30) DEFAULT NULL,
|
||||
`name` varchar(100) DEFAULT NULL,
|
||||
`scene_type` int(2) DEFAULT NULL,
|
||||
`query_type` int(2) DEFAULT NULL,
|
||||
`is_deleted` int(1) DEFAULT NULL,
|
||||
`module` varchar(30) DEFAULT NULL,
|
||||
`module` varchar(30) DEFAULT NULL,
|
||||
`entity` LONGVARCHAR COMMENT '',
|
||||
`chat_id` BIGINT DEFAULT NULL COMMENT 'chat id',
|
||||
`chat_id` BIGINT DEFAULT NULL COMMENT 'chat id',
|
||||
`recommend` text,
|
||||
`aggregator` varchar(20) DEFAULT 'trend',
|
||||
`top_num` int DEFAULT NULL,
|
||||
`start_time` varchar(30) DEFAULT NULL,
|
||||
`end_time` varchar(30) DEFAULT NULL,
|
||||
`aggregator` varchar(20) DEFAULT 'trend',
|
||||
`top_num` int DEFAULT NULL,
|
||||
`start_time` varchar(30) DEFAULT NULL,
|
||||
`end_time` varchar(30) DEFAULT NULL,
|
||||
`compare_recommend` LONGVARCHAR,
|
||||
`compare_entity` LONGVARCHAR,
|
||||
`query_sql` LONGVARCHAR,
|
||||
`columns` varchar(2000) DEFAULT NULL,
|
||||
`columns` varchar(2000) DEFAULT NULL,
|
||||
`result_list` LONGVARCHAR,
|
||||
`main_entity` varchar(5000) DEFAULT NULL,
|
||||
`semantic_text` varchar(5000) DEFAULT NULL,
|
||||
`score` int DEFAULT '0',
|
||||
`feedback` varchar(1024) DEFAULT '',
|
||||
`main_entity` varchar(5000) DEFAULT NULL,
|
||||
`semantic_text` varchar(5000) DEFAULT NULL,
|
||||
`score` int DEFAULT '0',
|
||||
`feedback` varchar(1024) DEFAULT '',
|
||||
PRIMARY KEY (`id`)
|
||||
);
|
||||
) ;
|
||||
|
||||
Reference in New Issue
Block a user