diff --git a/common/src/main/java/com/tencent/supersonic/common/util/embedding/EmbeddingPersistentTask.java b/common/src/main/java/com/tencent/supersonic/common/util/embedding/EmbeddingPersistentTask.java deleted file mode 100644 index 502d227f2..000000000 --- a/common/src/main/java/com/tencent/supersonic/common/util/embedding/EmbeddingPersistentTask.java +++ /dev/null @@ -1,32 +0,0 @@ -package com.tencent.supersonic.common.util.embedding; - -import com.tencent.supersonic.common.util.ComponentFactory; -import javax.annotation.PreDestroy; -import lombok.extern.slf4j.Slf4j; -import org.springframework.scheduling.annotation.Scheduled; -import org.springframework.stereotype.Component; - -@Component -@Slf4j -public class EmbeddingPersistentTask { - - private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore(); - - @PreDestroy - public void onShutdown() { - embeddingStorePersistentToFile(); - } - - private void embeddingStorePersistentToFile() { - if (s2EmbeddingStore instanceof InMemoryS2EmbeddingStore) { - log.info("start persistentToFile"); - ((InMemoryS2EmbeddingStore) s2EmbeddingStore).persistentToFile(); - log.info("end persistentToFile"); - } - } - - @Scheduled(cron = "${inMemoryEmbeddingStore.persistent.cron:0 0 * * * ?}") - public void executeTask() { - embeddingStorePersistentToFile(); - } -} \ No newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/util/embedding/EmbeddingQuery.java b/common/src/main/java/com/tencent/supersonic/common/util/embedding/EmbeddingQuery.java index f57448857..0fa077061 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/embedding/EmbeddingQuery.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/embedding/EmbeddingQuery.java @@ -1,10 +1,13 @@ package com.tencent.supersonic.common.util.embedding; +import com.alibaba.fastjson.JSONObject; +import com.tencent.supersonic.common.pojo.DataItem; import lombok.Data; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; @Data public class EmbeddingQuery { @@ -17,4 +20,17 @@ public class EmbeddingQuery { private List queryEmbedding; + public static List convertToEmbedding(List dataItems) { + return dataItems.stream().map(dataItem -> { + EmbeddingQuery embeddingQuery = new EmbeddingQuery(); + embeddingQuery.setQueryId( + dataItem.getId() + dataItem.getType().name().toLowerCase()); + embeddingQuery.setQuery(dataItem.getName()); + Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class); + embeddingQuery.setMetadata(meta); + embeddingQuery.setQueryEmbedding(null); + return embeddingQuery; + }).collect(Collectors.toList()); + } + } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QuerySqlReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QuerySqlReq.java index 91aed844d..85b90d1b4 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QuerySqlReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QuerySqlReq.java @@ -2,6 +2,9 @@ package com.tencent.supersonic.headless.api.pojo.request; import lombok.Data; import lombok.ToString; +import org.apache.commons.collections.CollectionUtils; + +import java.util.Objects; @Data @ToString @@ -25,4 +28,8 @@ public class QuerySqlReq extends SemanticQueryReq { return stringBuilder.toString(); } + public boolean needGetDataSetId() { + return (Objects.isNull(this.getDataSetId()) || this.getDataSetId() <= 0) + && (CollectionUtils.isEmpty(this.getModelIds())); + } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/MetaEmbeddingListener.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/MetaEmbeddingListener.java index 3bcd07ff6..e1a2b0256 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/MetaEmbeddingListener.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/MetaEmbeddingListener.java @@ -1,15 +1,12 @@ package com.tencent.supersonic.headless.server.listener; -import com.alibaba.fastjson.JSONObject; import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.pojo.DataEvent; +import com.tencent.supersonic.common.pojo.DataItem; import com.tencent.supersonic.common.pojo.enums.EventType; import com.tencent.supersonic.common.util.ComponentFactory; import com.tencent.supersonic.common.util.embedding.EmbeddingQuery; import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; @@ -18,6 +15,8 @@ import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; +import java.util.List; + @Component @Slf4j public class MetaEmbeddingListener implements ApplicationListener { @@ -33,30 +32,15 @@ public class MetaEmbeddingListener implements ApplicationListener { @Async @Override public void onApplicationEvent(DataEvent event) { - if (CollectionUtils.isEmpty(event.getDataItems())) { + List dataItems = event.getDataItems(); + if (CollectionUtils.isEmpty(dataItems)) { return; } - - List embeddingQueries = event.getDataItems() - .stream() - .map(dataItem -> { - EmbeddingQuery embeddingQuery = new EmbeddingQuery(); - embeddingQuery.setQueryId( - dataItem.getId() + dataItem.getType().name().toLowerCase()); - embeddingQuery.setQuery(dataItem.getName()); - Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class); - embeddingQuery.setMetadata(meta); - embeddingQuery.setQueryEmbedding(null); - return embeddingQuery; - }).collect(Collectors.toList()); + List embeddingQueries = EmbeddingQuery.convertToEmbedding(dataItems); if (CollectionUtils.isEmpty(embeddingQueries)) { return; } - try { - Thread.sleep(embeddingOperationSleepTime); - } catch (InterruptedException e) { - log.error("", e); - } + sleep(); s2EmbeddingStore.addCollection(embeddingConfig.getMetaCollectionName()); if (event.getEventType().equals(EventType.ADD)) { s2EmbeddingStore.addQuery(embeddingConfig.getMetaCollectionName(), embeddingQueries); @@ -68,4 +52,12 @@ public class MetaEmbeddingListener implements ApplicationListener { } } + private void sleep() { + try { + Thread.sleep(embeddingOperationSleepTime); + } catch (InterruptedException e) { + log.error("", e); + } + } + } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/KnowledgeController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/KnowledgeController.java index 4cd31fd71..01ea4a6c2 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/KnowledgeController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/KnowledgeController.java @@ -3,25 +3,25 @@ package com.tencent.supersonic.headless.server.rest; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; - -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import javax.validation.Valid; - import com.tencent.supersonic.headless.api.pojo.request.DictItemFilter; import com.tencent.supersonic.headless.api.pojo.request.DictItemReq; import com.tencent.supersonic.headless.api.pojo.request.DictSingleTaskReq; import com.tencent.supersonic.headless.api.pojo.response.DictItemResp; import com.tencent.supersonic.headless.api.pojo.response.DictTaskResp; +import com.tencent.supersonic.headless.server.schedule.EmbeddingTask; import com.tencent.supersonic.headless.server.service.DictConfService; import com.tencent.supersonic.headless.server.service.DictTaskService; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.PutMapping; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.validation.Valid; import java.util.List; @@ -35,6 +35,9 @@ public class KnowledgeController { @Autowired private DictConfService confService; + @Autowired + private EmbeddingTask embeddingTask; + /** * addDictConf-新增item的字典配置 * Add configuration information for dictionary entries @@ -43,8 +46,8 @@ public class KnowledgeController { */ @PostMapping("/conf") public DictItemResp addDictConf(@RequestBody @Valid DictItemReq dictItemReq, - HttpServletRequest request, - HttpServletResponse response) { + HttpServletRequest request, + HttpServletResponse response) { User user = UserHolder.findUser(request, response); return confService.addDictConf(dictItemReq, user); } @@ -57,8 +60,8 @@ public class KnowledgeController { */ @PutMapping("/conf") public DictItemResp editDictConf(@RequestBody @Valid DictItemReq dictItemReq, - HttpServletRequest request, - HttpServletResponse response) { + HttpServletRequest request, + HttpServletResponse response) { User user = UserHolder.findUser(request, response); return confService.editDictConf(dictItemReq, user); } @@ -129,4 +132,9 @@ public class KnowledgeController { return taskService.queryLatestDictTask(taskReq, user); } + @GetMapping("/meta/embedding/reload") + public Object reloadMetaEmbedding() { + embeddingTask.reloadMetaEmbedding(); + return true; + } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/api/SqlQueryApiController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/api/SqlQueryApiController.java index d9f152ce8..de6965f57 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/api/SqlQueryApiController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/api/SqlQueryApiController.java @@ -2,18 +2,11 @@ package com.tencent.supersonic.headless.server.rest.api; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; -import com.tencent.supersonic.common.pojo.enums.QueryType; -import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.headless.api.pojo.SemanticSchema; -import com.tencent.supersonic.headless.api.pojo.SqlInfo; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlsReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; -import com.tencent.supersonic.headless.core.chat.corrector.GrammarCorrector; -import com.tencent.supersonic.headless.core.pojo.QueryContext; +import com.tencent.supersonic.headless.server.service.ChatQueryService; import com.tencent.supersonic.headless.server.service.QueryService; -import com.tencent.supersonic.headless.server.service.impl.SemanticService; -import com.tencent.supersonic.headless.server.utils.ComponentFactory; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.annotation.Autowired; @@ -36,14 +29,14 @@ public class SqlQueryApiController { private QueryService queryService; @Autowired - private SemanticService semanticService; + private ChatQueryService chatQueryService; @PostMapping("/sql") public Object queryBySql(@RequestBody QuerySqlReq querySqlReq, HttpServletRequest request, HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); - correct(querySqlReq); + chatQueryService.correct(querySqlReq, user); return queryService.queryByReq(querySqlReq, user); } @@ -57,28 +50,9 @@ public class SqlQueryApiController { QuerySqlReq querySqlReq = new QuerySqlReq(); BeanUtils.copyProperties(querySqlsReq, querySqlReq); querySqlReq.setSql(sql); - correct(querySqlReq); + chatQueryService.correct(querySqlReq, user); return querySqlReq; }).collect(Collectors.toList()); return queryService.queryByReqs(semanticQueryReqs, user); } - - private void correct(QuerySqlReq querySqlReq) { - QueryContext queryCtx = new QueryContext(); - SemanticSchema semanticSchema = semanticService.getSemanticSchema(); - queryCtx.setSemanticSchema(semanticSchema); - SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); - SqlInfo sqlInfo = new SqlInfo(); - sqlInfo.setCorrectS2SQL(querySqlReq.getSql()); - sqlInfo.setS2SQL(querySqlReq.getSql()); - semanticParseInfo.setSqlInfo(sqlInfo); - semanticParseInfo.setQueryType(QueryType.TAG); - - ComponentFactory.getSemanticCorrectors().forEach(corrector -> { - if (!(corrector instanceof GrammarCorrector)) { - corrector.correct(queryCtx, semanticParseInfo); - } - }); - querySqlReq.setSql(sqlInfo.getCorrectS2SQL()); - } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/schedule/EmbeddingTask.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/schedule/EmbeddingTask.java new file mode 100644 index 000000000..c57546ce1 --- /dev/null +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/schedule/EmbeddingTask.java @@ -0,0 +1,72 @@ +package com.tencent.supersonic.headless.server.schedule; + +import com.tencent.supersonic.common.config.EmbeddingConfig; +import com.tencent.supersonic.common.pojo.DataItem; +import com.tencent.supersonic.common.util.ComponentFactory; +import com.tencent.supersonic.common.util.embedding.EmbeddingQuery; +import com.tencent.supersonic.common.util.embedding.InMemoryS2EmbeddingStore; +import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore; +import com.tencent.supersonic.headless.server.service.DimensionService; +import com.tencent.supersonic.headless.server.service.MetricService; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.scheduling.annotation.Scheduled; +import org.springframework.stereotype.Component; + +import javax.annotation.PreDestroy; +import java.util.List; + +@Component +@Slf4j +public class EmbeddingTask { + + private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore(); + @Autowired + private EmbeddingConfig embeddingConfig; + @Autowired + private MetricService metricService; + + @Autowired + private DimensionService dimensionService; + + @PreDestroy + public void onShutdown() { + embeddingStorePersistentToFile(); + } + + private void embeddingStorePersistentToFile() { + if (s2EmbeddingStore instanceof InMemoryS2EmbeddingStore) { + log.info("start persistentToFile"); + ((InMemoryS2EmbeddingStore) s2EmbeddingStore).persistentToFile(); + log.info("end persistentToFile"); + } + } + + @Scheduled(cron = "${inMemoryEmbeddingStore.persistent.cron:0 0 * * * ?}") + public void executeTask() { + embeddingStorePersistentToFile(); + } + + + /*** + * reload meta embedding + */ + @Scheduled(cron = "${reload.meta.embedding.corn:0 0 */2 * * ?}") + public void reloadMetaEmbedding() { + log.info("reload.meta.embedding start"); + try { + List metricDataItems = metricService.getDataEvent().getDataItems(); + + s2EmbeddingStore.addQuery(embeddingConfig.getMetaCollectionName(), + EmbeddingQuery.convertToEmbedding(metricDataItems)); + + List dimensionDataItems = dimensionService.getDataEvent().getDataItems(); + s2EmbeddingStore.addQuery(embeddingConfig.getMetaCollectionName(), + EmbeddingQuery.convertToEmbedding(dimensionDataItems)); + } catch (Exception e) { + log.error("reload.meta.embedding error", e); + } + + log.info("reload.meta.embedding end"); + } +} \ No newline at end of file diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ChatQueryService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ChatQueryService.java index 88f62435b..cb160b384 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ChatQueryService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ChatQueryService.java @@ -7,6 +7,7 @@ import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq; import com.tencent.supersonic.headless.api.pojo.request.QueryReq; +import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; import com.tencent.supersonic.headless.api.pojo.response.MapResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.QueryResult; @@ -29,5 +30,7 @@ public interface ChatQueryService { EntityInfo getEntityInfo(SemanticParseInfo parseInfo, User user); Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception; + + void correct(QuerySqlReq querySqlReq, User user); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/DataSetService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/DataSetService.java index 5c15ff815..1d1000e5b 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/DataSetService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/DataSetService.java @@ -38,4 +38,6 @@ public interface DataSetService { SemanticQueryReq convert(QueryDataSetReq queryDataSetReq); + Long getDataSetIdFromSql(String sql, User user); + } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/DimensionService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/DimensionService.java index 80c6509ad..4c23dfbf5 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/DimensionService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/DimensionService.java @@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.server.service; import com.github.pagehelper.PageInfo; import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.common.pojo.DataEvent; import com.tencent.supersonic.common.pojo.enums.EventType; import com.tencent.supersonic.headless.api.pojo.DimValueMap; import com.tencent.supersonic.headless.api.pojo.request.DimensionReq; @@ -10,6 +11,7 @@ import com.tencent.supersonic.headless.api.pojo.request.PageDimensionReq; import com.tencent.supersonic.headless.api.pojo.response.DimensionResp; import com.tencent.supersonic.headless.server.pojo.DimensionsFilter; import com.tencent.supersonic.headless.server.pojo.MetaFilter; + import java.util.List; public interface DimensionService { @@ -41,4 +43,6 @@ public interface DimensionService { List mockDimensionValueAlias(DimensionReq dimensionReq, User user); void sendDimensionEventBatch(List modelIds, EventType eventType); + + DataEvent getDataEvent(); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/MetricService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/MetricService.java index 9d312b749..1e51122ef 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/MetricService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/MetricService.java @@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.server.service; import com.github.pagehelper.PageInfo; import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.common.pojo.DataEvent; import com.tencent.supersonic.common.pojo.enums.EventType; import com.tencent.supersonic.headless.api.pojo.DrillDownDimension; import com.tencent.supersonic.headless.api.pojo.MetricQueryDefaultConfig; @@ -14,6 +15,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; import com.tencent.supersonic.headless.api.pojo.response.MetricResp; import com.tencent.supersonic.headless.server.pojo.MetaFilter; import com.tencent.supersonic.headless.server.pojo.MetricsFilter; + import java.util.List; import java.util.Set; @@ -60,4 +62,6 @@ public interface MetricService { List queryMetrics(MetricsFilter metricsFilter); QueryStructReq convert(QueryMetricReq queryMetricReq); + + DataEvent getDataEvent(); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/QueryService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/QueryService.java index 1c91edd25..c2462a3df 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/QueryService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/QueryService.java @@ -21,5 +21,4 @@ public interface QueryService { List getStatInfo(ItemUseReq itemUseCommend); ExplainResp explain(ExplainSqlReq explainSqlReq, User user) throws Exception; - } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java index 9bf9d9809..6412e7115 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java @@ -20,6 +20,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; +import com.tencent.supersonic.headless.api.pojo.SqlInfo; import com.tencent.supersonic.headless.api.pojo.enums.CostType; import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod; import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; @@ -29,6 +30,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; import com.tencent.supersonic.headless.api.pojo.request.QueryReq; +import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.response.ExplainResp; @@ -37,6 +39,7 @@ import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.QueryResult; import com.tencent.supersonic.headless.api.pojo.response.QueryState; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; +import com.tencent.supersonic.headless.core.chat.corrector.GrammarCorrector; import com.tencent.supersonic.headless.core.chat.corrector.SemanticCorrector; import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult; import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeService; @@ -647,5 +650,30 @@ public class ChatQueryServiceImpl implements ChatQueryService { return queryService.queryByReq(queryStructReq, user); } -} + public void correct(QuerySqlReq querySqlReq, User user) { + QueryContext queryCtx = new QueryContext(); + SemanticSchema semanticSchema = semanticService.getSemanticSchema(); + queryCtx.setSemanticSchema(semanticSchema); + SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); + SqlInfo sqlInfo = new SqlInfo(); + sqlInfo.setCorrectS2SQL(querySqlReq.getSql()); + sqlInfo.setS2SQL(querySqlReq.getSql()); + semanticParseInfo.setSqlInfo(sqlInfo); + semanticParseInfo.setQueryType(QueryType.TAG); + Long dataSetId = querySqlReq.getDataSetId(); + if (Objects.isNull(dataSetId)) { + dataSetId = dataSetService.getDataSetIdFromSql(querySqlReq.getSql(), user); + } + SchemaElement dataSet = semanticSchema.getDataSet(dataSetId); + semanticParseInfo.setDataSet(dataSet); + + ComponentFactory.getSemanticCorrectors().forEach(corrector -> { + if (!(corrector instanceof GrammarCorrector)) { + corrector.correct(queryCtx, semanticParseInfo); + } + }); + querySqlReq.setSql(sqlInfo.getCorrectS2SQL()); + } + +} \ No newline at end of file diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java index 4b4065975..0f867d0c2 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java @@ -13,6 +13,7 @@ import com.tencent.supersonic.common.pojo.enums.StatusEnum; import com.tencent.supersonic.common.pojo.enums.TypeEnums; import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException; import com.tencent.supersonic.common.util.BeanMapper; +import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.headless.api.pojo.DataSetDetail; import com.tencent.supersonic.headless.api.pojo.QueryConfig; import com.tencent.supersonic.headless.api.pojo.enums.TagDefineType; @@ -34,6 +35,15 @@ import com.tencent.supersonic.headless.server.service.DimensionService; import com.tencent.supersonic.headless.server.service.DomainService; import com.tencent.supersonic.headless.server.service.MetricService; import com.tencent.supersonic.headless.server.service.TagMetaService; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.Pair; +import org.springframework.beans.BeanUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Lazy; +import org.springframework.stereotype.Service; +import org.springframework.util.CollectionUtils; + import java.util.Arrays; import java.util.Comparator; import java.util.Date; @@ -45,15 +55,9 @@ import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.function.Function; import java.util.stream.Collectors; -import org.apache.commons.lang3.StringUtils; -import org.apache.commons.lang3.tuple.Pair; -import org.springframework.beans.BeanUtils; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.context.annotation.Lazy; -import org.springframework.stereotype.Service; -import org.springframework.util.CollectionUtils; @Service +@Slf4j public class DataSetServiceImpl extends ServiceImpl implements DataSetService { @@ -311,4 +315,21 @@ public class DataSetServiceImpl .map(Object::toString) .collect(Collectors.toList()); } + + public Long getDataSetIdFromSql(String sql, User user) { + List dataSets = null; + try { + String tableName = SqlSelectHelper.getTableName(sql); + dataSets = getDataSets(tableName, user); + } catch (Exception e) { + log.error("getDataSetIdFromSql error:{}", e); + } + if (org.apache.commons.collections.CollectionUtils.isEmpty(dataSets)) { + throw new InvalidArgumentException("从Sql参数中无法获取到DataSetId"); + } + Long dataSetId = dataSets.get(0).getId(); + log.info("getDataSetIdFromSql dataSetId:{}", dataSetId); + return dataSetId; + } + } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DimensionServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DimensionServiceImpl.java index 1c3ac182c..b20446421 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DimensionServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DimensionServiceImpl.java @@ -430,17 +430,27 @@ public class DimensionServiceImpl implements DimensionService { } private void sendEventBatch(List dimensionDOS, EventType eventType) { + DataEvent dataEvent = getDataEvent(dimensionDOS, eventType); + eventPublisher.publishEvent(dataEvent); + } + + public DataEvent getDataEvent() { + DimensionFilter dimensionFilter = new DimensionFilter(); + List dimensionDOS = queryDimension(dimensionFilter); + return getDataEvent(dimensionDOS, EventType.ADD); + } + + private DataEvent getDataEvent(List dimensionDOS, EventType eventType) { List dataItems = dimensionDOS.stream() .map(dimensionDO -> DataItem.builder().id(dimensionDO.getId() + Constants.UNDERLINE) .name(dimensionDO.getName()).modelId(dimensionDO.getModelId() + Constants.UNDERLINE) .type(TypeEnums.DIMENSION).build()) .collect(Collectors.toList()); - eventPublisher.publishEvent(new DataEvent(this, dataItems, eventType)); + return new DataEvent(this, dataItems, eventType); } private void sendEvent(DataItem dataItem, EventType eventType) { eventPublisher.publishEvent(new DataEvent(this, Lists.newArrayList(dataItem), eventType)); } - -} +} \ No newline at end of file diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/MetricServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/MetricServiceImpl.java index f00715621..a8af89d40 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/MetricServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/MetricServiceImpl.java @@ -97,13 +97,13 @@ public class MetricServiceImpl implements MetricService { private TagMetaService tagMetaService; public MetricServiceImpl(MetricRepository metricRepository, - ModelService modelService, - ChatGptHelper chatGptHelper, - CollectService collectService, - DataSetService dataSetService, - ApplicationEventPublisher eventPublisher, - DimensionService dimensionService, - TagMetaService tagMetaService) { + ModelService modelService, + ChatGptHelper chatGptHelper, + CollectService collectService, + DataSetService dataSetService, + ApplicationEventPublisher eventPublisher, + DimensionService dimensionService, + TagMetaService tagMetaService) { this.metricRepository = metricRepository; this.modelService = modelService; this.chatGptHelper = chatGptHelper; @@ -326,7 +326,7 @@ public class MetricServiceImpl implements MetricService { } private boolean filterByField(List metricResps, MetricResp metricResp, - List fields, Set metricRespFiltered) { + List fields, Set metricRespFiltered) { if (MetricDefineType.METRIC.equals(metricResp.getMetricDefineType())) { List ids = metricResp.getMetricDefineByMetricParams().getMetrics() .stream().map(MetricParam::getId).collect(Collectors.toList()); @@ -556,9 +556,20 @@ public class MetricServiceImpl implements MetricService { } private void sendEventBatch(List metricDOS, EventType eventType) { + DataEvent dataEvent = getDataEvent(metricDOS, eventType); + eventPublisher.publishEvent(dataEvent); + } + + public DataEvent getDataEvent() { + MetricsFilter metricsFilter = new MetricsFilter(); + List metricDOS = metricRepository.getMetrics(metricsFilter); + return getDataEvent(metricDOS, EventType.ADD); + } + + private DataEvent getDataEvent(List metricDOS, EventType eventType) { List dataItems = metricDOS.stream().map(this::getDataItem) .collect(Collectors.toList()); - eventPublisher.publishEvent(new DataEvent(this, dataItems, eventType)); + return new DataEvent(this, dataItems, eventType); } private void sendEvent(DataItem dataItem, EventType eventType) { @@ -662,7 +673,7 @@ public class MetricServiceImpl implements MetricService { } private Set getModelIds(Set modelIdsByDomainId, List metricResps, - List dimensionResps) { + List dimensionResps) { Set result = new HashSet<>(); if (org.apache.commons.collections.CollectionUtils.isNotEmpty(modelIdsByDomainId)) { result.addAll(modelIdsByDomainId); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/QueryServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/QueryServiceImpl.java index 8d890ad71..85a396977 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/QueryServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/QueryServiceImpl.java @@ -5,8 +5,6 @@ import com.google.common.collect.Sets; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.common.pojo.enums.TaskStatusEnum; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; -import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException; -import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.headless.api.pojo.Dim; import com.tencent.supersonic.headless.api.pojo.QueryParam; import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq; @@ -17,7 +15,6 @@ import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; import com.tencent.supersonic.headless.api.pojo.request.SchemaFilterReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; -import com.tencent.supersonic.headless.api.pojo.response.DataSetResp; import com.tencent.supersonic.headless.api.pojo.response.DimensionResp; import com.tencent.supersonic.headless.api.pojo.response.ExplainResp; import com.tencent.supersonic.headless.api.pojo.response.ItemUseResp; @@ -137,8 +134,8 @@ public class QueryServiceImpl implements QueryService { private QueryStatement buildSqlQueryStatement(QuerySqlReq querySqlReq, User user) throws Exception { //If dataSetId or DataSetName is empty, parse dataSetId from the SQL - if (needGetDataSetId(querySqlReq)) { - Long dataSetId = getDataSetIdFromSql(querySqlReq, user); + if (querySqlReq.needGetDataSetId()) { + Long dataSetId = dataSetService.getDataSetIdFromSql(querySqlReq.getSql(), user); querySqlReq.setDataSetId(dataSetId); } SchemaFilterReq filter = buildSchemaFilterReq(querySqlReq); @@ -151,27 +148,6 @@ public class QueryServiceImpl implements QueryService { return queryStatement; } - private static boolean needGetDataSetId(QuerySqlReq querySqlReq) { - return (Objects.isNull(querySqlReq.getDataSetId()) || querySqlReq.getDataSetId() <= 0) - && (CollectionUtils.isEmpty(querySqlReq.getModelIds())); - } - - private Long getDataSetIdFromSql(QuerySqlReq querySqlReq, User user) { - List dataSets = null; - try { - String tableName = SqlSelectHelper.getTableName(querySqlReq.getSql()); - dataSets = dataSetService.getDataSets(tableName, user); - } catch (Exception e) { - log.error("getDataSetIdFromSql error:{}", e); - } - if (CollectionUtils.isEmpty(dataSets)) { - throw new InvalidArgumentException("从Sql参数中无法获取到DataSetId"); - } - Long dataSetId = dataSets.get(0).getId(); - log.info("getDataSetIdFromSql dataSetId:{}", dataSetId); - return dataSetId; - } - private QueryStatement buildQueryStatement(SemanticQueryReq semanticQueryReq, User user) throws Exception { if (semanticQueryReq instanceof QuerySqlReq) { return buildSqlQueryStatement((QuerySqlReq) semanticQueryReq, user);