(improvement)(project) support for modifying filter conditions and fix group by pushdown and add windows scipt (#49)

Co-authored-by: lexluo <lexluo@tencent.com>
This commit is contained in:
lexluo09
2023-09-03 23:51:47 +08:00
committed by GitHub
parent 8440f1f30e
commit 559ef974b0
317 changed files with 7449 additions and 9413 deletions

View File

@@ -1,9 +1,8 @@
package com.tencent.supersonic.chat.query.llm.dsl.corrector;
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.corrector.DateFieldCorrector;
import org.junit.Assert;
import org.junit.jupiter.api.Test;

View File

@@ -1,8 +1,7 @@
package com.tencent.supersonic.chat.query.llm.dsl.corrector;
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.corrector.FieldValueCorrector;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
@@ -14,12 +13,12 @@ import java.util.Map;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
class FieldValueCorrectorTest {
class FieldNameCorrectorTest {
@Test
void rewriter() {
FieldValueCorrector corrector = new FieldValueCorrector();
FieldNameCorrector corrector = new FieldNameCorrector();
CorrectionInfo correctionInfo = CorrectionInfo.builder()
.sql("select 歌曲名 from 歌曲库 where 专辑照片 = '七里香' and 专辑名 = '流行' and 数据日期 = '2023-08-19'")
.build();

View File

@@ -0,0 +1,71 @@
package com.tencent.supersonic.chat.corrector;
import static org.mockito.Mockito.when;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
class FieldValueCorrectorTest {
@Test
void corrector() {
MockedStatic<ContextUtils> mockContextUtils = Mockito.mockStatic(ContextUtils.class);
SchemaService mockSchemaService = Mockito.mock(SchemaService.class);
SemanticSchema mockSemanticSchema = Mockito.mock(SemanticSchema.class);
List<SchemaElement> dimensions = new ArrayList<>();
List<SchemaValueMap> schemaValueMaps = new ArrayList<>();
SchemaValueMap value1 = new SchemaValueMap();
value1.setBizName("杰伦");
value1.setTechName("周杰伦");
value1.setAlias(Arrays.asList("周杰倫", "Jay Chou", "周董", "周先生"));
schemaValueMaps.add(value1);
SchemaElement schemaElement = SchemaElement.builder()
.bizName("singer_name")
.name("歌手名")
.model(2L)
.schemaValueMaps(schemaValueMaps)
.build();
dimensions.add(schemaElement);
when(mockSemanticSchema.getDimensions()).thenReturn(dimensions);
when(mockSchemaService.getSemanticSchema()).thenReturn(mockSemanticSchema);
mockContextUtils.when(() -> ContextUtils.getBean(SchemaService.class)).thenReturn(mockSchemaService);
SemanticParseInfo parseInfo = new SemanticParseInfo();
SchemaElement model = new SchemaElement();
model.setId(2L);
parseInfo.setModel(model);
CorrectionInfo correctionInfo = CorrectionInfo.builder()
.sql("select count(song_name) from 歌曲库 where singer_name = '周先生'")
.parseInfo(parseInfo)
.build();
FieldValueCorrector corrector = new FieldValueCorrector();
CorrectionInfo info = corrector.corrector(correctionInfo);
Assert.assertEquals("SELECT count(song_name) FROM 歌曲库 WHERE singer_name = '周杰伦'", info.getSql());
correctionInfo.setSql("select count(song_name) from 歌曲库 where singer_name = '杰伦'");
info = corrector.corrector(correctionInfo);
Assert.assertEquals("SELECT count(song_name) FROM 歌曲库 WHERE singer_name = '周杰伦'", info.getSql());
}
}

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.chat.query.llm.dsl.corrector;
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.corrector.SelectFieldAppendCorrector;
import org.junit.Assert;
import org.junit.jupiter.api.Test;

View File

@@ -1,12 +1,80 @@
package com.tencent.supersonic.chat.parser.llm.dsl;
import static org.mockito.Mockito.when;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.junit.jupiter.api.Test;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
class LLMDslParserTest {
@Test
void getDimensionFilter() {
}
void setFilter() {
MockedStatic<ContextUtils> mockContextUtils = Mockito.mockStatic(ContextUtils.class);
SchemaService mockSchemaService = Mockito.mock(SchemaService.class);
SemanticSchema mockSemanticSchema = Mockito.mock(SemanticSchema.class);
List<SchemaElement> dimensions = new ArrayList<>();
List<SchemaValueMap> schemaValueMaps = new ArrayList<>();
SchemaValueMap value1 = new SchemaValueMap();
value1.setBizName("杰伦");
value1.setTechName("周杰伦");
value1.setAlias(Arrays.asList("周杰倫", "Jay Chou", "周董", "周先生"));
schemaValueMaps.add(value1);
SchemaElement schemaElement = SchemaElement.builder()
.bizName("singer_name")
.name("歌手名")
.model(2L)
.schemaValueMaps(schemaValueMaps)
.build();
dimensions.add(schemaElement);
SchemaElement schemaElement2 = SchemaElement.builder()
.bizName("publish_time")
.name("发布时间")
.model(2L)
.build();
dimensions.add(schemaElement2);
when(mockSemanticSchema.getDimensions()).thenReturn(dimensions);
List<SchemaElement> metrics = new ArrayList<>();
SchemaElement metric = SchemaElement.builder()
.bizName("play_count")
.name("播放量")
.model(2L)
.build();
metrics.add(metric);
when(mockSemanticSchema.getMetrics()).thenReturn(metrics);
when(mockSchemaService.getSemanticSchema()).thenReturn(mockSemanticSchema);
mockContextUtils.when(() -> ContextUtils.getBean(SchemaService.class)).thenReturn(mockSchemaService);
SemanticParseInfo parseInfo = new SemanticParseInfo();
SchemaElement model = new SchemaElement();
model.setId(2L);
parseInfo.setModel(model);
CorrectionInfo correctionInfo = CorrectionInfo.builder()
.sql("select count(song_name) from 歌曲库 where singer_name = '周先生' and YEAR(publish_time) >= 2023 and ")
.parseInfo(parseInfo)
.build();
LLMDslParser llmDslParser = new LLMDslParser();
llmDslParser.setFilter(correctionInfo, 2L, parseInfo);
}
}

View File

@@ -1,13 +1,13 @@
package com.tencent.supersonic.chat.test.context;
import com.tencent.supersonic.chat.persistence.mapper.ChatContextMapper;
import com.tencent.supersonic.chat.persistence.repository.impl.ChatContextRepositoryImpl;
import com.tencent.supersonic.chat.test.ChatBizLauncher;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.persistence.mapper.ChatContextMapper;
import com.tencent.supersonic.knowledge.semantic.RemoteSemanticLayer;
import com.tencent.supersonic.chat.test.ChatBizLauncher;
import com.tencent.supersonic.semantic.model.domain.DimensionService;
import com.tencent.supersonic.semantic.model.domain.MetricService;
import com.tencent.supersonic.semantic.model.domain.ModelService;
import com.tencent.supersonic.semantic.model.domain.MetricService;
import com.tencent.supersonic.semantic.query.service.QueryService;
import org.junit.runner.RunWith;
import org.slf4j.Logger;

View File

@@ -2,11 +2,12 @@ package com.tencent.supersonic.chat.test.context;
import com.google.gson.Gson;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.DateConf;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;

View File

@@ -1,2 +1 @@
insert into chat_context (chat_id, modified_at, `user`, `query_text`, `semantic_parse`, ext_data)
VALUES (1, '2023-05-24 00:00:00', 'admin', '超音数访问次数', '', 'admin');
insert into chat_context (chat_id, modified_at , `user`, `query_text`, `semantic_parse` ,ext_data) VALUES(1, '2023-05-24 00:00:00', 'admin', '超音数访问次数', '', 'admin');

View File

@@ -1,59 +1,59 @@
CREATE TABLE `chat_context`
(
`chat_id` BIGINT NOT NULL, -- context chat id
`modified_at` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, -- row modify time
`user` varchar(64) DEFAULT NULL, -- row modify user
`query_text` LONGVARCHAR DEFAULT NULL, -- query text
`semantic_parse` LONGVARCHAR DEFAULT NULL, -- parse data
`ext_data` LONGVARCHAR DEFAULT NULL, -- extend data
`chat_id` BIGINT NOT NULL , -- context chat id
`modified_at` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP , -- row modify time
`user` varchar(64) DEFAULT NULL , -- row modify user
`query_text` LONGVARCHAR DEFAULT NULL , -- query text
`semantic_parse` LONGVARCHAR DEFAULT NULL , -- parse data
`ext_data` LONGVARCHAR DEFAULT NULL , -- extend data
PRIMARY KEY (`chat_id`)
);
CREATE TABLE `chat`
(
`chat_id` BIGINT NOT NULL,-- AUTO_INCREMENT,
`chat_name` varchar(100) DEFAULT NULL,
`create_time` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
`last_time` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
`creator` varchar(30) DEFAULT NULL,
`last_question` varchar(200) DEFAULT NULL,
`is_delete` INT DEFAULT '0' COMMENT 'is deleted',
`is_top` INT DEFAULT '0' COMMENT 'is top',
`chat_id` BIGINT NOT NULL ,-- AUTO_INCREMENT,
`chat_name` varchar(100) DEFAULT NULL,
`create_time` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP ,
`last_time` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP ,
`creator` varchar(30) DEFAULT NULL,
`last_question` varchar(200) DEFAULT NULL,
`is_delete` INT DEFAULT '0' COMMENT 'is deleted',
`is_top` INT DEFAULT '0' COMMENT 'is top',
PRIMARY KEY (`chat_id`)
);
) ;
CREATE TABLE `chat_query`
(
`id` BIGINT NOT NULL,--AUTO_INCREMENT,
`question_id` BIGINT DEFAULT NULL,
`id` BIGINT NOT NULL ,--AUTO_INCREMENT,
`question_id` BIGINT DEFAULT NULL,
`create_time` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
`user_name` varchar(150) DEFAULT NULL COMMENT '',
`question` varchar(300) DEFAULT NULL COMMENT '',
`user_name` varchar(150) DEFAULT NULL COMMENT '',
`question` varchar(300) DEFAULT NULL COMMENT '',
`query_result` LONGVARCHAR,
`time` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
`time` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP ,
`state` int(1) DEFAULT NULL,
`data_content` varchar(30) DEFAULT NULL,
`name` varchar(100) DEFAULT NULL,
`data_content` varchar(30) DEFAULT NULL,
`name` varchar(100) DEFAULT NULL,
`scene_type` int(2) DEFAULT NULL,
`query_type` int(2) DEFAULT NULL,
`is_deleted` int(1) DEFAULT NULL,
`module` varchar(30) DEFAULT NULL,
`module` varchar(30) DEFAULT NULL,
`entity` LONGVARCHAR COMMENT '',
`chat_id` BIGINT DEFAULT NULL COMMENT 'chat id',
`chat_id` BIGINT DEFAULT NULL COMMENT 'chat id',
`recommend` text,
`aggregator` varchar(20) DEFAULT 'trend',
`top_num` int DEFAULT NULL,
`start_time` varchar(30) DEFAULT NULL,
`end_time` varchar(30) DEFAULT NULL,
`aggregator` varchar(20) DEFAULT 'trend',
`top_num` int DEFAULT NULL,
`start_time` varchar(30) DEFAULT NULL,
`end_time` varchar(30) DEFAULT NULL,
`compare_recommend` LONGVARCHAR,
`compare_entity` LONGVARCHAR,
`query_sql` LONGVARCHAR,
`columns` varchar(2000) DEFAULT NULL,
`columns` varchar(2000) DEFAULT NULL,
`result_list` LONGVARCHAR,
`main_entity` varchar(5000) DEFAULT NULL,
`semantic_text` varchar(5000) DEFAULT NULL,
`score` int DEFAULT '0',
`feedback` varchar(1024) DEFAULT '',
`main_entity` varchar(5000) DEFAULT NULL,
`semantic_text` varchar(5000) DEFAULT NULL,
`score` int DEFAULT '0',
`feedback` varchar(1024) DEFAULT '',
PRIMARY KEY (`id`)
);
) ;