From 31c8fea2dc46aa20ec7cb5a35322c5ce43e17df5 Mon Sep 17 00:00:00 2001 From: jolunoluo Date: Tue, 19 Sep 2023 15:36:15 +0800 Subject: [PATCH 1/8] (improvement)(chat) add QueryResponder to recall history similar solved query --- .../chat/api/pojo/response/ParseResp.java | 1 + .../pojo/response/SolvedQueryRecallResp.java | 17 +++ .../plugin/embedding/EmbeddingConfig.java | 6 + .../plugin/embedding/RecallRetrieval.java | 2 + .../queryresponder/DefaultQueryResponder.java | 141 ++++++++++++++++++ .../chat/queryresponder/QueryResponder.java | 12 ++ .../chat/service/impl/QueryServiceImpl.java | 8 + 7 files changed, 187 insertions(+) create mode 100644 chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/SolvedQueryRecallResp.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/DefaultQueryResponder.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/QueryResponder.java diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/ParseResp.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/ParseResp.java index e8fc10db8..9af3946a9 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/ParseResp.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/ParseResp.java @@ -21,6 +21,7 @@ public class ParseResp { private ParseState state; private List selectedParses; private List candidateParses; + private List similarSolvedQuery; public enum ParseState { COMPLETED, diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/SolvedQueryRecallResp.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/SolvedQueryRecallResp.java new file mode 100644 index 000000000..e92dd1cb1 --- /dev/null +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/SolvedQueryRecallResp.java @@ -0,0 +1,17 @@ +package com.tencent.supersonic.chat.api.pojo.response; + + +import lombok.Builder; +import lombok.Data; + +@Data +@Builder +public class SolvedQueryRecallResp { + + private Long queryId; + + private Integer parseId; + + private String queryText; + +} \ No newline at end of file diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/embedding/EmbeddingConfig.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/embedding/EmbeddingConfig.java index df3bdfc0f..5725ed4e3 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/embedding/EmbeddingConfig.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/embedding/EmbeddingConfig.java @@ -23,4 +23,10 @@ public class EmbeddingConfig { @Value("${embedding.nResult:1}") private String nResult; + @Value("${embedding.solvedQuery.recall.path:/solved_query_retrival}") + private String solvedQueryRecallPath; + + @Value("${embedding.solvedQuery.add.path:/solved_query_add}") + private String solvedQueryAddPath; + } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/embedding/RecallRetrieval.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/embedding/RecallRetrieval.java index 4d5470e4f..3a61970e9 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/embedding/RecallRetrieval.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/embedding/RecallRetrieval.java @@ -14,4 +14,6 @@ public class RecallRetrieval { private String presetId; + private String query; + } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/DefaultQueryResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/DefaultQueryResponder.java new file mode 100644 index 000000000..5357363fe --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/DefaultQueryResponder.java @@ -0,0 +1,141 @@ +package com.tencent.supersonic.chat.queryresponder; + +import com.alibaba.fastjson.JSONObject; +import com.google.common.collect.Lists; +import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp; +import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingConfig; +import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingResp; +import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections.CollectionUtils; +import org.apache.logging.log4j.util.Strings; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.stereotype.Component; +import org.springframework.web.client.RestTemplate; +import org.springframework.web.util.UriComponentsBuilder; +import java.net.URI; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +@Slf4j +@Component +public class DefaultQueryResponder implements QueryResponder{ + + + private EmbeddingConfig embeddingConfig; + + private RestTemplate restTemplate; + + public DefaultQueryResponder(EmbeddingConfig embeddingConfig, + RestTemplate restTemplate) { + this.embeddingConfig = embeddingConfig; + this.restTemplate = restTemplate; + } + + @Override + public void saveSolvedQuery(String queryText, Long queryId, Integer parseId) { + try { + String uniqueId = generateUniqueId(queryId, parseId); + Map requestMap = new HashMap<>(); + requestMap.put("query", queryText); + requestMap.put("query_id", uniqueId); + doRequest(embeddingConfig.getSolvedQueryAddPath(), + JSONObject.toJSONString(Lists.newArrayList(requestMap))); + } catch (Exception e) { + log.warn("save history question to embedding failed, queryText:{}", queryText, e); + } + } + + @Override + public List recallSolvedQuery(String queryText) { + List solvedQueryRecallResps = Lists.newArrayList(); + try { + String url = embeddingConfig.getUrl() + embeddingConfig.getSolvedQueryRecallPath() + "?n_results=" + + embeddingConfig.getNResult(); + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + headers.setLocation(URI.create(url)); + URI requestUrl = UriComponentsBuilder + .fromHttpUrl(url).build().encode().toUri(); + String jsonBody = JSONObject.toJSONString(Lists.newArrayList(queryText)); + HttpEntity entity = new HttpEntity<>(jsonBody, headers); + log.info("[embedding] request body:{}, url:{}", jsonBody, url); + ResponseEntity> embeddingResponseEntity = + restTemplate.exchange(requestUrl, HttpMethod.POST, entity, + new ParameterizedTypeReference>() { + }); + log.info("[embedding] recognize result body:{}", embeddingResponseEntity); + List embeddingResps = embeddingResponseEntity.getBody(); + Set querySet = new HashSet<>(); + if (CollectionUtils.isNotEmpty(embeddingResps)) { + for (EmbeddingResp embeddingResp : embeddingResps) { + List embeddingRetrievals = embeddingResp.getRetrieval(); + for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) { + if (querySet.contains(embeddingRetrieval.getQuery())) { + continue; + } + String id = embeddingRetrieval.getId(); + SolvedQueryRecallResp solvedQueryRecallResp = + SolvedQueryRecallResp.builder() + .queryText(embeddingRetrieval.getQuery()) + .queryId(getQueryId(id)).parseId(getParseId(id)) + .build(); + solvedQueryRecallResps.add(solvedQueryRecallResp); + querySet.add(embeddingRetrieval.getQuery()); + } + } + } + } catch (Exception e) { + log.warn("recall similar solved query failed", e); + } + return solvedQueryRecallResps; + } + + private String generateUniqueId(Long queryId, Integer parseId) { + String uniqueId = queryId + String.valueOf(parseId); + if (parseId < 10) { + uniqueId = queryId + String.format("0%s", parseId); + } + return uniqueId; + } + + private Long getQueryId(String uniqueId) { + return Long.parseLong(uniqueId) / 100; + } + + private Integer getParseId(String uniqueId) { + return Integer.parseInt(uniqueId) % 100; + } + + private ResponseEntity doRequest(String path, String jsonBody) { + if (Strings.isEmpty(embeddingConfig.getUrl())) { + return ResponseEntity.of(Optional.empty()); + } + String url = embeddingConfig.getUrl() + path; + try { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + headers.setLocation(URI.create(url)); + URI requestUrl = UriComponentsBuilder + .fromHttpUrl(url).build().encode().toUri(); + HttpEntity entity = new HttpEntity<>(jsonBody, headers); + log.info("[embedding] request body :{}, url:{}", jsonBody, url); + ResponseEntity responseEntity = restTemplate.exchange(requestUrl, + HttpMethod.POST, entity, new ParameterizedTypeReference() {}); + log.info("[embedding] result body:{}", responseEntity); + return responseEntity; + } catch (Exception e) { + log.warn("connect to embedding service failed, url:{}", url); + } + return ResponseEntity.of(Optional.empty()); + } +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/QueryResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/QueryResponder.java new file mode 100644 index 000000000..2f154bd44 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/QueryResponder.java @@ -0,0 +1,12 @@ +package com.tencent.supersonic.chat.queryresponder; + +import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp; +import java.util.List; + +public interface QueryResponder { + + void saveSolvedQuery(String queryText, Long queryId, Integer parseId); + + List recallSolvedQuery(String queryText); + +} \ No newline at end of file diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index 0e1a1e59f..985715350 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -15,12 +15,14 @@ import com.tencent.supersonic.chat.api.pojo.response.EntityInfo; import com.tencent.supersonic.chat.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryState; +import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp; import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO; import com.tencent.supersonic.chat.persistence.dataobject.CostType; import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO; import com.tencent.supersonic.chat.query.QuerySelector; import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq; import com.tencent.supersonic.chat.query.QueryManager; +import com.tencent.supersonic.chat.queryresponder.QueryResponder; import com.tencent.supersonic.chat.service.ChatService; import com.tencent.supersonic.chat.service.QueryService; import com.tencent.supersonic.chat.service.SemanticService; @@ -63,6 +65,8 @@ public class QueryServiceImpl implements QueryService { private ChatService chatService; @Autowired private StatisticsService statisticsService; + @Autowired + private QueryResponder queryResponder; private final String entity = "ENTITY"; @@ -129,10 +133,13 @@ public class QueryServiceImpl implements QueryService { saveInfo(timeCostDOList, queryReq.getQueryText(), parseResult.getQueryId(), queryReq.getUser().getName(), queryReq.getChatId().longValue()); } else { + List solvedQueryRecallResps = + queryResponder.recallSolvedQuery(queryCtx.getRequest().getQueryText()); parseResult = ParseResp.builder() .chatId(queryReq.getChatId()) .queryText(queryReq.getQueryText()) .state(ParseResp.ParseState.FAILED) + .similarSolvedQuery(solvedQueryRecallResps) .build(); } return parseResult; @@ -171,6 +178,7 @@ public class QueryServiceImpl implements QueryService { chatCtx.setUser(queryReq.getUser().getName()); //chatService.addQuery(queryResult, chatCtx); chatService.updateQuery(queryReq.getQueryId(), queryResult, chatCtx); + queryResponder.saveSolvedQuery(queryReq.getQueryText(), queryReq.getQueryId(), queryReq.getParseId()); } else { chatService.deleteChatQuery(queryReq.getQueryId()); } From d5a253a7817826de117e281573f953d28fd45379 Mon Sep 17 00:00:00 2001 From: jolunoluo Date: Tue, 19 Sep 2023 21:08:41 +0800 Subject: [PATCH 2/8] (improvement)(chat) add QueryResponder to recall history similar solved query --- .../supersonic/chat/queryresponder/DefaultQueryResponder.java | 2 +- .../supersonic/chat/service/impl/QueryServiceImpl.java | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/DefaultQueryResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/DefaultQueryResponder.java index 5357363fe..03da21a34 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/DefaultQueryResponder.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/DefaultQueryResponder.java @@ -28,7 +28,7 @@ import java.util.Set; @Slf4j @Component -public class DefaultQueryResponder implements QueryResponder{ +public class DefaultQueryResponder implements QueryResponder { private EmbeddingConfig embeddingConfig; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index 985715350..a3b1050c9 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -133,8 +133,8 @@ public class QueryServiceImpl implements QueryService { saveInfo(timeCostDOList, queryReq.getQueryText(), parseResult.getQueryId(), queryReq.getUser().getName(), queryReq.getChatId().longValue()); } else { - List solvedQueryRecallResps = - queryResponder.recallSolvedQuery(queryCtx.getRequest().getQueryText()); + List solvedQueryRecallResps = + queryResponder.recallSolvedQuery(queryCtx.getRequest().getQueryText()); parseResult = ParseResp.builder() .chatId(queryReq.getChatId()) .queryText(queryReq.getQueryText()) From 3fe726ac23202627ac1d0489d1e43e59298e9391 Mon Sep 17 00:00:00 2001 From: jolunoluo Date: Tue, 19 Sep 2023 21:13:39 +0800 Subject: [PATCH 3/8] (improvement)(semantic) support adding tag for metric --- .../main/resources/db/semantic-schema-h2.sql | 1 + .../src/main/resources/db/schema-h2.sql | 1 + .../src/main/resources/db/schema-mysql.sql | 1 + .../src/main/resources/db/sql-update.sql | 4 +- .../src/test/resources/db/schema-h2.sql | 1 + .../api/model/request/MetricBaseReq.java | 3 + .../api/model/request/PageMetricReq.java | 2 - .../api/model/request/PageSchemaItemReq.java | 1 + .../api/model/response/MetricResp.java | 14 ++ .../model/application/MetricServiceImpl.java | 13 +- .../semantic/model/domain/MetricService.java | 5 +- .../model/domain/dataobject/MetricDO.java | 161 +++++++++++++++++- .../domain/dataobject/MetricDOExample.java | 155 +++++++++++++---- .../semantic/model/domain/pojo/Metric.java | 13 +- .../model/domain/utils/MetricConverter.java | 25 ++- .../semantic/model/rest/MetricController.java | 6 + .../main/resources/mapper/MetricDOMapper.xml | 20 ++- .../mapper/custom/MetricDOCustomMapper.xml | 44 ++--- 18 files changed, 384 insertions(+), 86 deletions(-) diff --git a/launchers/semantic/src/main/resources/db/semantic-schema-h2.sql b/launchers/semantic/src/main/resources/db/semantic-schema-h2.sql index e27fcd725..34da654cc 100644 --- a/launchers/semantic/src/main/resources/db/semantic-schema-h2.sql +++ b/launchers/semantic/src/main/resources/db/semantic-schema-h2.sql @@ -108,6 +108,7 @@ CREATE TABLE IF NOT EXISTS `s2_metric` ( `data_format_type` varchar(50) DEFAULT NULL , `data_format` varchar(500) DEFAULT NULL, `alias` varchar(500) DEFAULT NULL, + `tags` varchar(500) DEFAULT NULL, PRIMARY KEY (`id`) ); COMMENT ON TABLE s2_metric IS 'metric information table'; diff --git a/launchers/standalone/src/main/resources/db/schema-h2.sql b/launchers/standalone/src/main/resources/db/schema-h2.sql index b8d090f07..556b19550 100644 --- a/launchers/standalone/src/main/resources/db/schema-h2.sql +++ b/launchers/standalone/src/main/resources/db/schema-h2.sql @@ -190,6 +190,7 @@ CREATE TABLE IF NOT EXISTS `s2_metric` ( `data_format_type` varchar(50) DEFAULT NULL , `data_format` varchar(500) DEFAULT NULL, `alias` varchar(500) DEFAULT NULL, + `tags` varchar(500) DEFAULT NULL, PRIMARY KEY (`id`) ); COMMENT ON TABLE s2_metric IS 'metric information table'; diff --git a/launchers/standalone/src/main/resources/db/schema-mysql.sql b/launchers/standalone/src/main/resources/db/schema-mysql.sql index da3579178..255f2967a 100644 --- a/launchers/standalone/src/main/resources/db/schema-mysql.sql +++ b/launchers/standalone/src/main/resources/db/schema-mysql.sql @@ -254,6 +254,7 @@ CREATE TABLE `s2_metric` ( `data_format_type` varchar(50) DEFAULT NULL COMMENT '数值类型', `data_format` varchar(500) DEFAULT NULL COMMENT '数值类型参数', `alias` varchar(500) CHARACTER SET utf8 COLLATE utf8_unicode_ci DEFAULT NULL, + `tags` varchar(500) CHARACTER SET utf8 COLLATE utf8_unicode_ci DEFAULT NULL, PRIMARY KEY (`id`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='指标表'; diff --git a/launchers/standalone/src/main/resources/db/sql-update.sql b/launchers/standalone/src/main/resources/db/sql-update.sql index b56df3579..4ba811442 100644 --- a/launchers/standalone/src/main/resources/db/sql-update.sql +++ b/launchers/standalone/src/main/resources/db/sql-update.sql @@ -48,5 +48,7 @@ alter table s2_database drop column domain_id; alter table s2_chat add column agent_id int after chat_id; --20230907 +ALTER TABLE s2_model add alias varchar(200) default null after domain_id; -ALTER TABLE s2_model add alias varchar(200) default null after domain_id; \ No newline at end of file +--20230919 +alter table s2_metric add tags varchar(500) null; \ No newline at end of file diff --git a/launchers/standalone/src/test/resources/db/schema-h2.sql b/launchers/standalone/src/test/resources/db/schema-h2.sql index 64b84e38d..fd6a300f9 100644 --- a/launchers/standalone/src/test/resources/db/schema-h2.sql +++ b/launchers/standalone/src/test/resources/db/schema-h2.sql @@ -205,6 +205,7 @@ CREATE TABLE IF NOT EXISTS `s2_metric` ( `data_format_type` varchar(50) DEFAULT NULL , `data_format` varchar(500) DEFAULT NULL, `alias` varchar(500) DEFAULT NULL, + `tags` varchar(500) DEFAULT NULL, PRIMARY KEY (`id`) ); COMMENT ON TABLE s2_metric IS 'metric information table'; diff --git a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/request/MetricBaseReq.java b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/request/MetricBaseReq.java index 1540baeeb..7481d997d 100644 --- a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/request/MetricBaseReq.java +++ b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/request/MetricBaseReq.java @@ -4,6 +4,7 @@ package com.tencent.supersonic.semantic.api.model.request; import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem; import com.tencent.supersonic.common.pojo.DataFormat; import lombok.Data; +import java.util.List; @Data @@ -17,4 +18,6 @@ public class MetricBaseReq extends SchemaItem { private DataFormat dataFormat; + private List tags; + } diff --git a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/request/PageMetricReq.java b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/request/PageMetricReq.java index aad02b810..40edea0c2 100644 --- a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/request/PageMetricReq.java +++ b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/request/PageMetricReq.java @@ -9,6 +9,4 @@ public class PageMetricReq extends PageSchemaItemReq { private String type; - private String key; - } diff --git a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/request/PageSchemaItemReq.java b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/request/PageSchemaItemReq.java index b2580779d..94e26026d 100644 --- a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/request/PageSchemaItemReq.java +++ b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/request/PageSchemaItemReq.java @@ -16,4 +16,5 @@ public class PageSchemaItemReq extends PageBaseReq { private List modelIds = Lists.newArrayList(); private Integer sensitiveLevel; private Integer status; + private String key; } diff --git a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/response/MetricResp.java b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/response/MetricResp.java index 9afbe3368..08ff264e8 100644 --- a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/response/MetricResp.java +++ b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/response/MetricResp.java @@ -1,11 +1,15 @@ package com.tencent.supersonic.semantic.api.model.response; +import com.google.common.collect.Lists; import com.tencent.supersonic.common.pojo.DataFormat; import com.tencent.supersonic.semantic.api.model.pojo.MetricTypeParams; import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem; import lombok.Data; import lombok.ToString; +import org.apache.commons.lang3.StringUtils; +import java.util.Arrays; +import java.util.List; @Data @@ -14,6 +18,8 @@ public class MetricResp extends SchemaItem { private Long modelId; + private Long domainId; + private String modelName; //ATOMIC DERIVED @@ -27,5 +33,13 @@ public class MetricResp extends SchemaItem { private String alias; + private List tags; + public void setTag(String tag) { + if (StringUtils.isBlank(tag)) { + tags = Lists.newArrayList(); + } else { + tags = Arrays.asList(tag.split(",")); + } + } } diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/MetricServiceImpl.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/MetricServiceImpl.java index 6545e1b70..dec9d2b3f 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/MetricServiceImpl.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/MetricServiceImpl.java @@ -28,6 +28,7 @@ import com.tencent.supersonic.semantic.model.domain.utils.MetricConverter; import com.tencent.supersonic.semantic.model.domain.MetricService; import com.tencent.supersonic.semantic.model.domain.pojo.Metric; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; @@ -250,6 +251,16 @@ public class MetricServiceImpl implements MetricService { }); } + @Override + public Set getMetricTags() { + List metricResps = getMetrics(); + if (CollectionUtils.isEmpty(metricResps)) { + return new HashSet<>(); + } + return metricResps.stream().flatMap(metricResp -> + metricResp.getTags().stream()).collect(Collectors.toSet()); + } + private void saveMetricBatch(List metrics, User user) { if (CollectionUtils.isEmpty(metrics)) { @@ -293,7 +304,7 @@ public class MetricServiceImpl implements MetricService { Map modelMap = modelService.getModelMap(); if (!CollectionUtils.isEmpty(metricDOS)) { metricDescs = metricDOS.stream() - .map(metricDO -> MetricConverter.convert2MetricDesc(metricDO, modelMap)) + .map(metricDO -> MetricConverter.convert2MetricResp(metricDO, modelMap)) .collect(Collectors.toList()); } return metricDescs; diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/MetricService.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/MetricService.java index 1b28e8e64..1071fe27d 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/MetricService.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/MetricService.java @@ -7,6 +7,7 @@ import com.tencent.supersonic.semantic.api.model.request.PageMetricReq; import com.tencent.supersonic.semantic.api.model.response.MetricResp; import java.util.List; +import java.util.Set; public interface MetricService { @@ -22,7 +23,7 @@ public interface MetricService { void createMetricBatch(List metricReqs, User user) throws Exception; - PageInfo queryMetric(PageMetricReq pageMetrricReq); + PageInfo queryMetric(PageMetricReq pageMetricReq); MetricResp getMetric(Long modelId, String bizName); @@ -35,4 +36,6 @@ public interface MetricService { void deleteMetric(Long id) throws Exception; List mockAlias(MetricReq metricReq, String mockType, User user); + + Set getMetricTags(); } diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/dataobject/MetricDO.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/dataobject/MetricDO.java index ceb2fad56..098e63945 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/dataobject/MetricDO.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/dataobject/MetricDO.java @@ -3,116 +3,247 @@ package com.tencent.supersonic.semantic.model.domain.dataobject; import java.util.Date; public class MetricDO { - + /** + * + */ private Long id; + /** + * 主体域ID + */ private Long modelId; + /** + * 指标名称 + */ private String name; + /** + * 字段名称 + */ private String bizName; + /** + * 描述 + */ private String description; + /** + * 指标状态,0正常,1下架,2删除 + */ private Integer status; + /** + * 敏感级别 + */ private Integer sensitiveLevel; + /** + * 指标类型 proxy,expr + */ private String type; + /** + * 创建时间 + */ private Date createdAt; + /** + * 创建人 + */ private String createdBy; + /** + * 更新时间 + */ private Date updatedAt; + /** + * 更新人 + */ private String updatedBy; + /** + * 数值类型 + */ private String dataFormatType; + /** + * 数值类型参数 + */ private String dataFormat; + /** + * + */ private String alias; + /** + * + */ + private String tags; + + /** + * 类型参数 + */ private String typeParams; - + /** + * + * @return id + */ public Long getId() { return id; } + /** + * + * @param id + */ public void setId(Long id) { this.id = id; } + /** + * 主体域ID + * @return model_id 主体域ID + */ public Long getModelId() { return modelId; } + /** + * 主体域ID + * @param modelId 主体域ID + */ public void setModelId(Long modelId) { this.modelId = modelId; } + /** + * 指标名称 + * @return name 指标名称 + */ public String getName() { return name; } + /** + * 指标名称 + * @param name 指标名称 + */ public void setName(String name) { this.name = name == null ? null : name.trim(); } + /** + * 字段名称 + * @return biz_name 字段名称 + */ public String getBizName() { return bizName; } + /** + * 字段名称 + * @param bizName 字段名称 + */ public void setBizName(String bizName) { this.bizName = bizName == null ? null : bizName.trim(); } + /** + * 描述 + * @return description 描述 + */ public String getDescription() { return description; } + /** + * 描述 + * @param description 描述 + */ public void setDescription(String description) { this.description = description == null ? null : description.trim(); } + /** + * 指标状态,0正常,1下架,2删除 + * @return status 指标状态,0正常,1下架,2删除 + */ public Integer getStatus() { return status; } + /** + * 指标状态,0正常,1下架,2删除 + * @param status 指标状态,0正常,1下架,2删除 + */ public void setStatus(Integer status) { this.status = status; } + /** + * 敏感级别 + * @return sensitive_level 敏感级别 + */ public Integer getSensitiveLevel() { return sensitiveLevel; } + /** + * 敏感级别 + * @param sensitiveLevel 敏感级别 + */ public void setSensitiveLevel(Integer sensitiveLevel) { this.sensitiveLevel = sensitiveLevel; } + /** + * 指标类型 proxy,expr + * @return type 指标类型 proxy,expr + */ public String getType() { return type; } + /** + * 指标类型 proxy,expr + * @param type 指标类型 proxy,expr + */ public void setType(String type) { this.type = type == null ? null : type.trim(); } + /** + * 创建时间 + * @return created_at 创建时间 + */ public Date getCreatedAt() { return createdAt; } + /** + * 创建时间 + * @param createdAt 创建时间 + */ public void setCreatedAt(Date createdAt) { this.createdAt = createdAt; } + /** + * 创建人 + * @return created_by 创建人 + */ public String getCreatedBy() { return createdBy; } + /** + * 创建人 + * @param createdBy 创建人 + */ public void setCreatedBy(String createdBy) { this.createdBy = createdBy == null ? null : createdBy.trim(); } @@ -182,21 +313,37 @@ public class MetricDO { } /** - * - * @return alias + * + * @return alias */ public String getAlias() { return alias; } /** - * - * @param alias + * + * @param alias */ public void setAlias(String alias) { this.alias = alias == null ? null : alias.trim(); } + /** + * + * @return tags + */ + public String getTags() { + return tags; + } + + /** + * + * @param tags + */ + public void setTags(String tags) { + this.tags = tags == null ? null : tags.trim(); + } + /** * 类型参数 * @return type_params 类型参数 @@ -212,4 +359,4 @@ public class MetricDO { public void setTypeParams(String typeParams) { this.typeParams = typeParams == null ? null : typeParams.trim(); } -} +} \ No newline at end of file diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/dataobject/MetricDOExample.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/dataobject/MetricDOExample.java index a74ee01eb..d57855cfe 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/dataobject/MetricDOExample.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/dataobject/MetricDOExample.java @@ -31,6 +31,7 @@ public class MetricDOExample { protected Integer limitEnd; /** + * * @mbg.generated */ public MetricDOExample() { @@ -38,6 +39,7 @@ public class MetricDOExample { } /** + * * @mbg.generated */ public void setOrderByClause(String orderByClause) { @@ -45,6 +47,7 @@ public class MetricDOExample { } /** + * * @mbg.generated */ public String getOrderByClause() { @@ -52,6 +55,7 @@ public class MetricDOExample { } /** + * * @mbg.generated */ public void setDistinct(boolean distinct) { @@ -59,6 +63,7 @@ public class MetricDOExample { } /** + * * @mbg.generated */ public boolean isDistinct() { @@ -66,6 +71,7 @@ public class MetricDOExample { } /** + * * @mbg.generated */ public List getOredCriteria() { @@ -73,6 +79,7 @@ public class MetricDOExample { } /** + * * @mbg.generated */ public void or(Criteria criteria) { @@ -80,6 +87,7 @@ public class MetricDOExample { } /** + * * @mbg.generated */ public Criteria or() { @@ -89,6 +97,7 @@ public class MetricDOExample { } /** + * * @mbg.generated */ public Criteria createCriteria() { @@ -100,6 +109,7 @@ public class MetricDOExample { } /** + * * @mbg.generated */ protected Criteria createCriteriaInternal() { @@ -108,6 +118,7 @@ public class MetricDOExample { } /** + * * @mbg.generated */ public void clear() { @@ -117,13 +128,15 @@ public class MetricDOExample { } /** + * * @mbg.generated */ public void setLimitStart(Integer limitStart) { - this.limitStart = limitStart; + this.limitStart=limitStart; } /** + * * @mbg.generated */ public Integer getLimitStart() { @@ -131,13 +144,15 @@ public class MetricDOExample { } /** + * * @mbg.generated */ public void setLimitEnd(Integer limitEnd) { - this.limitEnd = limitEnd; + this.limitEnd=limitEnd; } /** + * * @mbg.generated */ public Integer getLimitEnd() { @@ -1177,6 +1192,76 @@ public class MetricDOExample { addCriterion("alias not between", value1, value2, "alias"); return (Criteria) this; } + + public Criteria andTagsIsNull() { + addCriterion("tags is null"); + return (Criteria) this; + } + + public Criteria andTagsIsNotNull() { + addCriterion("tags is not null"); + return (Criteria) this; + } + + public Criteria andTagsEqualTo(String value) { + addCriterion("tags =", value, "tags"); + return (Criteria) this; + } + + public Criteria andTagsNotEqualTo(String value) { + addCriterion("tags <>", value, "tags"); + return (Criteria) this; + } + + public Criteria andTagsGreaterThan(String value) { + addCriterion("tags >", value, "tags"); + return (Criteria) this; + } + + public Criteria andTagsGreaterThanOrEqualTo(String value) { + addCriterion("tags >=", value, "tags"); + return (Criteria) this; + } + + public Criteria andTagsLessThan(String value) { + addCriterion("tags <", value, "tags"); + return (Criteria) this; + } + + public Criteria andTagsLessThanOrEqualTo(String value) { + addCriterion("tags <=", value, "tags"); + return (Criteria) this; + } + + public Criteria andTagsLike(String value) { + addCriterion("tags like", value, "tags"); + return (Criteria) this; + } + + public Criteria andTagsNotLike(String value) { + addCriterion("tags not like", value, "tags"); + return (Criteria) this; + } + + public Criteria andTagsIn(List values) { + addCriterion("tags in", values, "tags"); + return (Criteria) this; + } + + public Criteria andTagsNotIn(List values) { + addCriterion("tags not in", values, "tags"); + return (Criteria) this; + } + + public Criteria andTagsBetween(String value1, String value2) { + addCriterion("tags between", value1, value2, "tags"); + return (Criteria) this; + } + + public Criteria andTagsNotBetween(String value1, String value2) { + addCriterion("tags not between", value1, value2, "tags"); + return (Criteria) this; + } } /** @@ -1209,6 +1294,38 @@ public class MetricDOExample { private String typeHandler; + public String getCondition() { + return condition; + } + + public Object getValue() { + return value; + } + + public Object getSecondValue() { + return secondValue; + } + + public boolean isNoValue() { + return noValue; + } + + public boolean isSingleValue() { + return singleValue; + } + + public boolean isBetweenValue() { + return betweenValue; + } + + public boolean isListValue() { + return listValue; + } + + public String getTypeHandler() { + return typeHandler; + } + protected Criterion(String condition) { super(); this.condition = condition; @@ -1244,37 +1361,5 @@ public class MetricDOExample { protected Criterion(String condition, Object value, Object secondValue) { this(condition, value, secondValue, null); } - - public String getCondition() { - return condition; - } - - public Object getValue() { - return value; - } - - public Object getSecondValue() { - return secondValue; - } - - public boolean isNoValue() { - return noValue; - } - - public boolean isSingleValue() { - return singleValue; - } - - public boolean isBetweenValue() { - return betweenValue; - } - - public boolean isListValue() { - return listValue; - } - - public String getTypeHandler() { - return typeHandler; - } } -} +} \ No newline at end of file diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/pojo/Metric.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/pojo/Metric.java index 771915de1..cfdc4917d 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/pojo/Metric.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/pojo/Metric.java @@ -1,10 +1,12 @@ package com.tencent.supersonic.semantic.model.domain.pojo; - import com.tencent.supersonic.common.pojo.DataFormat; import com.tencent.supersonic.semantic.api.model.pojo.MetricTypeParams; import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem; import lombok.Data; +import org.apache.commons.lang3.StringUtils; +import org.springframework.util.CollectionUtils; +import java.util.List; @Data public class Metric extends SchemaItem { @@ -23,4 +25,13 @@ public class Metric extends SchemaItem { private String alias; + private List tags; + + public String getTag() { + if (CollectionUtils.isEmpty(tags)) { + return ""; + } + return StringUtils.join(tags, ","); + } + } diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/utils/MetricConverter.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/utils/MetricConverter.java index 9b6513d3a..5573c7b87 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/utils/MetricConverter.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/utils/MetricConverter.java @@ -37,6 +37,7 @@ public class MetricConverter { if (metric.getDataFormat() != null) { metricDO.setDataFormat(JSONObject.toJSONString(metric.getDataFormat())); } + metricDO.setTags(metric.getTag()); return metricDO; } @@ -51,27 +52,23 @@ public class MetricConverter { BeanUtils.copyProperties(metric, metricDO); metricDO.setTypeParams(JSONObject.toJSONString(metric.getTypeParams())); metricDO.setDataFormat(JSONObject.toJSONString(metric.getDataFormat())); + metricDO.setTags(metric.getTag()); return metricDO; } - public static MetricResp convert2MetricDesc(MetricDO metricDO, Map modelMap) { - MetricResp metricDesc = new MetricResp(); - BeanUtils.copyProperties(metricDO, metricDesc); - metricDesc.setTypeParams(JSONObject.parseObject(metricDO.getTypeParams(), MetricTypeParams.class)); - metricDesc.setDataFormat(JSONObject.parseObject(metricDO.getDataFormat(), DataFormat.class)); + public static MetricResp convert2MetricResp(MetricDO metricDO, Map modelMap) { + MetricResp metricResp = new MetricResp(); + BeanUtils.copyProperties(metricDO, metricResp); + metricResp.setTypeParams(JSONObject.parseObject(metricDO.getTypeParams(), MetricTypeParams.class)); + metricResp.setDataFormat(JSONObject.parseObject(metricDO.getDataFormat(), DataFormat.class)); ModelResp modelResp = modelMap.get(metricDO.getModelId()); if (modelResp != null) { - metricDesc.setModelName(modelResp.getName()); + metricResp.setModelName(modelResp.getName()); + metricResp.setDomainId(modelResp.getDomainId()); } - return metricDesc; - } - - public static Metric convert2Metric(MetricDO metricDO) { - Metric metric = new Metric(); - BeanUtils.copyProperties(metricDO, metric); - metric.setTypeParams(JSONObject.parseObject(metricDO.getTypeParams(), MetricTypeParams.class)); - return metric; + metricResp.setTag(metricDO.getTags()); + return metricResp; } diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/rest/MetricController.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/rest/MetricController.java index 99a164695..d72485009 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/rest/MetricController.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/rest/MetricController.java @@ -9,6 +9,7 @@ import com.tencent.supersonic.semantic.api.model.request.PageMetricReq; import com.tencent.supersonic.semantic.api.model.response.MetricResp; import com.tencent.supersonic.semantic.model.domain.MetricService; import java.util.List; +import java.util.Set; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -90,4 +91,9 @@ public class MetricController { } + @GetMapping("/getMetricTags") + public Set getMetricTags() { + return metricService.getMetricTags(); + } + } diff --git a/semantic/model/src/main/resources/mapper/MetricDOMapper.xml b/semantic/model/src/main/resources/mapper/MetricDOMapper.xml index 37f09c2b5..2d20f9ee8 100644 --- a/semantic/model/src/main/resources/mapper/MetricDOMapper.xml +++ b/semantic/model/src/main/resources/mapper/MetricDOMapper.xml @@ -17,6 +17,7 @@ + @@ -52,7 +53,7 @@ id, model_id, name, biz_name, description, status, sensitive_level, type, created_at, - created_by, updated_at, updated_by, data_format_type, data_format, alias + created_by, updated_at, updated_by, data_format_type, data_format, alias, tags type_params @@ -108,13 +109,13 @@ sensitive_level, type, created_at, created_by, updated_at, updated_by, data_format_type, data_format, alias, - type_params) + tags, type_params) values (#{id,jdbcType=BIGINT}, #{modelId,jdbcType=BIGINT}, #{name,jdbcType=VARCHAR}, #{bizName,jdbcType=VARCHAR}, #{description,jdbcType=VARCHAR}, #{status,jdbcType=INTEGER}, #{sensitiveLevel,jdbcType=INTEGER}, #{type,jdbcType=VARCHAR}, #{createdAt,jdbcType=TIMESTAMP}, #{createdBy,jdbcType=VARCHAR}, #{updatedAt,jdbcType=TIMESTAMP}, #{updatedBy,jdbcType=VARCHAR}, #{dataFormatType,jdbcType=VARCHAR}, #{dataFormat,jdbcType=VARCHAR}, #{alias,jdbcType=VARCHAR}, - #{typeParams,jdbcType=LONGVARCHAR}) + #{tags,jdbcType=VARCHAR}, #{typeParams,jdbcType=LONGVARCHAR}) insert into s2_metric @@ -164,6 +165,9 @@ alias, + + tags, + type_params, @@ -214,6 +218,9 @@ #{alias,jdbcType=VARCHAR}, + + #{tags,jdbcType=VARCHAR}, + #{typeParams,jdbcType=LONGVARCHAR}, @@ -270,6 +277,9 @@ alias = #{alias,jdbcType=VARCHAR}, + + tags = #{tags,jdbcType=VARCHAR}, + type_params = #{typeParams,jdbcType=LONGVARCHAR}, @@ -292,6 +302,7 @@ data_format_type = #{dataFormatType,jdbcType=VARCHAR}, data_format = #{dataFormat,jdbcType=VARCHAR}, alias = #{alias,jdbcType=VARCHAR}, + tags = #{tags,jdbcType=VARCHAR}, type_params = #{typeParams,jdbcType=LONGVARCHAR} where id = #{id,jdbcType=BIGINT} @@ -310,7 +321,8 @@ updated_by = #{updatedBy,jdbcType=VARCHAR}, data_format_type = #{dataFormatType,jdbcType=VARCHAR}, data_format = #{dataFormat,jdbcType=VARCHAR}, - alias = #{alias,jdbcType=VARCHAR} + alias = #{alias,jdbcType=VARCHAR}, + tags = #{tags,jdbcType=VARCHAR} where id = #{id,jdbcType=BIGINT} \ No newline at end of file diff --git a/semantic/model/src/main/resources/mapper/custom/MetricDOCustomMapper.xml b/semantic/model/src/main/resources/mapper/custom/MetricDOCustomMapper.xml index 67db20d9c..8546d856b 100644 --- a/semantic/model/src/main/resources/mapper/custom/MetricDOCustomMapper.xml +++ b/semantic/model/src/main/resources/mapper/custom/MetricDOCustomMapper.xml @@ -2,22 +2,26 @@ - - - - - - - - - - - + + + + + + + + + + + + + + + + + - - + + @@ -51,12 +55,11 @@ - id - , model_id, name, biz_name, description, type, created_at, created_by, updated_at, - updated_by + id, model_id, name, biz_name, description, status, sensitive_level, type, created_at, + created_by, updated_at, updated_by, data_format_type, data_format, alias, tags - typeParams + type_params @@ -108,7 +111,8 @@ and ( id like CONCAT('%',#{key , jdbcType=VARCHAR},'%') or name like CONCAT('%',#{key , jdbcType=VARCHAR},'%') or biz_name like CONCAT('%',#{key , jdbcType=VARCHAR},'%') or - description like CONCAT('%',#{key , jdbcType=VARCHAR},'%') ) + description like CONCAT('%',#{key , jdbcType=VARCHAR},'%') or + tags like CONCAT('%',#{key , jdbcType=VARCHAR},'%') ) and id like CONCAT('%',#{id , jdbcType=VARCHAR},'%') From b824cd8ce7bcf50e4193211028d89cbaa5617b5d Mon Sep 17 00:00:00 2001 From: jolunoluo Date: Wed, 20 Sep 2023 11:08:40 +0800 Subject: [PATCH 4/8] (improvement)(auth) support super admin configuration --- .../constant/UserConstants.java | 2 + .../auth/api/authentication/pojo/User.java | 11 +- .../authentication/pojo/UserWithPassword.java | 9 +- .../adaptor/DefaultUserAdaptor.java | 2 +- .../persistence/dataobject/UserDO.java | 62 ++++-- .../persistence/dataobject/UserDOExample.java | 178 ++++++++++++------ .../authentication/utils/UserTokenUtils.java | 11 +- .../main/resources/mapper/UserDOMapper.xml | 53 ++---- .../src/main/resources/db/chat-data-h2.sql | 2 +- .../src/main/resources/db/chat-schema-h2.sql | 1 + .../main/resources/db/semantic-data-h2.sql | 2 +- .../main/resources/db/semantic-schema-h2.sql | 1 + .../src/main/resources/db/data-h2.sql | 4 +- .../src/main/resources/db/schema-h2.sql | 1 + .../src/main/resources/db/schema-mysql.sql | 3 +- .../src/main/resources/db/sql-update.sql | 5 +- .../tencent/supersonic/util/DataUtils.java | 2 +- .../src/test/resources/db/data-h2.sql | 2 +- .../src/test/resources/db/schema-h2.sql | 1 + .../application/DatabaseServiceImpl.java | 6 +- .../model/application/DomainServiceImpl.java | 29 +-- .../model/application/ModelServiceImpl.java | 30 +-- .../semantic/model/domain/DomainService.java | 2 +- .../semantic/model/domain/ModelService.java | 4 +- .../semantic/model/rest/ModelController.java | 2 +- .../query/service/SchemaServiceImpl.java | 2 +- .../query/utils/DataPermissionAOP.java | 4 +- 27 files changed, 273 insertions(+), 158 deletions(-) diff --git a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/constant/UserConstants.java b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/constant/UserConstants.java index 116629f19..3f582d2c5 100644 --- a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/constant/UserConstants.java +++ b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/constant/UserConstants.java @@ -12,6 +12,8 @@ public class UserConstants { public static final String TOKEN_USER_EMAIL = "token_user_email"; + public static final String TOKEN_IS_ADMIN = "token_is_admin"; + public static final String TOKEN_ALGORITHM = "HS512"; public static final String TOKEN_CREATE_TIME = "token_create_time"; diff --git a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/pojo/User.java b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/pojo/User.java index 28241eb14..4cf2b526d 100644 --- a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/pojo/User.java +++ b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/pojo/User.java @@ -18,17 +18,22 @@ public class User { private String email; - public static User get(Long id, String name, String displayName, String email) { - return new User(id, name, displayName, email); + private Integer isAdmin; + + public static User get(Long id, String name, String displayName, String email, Integer isAdmin) { + return new User(id, name, displayName, email, isAdmin); } public static User getFakeUser() { - return new User(1L, "admin", "admin", "admin@email"); + return new User(1L, "admin", "admin", "admin@email", 1); } public String getDisplayName() { return StringUtils.isBlank(displayName) ? name : displayName; } + public boolean isSuperAdmin() { + return isAdmin != null && isAdmin == 1; + } } diff --git a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/pojo/UserWithPassword.java b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/pojo/UserWithPassword.java index c7384c1e5..36f77eae2 100644 --- a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/pojo/UserWithPassword.java +++ b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/pojo/UserWithPassword.java @@ -9,13 +9,14 @@ public class UserWithPassword extends User { private String password; - public UserWithPassword(Long id, String name, String displayName, String email, String password) { - super(id, name, displayName, email); + public UserWithPassword(Long id, String name, String displayName, String email, String password, Integer isAdmin) { + super(id, name, displayName, email, isAdmin); this.password = password; } - public static UserWithPassword get(Long id, String name, String displayName, String email, String password) { - return new UserWithPassword(id, name, displayName, email, password); + public static UserWithPassword get(Long id, String name, String displayName, + String email, String password, Integer isAdmin) { + return new UserWithPassword(id, name, displayName, email, password, isAdmin); } } diff --git a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/adaptor/DefaultUserAdaptor.java b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/adaptor/DefaultUserAdaptor.java index e762ca9a3..9d5893343 100644 --- a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/adaptor/DefaultUserAdaptor.java +++ b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/adaptor/DefaultUserAdaptor.java @@ -71,7 +71,7 @@ public class DefaultUserAdaptor implements UserAdaptor { } if (userDO.getPassword().equals(userReq.getPassword())) { UserWithPassword user = UserWithPassword.get(userDO.getId(), userDO.getName(), userDO.getDisplayName(), - userDO.getEmail(), userDO.getPassword()); + userDO.getEmail(), userDO.getPassword(), userDO.getIsAdmin()); return userTokenUtils.generateToken(user); } throw new RuntimeException("password not correct, please try again"); diff --git a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/persistence/dataobject/UserDO.java b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/persistence/dataobject/UserDO.java index 77b4ae9e7..af32a9aff 100644 --- a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/persistence/dataobject/UserDO.java +++ b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/persistence/dataobject/UserDO.java @@ -1,99 +1,129 @@ package com.tencent.supersonic.auth.authentication.persistence.dataobject; public class UserDO { - /** - * + * */ private Long id; /** - * + * */ private String name; /** - * + * */ private String password; /** - * + * */ private String displayName; /** - * + * */ private String email; /** - * @return id + * + */ + private Integer isAdmin; + + /** + * + * @return id */ public Long getId() { return id; } /** - * @param id + * + * @param id */ public void setId(Long id) { this.id = id; } /** - * @return name + * + * @return name */ public String getName() { return name; } /** - * @param name + * + * @param name */ public void setName(String name) { this.name = name == null ? null : name.trim(); } /** - * @return password + * + * @return password */ public String getPassword() { return password; } /** - * @param password + * + * @param password */ public void setPassword(String password) { this.password = password == null ? null : password.trim(); } /** - * @return display_name + * + * @return display_name */ public String getDisplayName() { return displayName; } /** - * @param displayName + * + * @param displayName */ public void setDisplayName(String displayName) { this.displayName = displayName == null ? null : displayName.trim(); } /** - * @return email + * + * @return email */ public String getEmail() { return email; } /** - * @param email + * + * @param email */ public void setEmail(String email) { this.email = email == null ? null : email.trim(); } + + /** + * + * @return is_admin + */ + public Integer getIsAdmin() { + return isAdmin; + } + + /** + * + * @param isAdmin + */ + public void setIsAdmin(Integer isAdmin) { + this.isAdmin = isAdmin; + } } \ No newline at end of file diff --git a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/persistence/dataobject/UserDOExample.java b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/persistence/dataobject/UserDOExample.java index 21f01f4ca..96d8fafdd 100644 --- a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/persistence/dataobject/UserDOExample.java +++ b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/persistence/dataobject/UserDOExample.java @@ -4,7 +4,6 @@ import java.util.ArrayList; import java.util.List; public class UserDOExample { - /** * s2_user */ @@ -31,6 +30,7 @@ public class UserDOExample { protected Integer limitEnd; /** + * * @mbg.generated */ public UserDOExample() { @@ -38,13 +38,7 @@ public class UserDOExample { } /** - * @mbg.generated - */ - public String getOrderByClause() { - return orderByClause; - } - - /** + * * @mbg.generated */ public void setOrderByClause(String orderByClause) { @@ -52,13 +46,15 @@ public class UserDOExample { } /** + * * @mbg.generated */ - public boolean isDistinct() { - return distinct; + public String getOrderByClause() { + return orderByClause; } /** + * * @mbg.generated */ public void setDistinct(boolean distinct) { @@ -66,6 +62,15 @@ public class UserDOExample { } /** + * + * @mbg.generated + */ + public boolean isDistinct() { + return distinct; + } + + /** + * * @mbg.generated */ public List getOredCriteria() { @@ -73,6 +78,7 @@ public class UserDOExample { } /** + * * @mbg.generated */ public void or(Criteria criteria) { @@ -80,6 +86,7 @@ public class UserDOExample { } /** + * * @mbg.generated */ public Criteria or() { @@ -89,6 +96,7 @@ public class UserDOExample { } /** + * * @mbg.generated */ public Criteria createCriteria() { @@ -100,6 +108,7 @@ public class UserDOExample { } /** + * * @mbg.generated */ protected Criteria createCriteriaInternal() { @@ -108,6 +117,7 @@ public class UserDOExample { } /** + * * @mbg.generated */ public void clear() { @@ -117,6 +127,15 @@ public class UserDOExample { } /** + * + * @mbg.generated + */ + public void setLimitStart(Integer limitStart) { + this.limitStart=limitStart; + } + + /** + * * @mbg.generated */ public Integer getLimitStart() { @@ -124,31 +143,25 @@ public class UserDOExample { } /** + * * @mbg.generated */ - public void setLimitStart(Integer limitStart) { - this.limitStart = limitStart; + public void setLimitEnd(Integer limitEnd) { + this.limitEnd=limitEnd; } /** + * * @mbg.generated */ public Integer getLimitEnd() { return limitEnd; } - /** - * @mbg.generated - */ - public void setLimitEnd(Integer limitEnd) { - this.limitEnd = limitEnd; - } - /** * s2_user null */ protected abstract static class GeneratedCriteria { - protected List criteria; protected GeneratedCriteria() { @@ -528,6 +541,66 @@ public class UserDOExample { addCriterion("email not between", value1, value2, "email"); return (Criteria) this; } + + public Criteria andIsAdminIsNull() { + addCriterion("is_admin is null"); + return (Criteria) this; + } + + public Criteria andIsAdminIsNotNull() { + addCriterion("is_admin is not null"); + return (Criteria) this; + } + + public Criteria andIsAdminEqualTo(Integer value) { + addCriterion("is_admin =", value, "isAdmin"); + return (Criteria) this; + } + + public Criteria andIsAdminNotEqualTo(Integer value) { + addCriterion("is_admin <>", value, "isAdmin"); + return (Criteria) this; + } + + public Criteria andIsAdminGreaterThan(Integer value) { + addCriterion("is_admin >", value, "isAdmin"); + return (Criteria) this; + } + + public Criteria andIsAdminGreaterThanOrEqualTo(Integer value) { + addCriterion("is_admin >=", value, "isAdmin"); + return (Criteria) this; + } + + public Criteria andIsAdminLessThan(Integer value) { + addCriterion("is_admin <", value, "isAdmin"); + return (Criteria) this; + } + + public Criteria andIsAdminLessThanOrEqualTo(Integer value) { + addCriterion("is_admin <=", value, "isAdmin"); + return (Criteria) this; + } + + public Criteria andIsAdminIn(List values) { + addCriterion("is_admin in", values, "isAdmin"); + return (Criteria) this; + } + + public Criteria andIsAdminNotIn(List values) { + addCriterion("is_admin not in", values, "isAdmin"); + return (Criteria) this; + } + + public Criteria andIsAdminBetween(Integer value1, Integer value2) { + addCriterion("is_admin between", value1, value2, "isAdmin"); + return (Criteria) this; + } + + public Criteria andIsAdminNotBetween(Integer value1, Integer value2) { + addCriterion("is_admin not between", value1, value2, "isAdmin"); + return (Criteria) this; + } } /** @@ -544,7 +617,6 @@ public class UserDOExample { * s2_user null */ public static class Criterion { - private String condition; private Object value; @@ -561,6 +633,38 @@ public class UserDOExample { private String typeHandler; + public String getCondition() { + return condition; + } + + public Object getValue() { + return value; + } + + public Object getSecondValue() { + return secondValue; + } + + public boolean isNoValue() { + return noValue; + } + + public boolean isSingleValue() { + return singleValue; + } + + public boolean isBetweenValue() { + return betweenValue; + } + + public boolean isListValue() { + return listValue; + } + + public String getTypeHandler() { + return typeHandler; + } + protected Criterion(String condition) { super(); this.condition = condition; @@ -596,37 +700,5 @@ public class UserDOExample { protected Criterion(String condition, Object value, Object secondValue) { this(condition, value, secondValue, null); } - - public String getCondition() { - return condition; - } - - public Object getValue() { - return value; - } - - public Object getSecondValue() { - return secondValue; - } - - public boolean isNoValue() { - return noValue; - } - - public boolean isSingleValue() { - return singleValue; - } - - public boolean isBetweenValue() { - return betweenValue; - } - - public boolean isListValue() { - return listValue; - } - - public String getTypeHandler() { - return typeHandler; - } } } \ No newline at end of file diff --git a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/utils/UserTokenUtils.java b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/utils/UserTokenUtils.java index c8749ad43..82e93bcf3 100644 --- a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/utils/UserTokenUtils.java +++ b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/utils/UserTokenUtils.java @@ -2,6 +2,7 @@ package com.tencent.supersonic.auth.authentication.utils; import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_ALGORITHM; import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_CREATE_TIME; +import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_IS_ADMIN; import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_PREFIX; import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_TIME_OUT; import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_USER_DISPLAY_NAME; @@ -42,6 +43,7 @@ public class UserTokenUtils { claims.put(TOKEN_USER_PASSWORD, StringUtils.isEmpty(user.getPassword()) ? "" : user.getPassword()); claims.put(TOKEN_USER_DISPLAY_NAME, user.getDisplayName()); claims.put(TOKEN_CREATE_TIME, System.currentTimeMillis()); + claims.put(TOKEN_IS_ADMIN, user.getIsAdmin()); return generate(claims); } @@ -52,6 +54,7 @@ public class UserTokenUtils { claims.put(TOKEN_USER_PASSWORD, "admin"); claims.put(TOKEN_USER_DISPLAY_NAME, "admin"); claims.put(TOKEN_CREATE_TIME, System.currentTimeMillis()); + claims.put(TOKEN_IS_ADMIN, 1); return generate(claims); } @@ -63,7 +66,9 @@ public class UserTokenUtils { String userName = String.valueOf(claims.get(TOKEN_USER_NAME)); String email = String.valueOf(claims.get(TOKEN_USER_EMAIL)); String displayName = String.valueOf(claims.get(TOKEN_USER_DISPLAY_NAME)); - return User.get(userId, userName, displayName, email); + Integer isAdmin = claims.get(TOKEN_IS_ADMIN) == null + ? 0 : Integer.parseInt(claims.get(TOKEN_IS_ADMIN).toString()); + return User.get(userId, userName, displayName, email, isAdmin); } public UserWithPassword getUserWithPassword(HttpServletRequest request) { @@ -79,7 +84,9 @@ public class UserTokenUtils { String email = String.valueOf(claims.get(TOKEN_USER_EMAIL)); String displayName = String.valueOf(claims.get(TOKEN_USER_DISPLAY_NAME)); String password = String.valueOf(claims.get(TOKEN_USER_PASSWORD)); - return UserWithPassword.get(userId, userName, displayName, email, password); + Integer isAdmin = claims.get(TOKEN_IS_ADMIN) == null + ? 0 : Integer.parseInt(claims.get(TOKEN_IS_ADMIN).toString()); + return UserWithPassword.get(userId, userName, displayName, email, password, isAdmin); } private Claims getClaims(String token) { diff --git a/auth/authentication/src/main/resources/mapper/UserDOMapper.xml b/auth/authentication/src/main/resources/mapper/UserDOMapper.xml index 15eb2b49c..c1933db89 100644 --- a/auth/authentication/src/main/resources/mapper/UserDOMapper.xml +++ b/auth/authentication/src/main/resources/mapper/UserDOMapper.xml @@ -2,11 +2,12 @@ - + + @@ -38,7 +39,7 @@ - id, name, password, display_name, email + id, name, password, display_name, email, is_admin - - - delete from s2_user - where id = #{id,jdbcType=BIGINT} - insert into s2_user (id, name, password, - display_name, email) + display_name, email, is_admin + ) values (#{id,jdbcType=BIGINT}, #{name,jdbcType=VARCHAR}, #{password,jdbcType=VARCHAR}, - #{displayName,jdbcType=VARCHAR}, #{email,jdbcType=VARCHAR}) + #{displayName,jdbcType=VARCHAR}, #{email,jdbcType=VARCHAR}, #{isAdmin,jdbcType=INTEGER} + ) insert into s2_user @@ -91,6 +84,9 @@ email, + + is_admin, + @@ -108,6 +104,9 @@ #{email,jdbcType=VARCHAR}, + + #{isAdmin,jdbcType=INTEGER}, + - - update s2_user - - - name = #{name,jdbcType=VARCHAR}, - - - password = #{password,jdbcType=VARCHAR}, - - - display_name = #{displayName,jdbcType=VARCHAR}, - - - email = #{email,jdbcType=VARCHAR}, - - - where id = #{id,jdbcType=BIGINT} - - - update s2_user - set name = #{name,jdbcType=VARCHAR}, - password = #{password,jdbcType=VARCHAR}, - display_name = #{displayName,jdbcType=VARCHAR}, - email = #{email,jdbcType=VARCHAR} - where id = #{id,jdbcType=BIGINT} - \ No newline at end of file diff --git a/launchers/chat/src/main/resources/db/chat-data-h2.sql b/launchers/chat/src/main/resources/db/chat-data-h2.sql index a5207a41c..47989554b 100644 --- a/launchers/chat/src/main/resources/db/chat-data-h2.sql +++ b/launchers/chat/src/main/resources/db/chat-data-h2.sql @@ -1,4 +1,4 @@ -insert into s2_user (id, `name`, password, display_name, email) values (1, 'admin','admin','admin','admin@xx.com'); +insert into s2_user (id, `name`, password, display_name, email, is_admin) values (1, 'admin','admin','admin','admin@xx.com', 1); insert into s2_user (id, `name`, password, display_name, email) values (2, 'jack','123456','jack','jack@xx.com'); insert into s2_user (id, `name`, password, display_name, email) values (3, 'tom','123456','tom','tom@xx.com'); insert into s2_user (id, `name`, password, display_name, email) values (4, 'lucy','123456','lucy','lucy@xx.com'); diff --git a/launchers/chat/src/main/resources/db/chat-schema-h2.sql b/launchers/chat/src/main/resources/db/chat-schema-h2.sql index 8e28d349e..20e5c3bab 100644 --- a/launchers/chat/src/main/resources/db/chat-schema-h2.sql +++ b/launchers/chat/src/main/resources/db/chat-schema-h2.sql @@ -87,6 +87,7 @@ create table s2_user display_name varchar(100) null, password varchar(100) null, email varchar(100) null, + is_admin INT null, PRIMARY KEY (`id`) ); COMMENT ON TABLE s2_user IS 'user information table'; diff --git a/launchers/semantic/src/main/resources/db/semantic-data-h2.sql b/launchers/semantic/src/main/resources/db/semantic-data-h2.sql index b9904da13..c11f8e64d 100644 --- a/launchers/semantic/src/main/resources/db/semantic-data-h2.sql +++ b/launchers/semantic/src/main/resources/db/semantic-data-h2.sql @@ -36,7 +36,7 @@ insert into s2_auth_groups (group_id, config) values (2, '{"domainId":"1","name":"tom_sales_permission","groupId":2,"authRules":[{"metrics":["stay_hours"],"dimensions":["page"]}],"dimensionFilters":["department in (''sales'')"],"dimensionFilterDescription":"开通 tom sales部门权限", "authorizedUsers":["tom"],"authorizedDepartmentIds":[]}'); -insert into s2_user (id, `name`, password, display_name, email) values (1, 'admin','admin','admin','admin@xx.com'); +insert into s2_user (id, `name`, password, display_name, email, is_admin) values (1, 'admin','admin','admin','admin@xx.com', 1); insert into s2_user (id, `name`, password, display_name, email) values (2, 'jack','123456','jack','jack@xx.com'); insert into s2_user (id, `name`, password, display_name, email) values (3, 'tom','123456','tom','tom@xx.com'); insert into s2_user (id, `name`, password, display_name, email) values (4, 'lucy','123456','lucy','lucy@xx.com'); diff --git a/launchers/semantic/src/main/resources/db/semantic-schema-h2.sql b/launchers/semantic/src/main/resources/db/semantic-schema-h2.sql index 34da654cc..2c846cead 100644 --- a/launchers/semantic/src/main/resources/db/semantic-schema-h2.sql +++ b/launchers/semantic/src/main/resources/db/semantic-schema-h2.sql @@ -80,6 +80,7 @@ create table s2_user display_name varchar(100) null, password varchar(100) null, email varchar(100) null, + is_admin INT null, PRIMARY KEY (`id`) ); COMMENT ON TABLE s2_user IS 'user information table'; diff --git a/launchers/standalone/src/main/resources/db/data-h2.sql b/launchers/standalone/src/main/resources/db/data-h2.sql index 0ba57383b..68709c0c2 100644 --- a/launchers/standalone/src/main/resources/db/data-h2.sql +++ b/launchers/standalone/src/main/resources/db/data-h2.sql @@ -1,8 +1,8 @@ -- sample user -insert into s2_user (id, `name`, password, display_name, email) values (1, 'admin','admin','admin','admin@xx.com'); +insert into s2_user (id, `name`, password, display_name, email, is_admin) values (1, 'admin','admin','admin','admin@xx.com', 1); insert into s2_user (id, `name`, password, display_name, email) values (2, 'jack','123456','jack','jack@xx.com'); insert into s2_user (id, `name`, password, display_name, email) values (3, 'tom','123456','tom','tom@xx.com'); -insert into s2_user (id, `name`, password, display_name, email) values (4, 'lucy','123456','lucy','lucy@xx.com'); +insert into s2_user (id, `name`, password, display_name, email, is_admin) values (4, 'lucy','123456','lucy','lucy@xx.com', 1); insert into s2_user (id, `name`, password, display_name, email) values (5, 'alice','123456','alice','alice@xx.com'); -- sample models diff --git a/launchers/standalone/src/main/resources/db/schema-h2.sql b/launchers/standalone/src/main/resources/db/schema-h2.sql index 556b19550..1370bc82d 100644 --- a/launchers/standalone/src/main/resources/db/schema-h2.sql +++ b/launchers/standalone/src/main/resources/db/schema-h2.sql @@ -87,6 +87,7 @@ create table s2_user display_name varchar(100) null, password varchar(100) null, email varchar(100) null, + is_admin INT null, PRIMARY KEY (`id`) ); COMMENT ON TABLE s2_user IS 'user information table'; diff --git a/launchers/standalone/src/main/resources/db/schema-mysql.sql b/launchers/standalone/src/main/resources/db/schema-mysql.sql index 255f2967a..266c67ff6 100644 --- a/launchers/standalone/src/main/resources/db/schema-mysql.sql +++ b/launchers/standalone/src/main/resources/db/schema-mysql.sql @@ -369,7 +369,8 @@ create table s2_user display_name varchar(100) null, password varchar(100) null, email varchar(100) null, + is_admin int(11) null, PRIMARY KEY (`id`) ); -insert into s2_user (id, `name`, password, display_name, email) values (1, 'admin','admin','admin','admin@xx.com'); +insert into s2_user (id, `name`, password, display_name, email, is_admin) values (1, 'admin','admin','admin','admin@xx.com', 1); diff --git a/launchers/standalone/src/main/resources/db/sql-update.sql b/launchers/standalone/src/main/resources/db/sql-update.sql index 4ba811442..7799b3656 100644 --- a/launchers/standalone/src/main/resources/db/sql-update.sql +++ b/launchers/standalone/src/main/resources/db/sql-update.sql @@ -51,4 +51,7 @@ alter table s2_chat add column agent_id int after chat_id; ALTER TABLE s2_model add alias varchar(200) default null after domain_id; --20230919 -alter table s2_metric add tags varchar(500) null; \ No newline at end of file +alter table s2_metric add tags varchar(500) null; + +--20230920 +alter table s2_user add is_admin int null; \ No newline at end of file diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java b/launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java index 0bb34a21c..50938af67 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java @@ -23,7 +23,7 @@ import static java.time.LocalDate.now; public class DataUtils { - private static final User user_test = new User(1L, "admin", "admin", "admin@email"); + private static final User user_test = User.getFakeUser(); public static User getUser() { return user_test; diff --git a/launchers/standalone/src/test/resources/db/data-h2.sql b/launchers/standalone/src/test/resources/db/data-h2.sql index c2ee76401..10f6a3ef5 100644 --- a/launchers/standalone/src/test/resources/db/data-h2.sql +++ b/launchers/standalone/src/test/resources/db/data-h2.sql @@ -1,5 +1,5 @@ -- sample user -insert into s2_user (id, `name`, password, display_name, email) values (1, 'admin','admin','admin','admin@xx.com'); +insert into s2_user (id, `name`, password, display_name, email, is_admin) values (1, 'admin','admin','admin','admin@xx.com', 1); insert into s2_user (id, `name`, password, display_name, email) values (2, 'jack','123456','jack','jack@xx.com'); insert into s2_user (id, `name`, password, display_name, email) values (3, 'tom','123456','tom','tom@xx.com'); insert into s2_user (id, `name`, password, display_name, email) values (4, 'lucy','123456','lucy','lucy@xx.com'); diff --git a/launchers/standalone/src/test/resources/db/schema-h2.sql b/launchers/standalone/src/test/resources/db/schema-h2.sql index fd6a300f9..33429ca0b 100644 --- a/launchers/standalone/src/test/resources/db/schema-h2.sql +++ b/launchers/standalone/src/test/resources/db/schema-h2.sql @@ -102,6 +102,7 @@ create table s2_user display_name varchar(100) null, password varchar(100) null, email varchar(100) null, + is_admin INT null, PRIMARY KEY (`id`) ); COMMENT ON TABLE s2_user IS 'user information table'; diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/DatabaseServiceImpl.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/DatabaseServiceImpl.java index 6c4e731a3..e1a23c908 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/DatabaseServiceImpl.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/DatabaseServiceImpl.java @@ -72,7 +72,8 @@ public class DatabaseServiceImpl implements DatabaseService { private void fillPermission(List databaseResps, User user) { databaseResps.forEach(databaseResp -> { if (databaseResp.getAdmins().contains(user.getName()) - || user.getName().equalsIgnoreCase(databaseResp.getCreatedBy())) { + || user.getName().equalsIgnoreCase(databaseResp.getCreatedBy()) + || user.isSuperAdmin()) { databaseResp.setHasPermission(true); databaseResp.setHasEditPermission(true); databaseResp.setHasUsePermission(true); @@ -111,7 +112,8 @@ public class DatabaseServiceImpl implements DatabaseService { List viewers = databaseResp.getViewers(); if (!admins.contains(user.getName()) && !viewers.contains(user.getName()) - && !databaseResp.getCreatedBy().equalsIgnoreCase(user.getName())) { + && !databaseResp.getCreatedBy().equalsIgnoreCase(user.getName()) + && !user.isSuperAdmin()) { String message = String.format("您暂无当前数据库%s权限, 请联系数据库管理员%s开通", databaseResp.getName(), String.join(",", admins)); diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/DomainServiceImpl.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/DomainServiceImpl.java index fd51d9496..ebc629b8d 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/DomainServiceImpl.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/DomainServiceImpl.java @@ -96,12 +96,12 @@ public class DomainServiceImpl implements DomainService { @Override public List getDomainListWithAdminAuth(User user) { - Set domainWithAuthAll = getDomainAuthSet(user.getName(), AuthType.ADMIN); + Set domainWithAuthAll = getDomainAuthSet(user, AuthType.ADMIN); if (!CollectionUtils.isEmpty(domainWithAuthAll)) { List domainIds = domainWithAuthAll.stream().map(DomainResp::getId).collect(Collectors.toList()); domainWithAuthAll.addAll(getParentDomain(domainIds)); } - List modelResps = modelService.getModelAuthList(user.getName(), AuthType.ADMIN); + List modelResps = modelService.getModelAuthList(user, AuthType.ADMIN); if (!CollectionUtils.isEmpty(modelResps)) { List domainIds = modelResps.stream().map(ModelResp::getDomainId).collect(Collectors.toList()); domainWithAuthAll.addAll(getParentDomain(domainIds)); @@ -111,18 +111,18 @@ public class DomainServiceImpl implements DomainService { } @Override - public Set getDomainAuthSet(String userName, AuthType authTypeEnum) { + public Set getDomainAuthSet(User user, AuthType authTypeEnum) { List domainResps = getDomainList(); - Set orgIds = userService.getUserAllOrgId(userName); + Set orgIds = userService.getUserAllOrgId(user.getName()); List domainWithAuth = Lists.newArrayList(); if (authTypeEnum.equals(AuthType.ADMIN)) { domainWithAuth = domainResps.stream() - .filter(domainResp -> checkAdminPermission(orgIds, userName, domainResp)) + .filter(domainResp -> checkAdminPermission(orgIds, user, domainResp)) .collect(Collectors.toList()); } if (authTypeEnum.equals(AuthType.VISIBLE)) { domainWithAuth = domainResps.stream() - .filter(domainResp -> checkViewerPermission(orgIds, userName, domainResp)) + .filter(domainResp -> checkViewerPermission(orgIds, user, domainResp)) .collect(Collectors.toList()); } List domainIds = domainWithAuth.stream().map(DomainResp::getId) @@ -240,11 +240,13 @@ public class DomainServiceImpl implements DomainService { } - private boolean checkAdminPermission(Set orgIds, String userName, DomainResp domainResp) { - + private boolean checkAdminPermission(Set orgIds, User user, DomainResp domainResp) { List admins = domainResp.getAdmins(); List adminOrgs = domainResp.getAdminOrgs(); - if (admins.contains(userName) || domainResp.getCreatedBy().equals(userName)) { + if (user.isSuperAdmin()) { + return true; + } + if (admins.contains(user.getName()) || domainResp.getCreatedBy().equals(user.getName())) { return true; } if (CollectionUtils.isEmpty(adminOrgs)) { @@ -258,12 +260,17 @@ public class DomainServiceImpl implements DomainService { return false; } - private boolean checkViewerPermission(Set orgIds, String userName, DomainResp domainDesc) { + private boolean checkViewerPermission(Set orgIds, User user, DomainResp domainDesc) { List admins = domainDesc.getAdmins(); List viewers = domainDesc.getViewers(); List adminOrgs = domainDesc.getAdminOrgs(); List viewOrgs = domainDesc.getViewOrgs(); - if (admins.contains(userName) || viewers.contains(userName) || domainDesc.getCreatedBy().equals(userName)) { + if (user.isSuperAdmin()) { + return true; + } + if (admins.contains(user.getName()) + || viewers.contains(user.getName()) + || domainDesc.getCreatedBy().equals(user.getName())) { return true; } if (CollectionUtils.isEmpty(adminOrgs) && CollectionUtils.isEmpty(viewOrgs)) { diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/ModelServiceImpl.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/ModelServiceImpl.java index b10dd9704..8d28a26c0 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/ModelServiceImpl.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/ModelServiceImpl.java @@ -97,10 +97,10 @@ public class ModelServiceImpl implements ModelService { } @Override - public List getModelListWithAuth(String userName, Long domainId, AuthType authType) { - List modelResps = getModelAuthList(userName, authType); + public List getModelListWithAuth(User user, Long domainId, AuthType authType) { + List modelResps = getModelAuthList(user, authType); Set modelRespSet = new HashSet<>(modelResps); - List modelRespsAuthInheritDomain = getModelRespAuthInheritDomain(userName, authType); + List modelRespsAuthInheritDomain = getModelRespAuthInheritDomain(user, authType); modelRespSet.addAll(modelRespsAuthInheritDomain); if (domainId != null && domainId > 0) { modelRespSet = modelRespSet.stream().filter(modelResp -> @@ -109,8 +109,8 @@ public class ModelServiceImpl implements ModelService { return fillMetricInfo(new ArrayList<>(modelRespSet)); } - public List getModelRespAuthInheritDomain(String userName, AuthType authType) { - Set domainResps = domainService.getDomainAuthSet(userName, authType); + public List getModelRespAuthInheritDomain(User user, AuthType authType) { + Set domainResps = domainService.getDomainAuthSet(user, authType); if (CollectionUtils.isEmpty(domainResps)) { return Lists.newArrayList(); } @@ -121,18 +121,18 @@ public class ModelServiceImpl implements ModelService { } @Override - public List getModelAuthList(String userName, AuthType authTypeEnum) { + public List getModelAuthList(User user, AuthType authTypeEnum) { List modelResps = getModelList(); - Set orgIds = userService.getUserAllOrgId(userName); + Set orgIds = userService.getUserAllOrgId(user.getName()); List modelWithAuth = Lists.newArrayList(); if (authTypeEnum.equals(AuthType.ADMIN)) { modelWithAuth = modelResps.stream() - .filter(modelResp -> checkAdminPermission(orgIds, userName, modelResp)) + .filter(modelResp -> checkAdminPermission(orgIds, user, modelResp)) .collect(Collectors.toList()); } if (authTypeEnum.equals(AuthType.VISIBLE)) { modelWithAuth = modelResps.stream() - .filter(domainResp -> checkViewerPermission(orgIds, userName, domainResp)) + .filter(domainResp -> checkViewerPermission(orgIds, user, domainResp)) .collect(Collectors.toList()); } return modelWithAuth; @@ -325,9 +325,13 @@ public class ModelServiceImpl implements ModelService { return new ArrayList<>(getModelMap().keySet()); } - public static boolean checkAdminPermission(Set orgIds, String userName, ModelResp modelResp) { + public static boolean checkAdminPermission(Set orgIds, User user, ModelResp modelResp) { List admins = modelResp.getAdmins(); List adminOrgs = modelResp.getAdminOrgs(); + if (user.isSuperAdmin()) { + return true; + } + String userName = user.getName(); if (admins.contains(userName) || modelResp.getCreatedBy().equals(userName)) { return true; } @@ -342,14 +346,18 @@ public class ModelServiceImpl implements ModelService { return false; } - public static boolean checkViewerPermission(Set orgIds, String userName, ModelResp modelResp) { + public static boolean checkViewerPermission(Set orgIds, User user, ModelResp modelResp) { List admins = modelResp.getAdmins(); List viewers = modelResp.getViewers(); List adminOrgs = modelResp.getAdminOrgs(); List viewOrgs = modelResp.getViewOrgs(); + if (user.isSuperAdmin()) { + return true; + } if (modelResp.openToAll()) { return true; } + String userName = user.getName(); if (admins.contains(userName) || viewers.contains(userName) || modelResp.getCreatedBy().equals(userName)) { return true; } diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/DomainService.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/DomainService.java index 565167b8a..6a72d2adf 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/DomainService.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/DomainService.java @@ -30,7 +30,7 @@ public interface DomainService { List getDomainListWithAdminAuth(User user); - Set getDomainAuthSet(String userName, AuthType authTypeEnum); + Set getDomainAuthSet(User user, AuthType authTypeEnum); Set getDomainChildren(List domainId); diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/ModelService.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/ModelService.java index 7e05fa38e..f4458d8ba 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/ModelService.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/ModelService.java @@ -13,9 +13,9 @@ import java.util.Map; public interface ModelService { - List getModelListWithAuth(String userName, Long domainId, AuthType authType); + List getModelListWithAuth(User user, Long domainId, AuthType authType); - List getModelAuthList(String userName, AuthType authTypeEnum); + List getModelAuthList(User user, AuthType authTypeEnum); List getModelByDomainIds(List domainIds); diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/rest/ModelController.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/rest/ModelController.java index 0eb0b5175..3c3ec3624 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/rest/ModelController.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/rest/ModelController.java @@ -60,7 +60,7 @@ public class ModelController { HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); - return modelService.getModelListWithAuth(user.getName(), domainId, AuthType.ADMIN); + return modelService.getModelListWithAuth(user, domainId, AuthType.ADMIN); } diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/SchemaServiceImpl.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/SchemaServiceImpl.java index 68058404b..2658fc84c 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/SchemaServiceImpl.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/SchemaServiceImpl.java @@ -126,7 +126,7 @@ public class SchemaServiceImpl implements SchemaService { @Override public List getModelList(User user, AuthType authTypeEnum, Long domainId) { - return modelService.getModelListWithAuth(user.getName(), domainId, authTypeEnum); + return modelService.getModelListWithAuth(user, domainId, authTypeEnum); } } diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DataPermissionAOP.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DataPermissionAOP.java index 36afdef4e..dd0637795 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DataPermissionAOP.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DataPermissionAOP.java @@ -140,7 +140,7 @@ public class DataPermissionAOP { private boolean doModelAdmin(User user, QueryStructReq queryStructReq) { Long modelId = queryStructReq.getModelId(); - List modelListAdmin = modelService.getModelListWithAuth(user.getName(), null, AuthType.ADMIN); + List modelListAdmin = modelService.getModelListWithAuth(user, null, AuthType.ADMIN); if (CollectionUtils.isEmpty(modelListAdmin)) { return false; } else { @@ -153,7 +153,7 @@ public class DataPermissionAOP { private void doModelVisible(User user, QueryStructReq queryStructReq) { Boolean visible = true; Long modelId = queryStructReq.getModelId(); - List modelListVisible = modelService.getModelListWithAuth(user.getName(), null, AuthType.VISIBLE); + List modelListVisible = modelService.getModelListWithAuth(user, null, AuthType.VISIBLE); if (CollectionUtils.isEmpty(modelListVisible)) { visible = false; } else { From 63eff5c62a58719be7fcb5fce31581bde3717819 Mon Sep 17 00:00:00 2001 From: jolunoluo Date: Wed, 20 Sep 2023 16:08:53 +0800 Subject: [PATCH 5/8] (improvement)(semantic) add admin auth check in metric market --- .../chat/api/component/SemanticLayer.java | 2 +- .../chat/rest/ChatConfigController.java | 9 ++++---- .../semantic/LocalSemanticLayer.java | 4 ++-- .../semantic/RemoteSemanticLayer.java | 2 +- .../api/model/response/MetricResp.java | 2 ++ .../model/application/MetricServiceImpl.java | 22 +++++++++++++++++-- .../semantic/model/domain/MetricService.java | 2 +- .../semantic/model/rest/MetricController.java | 7 ++++-- .../query/service/SchemaServiceImpl.java | 2 +- 9 files changed, 38 insertions(+), 14 deletions(-) diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/SemanticLayer.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/SemanticLayer.java index 12d336ed2..a6be0a0ab 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/SemanticLayer.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/SemanticLayer.java @@ -49,7 +49,7 @@ public interface SemanticLayer { PageInfo getDimensionPage(PageDimensionReq pageDimensionCmd); - PageInfo getMetricPage(PageMetricReq pageMetricCmd); + PageInfo getMetricPage(PageMetricReq pageMetricCmd, User user); List getDomainList(User user); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatConfigController.java b/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatConfigController.java index 88612800a..d8521e73a 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatConfigController.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatConfigController.java @@ -110,17 +110,18 @@ public class ChatConfigController { } @PostMapping("/dimension/page") - public PageInfo getDimension(@RequestBody PageDimensionReq pageDimensionCmd, + public PageInfo getDimension(@RequestBody PageDimensionReq pageDimensionReq, HttpServletRequest request, HttpServletResponse response) { - return semanticLayer.getDimensionPage(pageDimensionCmd); + return semanticLayer.getDimensionPage(pageDimensionReq); } @PostMapping("/metric/page") - public PageInfo getMetric(@RequestBody PageMetricReq pageMetrricCmd, + public PageInfo getMetric(@RequestBody PageMetricReq pageMetricReq, HttpServletRequest request, HttpServletResponse response) { - return semanticLayer.getMetricPage(pageMetrricCmd); + User user = UserHolder.findUser(request, response); + return semanticLayer.getMetricPage(pageMetricReq, user); } diff --git a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/LocalSemanticLayer.java b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/LocalSemanticLayer.java index c675a7217..e9a9676f3 100644 --- a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/LocalSemanticLayer.java +++ b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/LocalSemanticLayer.java @@ -110,9 +110,9 @@ public class LocalSemanticLayer extends BaseSemanticLayer { } @Override - public PageInfo getMetricPage(PageMetricReq pageMetricReq) { + public PageInfo getMetricPage(PageMetricReq pageMetricReq, User user) { metricService = ContextUtils.getBean(MetricService.class); - return metricService.queryMetric(pageMetricReq); + return metricService.queryMetric(pageMetricReq, user); } } diff --git a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/RemoteSemanticLayer.java b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/RemoteSemanticLayer.java index 723dc6a40..c18479162 100644 --- a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/RemoteSemanticLayer.java +++ b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/RemoteSemanticLayer.java @@ -258,7 +258,7 @@ public class RemoteSemanticLayer extends BaseSemanticLayer { } @Override - public PageInfo getMetricPage(PageMetricReq pageMetricCmd) { + public PageInfo getMetricPage(PageMetricReq pageMetricCmd, User user) { String body = JsonUtil.toString(pageMetricCmd); DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class); log.info("url:{}", defaultSemanticConfig.getSemanticUrl() + defaultSemanticConfig.getFetchMetricPagePath()); diff --git a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/response/MetricResp.java b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/response/MetricResp.java index 08ff264e8..0a2aecd24 100644 --- a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/response/MetricResp.java +++ b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/response/MetricResp.java @@ -35,6 +35,8 @@ public class MetricResp extends SchemaItem { private List tags; + private boolean hasAdminRes = false; + public void setTag(String tag) { if (StringUtils.isBlank(tag)) { tags = Lists.newArrayList(); diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/MetricServiceImpl.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/MetricServiceImpl.java index dec9d2b3f..327e9017d 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/MetricServiceImpl.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/MetricServiceImpl.java @@ -9,6 +9,7 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.common.pojo.DataAddEvent; import com.tencent.supersonic.common.pojo.DataDeleteEvent; import com.tencent.supersonic.common.pojo.DataUpdateEvent; +import com.tencent.supersonic.common.pojo.enums.AuthType; import com.tencent.supersonic.common.pojo.enums.DictWordType; import com.tencent.supersonic.common.util.ChatGptHelper; import com.tencent.supersonic.semantic.api.model.pojo.Measure; @@ -124,7 +125,7 @@ public class MetricServiceImpl implements MetricService { } @Override - public PageInfo queryMetric(PageMetricReq pageMetricReq) { + public PageInfo queryMetric(PageMetricReq pageMetricReq, User user) { MetricFilter metricFilter = new MetricFilter(); BeanUtils.copyProperties(pageMetricReq, metricFilter); Set domainResps = domainService.getDomainChildren(pageMetricReq.getDomainIds()); @@ -138,7 +139,9 @@ public class MetricServiceImpl implements MetricService { .doSelectPageInfo(() -> queryMetric(metricFilter)); PageInfo pageInfo = new PageInfo<>(); BeanUtils.copyProperties(metricDOPageInfo, pageInfo); - pageInfo.setList(convertList(metricDOPageInfo.getList())); + List metricResps = convertList(metricDOPageInfo.getList()); + fillAdminRes(metricResps, user); + pageInfo.setList(metricResps); return pageInfo; } @@ -146,6 +149,21 @@ public class MetricServiceImpl implements MetricService { return metricRepository.getMetric(metricFilter); } + + private void fillAdminRes(List metricResps, User user) { + List modelResps = modelService.getModelListWithAuth(user, null, AuthType.ADMIN); + if (CollectionUtils.isEmpty(modelResps)) { + return; + } + Set modelIdSet = modelResps.stream().map(ModelResp::getId).collect(Collectors.toSet()); + for (MetricResp metricResp : metricResps) { + if (modelIdSet.contains(metricResp.getModelId())) { + metricResp.setHasAdminRes(true); + } + } + + } + @Override public MetricResp getMetric(Long modelId, String bizName) { List metricDescs = getMetricByModelId(modelId); diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/MetricService.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/MetricService.java index 1071fe27d..969b935e5 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/MetricService.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/MetricService.java @@ -23,7 +23,7 @@ public interface MetricService { void createMetricBatch(List metricReqs, User user) throws Exception; - PageInfo queryMetric(PageMetricReq pageMetricReq); + PageInfo queryMetric(PageMetricReq pageMetricReq, User user); MetricResp getMetric(Long modelId, String bizName); diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/rest/MetricController.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/rest/MetricController.java index d72485009..4b60bfab2 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/rest/MetricController.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/rest/MetricController.java @@ -69,8 +69,11 @@ public class MetricController { @PostMapping("/queryMetric") - public PageInfo queryMetric(@RequestBody PageMetricReq pageMetrricReq) { - return metricService.queryMetric(pageMetrricReq); + public PageInfo queryMetric(@RequestBody PageMetricReq pageMetricReq, + HttpServletRequest request, + HttpServletResponse response) { + User user = UserHolder.findUser(request, response); + return metricService.queryMetric(pageMetricReq, user); } @GetMapping("getMetric/{modelId}/{bizName}") diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/SchemaServiceImpl.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/SchemaServiceImpl.java index 2658fc84c..6049b4efe 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/SchemaServiceImpl.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/SchemaServiceImpl.java @@ -116,7 +116,7 @@ public class SchemaServiceImpl implements SchemaService { @Override public PageInfo queryMetric(PageMetricReq pageMetricCmd, User user) { - return metricService.queryMetric(pageMetricCmd); + return metricService.queryMetric(pageMetricCmd, user); } @Override From 51f62438cf88c6201af4891f7778ccdcba0ad1f5 Mon Sep 17 00:00:00 2001 From: jolunoluo Date: Fri, 22 Sep 2023 14:52:53 +0800 Subject: [PATCH 6/8] (improvement)(chat) add QueryResponder to recall history similar solved query --- .../chat/api/pojo/response/QueryResult.java | 1 + .../plugin/embedding/EmbeddingConfig.java | 3 + .../queryresponder/DefaultQueryResponder.java | 2 +- .../chat/service/impl/QueryServiceImpl.java | 45 +++--- .../resources/mapper/ChatQueryDOMapper.xml | 146 ++---------------- 5 files changed, 42 insertions(+), 155 deletions(-) diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/QueryResult.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/QueryResult.java index 3858f64f2..74020d195 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/QueryResult.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/QueryResult.java @@ -21,4 +21,5 @@ public class QueryResult { private SemanticParseInfo chatContext; private Object response; private List> queryResults; + private List similarSolvedQuery; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/embedding/EmbeddingConfig.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/embedding/EmbeddingConfig.java index 5725ed4e3..46ff9f848 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/embedding/EmbeddingConfig.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/embedding/EmbeddingConfig.java @@ -29,4 +29,7 @@ public class EmbeddingConfig { @Value("${embedding.solvedQuery.add.path:/solved_query_add}") private String solvedQueryAddPath; + @Value("${embedding.solved.query.nResult:5}") + private String solvedQueryResultNum; + } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/DefaultQueryResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/DefaultQueryResponder.java index 03da21a34..8a7a80e9f 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/DefaultQueryResponder.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/DefaultQueryResponder.java @@ -60,7 +60,7 @@ public class DefaultQueryResponder implements QueryResponder { List solvedQueryRecallResps = Lists.newArrayList(); try { String url = embeddingConfig.getUrl() + embeddingConfig.getSolvedQueryRecallPath() + "?n_results=" - + embeddingConfig.getNResult(); + + embeddingConfig.getSolvedQueryResultNum(); HttpHeaders headers = new HttpHeaders(); headers.setContentType(MediaType.APPLICATION_JSON); headers.setLocation(URI.create(url)); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index 817578804..efcc45b77 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -175,27 +175,34 @@ public class QueryServiceImpl implements QueryService { ChatContext chatCtx = chatService.getOrCreateContext(queryReq.getChatId()); chatCtx.setAgentId(queryReq.getAgentId()); Long startTime = System.currentTimeMillis(); - QueryResult queryResult = semanticQuery.execute(queryReq.getUser()); - - if (queryResult != null) { - timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime)) - .interfaceName(semanticQuery.getClass().getSimpleName()).type(CostType.QUERY.getType()).build()); - saveInfo(timeCostDOList, queryReq.getQueryText(), queryReq.getQueryId(), - queryReq.getUser().getName(), queryReq.getChatId().longValue()); - queryResult.setChatContext(parseInfo); - // update chat context after a successful semantic query - if (queryReq.isSaveAnswer() && QueryState.SUCCESS.equals(queryResult.getQueryState())) { - chatCtx.setParseInfo(parseInfo); - chatService.updateContext(chatCtx); - } - chatCtx.setQueryText(queryReq.getQueryText()); - chatCtx.setUser(queryReq.getUser().getName()); - chatService.updateQuery(queryReq.getQueryId(), queryResult, chatCtx); - queryResponder.saveSolvedQuery(queryReq.getQueryText(), queryReq.getQueryId(), queryReq.getParseId()); - } else { - chatService.deleteChatQuery(queryReq.getQueryId()); + QueryResult queryResult = null; + try { + queryResult = semanticQuery.execute(queryReq.getUser()); + } catch (Exception e) { + log.error("query execute failed, queryText:{}", queryReq.getQueryText(), e); + queryResult = new QueryResult(); + queryResult.setQueryState(QueryState.INVALID); } + timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime)) + .interfaceName(semanticQuery.getClass().getSimpleName()).type(CostType.QUERY.getType()).build()); + saveInfo(timeCostDOList, queryReq.getQueryText(), queryReq.getQueryId(), + queryReq.getUser().getName(), queryReq.getChatId().longValue()); + queryResult.setChatContext(parseInfo); + // update chat context after a successful semantic query + if (queryReq.isSaveAnswer() && QueryState.SUCCESS.equals(queryResult.getQueryState())) { + chatCtx.setParseInfo(parseInfo); + chatService.updateContext(chatCtx); + queryResponder.saveSolvedQuery(queryReq.getQueryText(), queryReq.getQueryId(), queryReq.getParseId()); + } + chatCtx.setQueryText(queryReq.getQueryText()); + chatCtx.setUser(queryReq.getUser().getName()); + chatService.updateQuery(queryReq.getQueryId(), queryResult, chatCtx); + if (!QueryState.SUCCESS.equals(queryResult.getQueryState())) { + List solvedQueryRecallResps = + queryResponder.recallSolvedQuery(queryReq.getQueryText()); + queryResult.setSimilarSolvedQuery(solvedQueryRecallResps); + } return queryResult; } diff --git a/chat/core/src/main/resources/mapper/ChatQueryDOMapper.xml b/chat/core/src/main/resources/mapper/ChatQueryDOMapper.xml index 72cbbc4b4..f965a0882 100644 --- a/chat/core/src/main/resources/mapper/ChatQueryDOMapper.xml +++ b/chat/core/src/main/resources/mapper/ChatQueryDOMapper.xml @@ -3,6 +3,7 @@ + @@ -44,7 +45,7 @@ - question_id, create_time, user_name, query_state, chat_id, score, feedback + question_id, agent_id, create_time, user_name, query_state, chat_id, score, feedback query_text, query_result @@ -65,142 +66,23 @@ order by ${orderByClause} - - + + delete from s2_chat_query where question_id = #{questionId,jdbcType=BIGINT} - insert into s2_chat_query (question_id, create_time, user_name, + insert into s2_chat_query (question_id, agent_id, create_time, user_name, query_state, chat_id, score, feedback, query_text, query_result ) - values (#{questionId,jdbcType=BIGINT}, #{createTime,jdbcType=TIMESTAMP}, #{userName,jdbcType=VARCHAR}, + values (#{questionId,jdbcType=BIGINT}, #{agentId,jdbcType=BIGINT}, #{createTime,jdbcType=TIMESTAMP}, #{userName,jdbcType=VARCHAR}, #{queryState,jdbcType=INTEGER}, #{chatId,jdbcType=BIGINT}, #{score,jdbcType=INTEGER}, #{feedback,jdbcType=VARCHAR}, #{queryText,jdbcType=LONGVARCHAR}, #{queryResult,jdbcType=LONGVARCHAR} ) - - insert into s2_chat_query - - - question_id, - - - create_time, - - - user_name, - - - query_state, - - - chat_id, - - - score, - - - feedback, - - - query_text, - - - query_result, - - - - - #{questionId,jdbcType=BIGINT}, - - - #{createTime,jdbcType=TIMESTAMP}, - - - #{userName,jdbcType=VARCHAR}, - - - #{queryState,jdbcType=INTEGER}, - - - #{chatId,jdbcType=BIGINT}, - - - #{score,jdbcType=INTEGER}, - - - #{feedback,jdbcType=VARCHAR}, - - - #{queryText,jdbcType=LONGVARCHAR}, - - - #{queryResult,jdbcType=LONGVARCHAR}, - - - - - - update s2_chat_query - - - create_time = #{createTime,jdbcType=TIMESTAMP}, - - - user_name = #{userName,jdbcType=VARCHAR}, - - - query_state = #{queryState,jdbcType=INTEGER}, - - - chat_id = #{chatId,jdbcType=BIGINT}, - - - score = #{score,jdbcType=INTEGER}, - - - feedback = #{feedback,jdbcType=VARCHAR}, - - - query_text = #{queryText,jdbcType=LONGVARCHAR}, - - - query_result = #{queryResult,jdbcType=LONGVARCHAR}, - - - where question_id = #{questionId,jdbcType=BIGINT} - + update s2_chat_query @@ -216,6 +98,9 @@ chat_id = #{chatId,jdbcType=BIGINT}, + + agent_id = #{agentId,jdbcType=INTEGER}, + score = #{score,jdbcType=INTEGER}, @@ -231,14 +116,5 @@ where question_id = #{questionId,jdbcType=BIGINT} - - update s2_chat_query - set create_time = #{createTime,jdbcType=TIMESTAMP}, - user_name = #{userName,jdbcType=VARCHAR}, - query_state = #{queryState,jdbcType=INTEGER}, - chat_id = #{chatId,jdbcType=BIGINT}, - score = #{score,jdbcType=INTEGER}, - feedback = #{feedback,jdbcType=VARCHAR} - where question_id = #{questionId,jdbcType=BIGINT} - + From 65653c0ee2749add3939510bcc4278637550ca6f Mon Sep 17 00:00:00 2001 From: jolunoluo Date: Mon, 25 Sep 2023 16:54:36 +0800 Subject: [PATCH 7/8] (improvement)(chat) save agentId in history query --- .../resources/mapper/ChatQueryDOMapper.xml | 142 +----------------- .../mapper/custom/ShowCaseCustomMapper.xml | 2 +- 2 files changed, 8 insertions(+), 136 deletions(-) diff --git a/chat/core/src/main/resources/mapper/ChatQueryDOMapper.xml b/chat/core/src/main/resources/mapper/ChatQueryDOMapper.xml index 72cbbc4b4..1cba2a902 100644 --- a/chat/core/src/main/resources/mapper/ChatQueryDOMapper.xml +++ b/chat/core/src/main/resources/mapper/ChatQueryDOMapper.xml @@ -3,6 +3,7 @@ + @@ -44,7 +45,7 @@ - question_id, create_time, user_name, query_state, chat_id, score, feedback + question_id, agent_id, create_time, user_name, query_state, chat_id, score, feedback query_text, query_result @@ -65,142 +66,23 @@ order by ${orderByClause} - - + + delete from s2_chat_query where question_id = #{questionId,jdbcType=BIGINT} - insert into s2_chat_query (question_id, create_time, user_name, + insert into s2_chat_query (question_id, agent_id, create_time, user_name, query_state, chat_id, score, feedback, query_text, query_result ) - values (#{questionId,jdbcType=BIGINT}, #{createTime,jdbcType=TIMESTAMP}, #{userName,jdbcType=VARCHAR}, + values (#{questionId,jdbcType=BIGINT}, #{agentId,jdbcType=INTEGER}, #{createTime,jdbcType=TIMESTAMP}, #{userName,jdbcType=VARCHAR}, #{queryState,jdbcType=INTEGER}, #{chatId,jdbcType=BIGINT}, #{score,jdbcType=INTEGER}, #{feedback,jdbcType=VARCHAR}, #{queryText,jdbcType=LONGVARCHAR}, #{queryResult,jdbcType=LONGVARCHAR} ) - - insert into s2_chat_query - - - question_id, - - - create_time, - - - user_name, - - - query_state, - - - chat_id, - - - score, - - - feedback, - - - query_text, - - - query_result, - - - - - #{questionId,jdbcType=BIGINT}, - - - #{createTime,jdbcType=TIMESTAMP}, - - - #{userName,jdbcType=VARCHAR}, - - - #{queryState,jdbcType=INTEGER}, - - - #{chatId,jdbcType=BIGINT}, - - - #{score,jdbcType=INTEGER}, - - - #{feedback,jdbcType=VARCHAR}, - - - #{queryText,jdbcType=LONGVARCHAR}, - - - #{queryResult,jdbcType=LONGVARCHAR}, - - - - - - update s2_chat_query - - - create_time = #{createTime,jdbcType=TIMESTAMP}, - - - user_name = #{userName,jdbcType=VARCHAR}, - - - query_state = #{queryState,jdbcType=INTEGER}, - - - chat_id = #{chatId,jdbcType=BIGINT}, - - - score = #{score,jdbcType=INTEGER}, - - - feedback = #{feedback,jdbcType=VARCHAR}, - - - query_text = #{queryText,jdbcType=LONGVARCHAR}, - - - query_result = #{queryResult,jdbcType=LONGVARCHAR}, - - - where question_id = #{questionId,jdbcType=BIGINT} - + update s2_chat_query @@ -231,14 +113,4 @@ where question_id = #{questionId,jdbcType=BIGINT} - - update s2_chat_query - set create_time = #{createTime,jdbcType=TIMESTAMP}, - user_name = #{userName,jdbcType=VARCHAR}, - query_state = #{queryState,jdbcType=INTEGER}, - chat_id = #{chatId,jdbcType=BIGINT}, - score = #{score,jdbcType=INTEGER}, - feedback = #{feedback,jdbcType=VARCHAR} - where question_id = #{questionId,jdbcType=BIGINT} - diff --git a/chat/core/src/main/resources/mapper/custom/ShowCaseCustomMapper.xml b/chat/core/src/main/resources/mapper/custom/ShowCaseCustomMapper.xml index adaf36822..7dcb1d213 100644 --- a/chat/core/src/main/resources/mapper/custom/ShowCaseCustomMapper.xml +++ b/chat/core/src/main/resources/mapper/custom/ShowCaseCustomMapper.xml @@ -59,7 +59,7 @@ join ( select distinct chat_id from s2_chat_query - where query_state = 0 and agent_id = ${agentId} + where query_state = 1 and agent_id = ${agentId} order by chat_id desc limit #{start}, #{limit} ) q2 From 34816451c0d7c53f7f4d75666c232adddce0a2fc Mon Sep 17 00:00:00 2001 From: jolunoluo Date: Mon, 25 Sep 2023 17:11:21 +0800 Subject: [PATCH 8/8] (improvement)(chat) recall history solved query in every parse --- .../chat/api/pojo/response/QueryResult.java | 1 - .../chat/corrector/GlobalCorrector.java | 6 + .../queryresponder/DefaultQueryResponder.java | 2 +- .../chat/service/impl/QueryServiceImpl.java | 6 +- .../util/jsqlparser/FiledExpression.java | 12 ++ .../jsqlparser/FiledFilterReplaceVisitor.java | 113 ++++++++++++++++++ .../jsqlparser/SqlParserSelectHelper.java | 30 ++++- .../jsqlparser/SqlParserUpdateHelper.java | 111 +++++++++++++++++ .../jsqlparser/SqlParserUpdateHelperTest.java | 82 +++++++++++++ 9 files changed, 352 insertions(+), 11 deletions(-) create mode 100644 common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FiledExpression.java create mode 100644 common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FiledFilterReplaceVisitor.java diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/QueryResult.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/QueryResult.java index 74020d195..3858f64f2 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/QueryResult.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/QueryResult.java @@ -21,5 +21,4 @@ public class QueryResult { private SemanticParseInfo chatContext; private Object response; private List> queryResults; - private List similarSolvedQuery; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalCorrector.java index 774e12aeb..ec9e819b0 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalCorrector.java @@ -6,6 +6,7 @@ import com.tencent.supersonic.chat.query.llm.dsl.LLMReq; import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.util.JsonUtil; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; import java.util.List; import java.util.Map; @@ -32,6 +33,11 @@ public class GlobalCorrector extends BaseSemanticCorrector { private void addAggregateToMetric(SemanticCorrectInfo semanticCorrectInfo) { + if (SqlParserSelectHelper.hasGroupBy(semanticCorrectInfo.getSql())) { + + return; + } + } private void replaceAlias(SemanticCorrectInfo semanticCorrectInfo) { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/DefaultQueryResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/DefaultQueryResponder.java index 8a7a80e9f..1c8030db0 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/DefaultQueryResponder.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/DefaultQueryResponder.java @@ -95,7 +95,7 @@ public class DefaultQueryResponder implements QueryResponder { } } } catch (Exception e) { - log.warn("recall similar solved query failed", e); + log.warn("recall similar solved query failed, queryText:{}", queryText); } return solvedQueryRecallResps; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index 40eb1aa51..4f4d8f0f4 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -143,15 +143,15 @@ public class QueryServiceImpl implements QueryService { saveInfo(timeCostDOList, queryReq.getQueryText(), parseResult.getQueryId(), queryReq.getUser().getName(), queryReq.getChatId().longValue()); } else { - List solvedQueryRecallResps = - queryResponder.recallSolvedQuery(queryCtx.getRequest().getQueryText()); parseResult = ParseResp.builder() .chatId(queryReq.getChatId()) .queryText(queryReq.getQueryText()) .state(ParseResp.ParseState.FAILED) - .similarSolvedQuery(solvedQueryRecallResps) .build(); } + List solvedQueryRecallResps = + queryResponder.recallSolvedQuery(queryCtx.getRequest().getQueryText()); + parseResult.setSimilarSolvedQuery(solvedQueryRecallResps); return parseResult; } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FiledExpression.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FiledExpression.java new file mode 100644 index 000000000..e19ae3d80 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FiledExpression.java @@ -0,0 +1,12 @@ +package com.tencent.supersonic.common.util.jsqlparser; + +import lombok.Data; + +@Data +public class FiledExpression { + + private String operator; + + private String fieldName; + +} \ No newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FiledFilterReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FiledFilterReplaceVisitor.java new file mode 100644 index 000000000..9950e8e43 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FiledFilterReplaceVisitor.java @@ -0,0 +1,113 @@ +package com.tencent.supersonic.common.util.jsqlparser; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import lombok.extern.slf4j.Slf4j; +import net.sf.jsqlparser.JSQLParserException; +import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.expression.ExpressionVisitorAdapter; +import net.sf.jsqlparser.expression.Function; +import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; +import net.sf.jsqlparser.expression.operators.relational.EqualsTo; +import net.sf.jsqlparser.expression.operators.relational.GreaterThan; +import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals; +import net.sf.jsqlparser.expression.operators.relational.MinorThan; +import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals; +import net.sf.jsqlparser.parser.CCJSqlParserUtil; +import net.sf.jsqlparser.schema.Column; +import org.apache.commons.collections.CollectionUtils; + +@Slf4j +public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter { + + private List waitingForAdds = new ArrayList<>(); + private Set fieldNames; + + public FiledFilterReplaceVisitor(Set fieldNames) { + this.fieldNames = fieldNames; + } + + @Override + public void visit(MinorThan expr) { + List expressions = parserFilter(expr); + if (Objects.nonNull(expressions)) { + waitingForAdds.addAll(expressions); + } + } + + @Override + public void visit(EqualsTo expr) { + List expressions = parserFilter(expr); + if (Objects.nonNull(expressions)) { + waitingForAdds.addAll(expressions); + } + } + + @Override + public void visit(MinorThanEquals expr) { + List expressions = parserFilter(expr); + if (Objects.nonNull(expressions)) { + waitingForAdds.addAll(expressions); + } + } + + + @Override + public void visit(GreaterThan expr) { + List expressions = parserFilter(expr); + if (Objects.nonNull(expressions)) { + waitingForAdds.addAll(expressions); + } + } + + @Override + public void visit(GreaterThanEquals expr) { + List expressions = parserFilter(expr); + if (Objects.nonNull(expressions)) { + waitingForAdds.addAll(expressions); + } + } + + public List getWaitingForAdds() { + return waitingForAdds; + } + + + public List parserFilter(ComparisonOperator comparisonOperator) { + List result = new ArrayList<>(); + String toString = comparisonOperator.toString(); + Expression leftExpression = comparisonOperator.getLeftExpression(); + if (!(leftExpression instanceof Function)) { + return result; + } + Function leftExpressionFunction = (Function) leftExpression; + if (leftExpressionFunction.toString().contains(DateFunctionHelper.DATE_FUNCTION)) { + return result; + } + + List leftExpressions = leftExpressionFunction.getParameters().getExpressions(); + if (CollectionUtils.isEmpty(leftExpressions)) { + return result; + } + Column field = (Column) leftExpressions.get(0); + String columnName = field.getColumnName(); + if (!fieldNames.contains(columnName)) { + return null; + } + try { + ComparisonOperator expression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(" 1 = 1 "); + comparisonOperator.setLeftExpression(expression.getLeftExpression()); + comparisonOperator.setRightExpression(expression.getRightExpression()); + comparisonOperator.setASTNode(expression.getASTNode()); + result.add(CCJSqlParserUtil.parseCondExpression(toString)); + return result; + } catch (JSQLParserException e) { + log.error("JSQLParserException", e); + } + return null; + } + +} \ No newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java index 59d40f166..a1e4002be 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java @@ -205,6 +205,30 @@ public class SqlParserSelectHelper { public static boolean hasAggregateFunction(String sql) { + if (hasFunction(sql)) { + return true; + } + return hasGroupBy(sql); + } + + public static boolean hasGroupBy(String sql) { + Select selectStatement = getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + + if (!(selectBody instanceof PlainSelect)) { + return false; + } + PlainSelect plainSelect = (PlainSelect) selectBody; + GroupByElement groupBy = plainSelect.getGroupBy(); + if (Objects.nonNull(groupBy)) { + GroupByVisitor replaceVisitor = new GroupByVisitor(); + groupBy.accept(replaceVisitor); + return replaceVisitor.isHasAggregateFunction(); + } + return false; + } + + public static boolean hasFunction(String sql) { Select selectStatement = getSelect(sql); SelectBody selectBody = selectStatement.getSelectBody(); @@ -221,12 +245,6 @@ public class SqlParserSelectHelper { if (selectFunction) { return true; } - GroupByElement groupBy = plainSelect.getGroupBy(); - if (Objects.nonNull(groupBy)) { - GroupByVisitor replaceVisitor = new GroupByVisitor(); - groupBy.accept(replaceVisitor); - return replaceVisitor.isHasAggregateFunction(); - } return false; } } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java index b841613b1..48906a2d0 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java @@ -11,6 +11,7 @@ import net.sf.jsqlparser.expression.LongValue; import net.sf.jsqlparser.expression.StringValue; import net.sf.jsqlparser.expression.operators.conditional.AndExpression; import net.sf.jsqlparser.expression.operators.relational.EqualsTo; +import net.sf.jsqlparser.expression.operators.relational.ExpressionList; import net.sf.jsqlparser.schema.Column; import net.sf.jsqlparser.schema.Table; import net.sf.jsqlparser.statement.select.GroupByElement; @@ -20,6 +21,7 @@ import net.sf.jsqlparser.statement.select.Select; import net.sf.jsqlparser.statement.select.SelectBody; import net.sf.jsqlparser.statement.select.SelectExpressionItem; import net.sf.jsqlparser.statement.select.SelectItem; +import net.sf.jsqlparser.statement.select.SelectVisitorAdapter; import net.sf.jsqlparser.util.SelectUtils; import org.apache.commons.lang3.StringUtils; import org.springframework.util.CollectionUtils; @@ -278,5 +280,114 @@ public class SqlParserUpdateHelper { return selectStatement.toString(); } + public static String addAggregateToField(String sql, Map fieldNameToAggregate) { + if (SqlParserSelectHelper.hasGroupBy(sql)) { + return sql; + } + + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + selectBody.accept(new SelectVisitorAdapter() { + @Override + public void visit(PlainSelect plainSelect) { + addAggregateToSelectItems(plainSelect.getSelectItems(), fieldNameToAggregate); + addAggregateToOrderByItems(plainSelect.getOrderByElements(), fieldNameToAggregate); + } + }); + return selectStatement.toString(); + } + + public static String addGroupBy(String sql, List groupByFields) { + if (SqlParserSelectHelper.hasGroupBy(sql)) { + return sql; + } + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + + PlainSelect plainSelect = (PlainSelect) selectBody; + GroupByElement groupByElement = new GroupByElement(); + for (String groupByField : groupByFields) { + groupByElement.addGroupByExpression(new Column(groupByField)); + } + plainSelect.setGroupByElement(groupByElement); + return selectStatement.toString(); + } + + private static void addAggregateToSelectItems(List selectItems, + Map fieldNameToAggregate) { + for (SelectItem selectItem : selectItems) { + if (selectItem instanceof SelectExpressionItem) { + SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem; + Expression expression = selectExpressionItem.getExpression(); + String columnName = ((Column) expression).getColumnName(); + Function function = getFunction(expression, fieldNameToAggregate.get(columnName)); + if (Objects.isNull(function)) { + continue; + } + selectExpressionItem.setExpression(function); + } + } + } + + private static void addAggregateToOrderByItems(List orderByElements, + Map fieldNameToAggregate) { + if (orderByElements == null) { + return; + } + for (OrderByElement orderByElement : orderByElements) { + Expression expression = orderByElement.getExpression(); + String columnName = ((Column) expression).getColumnName(); + if (StringUtils.isEmpty(columnName)) { + continue; + } + Function function = getFunction(expression, fieldNameToAggregate.get(columnName)); + if (Objects.isNull(function)) { + continue; + } + orderByElement.setExpression(function); + } + } + + private static Function getFunction(Expression expression, String aggregateName) { + if (StringUtils.isEmpty(aggregateName)) { + return null; + } + Function sumFunction = new Function(); + sumFunction.setName(aggregateName); + sumFunction.setParameters(new ExpressionList(expression)); + return sumFunction; + } + + public static String addHaving(String sql, Set fieldNames) { + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + + PlainSelect plainSelect = (PlainSelect) selectBody; + //replace metric to 1 and 1 and add having metric + Expression where = plainSelect.getWhere(); + FiledFilterReplaceVisitor visitor = new FiledFilterReplaceVisitor(fieldNames); + if (Objects.nonNull(where)) { + where.accept(visitor); + } + List waitingForAdds = visitor.getWaitingForAdds(); + if (!CollectionUtils.isEmpty(waitingForAdds)) { + for (Expression waitingForAdd : waitingForAdds) { + plainSelect.setHaving(waitingForAdd); + } + } + return selectStatement.toString(); + } } diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java index 77c0ab4c2..8745235e8 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java @@ -1,9 +1,12 @@ package com.tencent.supersonic.common.util.jsqlparser; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.parser.CCJSqlParserUtil; @@ -266,6 +269,85 @@ class SqlParserUpdateHelperTest { } + @Test + void addAggregateToField() { + String sql = "SELECT user_name FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " + + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; + Expression havingExpression = SqlParserSelectHelper.getHavingExpression(sql); + + String replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression); + System.out.println(replaceSql); + Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' " + + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", + replaceSql); + + sql = "SELECT user_name,sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " + + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; + havingExpression = SqlParserSelectHelper.getHavingExpression(sql); + + replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression); + System.out.println(replaceSql); + Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' " + + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", + replaceSql); + + } + + + @Test + void addAggregateToMetricField() { + String sql = "select department, pv from t_1 where sys_imp_date = '2023-09-11' order by pv desc limit 10"; + + Map filedNameToAggregate = new HashMap<>(); + filedNameToAggregate.put("pv", "sum"); + + List groupByFields = new ArrayList<>(); + groupByFields.add("department"); + + String replaceSql = SqlParserUpdateHelper.addAggregateToField(sql, filedNameToAggregate); + replaceSql = SqlParserUpdateHelper.addGroupBy(replaceSql, groupByFields); + + Assert.assertEquals( + "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " + + "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", + replaceSql); + } + + @Test + void addGroupBy() { + String sql = "select department, sum(pv) from t_1 where sys_imp_date = '2023-09-11' " + + "order by sum(pv) desc limit 10"; + + List groupByFields = new ArrayList<>(); + groupByFields.add("department"); + + String replaceSql = SqlParserUpdateHelper.addGroupBy(sql, groupByFields); + + Assert.assertEquals( + "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " + + "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", + replaceSql); + } + + @Test + void addHaving() { + String sql = "select department, sum(pv) from t_1 where sys_imp_date = '2023-09-11' and " + + "sum(pv) > 2000 group by department order by sum(pv) desc limit 10"; + List groupByFields = new ArrayList<>(); + groupByFields.add("department"); + + Set fieldNames = new HashSet<>(); + fieldNames.add("pv"); + + String replaceSql = SqlParserUpdateHelper.addHaving(sql, fieldNames); + + Assert.assertEquals( + "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " + + "AND 1 > 1 GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10", + replaceSql); + } + + private Map initParams() { Map fieldToBizName = new HashMap<>(); fieldToBizName.put("部门", "department");