(improvement)(Headless) Periodically refresh embedding metadata in full and optimize the code. (#917)

This commit is contained in:
lexluo09
2024-04-17 22:38:19 +08:00
committed by GitHub
parent d8c23cca05
commit ee798b7671
17 changed files with 237 additions and 142 deletions

View File

@@ -1,32 +0,0 @@
package com.tencent.supersonic.common.util.embedding;
import com.tencent.supersonic.common.util.ComponentFactory;
import javax.annotation.PreDestroy;
import lombok.extern.slf4j.Slf4j;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
@Component
@Slf4j
public class EmbeddingPersistentTask {

    /** Embedding store obtained once from {@link ComponentFactory} when this bean is created. */
    private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();

    /**
     * Scheduled persistence hook. The cron expression is read from
     * {@code inMemoryEmbeddingStore.persistent.cron} and falls back to {@code 0 0 * * * ?}.
     */
    @Scheduled(cron = "${inMemoryEmbeddingStore.persistent.cron:0 0 * * * ?}")
    public void executeTask() {
        embeddingStorePersistentToFile();
    }

    /** Persists the store one last time before the bean is destroyed. */
    @PreDestroy
    public void onShutdown() {
        embeddingStorePersistentToFile();
    }

    /**
     * Writes the embedding store to file, but only when the configured store is the
     * in-memory implementation; any other store type is left alone.
     */
    private void embeddingStorePersistentToFile() {
        if (!(s2EmbeddingStore instanceof InMemoryS2EmbeddingStore)) {
            return;
        }
        InMemoryS2EmbeddingStore inMemoryStore = (InMemoryS2EmbeddingStore) s2EmbeddingStore;
        log.info("start persistentToFile");
        inMemoryStore.persistentToFile();
        log.info("end persistentToFile");
    }
}

View File

@@ -1,10 +1,13 @@
package com.tencent.supersonic.common.util.embedding; package com.tencent.supersonic.common.util.embedding;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.common.pojo.DataItem;
import lombok.Data; import lombok.Data;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors;
@Data @Data
public class EmbeddingQuery { public class EmbeddingQuery {
@@ -17,4 +20,17 @@ public class EmbeddingQuery {
private List<Double> queryEmbedding; private List<Double> queryEmbedding;
/**
 * Converts metadata items into embedding queries that can be written to an embedding store.
 *
 * <p>For every item the query id is the item id followed by the lower-cased type name,
 * the query text is the item name, the metadata is the item serialized to JSON and parsed
 * back into a map, and the embedding vector is explicitly left {@code null}.
 *
 * @param dataItems items to convert; must not be {@code null}
 * @return one {@link EmbeddingQuery} per input item, in the order of the input list
 */
public static List<EmbeddingQuery> convertToEmbedding(List<DataItem> dataItems) {
    return dataItems.stream().map(dataItem -> {
        EmbeddingQuery embeddingQuery = new EmbeddingQuery();
        // Locale.ROOT keeps the generated id independent of the JVM default locale;
        // the no-arg toLowerCase() maps 'I' to a dotless i under e.g. a Turkish locale,
        // which would produce ids that no longer match previously stored ones.
        embeddingQuery.setQueryId(
                dataItem.getId() + dataItem.getType().name().toLowerCase(java.util.Locale.ROOT));
        embeddingQuery.setQuery(dataItem.getName());
        // Round-trip through JSON to obtain a plain map view of the item's properties.
        Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
        embeddingQuery.setMetadata(meta);
        embeddingQuery.setQueryEmbedding(null);
        return embeddingQuery;
    }).collect(Collectors.toList());
}
} }

View File

@@ -2,6 +2,9 @@ package com.tencent.supersonic.headless.api.pojo.request;
import lombok.Data; import lombok.Data;
import lombok.ToString; import lombok.ToString;
import org.apache.commons.collections.CollectionUtils;
import java.util.Objects;
@Data @Data
@ToString @ToString
@@ -25,4 +28,8 @@ public class QuerySqlReq extends SemanticQueryReq {
return stringBuilder.toString(); return stringBuilder.toString();
} }
/**
 * Tells whether the data set id still has to be looked up for this request.
 *
 * @return {@code true} only when the data set id is {@code null} or not positive
 *         and the model id collection is empty as well
 */
public boolean needGetDataSetId() {
    boolean dataSetIdMissing = Objects.isNull(this.getDataSetId()) || this.getDataSetId() <= 0;
    if (!dataSetIdMissing) {
        return false;
    }
    return CollectionUtils.isEmpty(this.getModelIds());
}
} }

View File

@@ -1,15 +1,12 @@
package com.tencent.supersonic.headless.server.listener; package com.tencent.supersonic.headless.server.listener;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.DataEvent; import com.tencent.supersonic.common.pojo.DataEvent;
import com.tencent.supersonic.common.pojo.DataItem;
import com.tencent.supersonic.common.pojo.enums.EventType; import com.tencent.supersonic.common.pojo.enums.EventType;
import com.tencent.supersonic.common.util.ComponentFactory; import com.tencent.supersonic.common.util.ComponentFactory;
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery; import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore; import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
@@ -18,6 +15,8 @@ import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.List;
@Component @Component
@Slf4j @Slf4j
public class MetaEmbeddingListener implements ApplicationListener<DataEvent> { public class MetaEmbeddingListener implements ApplicationListener<DataEvent> {
@@ -33,30 +32,15 @@ public class MetaEmbeddingListener implements ApplicationListener<DataEvent> {
@Async @Async
@Override @Override
public void onApplicationEvent(DataEvent event) { public void onApplicationEvent(DataEvent event) {
if (CollectionUtils.isEmpty(event.getDataItems())) { List<DataItem> dataItems = event.getDataItems();
if (CollectionUtils.isEmpty(dataItems)) {
return; return;
} }
List<EmbeddingQuery> embeddingQueries = EmbeddingQuery.convertToEmbedding(dataItems);
List<EmbeddingQuery> embeddingQueries = event.getDataItems()
.stream()
.map(dataItem -> {
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
embeddingQuery.setQueryId(
dataItem.getId() + dataItem.getType().name().toLowerCase());
embeddingQuery.setQuery(dataItem.getName());
Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
embeddingQuery.setMetadata(meta);
embeddingQuery.setQueryEmbedding(null);
return embeddingQuery;
}).collect(Collectors.toList());
if (CollectionUtils.isEmpty(embeddingQueries)) { if (CollectionUtils.isEmpty(embeddingQueries)) {
return; return;
} }
try { sleep();
Thread.sleep(embeddingOperationSleepTime);
} catch (InterruptedException e) {
log.error("", e);
}
s2EmbeddingStore.addCollection(embeddingConfig.getMetaCollectionName()); s2EmbeddingStore.addCollection(embeddingConfig.getMetaCollectionName());
if (event.getEventType().equals(EventType.ADD)) { if (event.getEventType().equals(EventType.ADD)) {
s2EmbeddingStore.addQuery(embeddingConfig.getMetaCollectionName(), embeddingQueries); s2EmbeddingStore.addQuery(embeddingConfig.getMetaCollectionName(), embeddingQueries);
@@ -68,4 +52,12 @@ public class MetaEmbeddingListener implements ApplicationListener<DataEvent> {
} }
} }
/**
 * Pauses the current thread for {@code embeddingOperationSleepTime} milliseconds before
 * the embedding store is touched.
 *
 * <p>If the thread is interrupted while sleeping, the interruption is logged and the
 * interrupt flag is restored so that callers further up the stack (e.g. the async
 * executor running this listener) can still observe it.
 */
private void sleep() {
    try {
        Thread.sleep(embeddingOperationSleepTime);
    } catch (InterruptedException e) {
        log.error("", e);
        // Catching InterruptedException clears the flag; set it again instead of swallowing it.
        Thread.currentThread().interrupt();
    }
}
} }

View File

@@ -3,25 +3,25 @@ package com.tencent.supersonic.headless.server.rest;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.validation.Valid;
import com.tencent.supersonic.headless.api.pojo.request.DictItemFilter; import com.tencent.supersonic.headless.api.pojo.request.DictItemFilter;
import com.tencent.supersonic.headless.api.pojo.request.DictItemReq; import com.tencent.supersonic.headless.api.pojo.request.DictItemReq;
import com.tencent.supersonic.headless.api.pojo.request.DictSingleTaskReq; import com.tencent.supersonic.headless.api.pojo.request.DictSingleTaskReq;
import com.tencent.supersonic.headless.api.pojo.response.DictItemResp; import com.tencent.supersonic.headless.api.pojo.response.DictItemResp;
import com.tencent.supersonic.headless.api.pojo.response.DictTaskResp; import com.tencent.supersonic.headless.api.pojo.response.DictTaskResp;
import com.tencent.supersonic.headless.server.schedule.EmbeddingTask;
import com.tencent.supersonic.headless.server.service.DictConfService; import com.tencent.supersonic.headless.server.service.DictConfService;
import com.tencent.supersonic.headless.server.service.DictTaskService; import com.tencent.supersonic.headless.server.service.DictTaskService;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.PutMapping; import org.springframework.web.bind.annotation.PutMapping;
import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.validation.Valid;
import java.util.List; import java.util.List;
@@ -35,6 +35,9 @@ public class KnowledgeController {
@Autowired @Autowired
private DictConfService confService; private DictConfService confService;
@Autowired
private EmbeddingTask embeddingTask;
/** /**
* addDictConf-新增item的字典配置 * addDictConf-新增item的字典配置
* Add configuration information for dictionary entries * Add configuration information for dictionary entries
@@ -43,8 +46,8 @@ public class KnowledgeController {
*/ */
@PostMapping("/conf") @PostMapping("/conf")
public DictItemResp addDictConf(@RequestBody @Valid DictItemReq dictItemReq, public DictItemResp addDictConf(@RequestBody @Valid DictItemReq dictItemReq,
HttpServletRequest request, HttpServletRequest request,
HttpServletResponse response) { HttpServletResponse response) {
User user = UserHolder.findUser(request, response); User user = UserHolder.findUser(request, response);
return confService.addDictConf(dictItemReq, user); return confService.addDictConf(dictItemReq, user);
} }
@@ -57,8 +60,8 @@ public class KnowledgeController {
*/ */
@PutMapping("/conf") @PutMapping("/conf")
public DictItemResp editDictConf(@RequestBody @Valid DictItemReq dictItemReq, public DictItemResp editDictConf(@RequestBody @Valid DictItemReq dictItemReq,
HttpServletRequest request, HttpServletRequest request,
HttpServletResponse response) { HttpServletResponse response) {
User user = UserHolder.findUser(request, response); User user = UserHolder.findUser(request, response);
return confService.editDictConf(dictItemReq, user); return confService.editDictConf(dictItemReq, user);
} }
@@ -129,4 +132,9 @@ public class KnowledgeController {
return taskService.queryLatestDictTask(taskReq, user); return taskService.queryLatestDictTask(taskReq, user);
} }
/**
 * Triggers a reload of the meta embedding through {@code embeddingTask} on demand.
 *
 * @return always {@code true} once the reload call has returned
 */
@GetMapping("/meta/embedding/reload")
public Object reloadMetaEmbedding() {
    embeddingTask.reloadMetaEmbedding();
    return Boolean.TRUE;
}
} }

View File

@@ -2,18 +2,11 @@ package com.tencent.supersonic.headless.server.rest.api;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlsReq; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlsReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.core.chat.corrector.GrammarCorrector; import com.tencent.supersonic.headless.server.service.ChatQueryService;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.server.service.QueryService; import com.tencent.supersonic.headless.server.service.QueryService;
import com.tencent.supersonic.headless.server.service.impl.SemanticService;
import com.tencent.supersonic.headless.server.utils.ComponentFactory;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
@@ -36,14 +29,14 @@ public class SqlQueryApiController {
private QueryService queryService; private QueryService queryService;
@Autowired @Autowired
private SemanticService semanticService; private ChatQueryService chatQueryService;
@PostMapping("/sql") @PostMapping("/sql")
public Object queryBySql(@RequestBody QuerySqlReq querySqlReq, public Object queryBySql(@RequestBody QuerySqlReq querySqlReq,
HttpServletRequest request, HttpServletRequest request,
HttpServletResponse response) throws Exception { HttpServletResponse response) throws Exception {
User user = UserHolder.findUser(request, response); User user = UserHolder.findUser(request, response);
correct(querySqlReq); chatQueryService.correct(querySqlReq, user);
return queryService.queryByReq(querySqlReq, user); return queryService.queryByReq(querySqlReq, user);
} }
@@ -57,28 +50,9 @@ public class SqlQueryApiController {
QuerySqlReq querySqlReq = new QuerySqlReq(); QuerySqlReq querySqlReq = new QuerySqlReq();
BeanUtils.copyProperties(querySqlsReq, querySqlReq); BeanUtils.copyProperties(querySqlsReq, querySqlReq);
querySqlReq.setSql(sql); querySqlReq.setSql(sql);
correct(querySqlReq); chatQueryService.correct(querySqlReq, user);
return querySqlReq; return querySqlReq;
}).collect(Collectors.toList()); }).collect(Collectors.toList());
return queryService.queryByReqs(semanticQueryReqs, user); return queryService.queryByReqs(semanticQueryReqs, user);
} }
/**
 * Runs the configured semantic correctors (all except {@link GrammarCorrector}) over the
 * SQL of the given request and stores the corrected SQL back into the request.
 *
 * @param querySqlReq request whose SQL is read and then overwritten with the corrected SQL
 */
private void correct(QuerySqlReq querySqlReq) {
    // Both the raw and the "corrected" SQL start out as the request's SQL.
    SqlInfo correctedSql = new SqlInfo();
    correctedSql.setCorrectS2SQL(querySqlReq.getSql());
    correctedSql.setS2SQL(querySqlReq.getSql());

    SemanticParseInfo parseInfo = new SemanticParseInfo();
    parseInfo.setSqlInfo(correctedSql);
    parseInfo.setQueryType(QueryType.TAG);

    QueryContext context = new QueryContext();
    context.setSemanticSchema(semanticService.getSemanticSchema());

    ComponentFactory.getSemanticCorrectors().forEach(corrector -> {
        if (corrector instanceof GrammarCorrector) {
            // Grammar correction is deliberately skipped for this entry point.
            return;
        }
        corrector.correct(context, parseInfo);
    });

    querySqlReq.setSql(correctedSql.getCorrectS2SQL());
}
} }

View File

@@ -0,0 +1,72 @@
package com.tencent.supersonic.headless.server.schedule;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.DataItem;
import com.tencent.supersonic.common.util.ComponentFactory;
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
import com.tencent.supersonic.common.util.embedding.InMemoryS2EmbeddingStore;
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
import com.tencent.supersonic.headless.server.service.DimensionService;
import com.tencent.supersonic.headless.server.service.MetricService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
import javax.annotation.PreDestroy;
import java.util.List;
/**
 * Scheduled maintenance for the embedding store: periodically persists an in-memory store
 * to file and periodically re-adds metric and dimension metadata to the meta collection.
 */
@Component
@Slf4j
public class EmbeddingTask {

    /** Embedding store obtained once from {@link ComponentFactory} when this bean is created. */
    private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();

    @Autowired
    private EmbeddingConfig embeddingConfig;

    @Autowired
    private MetricService metricService;

    @Autowired
    private DimensionService dimensionService;

    /** Persists the store one last time before the bean is destroyed. */
    @PreDestroy
    public void onShutdown() {
        embeddingStorePersistentToFile();
    }

    /**
     * Writes the embedding store to file, but only when the configured store is the
     * in-memory implementation; any other store type is left alone.
     */
    private void embeddingStorePersistentToFile() {
        if (s2EmbeddingStore instanceof InMemoryS2EmbeddingStore) {
            log.info("start persistentToFile");
            ((InMemoryS2EmbeddingStore) s2EmbeddingStore).persistentToFile();
            log.info("end persistentToFile");
        }
    }

    /**
     * Scheduled persistence hook. The cron expression is read from
     * {@code inMemoryEmbeddingStore.persistent.cron} and falls back to {@code 0 0 * * * ?}.
     */
    @Scheduled(cron = "${inMemoryEmbeddingStore.persistent.cron:0 0 * * * ?}")
    public void executeTask() {
        embeddingStorePersistentToFile();
    }

    /**
     * Reloads the meta embedding: fetches the data items of all metrics and then of all
     * dimensions from their services and adds them to the meta collection of the store.
     *
     * <p>The schedule is read from {@code reload.meta.embedding.cron}. The originally
     * shipped, misspelled key {@code reload.meta.embedding.corn} is still honoured as a
     * fallback so existing configurations keep working; with neither key set the default
     * {@code 0 0 *}{@code /2 * * ?} applies.
     *
     * <p>Any exception is logged and not propagated, so callers cannot tell from the
     * outcome of this method whether the reload succeeded.
     */
    @Scheduled(cron = "${reload.meta.embedding.cron:${reload.meta.embedding.corn:0 0 */2 * * ?}}")
    public void reloadMetaEmbedding() {
        log.info("reload.meta.embedding start");
        try {
            List<DataItem> metricDataItems = metricService.getDataEvent().getDataItems();
            s2EmbeddingStore.addQuery(embeddingConfig.getMetaCollectionName(),
                    EmbeddingQuery.convertToEmbedding(metricDataItems));

            List<DataItem> dimensionDataItems = dimensionService.getDataEvent().getDataItems();
            s2EmbeddingStore.addQuery(embeddingConfig.getMetaCollectionName(),
                    EmbeddingQuery.convertToEmbedding(dimensionDataItems));
        } catch (Exception e) {
            // Boundary of a scheduled job: log and carry on so the scheduler keeps running.
            log.error("reload.meta.embedding error", e);
        }
        log.info("reload.meta.embedding end");
    }
}

View File

@@ -7,6 +7,7 @@ import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq; import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq; import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp; import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult; import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
@@ -29,5 +30,7 @@ public interface ChatQueryService {
EntityInfo getEntityInfo(SemanticParseInfo parseInfo, User user); EntityInfo getEntityInfo(SemanticParseInfo parseInfo, User user);
Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception; Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception;
/**
 * Applies semantic correction to the SQL carried by the request and writes the corrected
 * SQL back into the same request object.
 *
 * @param querySqlReq request whose SQL is corrected in place
 * @param user        user on whose behalf the correction runs
 */
void correct(QuerySqlReq querySqlReq, User user);
} }

View File

@@ -38,4 +38,6 @@ public interface DataSetService {
SemanticQueryReq convert(QueryDataSetReq queryDataSetReq); SemanticQueryReq convert(QueryDataSetReq queryDataSetReq);
/**
 * Resolves the id of the data set that the given SQL statement refers to.
 *
 * @param sql  SQL statement from which the data set is derived
 * @param user user on whose behalf the data sets are looked up
 * @return id of the resolved data set
 */
Long getDataSetIdFromSql(String sql, User user);
} }

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.server.service;
import com.github.pagehelper.PageInfo; import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.DataEvent;
import com.tencent.supersonic.common.pojo.enums.EventType; import com.tencent.supersonic.common.pojo.enums.EventType;
import com.tencent.supersonic.headless.api.pojo.DimValueMap; import com.tencent.supersonic.headless.api.pojo.DimValueMap;
import com.tencent.supersonic.headless.api.pojo.request.DimensionReq; import com.tencent.supersonic.headless.api.pojo.request.DimensionReq;
@@ -10,6 +11,7 @@ import com.tencent.supersonic.headless.api.pojo.request.PageDimensionReq;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp; import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
import com.tencent.supersonic.headless.server.pojo.DimensionsFilter; import com.tencent.supersonic.headless.server.pojo.DimensionsFilter;
import com.tencent.supersonic.headless.server.pojo.MetaFilter; import com.tencent.supersonic.headless.server.pojo.MetaFilter;
import java.util.List; import java.util.List;
public interface DimensionService { public interface DimensionService {
@@ -41,4 +43,6 @@ public interface DimensionService {
List<DimValueMap> mockDimensionValueAlias(DimensionReq dimensionReq, User user); List<DimValueMap> mockDimensionValueAlias(DimensionReq dimensionReq, User user);
void sendDimensionEventBatch(List<Long> modelIds, EventType eventType); void sendDimensionEventBatch(List<Long> modelIds, EventType eventType);
/**
 * Builds a {@link DataEvent} carrying data items for dimensions, without publishing it.
 *
 * @return the event; its data items describe the dimensions
 */
DataEvent getDataEvent();
} }

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.server.service;
import com.github.pagehelper.PageInfo; import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.DataEvent;
import com.tencent.supersonic.common.pojo.enums.EventType; import com.tencent.supersonic.common.pojo.enums.EventType;
import com.tencent.supersonic.headless.api.pojo.DrillDownDimension; import com.tencent.supersonic.headless.api.pojo.DrillDownDimension;
import com.tencent.supersonic.headless.api.pojo.MetricQueryDefaultConfig; import com.tencent.supersonic.headless.api.pojo.MetricQueryDefaultConfig;
@@ -14,6 +15,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp; import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.server.pojo.MetaFilter; import com.tencent.supersonic.headless.server.pojo.MetaFilter;
import com.tencent.supersonic.headless.server.pojo.MetricsFilter; import com.tencent.supersonic.headless.server.pojo.MetricsFilter;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
@@ -60,4 +62,6 @@ public interface MetricService {
List<MetricResp> queryMetrics(MetricsFilter metricsFilter); List<MetricResp> queryMetrics(MetricsFilter metricsFilter);
QueryStructReq convert(QueryMetricReq queryMetricReq); QueryStructReq convert(QueryMetricReq queryMetricReq);
/**
 * Builds a {@link DataEvent} carrying data items for metrics.
 * NOTE(review): presumably covers all metrics and does not publish the event,
 * mirroring the dimension service — confirm against the implementation.
 *
 * @return the event; its data items describe the metrics
 */
DataEvent getDataEvent();
} }

View File

@@ -21,5 +21,4 @@ public interface QueryService {
List<ItemUseResp> getStatInfo(ItemUseReq itemUseCommend); List<ItemUseResp> getStatInfo(ItemUseReq itemUseCommend);
<T> ExplainResp explain(ExplainSqlReq<T> explainSqlReq, User user) throws Exception; <T> ExplainResp explain(ExplainSqlReq<T> explainSqlReq, User user) throws Exception;
} }

View File

@@ -20,6 +20,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.enums.CostType; import com.tencent.supersonic.headless.api.pojo.enums.CostType;
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod; import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
@@ -29,6 +30,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq; import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.ExplainResp; import com.tencent.supersonic.headless.api.pojo.response.ExplainResp;
@@ -37,6 +39,7 @@ import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult; import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.QueryState; import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.core.chat.corrector.GrammarCorrector;
import com.tencent.supersonic.headless.core.chat.corrector.SemanticCorrector; import com.tencent.supersonic.headless.core.chat.corrector.SemanticCorrector;
import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult; import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult;
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeService; import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeService;
@@ -647,5 +650,30 @@ public class ChatQueryServiceImpl implements ChatQueryService {
return queryService.queryByReq(queryStructReq, user); return queryService.queryByReq(queryStructReq, user);
} }
} public void correct(QuerySqlReq querySqlReq, User user) {
// Purpose: run every configured semantic corrector except GrammarCorrector over the
// request's SQL and write the corrected SQL back into the request.
QueryContext queryCtx = new QueryContext();
SemanticSchema semanticSchema = semanticService.getSemanticSchema();
queryCtx.setSemanticSchema(semanticSchema);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
SqlInfo sqlInfo = new SqlInfo();
// Both the raw and the "corrected" SQL start out as the request's SQL.
sqlInfo.setCorrectS2SQL(querySqlReq.getSql());
sqlInfo.setS2SQL(querySqlReq.getSql());
semanticParseInfo.setSqlInfo(sqlInfo);
semanticParseInfo.setQueryType(QueryType.TAG);
Long dataSetId = querySqlReq.getDataSetId();
// When the request carries no data set id, derive it from the SQL itself.
// NOTE(review): only null counts as "missing" here; a non-positive id is passed
// straight to getDataSet below — confirm that is intended.
if (Objects.isNull(dataSetId)) {
dataSetId = dataSetService.getDataSetIdFromSql(querySqlReq.getSql(), user);
}
SchemaElement dataSet = semanticSchema.getDataSet(dataSetId);
semanticParseInfo.setDataSet(dataSet);
ComponentFactory.getSemanticCorrectors().forEach(corrector -> {
// Grammar correction is skipped for this entry point.
if (!(corrector instanceof GrammarCorrector)) {
corrector.correct(queryCtx, semanticParseInfo);
}
});
querySqlReq.setSql(sqlInfo.getCorrectS2SQL());
}
}

View File

@@ -13,6 +13,7 @@ import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.pojo.enums.TypeEnums; import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException; import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
import com.tencent.supersonic.common.util.BeanMapper; import com.tencent.supersonic.common.util.BeanMapper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.DataSetDetail; import com.tencent.supersonic.headless.api.pojo.DataSetDetail;
import com.tencent.supersonic.headless.api.pojo.QueryConfig; import com.tencent.supersonic.headless.api.pojo.QueryConfig;
import com.tencent.supersonic.headless.api.pojo.enums.TagDefineType; import com.tencent.supersonic.headless.api.pojo.enums.TagDefineType;
@@ -34,6 +35,15 @@ import com.tencent.supersonic.headless.server.service.DimensionService;
import com.tencent.supersonic.headless.server.service.DomainService; import com.tencent.supersonic.headless.server.service.DomainService;
import com.tencent.supersonic.headless.server.service.MetricService; import com.tencent.supersonic.headless.server.service.MetricService;
import com.tencent.supersonic.headless.server.service.TagMetaService; import com.tencent.supersonic.headless.server.service.TagMetaService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.Arrays; import java.util.Arrays;
import java.util.Comparator; import java.util.Comparator;
import java.util.Date; import java.util.Date;
@@ -45,15 +55,9 @@ import java.util.Set;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@Service @Service
@Slf4j
public class DataSetServiceImpl public class DataSetServiceImpl
extends ServiceImpl<DataSetDOMapper, DataSetDO> implements DataSetService { extends ServiceImpl<DataSetDOMapper, DataSetDO> implements DataSetService {
@@ -311,4 +315,21 @@ public class DataSetServiceImpl
.map(Object::toString) .map(Object::toString)
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
/**
 * Resolves the id of the data set that the given SQL statement queries.
 *
 * <p>The table name is extracted from the SQL and used to look up the data sets visible
 * to the user; the id of the first match is returned. A failure while parsing or looking
 * up is logged and treated the same as "nothing found".
 *
 * @param sql  SQL statement whose table name identifies the data set
 * @param user user on whose behalf the data sets are looked up
 * @return id of the first matching data set
 * @throws InvalidArgumentException if no data set can be derived from the SQL
 */
public Long getDataSetIdFromSql(String sql, User user) {
    List<DataSetResp> dataSets = null;
    try {
        String tableName = SqlSelectHelper.getTableName(sql);
        dataSets = getDataSets(tableName, user);
    } catch (Exception e) {
        // SLF4J takes a trailing Throwable as the exception, not as a "{}" argument,
        // so the placeholder needs a real value; log the offending SQL with the stack trace.
        log.error("getDataSetIdFromSql error, sql:{}", sql, e);
    }
    // Null-safe: also covers the case where the lookup above failed.
    if (CollectionUtils.isEmpty(dataSets)) {
        throw new InvalidArgumentException("从Sql参数中无法获取到DataSetId");
    }
    Long dataSetId = dataSets.get(0).getId();
    log.info("getDataSetIdFromSql dataSetId:{}", dataSetId);
    return dataSetId;
}
} }

View File

@@ -430,17 +430,27 @@ public class DimensionServiceImpl implements DimensionService {
} }
private void sendEventBatch(List<DimensionDO> dimensionDOS, EventType eventType) { private void sendEventBatch(List<DimensionDO> dimensionDOS, EventType eventType) {
DataEvent dataEvent = getDataEvent(dimensionDOS, eventType);
eventPublisher.publishEvent(dataEvent);
}
/**
 * Builds an {@code ADD} data event for the dimensions returned by a query with a
 * freshly created, unconfigured {@link DimensionFilter}. The event is returned, not published.
 *
 * @return the assembled event
 */
public DataEvent getDataEvent() {
    return getDataEvent(queryDimension(new DimensionFilter()), EventType.ADD);
}
private DataEvent getDataEvent(List<DimensionDO> dimensionDOS, EventType eventType) {
List<DataItem> dataItems = dimensionDOS.stream() List<DataItem> dataItems = dimensionDOS.stream()
.map(dimensionDO -> DataItem.builder().id(dimensionDO.getId() + Constants.UNDERLINE) .map(dimensionDO -> DataItem.builder().id(dimensionDO.getId() + Constants.UNDERLINE)
.name(dimensionDO.getName()).modelId(dimensionDO.getModelId() + Constants.UNDERLINE) .name(dimensionDO.getName()).modelId(dimensionDO.getModelId() + Constants.UNDERLINE)
.type(TypeEnums.DIMENSION).build()) .type(TypeEnums.DIMENSION).build())
.collect(Collectors.toList()); .collect(Collectors.toList());
eventPublisher.publishEvent(new DataEvent(this, dataItems, eventType)); return new DataEvent(this, dataItems, eventType);
} }
private void sendEvent(DataItem dataItem, EventType eventType) { private void sendEvent(DataItem dataItem, EventType eventType) {
eventPublisher.publishEvent(new DataEvent(this, eventPublisher.publishEvent(new DataEvent(this,
Lists.newArrayList(dataItem), eventType)); Lists.newArrayList(dataItem), eventType));
} }
} }

View File

@@ -97,13 +97,13 @@ public class MetricServiceImpl implements MetricService {
private TagMetaService tagMetaService; private TagMetaService tagMetaService;
public MetricServiceImpl(MetricRepository metricRepository, public MetricServiceImpl(MetricRepository metricRepository,
ModelService modelService, ModelService modelService,
ChatGptHelper chatGptHelper, ChatGptHelper chatGptHelper,
CollectService collectService, CollectService collectService,
DataSetService dataSetService, DataSetService dataSetService,
ApplicationEventPublisher eventPublisher, ApplicationEventPublisher eventPublisher,
DimensionService dimensionService, DimensionService dimensionService,
TagMetaService tagMetaService) { TagMetaService tagMetaService) {
this.metricRepository = metricRepository; this.metricRepository = metricRepository;
this.modelService = modelService; this.modelService = modelService;
this.chatGptHelper = chatGptHelper; this.chatGptHelper = chatGptHelper;
@@ -326,7 +326,7 @@ public class MetricServiceImpl implements MetricService {
} }
private boolean filterByField(List<MetricResp> metricResps, MetricResp metricResp, private boolean filterByField(List<MetricResp> metricResps, MetricResp metricResp,
List<String> fields, Set<MetricResp> metricRespFiltered) { List<String> fields, Set<MetricResp> metricRespFiltered) {
if (MetricDefineType.METRIC.equals(metricResp.getMetricDefineType())) { if (MetricDefineType.METRIC.equals(metricResp.getMetricDefineType())) {
List<Long> ids = metricResp.getMetricDefineByMetricParams().getMetrics() List<Long> ids = metricResp.getMetricDefineByMetricParams().getMetrics()
.stream().map(MetricParam::getId).collect(Collectors.toList()); .stream().map(MetricParam::getId).collect(Collectors.toList());
@@ -556,9 +556,20 @@ public class MetricServiceImpl implements MetricService {
} }
// Publishes a single DataEvent for a batch of metrics: the event is built by
// getDataEvent(metricDOS, eventType) and handed to eventPublisher.
private void sendEventBatch(List<MetricDO> metricDOS, EventType eventType) { private void sendEventBatch(List<MetricDO> metricDOS, EventType eventType) {
DataEvent dataEvent = getDataEvent(metricDOS, eventType);
eventPublisher.publishEvent(dataEvent);
}
/**
 * Builds an {@code EventType.ADD} event from the metrics that
 * {@code metricRepository.getMetrics} returns for a freshly constructed
 * {@link MetricsFilter} (no filter fields are set here).
 *
 * @return the {@link DataEvent} wrapping the fetched metrics; nothing is published
 */
public DataEvent getDataEvent() {
    List<MetricDO> fetchedMetrics = metricRepository.getMetrics(new MetricsFilter());
    return getDataEvent(fetchedMetrics, EventType.ADD);
}
// Maps each MetricDO to a DataItem via getDataItem and wraps the resulting list in a
// DataEvent whose source is this service.
// NOTE(review): the line before the closing brace carries both the publishEvent call and
// the return statement; confirm which one is intended to remain.
private DataEvent getDataEvent(List<MetricDO> metricDOS, EventType eventType) {
List<DataItem> dataItems = metricDOS.stream().map(this::getDataItem) List<DataItem> dataItems = metricDOS.stream().map(this::getDataItem)
.collect(Collectors.toList()); .collect(Collectors.toList());
eventPublisher.publishEvent(new DataEvent(this, dataItems, eventType)); return new DataEvent(this, dataItems, eventType);
} }
private void sendEvent(DataItem dataItem, EventType eventType) { private void sendEvent(DataItem dataItem, EventType eventType) {
@@ -662,7 +673,7 @@ public class MetricServiceImpl implements MetricService {
} }
private Set<Long> getModelIds(Set<Long> modelIdsByDomainId, List<MetricResp> metricResps, private Set<Long> getModelIds(Set<Long> modelIdsByDomainId, List<MetricResp> metricResps,
List<DimensionResp> dimensionResps) { List<DimensionResp> dimensionResps) {
Set<Long> result = new HashSet<>(); Set<Long> result = new HashSet<>();
if (org.apache.commons.collections.CollectionUtils.isNotEmpty(modelIdsByDomainId)) { if (org.apache.commons.collections.CollectionUtils.isNotEmpty(modelIdsByDomainId)) {
result.addAll(modelIdsByDomainId); result.addAll(modelIdsByDomainId);

View File

@@ -5,8 +5,6 @@ import com.google.common.collect.Sets;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.enums.TaskStatusEnum; import com.tencent.supersonic.common.pojo.enums.TaskStatusEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.Dim; import com.tencent.supersonic.headless.api.pojo.Dim;
import com.tencent.supersonic.headless.api.pojo.QueryParam; import com.tencent.supersonic.headless.api.pojo.QueryParam;
import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq; import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq;
@@ -17,7 +15,6 @@ import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.request.SchemaFilterReq; import com.tencent.supersonic.headless.api.pojo.request.SchemaFilterReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp; import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
import com.tencent.supersonic.headless.api.pojo.response.ExplainResp; import com.tencent.supersonic.headless.api.pojo.response.ExplainResp;
import com.tencent.supersonic.headless.api.pojo.response.ItemUseResp; import com.tencent.supersonic.headless.api.pojo.response.ItemUseResp;
@@ -137,8 +134,8 @@ public class QueryServiceImpl implements QueryService {
private QueryStatement buildSqlQueryStatement(QuerySqlReq querySqlReq, User user) throws Exception { private QueryStatement buildSqlQueryStatement(QuerySqlReq querySqlReq, User user) throws Exception {
//If dataSetId or DataSetName is empty, parse dataSetId from the SQL //If dataSetId or DataSetName is empty, parse dataSetId from the SQL
if (needGetDataSetId(querySqlReq)) { if (querySqlReq.needGetDataSetId()) {
Long dataSetId = getDataSetIdFromSql(querySqlReq, user); Long dataSetId = dataSetService.getDataSetIdFromSql(querySqlReq.getSql(), user);
querySqlReq.setDataSetId(dataSetId); querySqlReq.setDataSetId(dataSetId);
} }
SchemaFilterReq filter = buildSchemaFilterReq(querySqlReq); SchemaFilterReq filter = buildSchemaFilterReq(querySqlReq);
@@ -151,27 +148,6 @@ public class QueryServiceImpl implements QueryService {
return queryStatement; return queryStatement;
} }
/**
 * Tells whether the data set id still has to be resolved for the request.
 *
 * @param querySqlReq the SQL query request to inspect
 * @return {@code true} when the request has no positive data set id and its
 *         model id collection is empty; {@code false} otherwise
 */
private static boolean needGetDataSetId(QuerySqlReq querySqlReq) {
    // A positive data set id is already usable, so no lookup is needed.
    if (Objects.nonNull(querySqlReq.getDataSetId()) && querySqlReq.getDataSetId() > 0) {
        return false;
    }
    return CollectionUtils.isEmpty(querySqlReq.getModelIds());
}
/**
 * Resolves a data set id from the SQL carried by the request: the table name is
 * extracted with {@code SqlSelectHelper.getTableName} and looked up through
 * {@code dataSetService.getDataSets}; the id of the first match is returned.
 *
 * @param querySqlReq the request whose SQL is inspected
 * @param user the user passed through to the data set lookup
 * @return the id of the first data set returned for the table name
 * @throws InvalidArgumentException when no data set is found, including when the
 *         table-name extraction or the lookup itself failed
 */
private Long getDataSetIdFromSql(QuerySqlReq querySqlReq, User user) {
    List<DataSetResp> dataSets = null;
    try {
        String tableName = SqlSelectHelper.getTableName(querySqlReq.getSql());
        dataSets = dataSetService.getDataSets(tableName, user);
    } catch (Exception e) {
        // Best-effort lookup: a failure here is logged and then reported to the caller
        // through the empty-result check below. The throwable is passed as the trailing
        // argument without a "{}" placeholder so the stack trace is logged and the
        // message carries no dangling placeholder.
        log.error("getDataSetIdFromSql error", e);
    }
    if (CollectionUtils.isEmpty(dataSets)) {
        throw new InvalidArgumentException("从Sql参数中无法获取到DataSetId");
    }
    Long dataSetId = dataSets.get(0).getId();
    log.info("getDataSetIdFromSql dataSetId:{}", dataSetId);
    return dataSetId;
}
private QueryStatement buildQueryStatement(SemanticQueryReq semanticQueryReq, User user) throws Exception { private QueryStatement buildQueryStatement(SemanticQueryReq semanticQueryReq, User user) throws Exception {
if (semanticQueryReq instanceof QuerySqlReq) { if (semanticQueryReq instanceof QuerySqlReq) {
return buildSqlQueryStatement((QuerySqlReq) semanticQueryReq, user); return buildSqlQueryStatement((QuerySqlReq) semanticQueryReq, user);