(improvement)(Headless) Periodically refresh embedding metadata in full and optimize the code. (#917)

This commit is contained in:
lexluo09
2024-04-17 22:38:19 +08:00
committed by GitHub
parent d8c23cca05
commit ee798b7671
17 changed files with 237 additions and 142 deletions

View File

@@ -1,32 +0,0 @@
package com.tencent.supersonic.common.util.embedding;
import com.tencent.supersonic.common.util.ComponentFactory;
import javax.annotation.PreDestroy;
import lombok.extern.slf4j.Slf4j;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
/**
 * Spring component that asks the embedding store to write itself to file,
 * both on a cron schedule and when the bean is destroyed.
 *
 * <p>Only an {@link InMemoryS2EmbeddingStore} is persisted; any other store
 * implementation is left untouched.
 */
@Component
@Slf4j
public class EmbeddingPersistentTask {

    private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();

    /**
     * Scheduled entry point. The cron expression is read from the
     * {@code inMemoryEmbeddingStore.persistent.cron} property.
     */
    @Scheduled(cron = "${inMemoryEmbeddingStore.persistent.cron:0 0 * * * ?}")
    public void executeTask() {
        embeddingStorePersistentToFile();
    }

    /**
     * Persists the store one last time before the bean is destroyed.
     */
    @PreDestroy
    public void onShutdown() {
        embeddingStorePersistentToFile();
    }

    /**
     * Calls {@code persistentToFile()} on the store when it is the in-memory
     * implementation; does nothing otherwise.
     */
    private void embeddingStorePersistentToFile() {
        if (!(s2EmbeddingStore instanceof InMemoryS2EmbeddingStore)) {
            return;
        }
        InMemoryS2EmbeddingStore inMemoryStore = (InMemoryS2EmbeddingStore) s2EmbeddingStore;
        log.info("start persistentToFile");
        inMemoryStore.persistentToFile();
        log.info("end persistentToFile");
    }
}

View File

@@ -1,10 +1,13 @@
package com.tencent.supersonic.common.util.embedding; package com.tencent.supersonic.common.util.embedding;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.common.pojo.DataItem;
import lombok.Data; import lombok.Data;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors;
@Data @Data
public class EmbeddingQuery { public class EmbeddingQuery {
@@ -17,4 +20,17 @@ public class EmbeddingQuery {
private List<Double> queryEmbedding; private List<Double> queryEmbedding;
/**
 * Converts data items into embedding queries.
 *
 * <p>For every item the query id is the item id followed by the lower-cased
 * type name, the query text is the item name, the metadata is the item
 * round-tripped through JSON into a map, and the query embedding is left
 * {@code null}.
 *
 * @param dataItems items to convert; {@code null} is treated as empty
 * @return one embedding query per data item, in the same order; an empty
 *         list when {@code dataItems} is {@code null}
 */
public static List<EmbeddingQuery> convertToEmbedding(List<DataItem> dataItems) {
    if (dataItems == null) {
        return new java.util.ArrayList<>();
    }
    return dataItems.stream().map(dataItem -> {
        EmbeddingQuery embeddingQuery = new EmbeddingQuery();
        // Locale.ROOT keeps the id identical on every JVM; the default-locale
        // overload would turn "I" into a dotless i under e.g. a Turkish locale,
        // so ids written on one host would not match ids built on another.
        embeddingQuery.setQueryId(
                dataItem.getId() + dataItem.getType().name().toLowerCase(java.util.Locale.ROOT));
        embeddingQuery.setQuery(dataItem.getName());
        Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
        embeddingQuery.setMetadata(meta);
        embeddingQuery.setQueryEmbedding(null);
        return embeddingQuery;
    }).collect(Collectors.toList());
}
} }

View File

@@ -2,6 +2,9 @@ package com.tencent.supersonic.headless.api.pojo.request;
import lombok.Data; import lombok.Data;
import lombok.ToString; import lombok.ToString;
import org.apache.commons.collections.CollectionUtils;
import java.util.Objects;
@Data @Data
@ToString @ToString
@@ -25,4 +28,8 @@ public class QuerySqlReq extends SemanticQueryReq {
return stringBuilder.toString(); return stringBuilder.toString();
} }
/**
 * Tells whether the data set id still has to be resolved for this request.
 *
 * @return {@code true} when the data set id is {@code null} or not positive
 *         and no model ids are present; {@code false} otherwise
 */
public boolean needGetDataSetId() {
    final Long currentDataSetId = this.getDataSetId();
    if (Objects.nonNull(currentDataSetId) && currentDataSetId > 0) {
        // A usable data set id is already present.
        return false;
    }
    return CollectionUtils.isEmpty(this.getModelIds());
}
} }

View File

@@ -1,15 +1,12 @@
package com.tencent.supersonic.headless.server.listener; package com.tencent.supersonic.headless.server.listener;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.DataEvent; import com.tencent.supersonic.common.pojo.DataEvent;
import com.tencent.supersonic.common.pojo.DataItem;
import com.tencent.supersonic.common.pojo.enums.EventType; import com.tencent.supersonic.common.pojo.enums.EventType;
import com.tencent.supersonic.common.util.ComponentFactory; import com.tencent.supersonic.common.util.ComponentFactory;
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery; import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore; import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
@@ -18,6 +15,8 @@ import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.List;
@Component @Component
@Slf4j @Slf4j
public class MetaEmbeddingListener implements ApplicationListener<DataEvent> { public class MetaEmbeddingListener implements ApplicationListener<DataEvent> {
@@ -33,30 +32,15 @@ public class MetaEmbeddingListener implements ApplicationListener<DataEvent> {
@Async @Async
@Override @Override
public void onApplicationEvent(DataEvent event) { public void onApplicationEvent(DataEvent event) {
if (CollectionUtils.isEmpty(event.getDataItems())) { List<DataItem> dataItems = event.getDataItems();
if (CollectionUtils.isEmpty(dataItems)) {
return; return;
} }
List<EmbeddingQuery> embeddingQueries = EmbeddingQuery.convertToEmbedding(dataItems);
List<EmbeddingQuery> embeddingQueries = event.getDataItems()
.stream()
.map(dataItem -> {
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
embeddingQuery.setQueryId(
dataItem.getId() + dataItem.getType().name().toLowerCase());
embeddingQuery.setQuery(dataItem.getName());
Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
embeddingQuery.setMetadata(meta);
embeddingQuery.setQueryEmbedding(null);
return embeddingQuery;
}).collect(Collectors.toList());
if (CollectionUtils.isEmpty(embeddingQueries)) { if (CollectionUtils.isEmpty(embeddingQueries)) {
return; return;
} }
try { sleep();
Thread.sleep(embeddingOperationSleepTime);
} catch (InterruptedException e) {
log.error("", e);
}
s2EmbeddingStore.addCollection(embeddingConfig.getMetaCollectionName()); s2EmbeddingStore.addCollection(embeddingConfig.getMetaCollectionName());
if (event.getEventType().equals(EventType.ADD)) { if (event.getEventType().equals(EventType.ADD)) {
s2EmbeddingStore.addQuery(embeddingConfig.getMetaCollectionName(), embeddingQueries); s2EmbeddingStore.addQuery(embeddingConfig.getMetaCollectionName(), embeddingQueries);
@@ -68,4 +52,12 @@ public class MetaEmbeddingListener implements ApplicationListener<DataEvent> {
} }
} }
/**
 * Blocks the calling thread for {@code embeddingOperationSleepTime}
 * milliseconds.
 *
 * <p>If the thread is interrupted while sleeping, the interruption is logged
 * and the interrupt flag is restored so that callers further up the stack
 * can still observe it.
 */
private void sleep() {
    try {
        Thread.sleep(embeddingOperationSleepTime);
    } catch (InterruptedException e) {
        log.error("", e);
        // Catching InterruptedException clears the flag; set it again instead
        // of silently discarding the interruption request.
        Thread.currentThread().interrupt();
    }
}
} }

View File

@@ -3,25 +3,25 @@ package com.tencent.supersonic.headless.server.rest;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.validation.Valid;
import com.tencent.supersonic.headless.api.pojo.request.DictItemFilter; import com.tencent.supersonic.headless.api.pojo.request.DictItemFilter;
import com.tencent.supersonic.headless.api.pojo.request.DictItemReq; import com.tencent.supersonic.headless.api.pojo.request.DictItemReq;
import com.tencent.supersonic.headless.api.pojo.request.DictSingleTaskReq; import com.tencent.supersonic.headless.api.pojo.request.DictSingleTaskReq;
import com.tencent.supersonic.headless.api.pojo.response.DictItemResp; import com.tencent.supersonic.headless.api.pojo.response.DictItemResp;
import com.tencent.supersonic.headless.api.pojo.response.DictTaskResp; import com.tencent.supersonic.headless.api.pojo.response.DictTaskResp;
import com.tencent.supersonic.headless.server.schedule.EmbeddingTask;
import com.tencent.supersonic.headless.server.service.DictConfService; import com.tencent.supersonic.headless.server.service.DictConfService;
import com.tencent.supersonic.headless.server.service.DictTaskService; import com.tencent.supersonic.headless.server.service.DictTaskService;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.PutMapping; import org.springframework.web.bind.annotation.PutMapping;
import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.validation.Valid;
import java.util.List; import java.util.List;
@@ -35,6 +35,9 @@ public class KnowledgeController {
@Autowired @Autowired
private DictConfService confService; private DictConfService confService;
@Autowired
private EmbeddingTask embeddingTask;
/** /**
* addDictConf-新增item的字典配置 * addDictConf-新增item的字典配置
* Add configuration information for dictionary entries * Add configuration information for dictionary entries
@@ -129,4 +132,9 @@ public class KnowledgeController {
return taskService.queryLatestDictTask(taskReq, user); return taskService.queryLatestDictTask(taskReq, user);
} }
/**
 * Triggers a reload of the meta embedding by delegating to the embedding task.
 *
 * @return {@code true} once the delegate call has returned
 */
@GetMapping("/meta/embedding/reload")
public Object reloadMetaEmbedding() {
    embeddingTask.reloadMetaEmbedding();
    return Boolean.TRUE;
}
} }

View File

@@ -2,18 +2,11 @@ package com.tencent.supersonic.headless.server.rest.api;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlsReq; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlsReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.core.chat.corrector.GrammarCorrector; import com.tencent.supersonic.headless.server.service.ChatQueryService;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.server.service.QueryService; import com.tencent.supersonic.headless.server.service.QueryService;
import com.tencent.supersonic.headless.server.service.impl.SemanticService;
import com.tencent.supersonic.headless.server.utils.ComponentFactory;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
@@ -36,14 +29,14 @@ public class SqlQueryApiController {
private QueryService queryService; private QueryService queryService;
@Autowired @Autowired
private SemanticService semanticService; private ChatQueryService chatQueryService;
@PostMapping("/sql") @PostMapping("/sql")
public Object queryBySql(@RequestBody QuerySqlReq querySqlReq, public Object queryBySql(@RequestBody QuerySqlReq querySqlReq,
HttpServletRequest request, HttpServletRequest request,
HttpServletResponse response) throws Exception { HttpServletResponse response) throws Exception {
User user = UserHolder.findUser(request, response); User user = UserHolder.findUser(request, response);
correct(querySqlReq); chatQueryService.correct(querySqlReq, user);
return queryService.queryByReq(querySqlReq, user); return queryService.queryByReq(querySqlReq, user);
} }
@@ -57,28 +50,9 @@ public class SqlQueryApiController {
QuerySqlReq querySqlReq = new QuerySqlReq(); QuerySqlReq querySqlReq = new QuerySqlReq();
BeanUtils.copyProperties(querySqlsReq, querySqlReq); BeanUtils.copyProperties(querySqlsReq, querySqlReq);
querySqlReq.setSql(sql); querySqlReq.setSql(sql);
correct(querySqlReq); chatQueryService.correct(querySqlReq, user);
return querySqlReq; return querySqlReq;
}).collect(Collectors.toList()); }).collect(Collectors.toList());
return queryService.queryByReqs(semanticQueryReqs, user); return queryService.queryByReqs(semanticQueryReqs, user);
} }
/**
 * Runs the request SQL through every registered semantic corrector except
 * {@link GrammarCorrector} and writes the corrected S2SQL back into the request.
 *
 * @param querySqlReq request whose SQL is read and then overwritten
 */
private void correct(QuerySqlReq querySqlReq) {
    QueryContext queryCtx = new QueryContext();
    queryCtx.setSemanticSchema(semanticService.getSemanticSchema());

    SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
    SqlInfo sqlInfo = new SqlInfo();
    // Both the raw and the corrected S2SQL start out as the incoming SQL.
    String requestSql = querySqlReq.getSql();
    sqlInfo.setCorrectS2SQL(requestSql);
    sqlInfo.setS2SQL(requestSql);
    semanticParseInfo.setSqlInfo(sqlInfo);
    semanticParseInfo.setQueryType(QueryType.TAG);

    ComponentFactory.getSemanticCorrectors().forEach(corrector -> {
        if (corrector instanceof GrammarCorrector) {
            // Grammar correction is deliberately skipped for this entry point.
            return;
        }
        corrector.correct(queryCtx, semanticParseInfo);
    });

    querySqlReq.setSql(sqlInfo.getCorrectS2SQL());
}
} }

View File

@@ -0,0 +1,72 @@
package com.tencent.supersonic.headless.server.schedule;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.DataItem;
import com.tencent.supersonic.common.util.ComponentFactory;
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
import com.tencent.supersonic.common.util.embedding.InMemoryS2EmbeddingStore;
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
import com.tencent.supersonic.headless.server.service.DimensionService;
import com.tencent.supersonic.headless.server.service.MetricService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
import javax.annotation.PreDestroy;
import java.util.List;
/**
 * Scheduled maintenance for the embedding store.
 *
 * <p>Two jobs live here: writing an {@link InMemoryS2EmbeddingStore} to file
 * (on a cron schedule and at bean destruction), and re-adding all metric and
 * dimension data items to the meta collection.
 */
@Component
@Slf4j
public class EmbeddingTask {

    private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();

    @Autowired
    private EmbeddingConfig embeddingConfig;

    @Autowired
    private MetricService metricService;

    @Autowired
    private DimensionService dimensionService;

    /**
     * Persists the store one last time before the bean is destroyed.
     */
    @PreDestroy
    public void onShutdown() {
        embeddingStorePersistentToFile();
    }

    /**
     * Calls {@code persistentToFile()} on the store when it is the in-memory
     * implementation; does nothing for any other store type.
     */
    private void embeddingStorePersistentToFile() {
        if (s2EmbeddingStore instanceof InMemoryS2EmbeddingStore) {
            log.info("start persistentToFile");
            ((InMemoryS2EmbeddingStore) s2EmbeddingStore).persistentToFile();
            log.info("end persistentToFile");
        }
    }

    /**
     * Scheduled persistence job. The cron expression is read from the
     * {@code inMemoryEmbeddingStore.persistent.cron} property.
     */
    @Scheduled(cron = "${inMemoryEmbeddingStore.persistent.cron:0 0 * * * ?}")
    public void executeTask() {
        embeddingStorePersistentToFile();
    }

    /**
     * Re-adds the data items of all metrics and dimensions to the meta
     * collection of the embedding store.
     *
     * <p>The collection is registered first, mirroring what the meta embedding
     * listener does before it adds queries. Any exception is logged and
     * swallowed so that a failed reload never propagates to the scheduler or
     * to the REST caller.
     *
     * <p>The property key is spelled {@code reload.meta.embedding.corn}; the
     * spelling is kept as-is so existing configuration keeps working.
     */
    @Scheduled(cron = "${reload.meta.embedding.corn:0 0 */2 * * ?}")
    public void reloadMetaEmbedding() {
        log.info("reload.meta.embedding start");
        try {
            s2EmbeddingStore.addCollection(embeddingConfig.getMetaCollectionName());
            addToMetaCollection(metricService.getDataEvent().getDataItems());
            addToMetaCollection(dimensionService.getDataEvent().getDataItems());
        } catch (Exception e) {
            log.error("reload.meta.embedding error", e);
        }
        log.info("reload.meta.embedding end");
    }

    /**
     * Converts the given data items to embedding queries and adds them to the
     * meta collection; a {@code null} or empty list is skipped entirely.
     */
    private void addToMetaCollection(List<DataItem> dataItems) {
        if (dataItems == null || dataItems.isEmpty()) {
            return;
        }
        s2EmbeddingStore.addQuery(embeddingConfig.getMetaCollectionName(),
                EmbeddingQuery.convertToEmbedding(dataItems));
    }
}

View File

@@ -7,6 +7,7 @@ import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq; import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq; import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp; import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult; import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
@@ -29,5 +30,7 @@ public interface ChatQueryService {
EntityInfo getEntityInfo(SemanticParseInfo parseInfo, User user); EntityInfo getEntityInfo(SemanticParseInfo parseInfo, User user);
Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception; Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception;
void correct(QuerySqlReq querySqlReq, User user);
} }

View File

@@ -38,4 +38,6 @@ public interface DataSetService {
SemanticQueryReq convert(QueryDataSetReq queryDataSetReq); SemanticQueryReq convert(QueryDataSetReq queryDataSetReq);
Long getDataSetIdFromSql(String sql, User user);
} }

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.server.service;
import com.github.pagehelper.PageInfo; import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.DataEvent;
import com.tencent.supersonic.common.pojo.enums.EventType; import com.tencent.supersonic.common.pojo.enums.EventType;
import com.tencent.supersonic.headless.api.pojo.DimValueMap; import com.tencent.supersonic.headless.api.pojo.DimValueMap;
import com.tencent.supersonic.headless.api.pojo.request.DimensionReq; import com.tencent.supersonic.headless.api.pojo.request.DimensionReq;
@@ -10,6 +11,7 @@ import com.tencent.supersonic.headless.api.pojo.request.PageDimensionReq;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp; import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
import com.tencent.supersonic.headless.server.pojo.DimensionsFilter; import com.tencent.supersonic.headless.server.pojo.DimensionsFilter;
import com.tencent.supersonic.headless.server.pojo.MetaFilter; import com.tencent.supersonic.headless.server.pojo.MetaFilter;
import java.util.List; import java.util.List;
public interface DimensionService { public interface DimensionService {
@@ -41,4 +43,6 @@ public interface DimensionService {
List<DimValueMap> mockDimensionValueAlias(DimensionReq dimensionReq, User user); List<DimValueMap> mockDimensionValueAlias(DimensionReq dimensionReq, User user);
void sendDimensionEventBatch(List<Long> modelIds, EventType eventType); void sendDimensionEventBatch(List<Long> modelIds, EventType eventType);
DataEvent getDataEvent();
} }

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.server.service;
import com.github.pagehelper.PageInfo; import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.DataEvent;
import com.tencent.supersonic.common.pojo.enums.EventType; import com.tencent.supersonic.common.pojo.enums.EventType;
import com.tencent.supersonic.headless.api.pojo.DrillDownDimension; import com.tencent.supersonic.headless.api.pojo.DrillDownDimension;
import com.tencent.supersonic.headless.api.pojo.MetricQueryDefaultConfig; import com.tencent.supersonic.headless.api.pojo.MetricQueryDefaultConfig;
@@ -14,6 +15,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp; import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.server.pojo.MetaFilter; import com.tencent.supersonic.headless.server.pojo.MetaFilter;
import com.tencent.supersonic.headless.server.pojo.MetricsFilter; import com.tencent.supersonic.headless.server.pojo.MetricsFilter;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
@@ -60,4 +62,6 @@ public interface MetricService {
List<MetricResp> queryMetrics(MetricsFilter metricsFilter); List<MetricResp> queryMetrics(MetricsFilter metricsFilter);
QueryStructReq convert(QueryMetricReq queryMetricReq); QueryStructReq convert(QueryMetricReq queryMetricReq);
DataEvent getDataEvent();
} }

View File

@@ -21,5 +21,4 @@ public interface QueryService {
List<ItemUseResp> getStatInfo(ItemUseReq itemUseCommend); List<ItemUseResp> getStatInfo(ItemUseReq itemUseCommend);
<T> ExplainResp explain(ExplainSqlReq<T> explainSqlReq, User user) throws Exception; <T> ExplainResp explain(ExplainSqlReq<T> explainSqlReq, User user) throws Exception;
} }

View File

@@ -20,6 +20,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.enums.CostType; import com.tencent.supersonic.headless.api.pojo.enums.CostType;
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod; import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
@@ -29,6 +30,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq; import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.ExplainResp; import com.tencent.supersonic.headless.api.pojo.response.ExplainResp;
@@ -37,6 +39,7 @@ import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult; import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.QueryState; import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.core.chat.corrector.GrammarCorrector;
import com.tencent.supersonic.headless.core.chat.corrector.SemanticCorrector; import com.tencent.supersonic.headless.core.chat.corrector.SemanticCorrector;
import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult; import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult;
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeService; import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeService;
@@ -647,5 +650,30 @@ public class ChatQueryServiceImpl implements ChatQueryService {
return queryService.queryByReq(queryStructReq, user); return queryService.queryByReq(queryStructReq, user);
} }
/**
 * Corrects the SQL carried by the request and writes the result back into it.
 *
 * <p>The incoming SQL is used as both the raw and the corrected S2SQL, the
 * data set is looked up in the semantic schema, and every registered semantic
 * corrector except {@link GrammarCorrector} is applied.
 *
 * @param querySqlReq request whose SQL is read and then overwritten
 * @param user        user passed to the data set lookup when the request has no data set id
 */
public void correct(QuerySqlReq querySqlReq, User user) {
QueryContext queryCtx = new QueryContext();
SemanticSchema semanticSchema = semanticService.getSemanticSchema();
queryCtx.setSemanticSchema(semanticSchema);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
SqlInfo sqlInfo = new SqlInfo();
// Raw and corrected S2SQL both start out as the incoming SQL.
sqlInfo.setCorrectS2SQL(querySqlReq.getSql());
sqlInfo.setS2SQL(querySqlReq.getSql());
semanticParseInfo.setSqlInfo(sqlInfo);
semanticParseInfo.setQueryType(QueryType.TAG);
// Fall back to resolving the data set from the SQL only when the id is null.
// NOTE(review): QuerySqlReq.needGetDataSetId() also treats non-positive ids as
// missing and skips the lookup when model ids are present — confirm whether
// that check should be used here instead of the plain null test.
Long dataSetId = querySqlReq.getDataSetId();
if (Objects.isNull(dataSetId)) {
dataSetId = dataSetService.getDataSetIdFromSql(querySqlReq.getSql(), user);
}
// NOTE(review): presumably getDataSet may return null for an unknown id — verify
// that the correctors tolerate a null data set.
SchemaElement dataSet = semanticSchema.getDataSet(dataSetId);
semanticParseInfo.setDataSet(dataSet);
// Grammar correction is deliberately skipped for this entry point.
ComponentFactory.getSemanticCorrectors().forEach(corrector -> {
if (!(corrector instanceof GrammarCorrector)) {
corrector.correct(queryCtx, semanticParseInfo);
}
});
querySqlReq.setSql(sqlInfo.getCorrectS2SQL());
} }
}

View File

@@ -13,6 +13,7 @@ import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.pojo.enums.TypeEnums; import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException; import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
import com.tencent.supersonic.common.util.BeanMapper; import com.tencent.supersonic.common.util.BeanMapper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.DataSetDetail; import com.tencent.supersonic.headless.api.pojo.DataSetDetail;
import com.tencent.supersonic.headless.api.pojo.QueryConfig; import com.tencent.supersonic.headless.api.pojo.QueryConfig;
import com.tencent.supersonic.headless.api.pojo.enums.TagDefineType; import com.tencent.supersonic.headless.api.pojo.enums.TagDefineType;
@@ -34,6 +35,15 @@ import com.tencent.supersonic.headless.server.service.DimensionService;
import com.tencent.supersonic.headless.server.service.DomainService; import com.tencent.supersonic.headless.server.service.DomainService;
import com.tencent.supersonic.headless.server.service.MetricService; import com.tencent.supersonic.headless.server.service.MetricService;
import com.tencent.supersonic.headless.server.service.TagMetaService; import com.tencent.supersonic.headless.server.service.TagMetaService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.Arrays; import java.util.Arrays;
import java.util.Comparator; import java.util.Comparator;
import java.util.Date; import java.util.Date;
@@ -45,15 +55,9 @@ import java.util.Set;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@Service @Service
@Slf4j
public class DataSetServiceImpl public class DataSetServiceImpl
extends ServiceImpl<DataSetDOMapper, DataSetDO> implements DataSetService { extends ServiceImpl<DataSetDOMapper, DataSetDO> implements DataSetService {
@@ -311,4 +315,21 @@ public class DataSetServiceImpl
.map(Object::toString) .map(Object::toString)
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
/**
 * Resolves a data set id from the table name referenced by the given SQL.
 *
 * <p>The table name is extracted from the SQL and used to look up data sets
 * for the user; the id of the first match is returned. A failure while
 * parsing or looking up is logged and then handled like "no data set found".
 *
 * @param sql  SQL whose table name identifies the data set
 * @param user user on whose behalf the data sets are looked up
 * @return id of the first matching data set
 * @throws InvalidArgumentException when no data set can be resolved from the SQL
 */
public Long getDataSetIdFromSql(String sql, User user) {
    List<DataSetResp> dataSets = null;
    try {
        String tableName = SqlSelectHelper.getTableName(sql);
        dataSets = getDataSets(tableName, user);
    } catch (Exception e) {
        // A trailing Throwable is logged as the exception and never fills a
        // placeholder, so the placeholder is bound to the SQL instead.
        log.error("getDataSetIdFromSql error, sql:{}", sql, e);
    }
    if (org.apache.commons.collections.CollectionUtils.isEmpty(dataSets)) {
        throw new InvalidArgumentException("从Sql参数中无法获取到DataSetId");
    }
    Long dataSetId = dataSets.get(0).getId();
    log.info("getDataSetIdFromSql dataSetId:{}", dataSetId);
    return dataSetId;
}
} }

View File

@@ -430,17 +430,27 @@ public class DimensionServiceImpl implements DimensionService {
} }
private void sendEventBatch(List<DimensionDO> dimensionDOS, EventType eventType) { private void sendEventBatch(List<DimensionDO> dimensionDOS, EventType eventType) {
DataEvent dataEvent = getDataEvent(dimensionDOS, eventType);
eventPublisher.publishEvent(dataEvent);
}
/**
 * Builds an {@code ADD} data event from the dimensions returned by a query
 * with a freshly constructed, unconfigured filter.
 *
 * @return the data event; it is only built here, not published
 */
public DataEvent getDataEvent() {
    return getDataEvent(queryDimension(new DimensionFilter()), EventType.ADD);
}
private DataEvent getDataEvent(List<DimensionDO> dimensionDOS, EventType eventType) {
List<DataItem> dataItems = dimensionDOS.stream() List<DataItem> dataItems = dimensionDOS.stream()
.map(dimensionDO -> DataItem.builder().id(dimensionDO.getId() + Constants.UNDERLINE) .map(dimensionDO -> DataItem.builder().id(dimensionDO.getId() + Constants.UNDERLINE)
.name(dimensionDO.getName()).modelId(dimensionDO.getModelId() + Constants.UNDERLINE) .name(dimensionDO.getName()).modelId(dimensionDO.getModelId() + Constants.UNDERLINE)
.type(TypeEnums.DIMENSION).build()) .type(TypeEnums.DIMENSION).build())
.collect(Collectors.toList()); .collect(Collectors.toList());
eventPublisher.publishEvent(new DataEvent(this, dataItems, eventType)); return new DataEvent(this, dataItems, eventType);
} }
private void sendEvent(DataItem dataItem, EventType eventType) { private void sendEvent(DataItem dataItem, EventType eventType) {
eventPublisher.publishEvent(new DataEvent(this, eventPublisher.publishEvent(new DataEvent(this,
Lists.newArrayList(dataItem), eventType)); Lists.newArrayList(dataItem), eventType));
} }
} }

View File

@@ -556,9 +556,20 @@ public class MetricServiceImpl implements MetricService {
} }
private void sendEventBatch(List<MetricDO> metricDOS, EventType eventType) { private void sendEventBatch(List<MetricDO> metricDOS, EventType eventType) {
DataEvent dataEvent = getDataEvent(metricDOS, eventType);
eventPublisher.publishEvent(dataEvent);
}
/**
 * Builds an {@code ADD} data event from the metrics the repository returns
 * for a freshly constructed, unconfigured filter.
 *
 * @return the data event; it is only built here, not published
 */
public DataEvent getDataEvent() {
    return getDataEvent(metricRepository.getMetrics(new MetricsFilter()), EventType.ADD);
}
private DataEvent getDataEvent(List<MetricDO> metricDOS, EventType eventType) {
List<DataItem> dataItems = metricDOS.stream().map(this::getDataItem) List<DataItem> dataItems = metricDOS.stream().map(this::getDataItem)
.collect(Collectors.toList()); .collect(Collectors.toList());
eventPublisher.publishEvent(new DataEvent(this, dataItems, eventType)); return new DataEvent(this, dataItems, eventType);
} }
private void sendEvent(DataItem dataItem, EventType eventType) { private void sendEvent(DataItem dataItem, EventType eventType) {

View File

@@ -5,8 +5,6 @@ import com.google.common.collect.Sets;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.enums.TaskStatusEnum; import com.tencent.supersonic.common.pojo.enums.TaskStatusEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.Dim; import com.tencent.supersonic.headless.api.pojo.Dim;
import com.tencent.supersonic.headless.api.pojo.QueryParam; import com.tencent.supersonic.headless.api.pojo.QueryParam;
import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq; import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq;
@@ -17,7 +15,6 @@ import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.request.SchemaFilterReq; import com.tencent.supersonic.headless.api.pojo.request.SchemaFilterReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp; import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
import com.tencent.supersonic.headless.api.pojo.response.ExplainResp; import com.tencent.supersonic.headless.api.pojo.response.ExplainResp;
import com.tencent.supersonic.headless.api.pojo.response.ItemUseResp; import com.tencent.supersonic.headless.api.pojo.response.ItemUseResp;
@@ -137,8 +134,8 @@ public class QueryServiceImpl implements QueryService {
private QueryStatement buildSqlQueryStatement(QuerySqlReq querySqlReq, User user) throws Exception { private QueryStatement buildSqlQueryStatement(QuerySqlReq querySqlReq, User user) throws Exception {
//If dataSetId or DataSetName is empty, parse dataSetId from the SQL //If dataSetId or DataSetName is empty, parse dataSetId from the SQL
if (needGetDataSetId(querySqlReq)) { if (querySqlReq.needGetDataSetId()) {
Long dataSetId = getDataSetIdFromSql(querySqlReq, user); Long dataSetId = dataSetService.getDataSetIdFromSql(querySqlReq.getSql(), user);
querySqlReq.setDataSetId(dataSetId); querySqlReq.setDataSetId(dataSetId);
} }
SchemaFilterReq filter = buildSchemaFilterReq(querySqlReq); SchemaFilterReq filter = buildSchemaFilterReq(querySqlReq);
@@ -151,27 +148,6 @@ public class QueryServiceImpl implements QueryService {
return queryStatement; return queryStatement;
} }
private static boolean needGetDataSetId(QuerySqlReq querySqlReq) {
return (Objects.isNull(querySqlReq.getDataSetId()) || querySqlReq.getDataSetId() <= 0)
&& (CollectionUtils.isEmpty(querySqlReq.getModelIds()));
}
private Long getDataSetIdFromSql(QuerySqlReq querySqlReq, User user) {
List<DataSetResp> dataSets = null;
try {
String tableName = SqlSelectHelper.getTableName(querySqlReq.getSql());
dataSets = dataSetService.getDataSets(tableName, user);
} catch (Exception e) {
log.error("getDataSetIdFromSql error:{}", e);
}
if (CollectionUtils.isEmpty(dataSets)) {
throw new InvalidArgumentException("从Sql参数中无法获取到DataSetId");
}
Long dataSetId = dataSets.get(0).getId();
log.info("getDataSetIdFromSql dataSetId:{}", dataSetId);
return dataSetId;
}
private QueryStatement buildQueryStatement(SemanticQueryReq semanticQueryReq, User user) throws Exception { private QueryStatement buildQueryStatement(SemanticQueryReq semanticQueryReq, User user) throws Exception {
if (semanticQueryReq instanceof QuerySqlReq) { if (semanticQueryReq instanceof QuerySqlReq) {
return buildSqlQueryStatement((QuerySqlReq) semanticQueryReq, user); return buildSqlQueryStatement((QuerySqlReq) semanticQueryReq, user);