mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(Headless) Periodically refresh embedding metadata in full and optimize the code. (#917)
This commit is contained in:
@@ -1,32 +0,0 @@
|
||||
package com.tencent.supersonic.common.util.embedding;
|
||||
|
||||
import com.tencent.supersonic.common.util.ComponentFactory;
|
||||
import javax.annotation.PreDestroy;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.scheduling.annotation.Scheduled;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
public class EmbeddingPersistentTask {
|
||||
|
||||
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
|
||||
|
||||
@PreDestroy
|
||||
public void onShutdown() {
|
||||
embeddingStorePersistentToFile();
|
||||
}
|
||||
|
||||
private void embeddingStorePersistentToFile() {
|
||||
if (s2EmbeddingStore instanceof InMemoryS2EmbeddingStore) {
|
||||
log.info("start persistentToFile");
|
||||
((InMemoryS2EmbeddingStore) s2EmbeddingStore).persistentToFile();
|
||||
log.info("end persistentToFile");
|
||||
}
|
||||
}
|
||||
|
||||
@Scheduled(cron = "${inMemoryEmbeddingStore.persistent.cron:0 0 * * * ?}")
|
||||
public void executeTask() {
|
||||
embeddingStorePersistentToFile();
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,13 @@
|
||||
package com.tencent.supersonic.common.util.embedding;
|
||||
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.tencent.supersonic.common.pojo.DataItem;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Data
|
||||
public class EmbeddingQuery {
|
||||
@@ -17,4 +20,17 @@ public class EmbeddingQuery {
|
||||
|
||||
private List<Double> queryEmbedding;
|
||||
|
||||
public static List<EmbeddingQuery> convertToEmbedding(List<DataItem> dataItems) {
|
||||
return dataItems.stream().map(dataItem -> {
|
||||
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
|
||||
embeddingQuery.setQueryId(
|
||||
dataItem.getId() + dataItem.getType().name().toLowerCase());
|
||||
embeddingQuery.setQuery(dataItem.getName());
|
||||
Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
|
||||
embeddingQuery.setMetadata(meta);
|
||||
embeddingQuery.setQueryEmbedding(null);
|
||||
return embeddingQuery;
|
||||
}).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -2,6 +2,9 @@ package com.tencent.supersonic.headless.api.pojo.request;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
@Data
|
||||
@ToString
|
||||
@@ -25,4 +28,8 @@ public class QuerySqlReq extends SemanticQueryReq {
|
||||
return stringBuilder.toString();
|
||||
}
|
||||
|
||||
public boolean needGetDataSetId() {
|
||||
return (Objects.isNull(this.getDataSetId()) || this.getDataSetId() <= 0)
|
||||
&& (CollectionUtils.isEmpty(this.getModelIds()));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
package com.tencent.supersonic.headless.server.listener;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.pojo.DataEvent;
|
||||
import com.tencent.supersonic.common.pojo.DataItem;
|
||||
import com.tencent.supersonic.common.pojo.enums.EventType;
|
||||
import com.tencent.supersonic.common.util.ComponentFactory;
|
||||
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
@@ -18,6 +15,8 @@ import org.springframework.scheduling.annotation.Async;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
public class MetaEmbeddingListener implements ApplicationListener<DataEvent> {
|
||||
@@ -33,30 +32,15 @@ public class MetaEmbeddingListener implements ApplicationListener<DataEvent> {
|
||||
@Async
|
||||
@Override
|
||||
public void onApplicationEvent(DataEvent event) {
|
||||
if (CollectionUtils.isEmpty(event.getDataItems())) {
|
||||
List<DataItem> dataItems = event.getDataItems();
|
||||
if (CollectionUtils.isEmpty(dataItems)) {
|
||||
return;
|
||||
}
|
||||
|
||||
List<EmbeddingQuery> embeddingQueries = event.getDataItems()
|
||||
.stream()
|
||||
.map(dataItem -> {
|
||||
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
|
||||
embeddingQuery.setQueryId(
|
||||
dataItem.getId() + dataItem.getType().name().toLowerCase());
|
||||
embeddingQuery.setQuery(dataItem.getName());
|
||||
Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
|
||||
embeddingQuery.setMetadata(meta);
|
||||
embeddingQuery.setQueryEmbedding(null);
|
||||
return embeddingQuery;
|
||||
}).collect(Collectors.toList());
|
||||
List<EmbeddingQuery> embeddingQueries = EmbeddingQuery.convertToEmbedding(dataItems);
|
||||
if (CollectionUtils.isEmpty(embeddingQueries)) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
Thread.sleep(embeddingOperationSleepTime);
|
||||
} catch (InterruptedException e) {
|
||||
log.error("", e);
|
||||
}
|
||||
sleep();
|
||||
s2EmbeddingStore.addCollection(embeddingConfig.getMetaCollectionName());
|
||||
if (event.getEventType().equals(EventType.ADD)) {
|
||||
s2EmbeddingStore.addQuery(embeddingConfig.getMetaCollectionName(), embeddingQueries);
|
||||
@@ -68,4 +52,12 @@ public class MetaEmbeddingListener implements ApplicationListener<DataEvent> {
|
||||
}
|
||||
}
|
||||
|
||||
private void sleep() {
|
||||
try {
|
||||
Thread.sleep(embeddingOperationSleepTime);
|
||||
} catch (InterruptedException e) {
|
||||
log.error("", e);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -3,25 +3,25 @@ package com.tencent.supersonic.headless.server.rest;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import javax.validation.Valid;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DictItemFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DictItemReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DictSingleTaskReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DictItemResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DictTaskResp;
|
||||
import com.tencent.supersonic.headless.server.schedule.EmbeddingTask;
|
||||
import com.tencent.supersonic.headless.server.service.DictConfService;
|
||||
import com.tencent.supersonic.headless.server.service.DictTaskService;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.PutMapping;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import javax.validation.Valid;
|
||||
import java.util.List;
|
||||
|
||||
|
||||
@@ -35,6 +35,9 @@ public class KnowledgeController {
|
||||
@Autowired
|
||||
private DictConfService confService;
|
||||
|
||||
@Autowired
|
||||
private EmbeddingTask embeddingTask;
|
||||
|
||||
/**
|
||||
* addDictConf-新增item的字典配置
|
||||
* Add configuration information for dictionary entries
|
||||
@@ -43,8 +46,8 @@ public class KnowledgeController {
|
||||
*/
|
||||
@PostMapping("/conf")
|
||||
public DictItemResp addDictConf(@RequestBody @Valid DictItemReq dictItemReq,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
User user = UserHolder.findUser(request, response);
|
||||
return confService.addDictConf(dictItemReq, user);
|
||||
}
|
||||
@@ -57,8 +60,8 @@ public class KnowledgeController {
|
||||
*/
|
||||
@PutMapping("/conf")
|
||||
public DictItemResp editDictConf(@RequestBody @Valid DictItemReq dictItemReq,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
User user = UserHolder.findUser(request, response);
|
||||
return confService.editDictConf(dictItemReq, user);
|
||||
}
|
||||
@@ -129,4 +132,9 @@ public class KnowledgeController {
|
||||
return taskService.queryLatestDictTask(taskReq, user);
|
||||
}
|
||||
|
||||
@GetMapping("/meta/embedding/reload")
|
||||
public Object reloadMetaEmbedding() {
|
||||
embeddingTask.reloadMetaEmbedding();
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,18 +2,11 @@ package com.tencent.supersonic.headless.server.rest.api;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlsReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.core.chat.corrector.GrammarCorrector;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.headless.server.service.ChatQueryService;
|
||||
import com.tencent.supersonic.headless.server.service.QueryService;
|
||||
import com.tencent.supersonic.headless.server.service.impl.SemanticService;
|
||||
import com.tencent.supersonic.headless.server.utils.ComponentFactory;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
@@ -36,14 +29,14 @@ public class SqlQueryApiController {
|
||||
private QueryService queryService;
|
||||
|
||||
@Autowired
|
||||
private SemanticService semanticService;
|
||||
private ChatQueryService chatQueryService;
|
||||
|
||||
@PostMapping("/sql")
|
||||
public Object queryBySql(@RequestBody QuerySqlReq querySqlReq,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) throws Exception {
|
||||
User user = UserHolder.findUser(request, response);
|
||||
correct(querySqlReq);
|
||||
chatQueryService.correct(querySqlReq, user);
|
||||
return queryService.queryByReq(querySqlReq, user);
|
||||
}
|
||||
|
||||
@@ -57,28 +50,9 @@ public class SqlQueryApiController {
|
||||
QuerySqlReq querySqlReq = new QuerySqlReq();
|
||||
BeanUtils.copyProperties(querySqlsReq, querySqlReq);
|
||||
querySqlReq.setSql(sql);
|
||||
correct(querySqlReq);
|
||||
chatQueryService.correct(querySqlReq, user);
|
||||
return querySqlReq;
|
||||
}).collect(Collectors.toList());
|
||||
return queryService.queryByReqs(semanticQueryReqs, user);
|
||||
}
|
||||
|
||||
private void correct(QuerySqlReq querySqlReq) {
|
||||
QueryContext queryCtx = new QueryContext();
|
||||
SemanticSchema semanticSchema = semanticService.getSemanticSchema();
|
||||
queryCtx.setSemanticSchema(semanticSchema);
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
SqlInfo sqlInfo = new SqlInfo();
|
||||
sqlInfo.setCorrectS2SQL(querySqlReq.getSql());
|
||||
sqlInfo.setS2SQL(querySqlReq.getSql());
|
||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||
semanticParseInfo.setQueryType(QueryType.TAG);
|
||||
|
||||
ComponentFactory.getSemanticCorrectors().forEach(corrector -> {
|
||||
if (!(corrector instanceof GrammarCorrector)) {
|
||||
corrector.correct(queryCtx, semanticParseInfo);
|
||||
}
|
||||
});
|
||||
querySqlReq.setSql(sqlInfo.getCorrectS2SQL());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
package com.tencent.supersonic.headless.server.schedule;
|
||||
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.pojo.DataItem;
|
||||
import com.tencent.supersonic.common.util.ComponentFactory;
|
||||
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.InMemoryS2EmbeddingStore;
|
||||
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
||||
import com.tencent.supersonic.headless.server.service.DimensionService;
|
||||
import com.tencent.supersonic.headless.server.service.MetricService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.scheduling.annotation.Scheduled;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import javax.annotation.PreDestroy;
|
||||
import java.util.List;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
public class EmbeddingTask {
|
||||
|
||||
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
|
||||
@Autowired
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
@Autowired
|
||||
private MetricService metricService;
|
||||
|
||||
@Autowired
|
||||
private DimensionService dimensionService;
|
||||
|
||||
@PreDestroy
|
||||
public void onShutdown() {
|
||||
embeddingStorePersistentToFile();
|
||||
}
|
||||
|
||||
private void embeddingStorePersistentToFile() {
|
||||
if (s2EmbeddingStore instanceof InMemoryS2EmbeddingStore) {
|
||||
log.info("start persistentToFile");
|
||||
((InMemoryS2EmbeddingStore) s2EmbeddingStore).persistentToFile();
|
||||
log.info("end persistentToFile");
|
||||
}
|
||||
}
|
||||
|
||||
@Scheduled(cron = "${inMemoryEmbeddingStore.persistent.cron:0 0 * * * ?}")
|
||||
public void executeTask() {
|
||||
embeddingStorePersistentToFile();
|
||||
}
|
||||
|
||||
|
||||
/***
|
||||
* reload meta embedding
|
||||
*/
|
||||
@Scheduled(cron = "${reload.meta.embedding.corn:0 0 */2 * * ?}")
|
||||
public void reloadMetaEmbedding() {
|
||||
log.info("reload.meta.embedding start");
|
||||
try {
|
||||
List<DataItem> metricDataItems = metricService.getDataEvent().getDataItems();
|
||||
|
||||
s2EmbeddingStore.addQuery(embeddingConfig.getMetaCollectionName(),
|
||||
EmbeddingQuery.convertToEmbedding(metricDataItems));
|
||||
|
||||
List<DataItem> dimensionDataItems = dimensionService.getDataEvent().getDataItems();
|
||||
s2EmbeddingStore.addQuery(embeddingConfig.getMetaCollectionName(),
|
||||
EmbeddingQuery.convertToEmbedding(dimensionDataItems));
|
||||
} catch (Exception e) {
|
||||
log.error("reload.meta.embedding error", e);
|
||||
}
|
||||
|
||||
log.info("reload.meta.embedding end");
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
@@ -29,5 +30,7 @@ public interface ChatQueryService {
|
||||
EntityInfo getEntityInfo(SemanticParseInfo parseInfo, User user);
|
||||
|
||||
Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception;
|
||||
|
||||
void correct(QuerySqlReq querySqlReq, User user);
|
||||
}
|
||||
|
||||
|
||||
@@ -38,4 +38,6 @@ public interface DataSetService {
|
||||
|
||||
SemanticQueryReq convert(QueryDataSetReq queryDataSetReq);
|
||||
|
||||
Long getDataSetIdFromSql(String sql, User user);
|
||||
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.server.service;
|
||||
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.DataEvent;
|
||||
import com.tencent.supersonic.common.pojo.enums.EventType;
|
||||
import com.tencent.supersonic.headless.api.pojo.DimValueMap;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionReq;
|
||||
@@ -10,6 +11,7 @@ import com.tencent.supersonic.headless.api.pojo.request.PageDimensionReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
|
||||
import com.tencent.supersonic.headless.server.pojo.DimensionsFilter;
|
||||
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface DimensionService {
|
||||
@@ -41,4 +43,6 @@ public interface DimensionService {
|
||||
List<DimValueMap> mockDimensionValueAlias(DimensionReq dimensionReq, User user);
|
||||
|
||||
void sendDimensionEventBatch(List<Long> modelIds, EventType eventType);
|
||||
|
||||
DataEvent getDataEvent();
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.server.service;
|
||||
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.DataEvent;
|
||||
import com.tencent.supersonic.common.pojo.enums.EventType;
|
||||
import com.tencent.supersonic.headless.api.pojo.DrillDownDimension;
|
||||
import com.tencent.supersonic.headless.api.pojo.MetricQueryDefaultConfig;
|
||||
@@ -14,6 +15,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
|
||||
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
|
||||
import com.tencent.supersonic.headless.server.pojo.MetricsFilter;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
@@ -60,4 +62,6 @@ public interface MetricService {
|
||||
List<MetricResp> queryMetrics(MetricsFilter metricsFilter);
|
||||
|
||||
QueryStructReq convert(QueryMetricReq queryMetricReq);
|
||||
|
||||
DataEvent getDataEvent();
|
||||
}
|
||||
|
||||
@@ -21,5 +21,4 @@ public interface QueryService {
|
||||
List<ItemUseResp> getStatInfo(ItemUseReq itemUseCommend);
|
||||
|
||||
<T> ExplainResp explain(ExplainSqlReq<T> explainSqlReq, User user) throws Exception;
|
||||
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.CostType;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
@@ -29,6 +30,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ExplainResp;
|
||||
@@ -37,6 +39,7 @@ import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.core.chat.corrector.GrammarCorrector;
|
||||
import com.tencent.supersonic.headless.core.chat.corrector.SemanticCorrector;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeService;
|
||||
@@ -647,5 +650,30 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
return queryService.queryByReq(queryStructReq, user);
|
||||
}
|
||||
|
||||
}
|
||||
public void correct(QuerySqlReq querySqlReq, User user) {
|
||||
QueryContext queryCtx = new QueryContext();
|
||||
SemanticSchema semanticSchema = semanticService.getSemanticSchema();
|
||||
queryCtx.setSemanticSchema(semanticSchema);
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
SqlInfo sqlInfo = new SqlInfo();
|
||||
sqlInfo.setCorrectS2SQL(querySqlReq.getSql());
|
||||
sqlInfo.setS2SQL(querySqlReq.getSql());
|
||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||
semanticParseInfo.setQueryType(QueryType.TAG);
|
||||
|
||||
Long dataSetId = querySqlReq.getDataSetId();
|
||||
if (Objects.isNull(dataSetId)) {
|
||||
dataSetId = dataSetService.getDataSetIdFromSql(querySqlReq.getSql(), user);
|
||||
}
|
||||
SchemaElement dataSet = semanticSchema.getDataSet(dataSetId);
|
||||
semanticParseInfo.setDataSet(dataSet);
|
||||
|
||||
ComponentFactory.getSemanticCorrectors().forEach(corrector -> {
|
||||
if (!(corrector instanceof GrammarCorrector)) {
|
||||
corrector.correct(queryCtx, semanticParseInfo);
|
||||
}
|
||||
});
|
||||
querySqlReq.setSql(sqlInfo.getCorrectS2SQL());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -13,6 +13,7 @@ import com.tencent.supersonic.common.pojo.enums.StatusEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
|
||||
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
||||
import com.tencent.supersonic.common.util.BeanMapper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetDetail;
|
||||
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.TagDefineType;
|
||||
@@ -34,6 +35,15 @@ import com.tencent.supersonic.headless.server.service.DimensionService;
|
||||
import com.tencent.supersonic.headless.server.service.DomainService;
|
||||
import com.tencent.supersonic.headless.server.service.MetricService;
|
||||
import com.tencent.supersonic.headless.server.service.TagMetaService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Comparator;
|
||||
import java.util.Date;
|
||||
@@ -45,15 +55,9 @@ import java.util.Set;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class DataSetServiceImpl
|
||||
extends ServiceImpl<DataSetDOMapper, DataSetDO> implements DataSetService {
|
||||
|
||||
@@ -311,4 +315,21 @@ public class DataSetServiceImpl
|
||||
.map(Object::toString)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public Long getDataSetIdFromSql(String sql, User user) {
|
||||
List<DataSetResp> dataSets = null;
|
||||
try {
|
||||
String tableName = SqlSelectHelper.getTableName(sql);
|
||||
dataSets = getDataSets(tableName, user);
|
||||
} catch (Exception e) {
|
||||
log.error("getDataSetIdFromSql error:{}", e);
|
||||
}
|
||||
if (org.apache.commons.collections.CollectionUtils.isEmpty(dataSets)) {
|
||||
throw new InvalidArgumentException("从Sql参数中无法获取到DataSetId");
|
||||
}
|
||||
Long dataSetId = dataSets.get(0).getId();
|
||||
log.info("getDataSetIdFromSql dataSetId:{}", dataSetId);
|
||||
return dataSetId;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -430,17 +430,27 @@ public class DimensionServiceImpl implements DimensionService {
|
||||
}
|
||||
|
||||
private void sendEventBatch(List<DimensionDO> dimensionDOS, EventType eventType) {
|
||||
DataEvent dataEvent = getDataEvent(dimensionDOS, eventType);
|
||||
eventPublisher.publishEvent(dataEvent);
|
||||
}
|
||||
|
||||
public DataEvent getDataEvent() {
|
||||
DimensionFilter dimensionFilter = new DimensionFilter();
|
||||
List<DimensionDO> dimensionDOS = queryDimension(dimensionFilter);
|
||||
return getDataEvent(dimensionDOS, EventType.ADD);
|
||||
}
|
||||
|
||||
private DataEvent getDataEvent(List<DimensionDO> dimensionDOS, EventType eventType) {
|
||||
List<DataItem> dataItems = dimensionDOS.stream()
|
||||
.map(dimensionDO -> DataItem.builder().id(dimensionDO.getId() + Constants.UNDERLINE)
|
||||
.name(dimensionDO.getName()).modelId(dimensionDO.getModelId() + Constants.UNDERLINE)
|
||||
.type(TypeEnums.DIMENSION).build())
|
||||
.collect(Collectors.toList());
|
||||
eventPublisher.publishEvent(new DataEvent(this, dataItems, eventType));
|
||||
return new DataEvent(this, dataItems, eventType);
|
||||
}
|
||||
|
||||
private void sendEvent(DataItem dataItem, EventType eventType) {
|
||||
eventPublisher.publishEvent(new DataEvent(this,
|
||||
Lists.newArrayList(dataItem), eventType));
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
@@ -97,13 +97,13 @@ public class MetricServiceImpl implements MetricService {
|
||||
private TagMetaService tagMetaService;
|
||||
|
||||
public MetricServiceImpl(MetricRepository metricRepository,
|
||||
ModelService modelService,
|
||||
ChatGptHelper chatGptHelper,
|
||||
CollectService collectService,
|
||||
DataSetService dataSetService,
|
||||
ApplicationEventPublisher eventPublisher,
|
||||
DimensionService dimensionService,
|
||||
TagMetaService tagMetaService) {
|
||||
ModelService modelService,
|
||||
ChatGptHelper chatGptHelper,
|
||||
CollectService collectService,
|
||||
DataSetService dataSetService,
|
||||
ApplicationEventPublisher eventPublisher,
|
||||
DimensionService dimensionService,
|
||||
TagMetaService tagMetaService) {
|
||||
this.metricRepository = metricRepository;
|
||||
this.modelService = modelService;
|
||||
this.chatGptHelper = chatGptHelper;
|
||||
@@ -326,7 +326,7 @@ public class MetricServiceImpl implements MetricService {
|
||||
}
|
||||
|
||||
private boolean filterByField(List<MetricResp> metricResps, MetricResp metricResp,
|
||||
List<String> fields, Set<MetricResp> metricRespFiltered) {
|
||||
List<String> fields, Set<MetricResp> metricRespFiltered) {
|
||||
if (MetricDefineType.METRIC.equals(metricResp.getMetricDefineType())) {
|
||||
List<Long> ids = metricResp.getMetricDefineByMetricParams().getMetrics()
|
||||
.stream().map(MetricParam::getId).collect(Collectors.toList());
|
||||
@@ -556,9 +556,20 @@ public class MetricServiceImpl implements MetricService {
|
||||
}
|
||||
|
||||
private void sendEventBatch(List<MetricDO> metricDOS, EventType eventType) {
|
||||
DataEvent dataEvent = getDataEvent(metricDOS, eventType);
|
||||
eventPublisher.publishEvent(dataEvent);
|
||||
}
|
||||
|
||||
public DataEvent getDataEvent() {
|
||||
MetricsFilter metricsFilter = new MetricsFilter();
|
||||
List<MetricDO> metricDOS = metricRepository.getMetrics(metricsFilter);
|
||||
return getDataEvent(metricDOS, EventType.ADD);
|
||||
}
|
||||
|
||||
private DataEvent getDataEvent(List<MetricDO> metricDOS, EventType eventType) {
|
||||
List<DataItem> dataItems = metricDOS.stream().map(this::getDataItem)
|
||||
.collect(Collectors.toList());
|
||||
eventPublisher.publishEvent(new DataEvent(this, dataItems, eventType));
|
||||
return new DataEvent(this, dataItems, eventType);
|
||||
}
|
||||
|
||||
private void sendEvent(DataItem dataItem, EventType eventType) {
|
||||
@@ -662,7 +673,7 @@ public class MetricServiceImpl implements MetricService {
|
||||
}
|
||||
|
||||
private Set<Long> getModelIds(Set<Long> modelIdsByDomainId, List<MetricResp> metricResps,
|
||||
List<DimensionResp> dimensionResps) {
|
||||
List<DimensionResp> dimensionResps) {
|
||||
Set<Long> result = new HashSet<>();
|
||||
if (org.apache.commons.collections.CollectionUtils.isNotEmpty(modelIdsByDomainId)) {
|
||||
result.addAll(modelIdsByDomainId);
|
||||
|
||||
@@ -5,8 +5,6 @@ import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.TaskStatusEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.headless.api.pojo.Dim;
|
||||
import com.tencent.supersonic.headless.api.pojo.QueryParam;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq;
|
||||
@@ -17,7 +15,6 @@ import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SchemaFilterReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ExplainResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ItemUseResp;
|
||||
@@ -137,8 +134,8 @@ public class QueryServiceImpl implements QueryService {
|
||||
|
||||
private QueryStatement buildSqlQueryStatement(QuerySqlReq querySqlReq, User user) throws Exception {
|
||||
//If dataSetId or DataSetName is empty, parse dataSetId from the SQL
|
||||
if (needGetDataSetId(querySqlReq)) {
|
||||
Long dataSetId = getDataSetIdFromSql(querySqlReq, user);
|
||||
if (querySqlReq.needGetDataSetId()) {
|
||||
Long dataSetId = dataSetService.getDataSetIdFromSql(querySqlReq.getSql(), user);
|
||||
querySqlReq.setDataSetId(dataSetId);
|
||||
}
|
||||
SchemaFilterReq filter = buildSchemaFilterReq(querySqlReq);
|
||||
@@ -151,27 +148,6 @@ public class QueryServiceImpl implements QueryService {
|
||||
return queryStatement;
|
||||
}
|
||||
|
||||
private static boolean needGetDataSetId(QuerySqlReq querySqlReq) {
|
||||
return (Objects.isNull(querySqlReq.getDataSetId()) || querySqlReq.getDataSetId() <= 0)
|
||||
&& (CollectionUtils.isEmpty(querySqlReq.getModelIds()));
|
||||
}
|
||||
|
||||
private Long getDataSetIdFromSql(QuerySqlReq querySqlReq, User user) {
|
||||
List<DataSetResp> dataSets = null;
|
||||
try {
|
||||
String tableName = SqlSelectHelper.getTableName(querySqlReq.getSql());
|
||||
dataSets = dataSetService.getDataSets(tableName, user);
|
||||
} catch (Exception e) {
|
||||
log.error("getDataSetIdFromSql error:{}", e);
|
||||
}
|
||||
if (CollectionUtils.isEmpty(dataSets)) {
|
||||
throw new InvalidArgumentException("从Sql参数中无法获取到DataSetId");
|
||||
}
|
||||
Long dataSetId = dataSets.get(0).getId();
|
||||
log.info("getDataSetIdFromSql dataSetId:{}", dataSetId);
|
||||
return dataSetId;
|
||||
}
|
||||
|
||||
private QueryStatement buildQueryStatement(SemanticQueryReq semanticQueryReq, User user) throws Exception {
|
||||
if (semanticQueryReq instanceof QuerySqlReq) {
|
||||
return buildSqlQueryStatement((QuerySqlReq) semanticQueryReq, user);
|
||||
|
||||
Reference in New Issue
Block a user