Merge pull request #759 from lexluo09/master

(improvement)(project) merge master to dev-0.9
This commit is contained in:
lexluo09
2024-02-26 14:41:23 +08:00
committed by GitHub
142 changed files with 2041 additions and 664 deletions

View File

@@ -4,6 +4,14 @@
- "Breaking Changes" describes any changes that may break existing functionality or cause - "Breaking Changes" describes any changes that may break existing functionality or cause
compatibility issues with previous versions. compatibility issues with previous versions.
## SuperSonic [0.8.6] - 2024-02-23
### Added
- support view abstraction to Headless.
- add the Metric API to Headless and optimizing the Headless API.
- add integration tests to Headless.
- add TimeCorrector to Chat.
## SuperSonic [0.8.4] - 2024-01-19 ## SuperSonic [0.8.4] - 2024-01-19
### Added ### Added

View File

@@ -2,16 +2,14 @@ package com.tencent.supersonic.chat.api.pojo;
import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import org.springframework.util.CollectionUtils;
import java.io.Serializable; import java.io.Serializable;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.springframework.util.CollectionUtils;
public class SemanticSchema implements Serializable { public class SemanticSchema implements Serializable {
@@ -54,35 +52,6 @@ public class SemanticSchema implements Serializable {
} }
} }
public SchemaElement getElementByName(SchemaElementType elementType, String name) {
Optional<SchemaElement> element = Optional.empty();
switch (elementType) {
case ENTITY:
element = getElementsByNameOrAlias(name, getEntities());
break;
case VIEW:
element = getElementsByNameOrAlias(name, getViews());
break;
case METRIC:
element = getElementsByNameOrAlias(name, getMetrics());
break;
case DIMENSION:
element = getElementsByNameOrAlias(name, getDimensions());
break;
case VALUE:
element = getElementsByNameOrAlias(name, getDimensionValues());
break;
default:
}
if (element.isPresent()) {
return element.get();
} else {
return null;
}
}
public Map<Long, String> getViewIdToName() { public Map<Long, String> getViewIdToName() {
return viewSchemaList.stream() return viewSchemaList.stream()
.collect(Collectors.toMap(a -> a.getView().getId(), a -> a.getView().getName(), (k1, k2) -> k1)); .collect(Collectors.toMap(a -> a.getView().getId(), a -> a.getView().getName(), (k1, k2) -> k1));
@@ -159,14 +128,6 @@ public class SemanticSchema implements Serializable {
.findFirst(); .findFirst();
} }
private Optional<SchemaElement> getElementsByNameOrAlias(String name, List<SchemaElement> elements) {
return elements.stream()
.filter(schemaElement ->
name.equals(schemaElement.getName()) || (Objects.nonNull(schemaElement.getAlias())
&& schemaElement.getAlias().contains(name))
).findFirst();
}
public SchemaElement getView(Long viewId) { public SchemaElement getView(Long viewId) {
List<SchemaElement> views = getViews(); List<SchemaElement> views = getViews();
return getElementsById(viewId, views).orElse(null); return getElementsById(viewId, views).orElse(null);

View File

@@ -73,9 +73,6 @@ public class OptimizationConfig {
@Value("${text2sql.self.consistency.num:5}") @Value("${text2sql.self.consistency.num:5}")
private int text2sqlSelfConsistencyNum; private int text2sqlSelfConsistencyNum;
@Value("${text2sql.collection.name:text2dsl_agent_collection}")
private String text2sqlCollectionName;
@Value("${parse.show.count:3}") @Value("${parse.show.count:3}")
private Integer parseShowCount; private Integer parseShowCount;

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.chat.core.corrector; package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
@@ -12,6 +13,7 @@ import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Pair;
import org.springframework.core.env.Environment;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.ArrayList; import java.util.ArrayList;
@@ -77,7 +79,13 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
protected void addFieldsToSelect(SemanticParseInfo semanticParseInfo, String correctS2SQL) { protected void addFieldsToSelect(SemanticParseInfo semanticParseInfo, String correctS2SQL) {
Set<String> selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL)); Set<String> selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL));
Set<String> needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL)); Set<String> needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL));
//needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL));
//decide whether add order by expression field to select
Environment environment = ContextUtils.getBean(Environment.class);
String correctorAdditionalInfo = environment.getProperty("corrector.additional.information");
if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) {
needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL));
}
// If there is no aggregate function in the S2SQL statement and // If there is no aggregate function in the S2SQL statement and
// there is a data field in 'WHERE' statement, add the field to the 'SELECT' statement. // there is a data field in 'WHERE' statement, add the field to the 'SELECT' statement.

View File

@@ -45,7 +45,8 @@ public class GroupByCorrector extends BaseSemanticCorrector {
ViewResp viewResp = viewService.getView(viewId); ViewResp viewResp = viewService.getView(viewId);
List<Long> modelIds = viewResp.getViewDetail().getViewModelConfigs().stream().map(config -> config.getId()) List<Long> modelIds = viewResp.getViewDetail().getViewModelConfigs().stream().map(config -> config.getId())
.collect(Collectors.toList()); .collect(Collectors.toList());
MetaFilter metaFilter = new MetaFilter(modelIds); MetaFilter metaFilter = new MetaFilter();
metaFilter.setIds(modelIds);
List<ModelResp> modelRespList = modelService.getModelList(metaFilter); List<ModelResp> modelRespList = modelService.getModelList(metaFilter);
for (ModelResp modelResp : modelRespList) { for (ModelResp modelResp : modelRespList) {
List<Dim> dimList = modelResp.getModelDetail().getDimensions(); List<Dim> dimList = modelResp.getModelDetail().getDimensions();

View File

@@ -3,11 +3,14 @@ package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.core.pojo.QueryContext; import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectFunctionHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlSelectFunctionHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Expression;
import org.apache.commons.lang3.StringUtils;
import org.springframework.core.env.Environment;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.List; import java.util.List;
@@ -26,8 +29,12 @@ public class HavingCorrector extends BaseSemanticCorrector {
//add aggregate to all metric //add aggregate to all metric
addHaving(queryContext, semanticParseInfo); addHaving(queryContext, semanticParseInfo);
//add having expression filed to select //decide whether add having expression field to select
//addHavingToSelect(semanticParseInfo); Environment environment = ContextUtils.getBean(Environment.class);
String correctorAdditionalInfo = environment.getProperty("corrector.additional.information");
if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) {
addHavingToSelect(semanticParseInfo);
}
} }

View File

@@ -5,7 +5,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaValueMap;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.core.parser.sql.llm.S2SqlDateHelper; import com.tencent.supersonic.chat.core.utils.S2SqlDateHelper;
import com.tencent.supersonic.chat.core.pojo.QueryContext; import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
@@ -18,6 +18,7 @@ import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil; import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.util.Strings; import org.apache.logging.log4j.util.Strings;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
@@ -65,11 +66,20 @@ public class WhereCorrector extends BaseSemanticCorrector {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL); List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) { if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
String currentDate = S2SqlDateHelper.getReferenceDate(queryContext, semanticParseInfo.getViewId()); Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(queryContext,
if (StringUtils.isNotBlank(currentDate)) { semanticParseInfo.getViewId(), semanticParseInfo.getQueryType());
if (StringUtils.isNotBlank(startEndDate.getLeft())
&& StringUtils.isNotBlank(startEndDate.getRight())) {
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL); correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
correctS2SQL = SqlAddHelper.addWhere( String dateChName = TimeDimensionEnum.DAY.getChName();
correctS2SQL, TimeDimensionEnum.DAY.getChName(), currentDate); String condExpr = String.format(" ( %s >= '%s' and %s <= '%s' )", dateChName,
startEndDate.getLeft(), dateChName, startEndDate.getRight());
try {
Expression expression = CCJSqlParserUtil.parseCondExpression(condExpr);
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression);
} catch (JSQLParserException e) {
log.error("parseCondExpression:{}", e);
}
} }
} }
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL); semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);

View File

@@ -10,9 +10,10 @@ import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult; import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder; import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder;
import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper; import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper;
import com.tencent.supersonic.headless.server.service.KnowledgeService;
import lombok.extern.slf4j.Slf4j;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
/*** /***
* A mapper that recognizes schema elements with vector embedding. * A mapper that recognizes schema elements with vector embedding.
@@ -24,7 +25,8 @@ public class EmbeddingMapper extends BaseMapper {
public void doMap(QueryContext queryContext) { public void doMap(QueryContext queryContext) {
//1. query from embedding by queryText //1. query from embedding by queryText
String queryText = queryContext.getQueryText(); String queryText = queryContext.getQueryText();
List<S2Term> terms = HanlpHelper.getTerms(queryText); KnowledgeService knowledgeService = ContextUtils.getBean(KnowledgeService.class);
List<S2Term> terms = knowledgeService.getTerms(queryText);
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class); EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
List<EmbeddingResult> matchResults = matchStrategy.getMatches(queryContext, terms); List<EmbeddingResult> matchResults = matchStrategy.getMatches(queryContext, terms);

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.chat.core.parser.sql.llm; package com.tencent.supersonic.chat.core.parser.sql.llm;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.core.utils.S2SqlDateHelper;
import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SchemaElementType;

View File

@@ -48,7 +48,7 @@ public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq); keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(), List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum()); optimizationConfig.getText2sqlExampleNum());
List<List<Map<String, String>>> exampleListPool = sqlPromptGenerator.getExampleCombos(sqlExamples, List<List<Map<String, String>>> exampleListPool = sqlPromptGenerator.getExampleCombos(sqlExamples,
optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum()); optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum());

View File

@@ -45,7 +45,7 @@ public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
//1.retriever sqlExamples //1.retriever sqlExamples
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq); keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(), List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum()); optimizationConfig.getText2sqlExampleNum());
//2.generator linking and sql prompt by sqlExamples,and generate response. //2.generator linking and sql prompt by sqlExamples,and generate response.
String promptStr = sqlPromptGenerator.generatorLinkingAndSqlPrompt(llmReq, sqlExamples); String promptStr = sqlPromptGenerator.generatorLinkingAndSqlPrompt(llmReq, sqlExamples);

View File

@@ -1,49 +0,0 @@
package com.tencent.supersonic.chat.core.parser.sql.llm;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatDefaultRichConfigResp;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.util.DatePeriodEnum;
import com.tencent.supersonic.common.util.DateUtils;
import java.util.Objects;
public class S2SqlDateHelper {
public static String getReferenceDate(QueryContext queryContext, Long modelId) {
String defaultDate = DateUtils.getBeforeDate(0);
if (Objects.isNull(modelId)) {
return defaultDate;
}
ChatConfigFilter filter = new ChatConfigFilter();
filter.setModelId(modelId);
ChatConfigRichResp chatConfigRichResp = queryContext.getModelIdToChatRichConfig().get(modelId);
if (Objects.isNull(chatConfigRichResp)) {
return defaultDate;
}
if (Objects.isNull(chatConfigRichResp.getChatDetailRichConfig()) || Objects.isNull(
chatConfigRichResp.getChatDetailRichConfig().getChatDefaultConfig())) {
return defaultDate;
}
ChatDefaultRichConfigResp chatDefaultConfig = chatConfigRichResp.getChatDetailRichConfig()
.getChatDefaultConfig();
Integer unit = chatDefaultConfig.getUnit();
String period = chatDefaultConfig.getPeriod();
if (Objects.nonNull(unit)) {
// If the unit is set to less than 0, then do not add relative date.
if (unit < 0) {
return null;
}
DatePeriodEnum datePeriodEnum = DatePeriodEnum.get(period);
if (Objects.isNull(datePeriodEnum)) {
return DateUtils.getBeforeDate(unit);
} else {
return DateUtils.getBeforeDate(unit, datePeriodEnum);
}
}
return defaultDate;
}
}

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.core.parser.sql.llm;
import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.core.type.TypeReference;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.ComponentFactory; import com.tencent.supersonic.common.util.ComponentFactory;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery; import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
@@ -19,6 +20,7 @@ import java.util.Objects;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.collections4.CollectionUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.ClassPathResource;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
@@ -32,6 +34,9 @@ public class SqlExamplarLoader {
private TypeReference<List<SqlExample>> valueTypeRef = new TypeReference<List<SqlExample>>() { private TypeReference<List<SqlExample>> valueTypeRef = new TypeReference<List<SqlExample>>() {
}; };
@Autowired
private EmbeddingConfig embeddingConfig;
public List<SqlExample> getSqlExamples() throws IOException { public List<SqlExample> getSqlExamples() throws IOException {
ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE); ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE);
InputStream inputStream = resource.getInputStream(); InputStream inputStream = resource.getInputStream();
@@ -53,8 +58,8 @@ public class SqlExamplarLoader {
s2EmbeddingStore.addQuery(collectionName, queries); s2EmbeddingStore.addQuery(collectionName, queries);
} }
public List<Map<String, String>> retrieverSqlExamples(String queryText, String collectionName, int maxResults) { public List<Map<String, String>> retrieverSqlExamples(String queryText, int maxResults) {
String collectionName = embeddingConfig.getText2sqlCollectionName();
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(Collections.singletonList(queryText)) RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(Collections.singletonList(queryText))
.queryEmbeddings(null).build(); .queryEmbeddings(null).build();

View File

@@ -44,7 +44,7 @@ public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
//1.retriever sqlExamples and generate exampleListPool //1.retriever sqlExamples and generate exampleListPool
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq); keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(), List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum()); optimizationConfig.getText2sqlExampleNum());
List<List<Map<String, String>>> exampleListPool = sqlPromptGenerator.getExampleCombos(sqlExamples, List<List<Map<String, String>>> exampleListPool = sqlPromptGenerator.getExampleCombos(sqlExamples,
optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum()); optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum());

View File

@@ -43,7 +43,7 @@ public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
public LLMResp generation(LLMReq llmReq, Long viewId) { public LLMResp generation(LLMReq llmReq, Long viewId) {
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq); keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(), List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum()); optimizationConfig.getText2sqlExampleNum());
String linkingPromptStr = sqlPromptGenerator.generateLinkingPrompt(llmReq, sqlExamples); String linkingPromptStr = sqlPromptGenerator.generateLinkingPrompt(llmReq, sqlExamples);

View File

@@ -121,7 +121,7 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
} }
QueryStructReq queryStructReq = convertQueryStruct(); QueryStructReq queryStructReq = convertQueryStruct();
convertBizNameToName(semanticSchema, queryStructReq); convertBizNameToName(semanticSchema, queryStructReq);
QuerySqlReq querySQLReq = queryStructReq.convert(queryStructReq); QuerySqlReq querySQLReq = queryStructReq.convert();
parseInfo.getSqlInfo().setS2SQL(querySQLReq.getSql()); parseInfo.getSqlInfo().setS2SQL(querySQLReq.getSql());
parseInfo.getSqlInfo().setCorrectS2SQL(querySQLReq.getSql()); parseInfo.getSqlInfo().setCorrectS2SQL(querySQLReq.getSql());
} }

View File

@@ -44,6 +44,7 @@ public class ViewSchemaBuilder {
SchemaElement metricToAdd = SchemaElement.builder() SchemaElement metricToAdd = SchemaElement.builder()
.view(resp.getId()) .view(resp.getId())
.model(metric.getModelId())
.id(metric.getId()) .id(metric.getId())
.name(metric.getName()) .name(metric.getName())
.bizName(metric.getBizName()) .bizName(metric.getBizName())
@@ -84,6 +85,7 @@ public class ViewSchemaBuilder {
} }
SchemaElement dimToAdd = SchemaElement.builder() SchemaElement dimToAdd = SchemaElement.builder()
.view(resp.getId()) .view(resp.getId())
.model(dim.getModelId())
.id(dim.getId()) .id(dim.getId())
.name(dim.getName()) .name(dim.getName())
.bizName(dim.getBizName()) .bizName(dim.getBizName())
@@ -96,6 +98,7 @@ public class ViewSchemaBuilder {
SchemaElement dimValueToAdd = SchemaElement.builder() SchemaElement dimValueToAdd = SchemaElement.builder()
.view(resp.getId()) .view(resp.getId())
.model(dim.getModelId())
.id(dim.getId()) .id(dim.getId())
.name(dim.getName()) .name(dim.getName())
.bizName(dim.getBizName()) .bizName(dim.getBizName())
@@ -107,6 +110,7 @@ public class ViewSchemaBuilder {
if (dim.getIsTag() == 1) { if (dim.getIsTag() == 1) {
SchemaElement tagToAdd = SchemaElement.builder() SchemaElement tagToAdd = SchemaElement.builder()
.view(resp.getId()) .view(resp.getId())
.model(dim.getModelId())
.id(dim.getId()) .id(dim.getId())
.name(dim.getName()) .name(dim.getName())
.bizName(dim.getBizName()) .bizName(dim.getBizName())
@@ -126,6 +130,7 @@ public class ViewSchemaBuilder {
if (dim != null) { if (dim != null) {
SchemaElement entity = SchemaElement.builder() SchemaElement entity = SchemaElement.builder()
.view(resp.getId()) .view(resp.getId())
.model(dim.getModelId())
.id(dim.getId()) .id(dim.getId())
.name(dim.getName()) .name(dim.getName())
.bizName(dim.getBizName()) .bizName(dim.getBizName())

View File

@@ -0,0 +1,68 @@
package com.tencent.supersonic.chat.core.utils;
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeMode;
import com.tencent.supersonic.common.util.DatePeriodEnum;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import java.util.Objects;
import org.apache.commons.lang3.tuple.Pair;
public class S2SqlDateHelper {
public static String getReferenceDate(QueryContext queryContext, Long viewId) {
String defaultDate = DateUtils.getBeforeDate(0);
if (Objects.isNull(viewId)) {
return defaultDate;
}
ViewSchema viewSchema = queryContext.getSemanticSchema().getViewSchemaMap().get(viewId);
if (viewSchema == null || viewSchema.getTagTypeTimeDefaultConfig() == null) {
return defaultDate;
}
TimeDefaultConfig tagTypeTimeDefaultConfig = viewSchema.getTagTypeTimeDefaultConfig();
return getDefaultDate(defaultDate, tagTypeTimeDefaultConfig).getLeft();
}
public static Pair<String, String> getStartEndDate(QueryContext queryContext,
Long viewId, QueryType queryType) {
String defaultDate = DateUtils.getBeforeDate(0);
if (Objects.isNull(viewId)) {
return Pair.of(defaultDate, defaultDate);
}
ViewSchema viewSchema = queryContext.getSemanticSchema().getViewSchemaMap().get(viewId);
if (viewSchema == null) {
return Pair.of(defaultDate, defaultDate);
}
TimeDefaultConfig defaultConfig = viewSchema.getMetricTypeTimeDefaultConfig();
if (QueryType.TAG.equals(queryType)) {
defaultConfig = viewSchema.getTagTypeTimeDefaultConfig();
}
return getDefaultDate(defaultDate, defaultConfig);
}
private static Pair<String, String> getDefaultDate(String defaultDate, TimeDefaultConfig defaultConfig) {
if (Objects.isNull(defaultConfig)) {
return Pair.of(null, null);
}
Integer unit = defaultConfig.getUnit();
String period = defaultConfig.getPeriod();
TimeMode timeMode = defaultConfig.getTimeMode();
if (Objects.nonNull(unit)) {
// If the unit is set to less than 0, then do not add relative date.
if (unit < 0) {
return Pair.of(null, null);
}
DatePeriodEnum datePeriodEnum = DatePeriodEnum.get(period);
String startDate = DateUtils.getBeforeDate(unit, datePeriodEnum);
String endDate = DateUtils.getBeforeDate(1, datePeriodEnum);
if (TimeMode.LAST.equals(timeMode)) {
endDate = startDate;
}
return Pair.of(startDate, endDate);
}
return Pair.of(defaultDate, defaultDate);
}
}

View File

@@ -0,0 +1,121 @@
package com.tencent.supersonic.chat.core.utils;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeMode;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.lang3.tuple.Pair;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
class S2SqlDateHelperTest {
@Test
void getReferenceDate() {
Long viewId = 1L;
QueryContext queryContext = buildQueryContext(viewId);
String referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, null);
Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(0));
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, viewId);
Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(0));
ViewSchema viewSchema = queryContext.getSemanticSchema().getViewSchemaMap().get(viewId);
QueryConfig queryConfig = viewSchema.getQueryConfig();
TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig();
timeDefaultConfig.setTimeMode(TimeMode.LAST);
timeDefaultConfig.setPeriod(Constants.DAY);
timeDefaultConfig.setUnit(20);
queryConfig.getTagTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig);
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, viewId);
Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(20));
timeDefaultConfig.setUnit(1);
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, viewId);
Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(1));
timeDefaultConfig.setUnit(-1);
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, viewId);
Assert.assertNull(referenceDate);
}
@Test
void getStartEndDate() {
Long viewId = 1L;
QueryContext queryContext = buildQueryContext(viewId);
Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, null, QueryType.TAG);
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(0));
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(0));
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, viewId, QueryType.TAG);
Assert.assertNull(startEndDate.getLeft());
Assert.assertNull(startEndDate.getRight());
ViewSchema viewSchema = queryContext.getSemanticSchema().getViewSchemaMap().get(viewId);
QueryConfig queryConfig = viewSchema.getQueryConfig();
TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig();
timeDefaultConfig.setTimeMode(TimeMode.LAST);
timeDefaultConfig.setPeriod(Constants.DAY);
timeDefaultConfig.setUnit(20);
queryConfig.getTagTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig);
queryConfig.getMetricTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig);
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, viewId, QueryType.TAG);
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(20));
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(20));
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, viewId, QueryType.METRIC);
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(20));
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(20));
timeDefaultConfig.setUnit(2);
timeDefaultConfig.setTimeMode(TimeMode.RECENT);
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, viewId, QueryType.METRIC);
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(2));
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(1));
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, viewId, QueryType.TAG);
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(2));
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(1));
timeDefaultConfig.setUnit(-1);
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, viewId, QueryType.METRIC);
Assert.assertNull(startEndDate.getLeft());
Assert.assertNull(startEndDate.getRight());
timeDefaultConfig.setTimeMode(TimeMode.LAST);
timeDefaultConfig.setPeriod(Constants.DAY);
timeDefaultConfig.setUnit(5);
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, viewId, QueryType.METRIC);
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(5));
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(5));
}
private QueryContext buildQueryContext(Long viewId) {
QueryContext queryContext = new QueryContext();
List<ViewSchema> viewSchemaList = new ArrayList<>();
ViewSchema viewSchema = new ViewSchema();
QueryConfig queryConfig = new QueryConfig();
viewSchema.setQueryConfig(queryConfig);
SchemaElement schemaElement = new SchemaElement();
schemaElement.setView(viewId);
viewSchema.setView(schemaElement);
viewSchemaList.add(viewSchema);
SemanticSchema semanticSchema = new SemanticSchema(viewSchemaList);
queryContext.setSemanticSchema(semanticSchema);
return queryContext;
}
}

View File

@@ -38,11 +38,11 @@ public class DimensionRecommendProcessor implements ExecuteResultProcessor {
queryResult.setRecommendedDimensions(dimensionRecommended); queryResult.setRecommendedDimensions(dimensionRecommended);
} }
private List<SchemaElement> getDimensions(Long metricId, Long modelId) { private List<SchemaElement> getDimensions(Long metricId, Long viewId) {
SemanticService semanticService = ContextUtils.getBean(SemanticService.class); SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
ViewSchema modelSchema = semanticService.getModelSchema(modelId); ViewSchema viewSchema = semanticService.getViewSchema(viewId);
List<Long> drillDownDimensions = Lists.newArrayList(); List<Long> drillDownDimensions = Lists.newArrayList();
Set<SchemaElement> metricElements = modelSchema.getMetrics(); Set<SchemaElement> metricElements = viewSchema.getMetrics();
if (!CollectionUtils.isEmpty(metricElements)) { if (!CollectionUtils.isEmpty(metricElements)) {
Optional<SchemaElement> metric = metricElements.stream().filter(schemaElement -> Optional<SchemaElement> metric = metricElements.stream().filter(schemaElement ->
metricId.equals(schemaElement.getId()) metricId.equals(schemaElement.getId())
@@ -54,7 +54,7 @@ public class DimensionRecommendProcessor implements ExecuteResultProcessor {
} }
} }
final List<Long> drillDownDimensionsFinal = drillDownDimensions; final List<Long> drillDownDimensionsFinal = drillDownDimensions;
return modelSchema.getDimensions().stream() return viewSchema.getDimensions().stream()
.filter(dim -> filterDimension(drillDownDimensionsFinal, dim)) .filter(dim -> filterDimension(drillDownDimensionsFinal, dim))
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(recommend_dimension_size) .limit(recommend_dimension_size)

View File

@@ -50,7 +50,7 @@ public class SemanticService {
return schemaService.getSemanticSchema(); return schemaService.getSemanticSchema();
} }
public ViewSchema getModelSchema(Long id) { public ViewSchema getViewSchema(Long id) {
return schemaService.getViewSchema(id); return schemaService.getViewSchema(id);
} }

View File

@@ -219,7 +219,7 @@ public class ConfigServiceImpl implements ConfigService {
} }
BeanUtils.copyProperties(chatConfigResp, chatConfigRich); BeanUtils.copyProperties(chatConfigResp, chatConfigRich);
ViewSchema viewSchema = semanticService.getModelSchema(modelId); ViewSchema viewSchema = semanticService.getViewSchema(modelId);
if (viewSchema == null) { if (viewSchema == null) {
return chatConfigRich; return chatConfigRich;
} }

View File

@@ -48,7 +48,7 @@ public class RecommendServiceImpl implements RecommendService {
if (Objects.isNull(modelId)) { if (Objects.isNull(modelId)) {
return new RecommendResp(); return new RecommendResp();
} }
ViewSchema modelSchema = semanticService.getModelSchema(modelId); ViewSchema modelSchema = semanticService.getViewSchema(modelId);
if (Objects.isNull(modelSchema)) { if (Objects.isNull(modelSchema)) {
return new RecommendResp(); return new RecommendResp();
} }

View File

@@ -1,17 +1,19 @@
package com.tencent.supersonic.chat.server.service.impl; package com.tencent.supersonic.chat.server.service.impl;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.core.knowledge.DictWord;
import com.tencent.supersonic.headless.core.knowledge.builder.WordBuilderFactory;
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter; import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
import com.tencent.supersonic.chat.core.utils.ComponentFactory; import com.tencent.supersonic.chat.core.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.enums.DictWordType; import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.core.knowledge.DictWord;
import com.tencent.supersonic.headless.core.knowledge.builder.WordBuilderFactory;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
@Service @Service
@@ -28,7 +30,6 @@ public class WordService {
addWordsByType(DictWordType.DIMENSION, semanticSchema.getDimensions(), words); addWordsByType(DictWordType.DIMENSION, semanticSchema.getDimensions(), words);
addWordsByType(DictWordType.METRIC, semanticSchema.getMetrics(), words); addWordsByType(DictWordType.METRIC, semanticSchema.getMetrics(), words);
addWordsByType(DictWordType.VIEW, semanticSchema.getViews(), words);
addWordsByType(DictWordType.ENTITY, semanticSchema.getEntities(), words); addWordsByType(DictWordType.ENTITY, semanticSchema.getEntities(), words);
addWordsByType(DictWordType.VALUE, semanticSchema.getDimensionValues(), words); addWordsByType(DictWordType.VALUE, semanticSchema.getDimensionValues(), words);
@@ -36,6 +37,7 @@ public class WordService {
} }
private void addWordsByType(DictWordType value, List<SchemaElement> metas, List<DictWord> natures) { private void addWordsByType(DictWordType value, List<SchemaElement> metas, List<DictWord> natures) {
metas = distinct(metas);
List<DictWord> natureList = WordBuilderFactory.get(value).getDictWords(metas); List<DictWord> natureList = WordBuilderFactory.get(value).getDictWords(metas);
log.debug("nature type:{} , nature size:{}", value.name(), natureList.size()); log.debug("nature type:{} , nature size:{}", value.name(), natureList.size());
natures.addAll(natureList); natures.addAll(natureList);
@@ -48,4 +50,13 @@ public class WordService {
public void setPreDictWords(List<DictWord> preDictWords) { public void setPreDictWords(List<DictWord> preDictWords) {
this.preDictWords = preDictWords; this.preDictWords = preDictWords;
} }
private List<SchemaElement> distinct(List<SchemaElement> metas) {
return metas.stream()
.collect(Collectors.toMap(SchemaElement::getId, Function.identity(), (e1, e2) -> e1))
.values()
.stream()
.collect(Collectors.toList());
}
} }

View File

@@ -52,14 +52,14 @@ class QueryReqBuilderTest {
orders.add(order); orders.add(order);
queryStructReq.setOrders(orders); queryStructReq.setOrders(orders);
QuerySqlReq querySQLReq = queryStructReq.convert(queryStructReq); QuerySqlReq querySQLReq = queryStructReq.convert();
Assert.assertEquals( Assert.assertEquals(
"SELECT department, SUM(pv) AS pv FROM 内容库 " "SELECT department, SUM(pv) AS pv FROM 内容库 "
+ "WHERE (sys_imp_date IN ('2023-08-01')) GROUP " + "WHERE (sys_imp_date IN ('2023-08-01')) GROUP "
+ "BY department ORDER BY uv LIMIT 2000", querySQLReq.getSql()); + "BY department ORDER BY uv LIMIT 2000", querySQLReq.getSql());
queryStructReq.setQueryType(QueryType.TAG); queryStructReq.setQueryType(QueryType.TAG);
querySQLReq = queryStructReq.convert(queryStructReq); querySQLReq = queryStructReq.convert();
Assert.assertEquals( Assert.assertEquals(
"SELECT department, pv FROM 内容库 WHERE (sys_imp_date IN ('2023-08-01')) " "SELECT department, pv FROM 内容库 WHERE (sys_imp_date IN ('2023-08-01')) "
+ "ORDER BY uv LIMIT 2000", + "ORDER BY uv LIMIT 2000",

View File

@@ -32,6 +32,9 @@ public class EmbeddingConfig {
@Value("${embedding.metric.analyzeQuery.collection:solved_query_collection}") @Value("${embedding.metric.analyzeQuery.collection:solved_query_collection}")
private String metricAnalyzeQueryCollection; private String metricAnalyzeQueryCollection;
@Value("${text2sql.collection.name:text2dsl_agent_collection}")
private String text2sqlCollectionName;
@Value("${embedding.metric.analyzeQuery.nResult:5}") @Value("${embedding.metric.analyzeQuery.nResult:5}")
private int metricAnalyzeQueryResultNum; private int metricAnalyzeQueryResultNum;

View File

@@ -68,6 +68,9 @@ public class DateUtils {
} }
public static String getBeforeDate(int intervalDay, DatePeriodEnum datePeriodEnum) { public static String getBeforeDate(int intervalDay, DatePeriodEnum datePeriodEnum) {
if (Objects.isNull(datePeriodEnum)) {
return getBeforeDate(intervalDay);
}
SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_FORMAT); SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_FORMAT);
String currentDate = dateFormat.format(new Date()); String currentDate = dateFormat.format(new Date());
return getBeforeDate(currentDate, intervalDay, datePeriodEnum); return getBeforeDate(currentDate, intervalDay, datePeriodEnum);
@@ -101,7 +104,7 @@ public class DateUtils {
int month = tempDate.get(ChronoField.MONTH_OF_YEAR); int month = tempDate.get(ChronoField.MONTH_OF_YEAR);
int firstMonthOfQuarter = ((month - 1) / 3) * 3 + 1; int firstMonthOfQuarter = ((month - 1) / 3) * 3 + 1;
return tempDate.with(ChronoField.MONTH_OF_YEAR, firstMonthOfQuarter) return tempDate.with(ChronoField.MONTH_OF_YEAR, firstMonthOfQuarter)
.with(TemporalAdjusters.firstDayOfMonth()); .with(TemporalAdjusters.firstDayOfMonth());
}; };
result = result.with(firstDayOfQuarter); result = result.with(firstDayOfQuarter);
} }

View File

@@ -48,7 +48,9 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
InMemoryEmbeddingStore<EmbeddingQuery> embeddingStore = null; InMemoryEmbeddingStore<EmbeddingQuery> embeddingStore = null;
Path filePath = getPersistentPath(collectionName); Path filePath = getPersistentPath(collectionName);
try { try {
if (Files.exists(filePath)) { EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
if (Files.exists(filePath) && !collectionName.equals(embeddingConfig.getMetaCollectionName())
&& !collectionName.equals(embeddingConfig.getText2sqlCollectionName())) {
embeddingStore = InMemoryEmbeddingStore.fromFile(filePath); embeddingStore = InMemoryEmbeddingStore.fromFile(filePath);
embeddingStore.entries = new CopyOnWriteArraySet<>(embeddingStore.entries); embeddingStore.entries = new CopyOnWriteArraySet<>(embeddingStore.entries);
log.info("embeddingStore reload from file:{}", filePath); log.info("embeddingStore reload from file:{}", filePath);

View File

@@ -1,5 +1,11 @@
package com.tencent.supersonic.common.util.jsqlparser; package com.tencent.supersonic.common.util.jsqlparser;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Alias; import net.sf.jsqlparser.expression.Alias;
@@ -9,6 +15,7 @@ import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter; import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
import net.sf.jsqlparser.expression.Function; import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.LongValue; import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.StringValue; import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.WhenClause; import net.sf.jsqlparser.expression.WhenClause;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression; import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
@@ -33,13 +40,6 @@ import net.sf.jsqlparser.statement.select.SubSelect;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/** /**
* Sql Parser Select Helper * Sql Parser Select Helper
*/ */
@@ -543,6 +543,10 @@ public class SqlSelectHelper {
getColumnFromExpr(expr.getLeftExpression(), columns); getColumnFromExpr(expr.getLeftExpression(), columns);
getColumnFromExpr(expr.getRightExpression(), columns); getColumnFromExpr(expr.getRightExpression(), columns);
} }
if (expression instanceof Parenthesis) {
Parenthesis expr = (Parenthesis) expression;
getColumnFromExpr(expr.getExpression(), columns);
}
} }
} }

View File

@@ -65,12 +65,19 @@ def get_pred_result():
questions=read_query(input_path) questions=read_query(input_path)
pred_sql_list=[] pred_sql_list=[]
default_sql="select * from tablea " default_sql="select * from tablea "
time_cost=[]
for i in range(0,len(questions)): for i in range(0,len(questions)):
start_time = time.time()
pred_sql=get_pred_sql(questions[i],url,agent_id,chat_id,authorization,default_sql) pred_sql=get_pred_sql(questions[i],url,agent_id,chat_id,authorization,default_sql)
end_time = time.time()
cost='%.3f'%(end_time-start_time)
time_cost.append(cost)
pred_sql_list.append(pred_sql) pred_sql_list.append(pred_sql)
time.sleep(60) time.sleep(60)
write_sql(pred_sql_path, pred_sql_list) write_sql(pred_sql_path, pred_sql_list)
return [float(cost) for cost in time_cost]
if __name__ == "__main__": if __name__ == "__main__":
print("pred") print("pred")

View File

@@ -482,7 +482,7 @@ def print_scores(scores, etype):
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores)) print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
def evaluate(gold, predict, db_dir, etype, kmaps,query_path): def evaluate(gold, predict, db_dir, etype, kmaps,query_path,time_cost):
with open(gold) as f: with open(gold) as f:
glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0] glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
@@ -597,7 +597,11 @@ def evaluate(gold, predict, db_dir, etype, kmaps,query_path):
scores[level]['partial'][type_]['f1'] = \ scores[level]['partial'][type_]['f1'] = \
2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / ( 2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc']) scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])
cost_dic = {}
cost_dic["max_time"] = max(time_cost)
cost_dic["min_time"] = min(time_cost)
cost_dic["avg_time"] = sum(time_cost)/len(time_cost)
log_list.append(cost_dic)
print_scores(scores, etype) print_scores(scores, etype)
print(scores['all']['exec']) print(scores['all']['exec'])
current_directory = os.path.dirname(os.path.abspath(__file__)) current_directory = os.path.dirname(os.path.abspath(__file__))
@@ -608,7 +612,6 @@ def evaluate(gold, predict, db_dir, etype, kmaps,query_path):
with open(file_name, 'w') as json_file: with open(file_name, 'w') as json_file:
json.dump(log_list, json_file, indent=4, ensure_ascii=False) json.dump(log_list, json_file, indent=4, ensure_ascii=False)
def eval_exec_match(db, p_str, g_str, pred, gold): def eval_exec_match(db, p_str, g_str, pred, gold):
""" """
return 1 if the values between prediction and gold are matching return 1 if the values between prediction and gold are matching
@@ -890,7 +893,7 @@ def build_foreign_key_map_from_json(table):
tables[entry['db_id']] = build_foreign_key_map(entry) tables[entry['db_id']] = build_foreign_key_map(entry)
return tables return tables
def get_evaluation_result(): def get_evaluation_result(time_cost):
current_directory = os.path.dirname(os.path.abspath(__file__)) current_directory = os.path.dirname(os.path.abspath(__file__))
config_file=current_directory+"/config/config.yaml" config_file=current_directory+"/config/config.yaml"
with open(config_file, 'r') as file: with open(config_file, 'r') as file:
@@ -905,7 +908,7 @@ def get_evaluation_result():
etype="exec" etype="exec"
kmaps = build_foreign_key_map_from_json(table) kmaps = build_foreign_key_map_from_json(table)
evaluate(gold, pred, db_dir, etype, kmaps,query_path) evaluate(gold, pred, db_dir, etype, kmaps,query_path,time_cost)
def remove_unused_file(): def remove_unused_file():
current_directory = os.path.dirname(os.path.abspath(__file__)) current_directory = os.path.dirname(os.path.abspath(__file__))
@@ -927,8 +930,8 @@ def remove_unused_file():
if __name__ == "__main__": if __name__ == "__main__":
build_table() build_table()
get_pred_result() time_cost=get_pred_result()
get_evaluation_result() get_evaluation_result(time_cost)
remove_unused_file() remove_unused_file()

View File

@@ -19,13 +19,15 @@ public class ModelDetail {
private String tableQuery; private String tableQuery;
private List<Identify> identifiers; private List<Identify> identifiers = Lists.newArrayList();
private List<Dim> dimensions; private List<Dim> dimensions = Lists.newArrayList();
private List<Measure> measures; private List<Measure> measures = Lists.newArrayList();
private List<Field> fields; private List<Field> fields = Lists.newArrayList();
private List<SqlVariable> sqlVariables = Lists.newArrayList();
public String getSqlQuery() { public String getSqlQuery() {
if (StringUtils.isNotBlank(sqlQuery) && sqlQuery.endsWith(";")) { if (StringUtils.isNotBlank(sqlQuery) && sqlQuery.endsWith(";")) {

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.api.pojo;
import lombok.Data; import lombok.Data;
import javax.validation.constraints.NotBlank; import javax.validation.constraints.NotBlank;
import javax.validation.constraints.NotNull;
@Data @Data
public class Param { public class Param {
@@ -10,7 +11,7 @@ public class Param {
@NotBlank(message = "Invald parameter name") @NotBlank(message = "Invald parameter name")
private String name; private String name;
@NotBlank(message = "Invalid parameter value") @NotNull(message = "Invalid parameter value")
private String value; private String value;
public Param() { public Param() {
@@ -21,14 +22,4 @@ public class Param {
this.value = value; this.value = value;
} }
@Override
public String toString() {
final StringBuilder sb = new StringBuilder("{");
sb.append("\"name\":\"")
.append(name).append('\"');
sb.append(",\"value\":\"")
.append(value).append('\"');
sb.append('}');
return sb.toString();
}
} }

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.headless.api.pojo; package com.tencent.supersonic.headless.api.pojo;
import com.google.common.base.Objects; import com.google.common.base.Objects;
import com.google.common.collect.Lists;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
@@ -19,6 +18,7 @@ import java.util.List;
public class SchemaElement implements Serializable { public class SchemaElement implements Serializable {
private Long view; private Long view;
private Long model;
private Long id; private Long id;
private String name; private String name;
private String bizName; private String bizName;
@@ -52,7 +52,4 @@ public class SchemaElement implements Serializable {
return Objects.hashCode(view, id, name, bizName, type); return Objects.hashCode(view, id, name, bizName, type);
} }
public List<String> getModelNames() {
return Lists.newArrayList(name);
}
} }

View File

@@ -0,0 +1,16 @@
package com.tencent.supersonic.headless.api.pojo;
import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.enums.VariableValueType;
import lombok.Data;
import java.util.List;
@Data
public class SqlVariable {
private String name;
private VariableValueType valueType;
private List<Object> defaultValues = Lists.newArrayList();
}

View File

@@ -0,0 +1,7 @@
package com.tencent.supersonic.headless.api.pojo.enums;
public enum VariableValueType {
STRING,
NUMBER,
EXPR
}

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.headless.api.pojo.request; package com.tencent.supersonic.headless.api.pojo.request;
import com.tencent.supersonic.common.pojo.DateConf;
import java.util.List; import java.util.List;
import lombok.Data; import lombok.Data;
import lombok.ToString; import lombok.ToString;
@@ -18,4 +19,8 @@ public class QueryMetricReq {
private List<String> dimensionNames; private List<String> dimensionNames;
private DateConf dateInfo = new DateConf();
private Long limit = 2000L;
} }

View File

@@ -44,6 +44,7 @@ import java.util.stream.Collectors;
@Data @Data
@Slf4j @Slf4j
public class QueryStructReq extends SemanticQueryReq { public class QueryStructReq extends SemanticQueryReq {
private List<String> groups = new ArrayList<>(); private List<String> groups = new ArrayList<>();
private List<Aggregator> aggregators = new ArrayList<>(); private List<Aggregator> aggregators = new ArrayList<>();
private List<Order> orders = new ArrayList<>(); private List<Order> orders = new ArrayList<>();
@@ -151,28 +152,27 @@ public class QueryStructReq extends SemanticQueryReq {
return sb.toString(); return sb.toString();
} }
public QuerySqlReq convert(QueryStructReq queryStructReq) { public QuerySqlReq convert() {
return convert(queryStructReq, false); return convert(false);
} }
/** /**
* convert queryStructReq to QueryS2SQLReq * convert queryStructReq to QueryS2SQLReq
* *
* @param queryStructReq
* @return * @return
*/ */
public QuerySqlReq convert(QueryStructReq queryStructReq, boolean isBizName) { public QuerySqlReq convert(boolean isBizName) {
String sql = null; String sql = null;
try { try {
sql = buildSql(queryStructReq, isBizName); sql = buildSql(this, isBizName);
} catch (Exception e) { } catch (Exception e) {
log.error("buildSql error", e); log.error("buildSql error", e);
} }
QuerySqlReq result = new QuerySqlReq(); QuerySqlReq result = new QuerySqlReq();
result.setSql(sql); result.setSql(sql);
result.setViewId(queryStructReq.getViewId()); result.setViewId(this.getViewId());
result.setModelIds(queryStructReq.getModelIdSet()); result.setModelIds(this.getModelIdSet());
result.setParams(new ArrayList<>()); result.setParams(new ArrayList<>());
return result; return result;
} }

View File

@@ -16,8 +16,6 @@ public class ViewReq extends SchemaItem {
private String alias; private String alias;
private String filterSql;
private QueryConfig queryConfig; private QueryConfig queryConfig;
private List<String> admins; private List<String> admins;

View File

@@ -30,6 +30,8 @@ public class MetricResp extends SchemaItem {
private Long domainId; private Long domainId;
private String modelBizName;
private String modelName; private String modelName;
//ATOMIC DERIVED //ATOMIC DERIVED

View File

@@ -1,12 +1,14 @@
package com.tencent.supersonic.headless.api.pojo.response; package com.tencent.supersonic.headless.api.pojo.response;
import com.google.common.collect.Sets;
import com.tencent.supersonic.common.pojo.ModelRela; import com.tencent.supersonic.common.pojo.ModelRela;
import com.tencent.supersonic.headless.api.pojo.Identify; import java.util.HashSet;
import java.util.List;
import java.util.Set;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.apache.commons.collections4.CollectionUtils;
import java.util.List;
@Data @Data
@AllArgsConstructor @AllArgsConstructor
@@ -17,18 +19,17 @@ public class ModelSchemaResp extends ModelResp {
private List<DimSchemaResp> dimensions; private List<DimSchemaResp> dimensions;
private List<ModelRela> modelRelas; private List<ModelRela> modelRelas;
public DimSchemaResp getPrimaryKey() { public Set<Long> getModelClusterSet() {
Identify identify = getPrimaryIdentify(); if (CollectionUtils.isEmpty(this.modelRelas)) {
if (identify == null) { return Sets.newHashSet();
return null; } else {
Set<Long> modelClusterSet = new HashSet();
this.modelRelas.forEach((modelRela) -> {
modelClusterSet.add(modelRela.getToModelId());
modelClusterSet.add(modelRela.getFromModelId());
});
return modelClusterSet;
} }
for (DimSchemaResp dimension : dimensions) {
if (identify.getBizName().equals(dimension.getBizName())) {
dimension.setEntityAlias(identify.getEntityNames());
return dimension;
}
}
return null;
} }
} }

View File

@@ -24,6 +24,11 @@
<version>${lombok.version}</version> <version>${lombok.version}</version>
<scope>provided</scope> <scope>provided</scope>
</dependency> </dependency>
<dependency>
<groupId>org.antlr</groupId>
<artifactId>ST4</artifactId>
<version>${st.version}</version>
</dependency>
<dependency> <dependency>
<groupId>org.springframework</groupId> <groupId>org.springframework</groupId>
@@ -106,6 +111,28 @@
<groupId>org.apache.hadoop</groupId> <groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-hdfs</artifactId> <artifactId>hadoop-hdfs</artifactId>
<version>${hadoop.version}</version> <version>${hadoop.version}</version>
<exclusions>
<exclusion>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-log4j12</artifactId>
</exclusion>
<exclusion>
<groupId>log4j</groupId>
<artifactId>log4j</artifactId>
</exclusion>
<exclusion>
<groupId>org.apache.zookeeper</groupId>
<artifactId>zookeeper</artifactId>
</exclusion>
<exclusion>
<groupId>org.apache.curator</groupId>
<artifactId>*</artifactId>
</exclusion>
<exclusion>
<groupId>javax.servlet</groupId>
<artifactId>servlet-api</artifactId>
</exclusion>
</exclusions>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.apache.commons</groupId> <groupId>org.apache.commons</groupId>

View File

@@ -6,8 +6,15 @@ import com.hankcs.hanlp.corpus.tag.Nature;
import com.hankcs.hanlp.dictionary.CoreDictionary; import com.hankcs.hanlp.dictionary.CoreDictionary;
import com.hankcs.hanlp.seg.common.Term; import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.common.pojo.enums.DictWordType; import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
@@ -17,11 +24,6 @@ import java.util.TreeMap;
import java.util.TreeSet; import java.util.TreeSet;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
@Slf4j @Slf4j
public class SearchService { public class SearchService {
@@ -39,14 +41,14 @@ public class SearchService {
* @param key * @param key
* @return * @return
*/ */
public static List<HanlpMapResult> prefixSearch(String key, int limit, Set<Long> detectModelIds) { public static List<HanlpMapResult> prefixSearch(String key, int limit, Map<Long, List<Long>> modelIdToViewIds) {
return prefixSearch(key, limit, trie, detectModelIds); return prefixSearch(key, limit, trie, modelIdToViewIds);
} }
public static List<HanlpMapResult> prefixSearch(String key, int limit, BinTrie<List<String>> binTrie, public static List<HanlpMapResult> prefixSearch(String key, int limit, BinTrie<List<String>> binTrie,
Set<Long> detectModelIds) { Map<Long, List<Long>> modelIdToViewIds) {
Set<Map.Entry<String, List<String>>> result = prefixSearchLimit(key, limit, binTrie, detectModelIds); Set<Map.Entry<String, List<String>>> result = prefixSearchLimit(key, limit, binTrie, modelIdToViewIds.keySet());
return result.stream().map( List<HanlpMapResult> hanlpMapResults = result.stream().map(
entry -> { entry -> {
String name = entry.getKey().replace("#", " "); String name = entry.getKey().replace("#", " ");
return new HanlpMapResult(name, entry.getValue(), key); return new HanlpMapResult(name, entry.getValue(), key);
@@ -54,6 +56,13 @@ public class SearchService {
).sorted((a, b) -> -(b.getName().length() - a.getName().length())) ).sorted((a, b) -> -(b.getName().length() - a.getName().length()))
.limit(SEARCH_SIZE) .limit(SEARCH_SIZE)
.collect(Collectors.toList()); .collect(Collectors.toList());
for (HanlpMapResult hanlpMapResult : hanlpMapResults) {
List<String> natures = hanlpMapResult.getNatures().stream()
.map(nature -> NatureHelper.changeModel2View(nature, modelIdToViewIds))
.flatMap(Collection::stream).collect(Collectors.toList());
hanlpMapResult.setNatures(natures);
}
return hanlpMapResults;
} }
/*** /***

View File

@@ -38,11 +38,11 @@ public class DimensionWordBuilder extends BaseWordBuilder {
private DictWord getOnwWordNature(String word, SchemaElement schemaElement, boolean isSuffix) { private DictWord getOnwWordNature(String word, SchemaElement schemaElement, boolean isSuffix) {
DictWord dictWord = new DictWord(); DictWord dictWord = new DictWord();
dictWord.setWord(word); dictWord.setWord(word);
Long viewId = schemaElement.getView(); Long modelId = schemaElement.getModel();
String nature = DictWordType.NATURE_SPILT + viewId + DictWordType.NATURE_SPILT + schemaElement.getId() String nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId()
+ DictWordType.DIMENSION.getType(); + DictWordType.DIMENSION.getType();
if (isSuffix) { if (isSuffix) {
nature = DictWordType.NATURE_SPILT + viewId + DictWordType.NATURE_SPILT + schemaElement.getId() nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId()
+ DictWordType.SUFFIX.getType() + DictWordType.DIMENSION.getType(); + DictWordType.SUFFIX.getType() + DictWordType.DIMENSION.getType();
} }
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature)); dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));

View File

@@ -27,8 +27,8 @@ public class EntityWordBuilder extends BaseWordBuilder {
return result; return result;
} }
Long view = schemaElement.getView(); Long modelId = schemaElement.getModel();
String nature = DictWordType.NATURE_SPILT + view + DictWordType.NATURE_SPILT + schemaElement.getId() String nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId()
+ DictWordType.ENTITY.getType(); + DictWordType.ENTITY.getType();
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) { if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {

View File

@@ -38,11 +38,11 @@ public class MetricWordBuilder extends BaseWordBuilder {
private DictWord getOnwWordNature(String word, SchemaElement schemaElement, boolean isSuffix) { private DictWord getOnwWordNature(String word, SchemaElement schemaElement, boolean isSuffix) {
DictWord dictWord = new DictWord(); DictWord dictWord = new DictWord();
dictWord.setWord(word); dictWord.setWord(word);
Long viewId = schemaElement.getView(); Long modelId = schemaElement.getModel();
String nature = DictWordType.NATURE_SPILT + viewId + DictWordType.NATURE_SPILT + schemaElement.getId() String nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId()
+ DictWordType.METRIC.getType(); + DictWordType.METRIC.getType();
if (isSuffix) { if (isSuffix) {
nature = DictWordType.NATURE_SPILT + viewId + DictWordType.NATURE_SPILT + schemaElement.getId() nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId()
+ DictWordType.SUFFIX.getType() + DictWordType.METRIC.getType(); + DictWordType.SUFFIX.getType() + DictWordType.METRIC.getType();
} }
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature)); dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));

View File

@@ -27,8 +27,8 @@ public class ValueWordBuilder extends BaseWordBuilder {
schemaElement.getAlias().stream().forEach(value -> { schemaElement.getAlias().stream().forEach(value -> {
DictWord dictWord = new DictWord(); DictWord dictWord = new DictWord();
Long viewId = schemaElement.getView(); Long modelId = schemaElement.getModel();
String nature = DictWordType.NATURE_SPILT + viewId + DictWordType.NATURE_SPILT + schemaElement.getId(); String nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId();
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature)); dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
dictWord.setWord(value); dictWord.setWord(value);
result.add(dictWord); result.add(dictWord);

View File

@@ -1,34 +1,36 @@
package com.tencent.supersonic.headless.core.knowledge.helper; package com.tencent.supersonic.headless.core.knowledge.helper;
import static com.hankcs.hanlp.HanLP.Config.CustomDictionaryPath; import com.google.common.collect.Lists;
import com.hankcs.hanlp.HanLP; import com.hankcs.hanlp.HanLP;
import com.hankcs.hanlp.corpus.tag.Nature; import com.hankcs.hanlp.corpus.tag.Nature;
import com.hankcs.hanlp.dictionary.CoreDictionary; import com.hankcs.hanlp.dictionary.CoreDictionary;
import com.hankcs.hanlp.dictionary.DynamicCustomDictionary; import com.hankcs.hanlp.dictionary.DynamicCustomDictionary;
import com.hankcs.hanlp.seg.Segment; import com.hankcs.hanlp.seg.Segment;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.headless.api.pojo.response.S2Term; import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.knowledge.DictWord; import com.tencent.supersonic.headless.core.knowledge.DictWord;
import com.tencent.supersonic.headless.core.knowledge.HadoopFileIOAdapter; import com.tencent.supersonic.headless.core.knowledge.HadoopFileIOAdapter;
import com.tencent.supersonic.headless.core.knowledge.MapResult; import com.tencent.supersonic.headless.core.knowledge.MapResult;
import com.tencent.supersonic.headless.core.knowledge.MultiCustomDictionary; import com.tencent.supersonic.headless.core.knowledge.MultiCustomDictionary;
import com.tencent.supersonic.headless.core.knowledge.SearchService; import com.tencent.supersonic.headless.core.knowledge.SearchService;
import com.tencent.supersonic.common.pojo.enums.DictWordType; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.util.ResourceUtils;
import java.io.File; import java.io.File;
import java.io.FileNotFoundException; import java.io.FileNotFoundException;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import com.hankcs.hanlp.seg.common.Term; import static com.hankcs.hanlp.HanLP.Config.CustomDictionaryPath;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.util.ResourceUtils;
/** /**
* HanLP helper * HanLP helper
@@ -212,18 +214,25 @@ public class HanlpHelper {
} }
} }
public static List<com.tencent.supersonic.headless.api.pojo.response.S2Term> getTerms(String text) { public static List<S2Term> getTerms(String text, Map<Long, List<Long>> modelIdToViewIds) {
return getSegment().seg(text.toLowerCase()).stream() return getSegment().seg(text.toLowerCase()).stream()
.filter(term -> term.getNature().startsWith(DictWordType.NATURE_SPILT)) .filter(term -> term.getNature().startsWith(DictWordType.NATURE_SPILT))
.map(term -> transform2ApiTerm(term)) .map(term -> transform2ApiTerm(term, modelIdToViewIds))
.flatMap(Collection::stream)
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
public static S2Term transform2ApiTerm(Term term) { public static List<S2Term> transform2ApiTerm(Term term, Map<Long, List<Long>> modelIdToViewIds) {
S2Term knowledgeTerm = new S2Term(); List<S2Term> s2Terms = Lists.newArrayList();
BeanUtils.copyProperties(term, knowledgeTerm); List<String> natures = NatureHelper.changeModel2View(String.valueOf(term.getNature()), modelIdToViewIds);
knowledgeTerm.setFrequency(term.getFrequency()); for (String nature : natures) {
return knowledgeTerm; S2Term s2Term = new S2Term();
BeanUtils.copyProperties(term, s2Term);
s2Term.setNature(Nature.create(nature));
s2Term.setFrequency(term.getFrequency());
s2Terms.add(s2Term);
}
return s2Terms;
} }
} }

View File

@@ -1,12 +1,14 @@
package com.tencent.supersonic.headless.core.knowledge.helper; package com.tencent.supersonic.headless.core.knowledge.helper;
import com.google.common.collect.Lists;
import com.hankcs.hanlp.corpus.tag.Nature; import com.hankcs.hanlp.corpus.tag.Nature;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.response.S2Term; import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.knowledge.ViewInfoStat; import com.tencent.supersonic.headless.core.knowledge.ViewInfoStat;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Comparator; import java.util.Comparator;
@@ -81,6 +83,43 @@ public class NatureHelper {
return null; return null;
} }
private static Long getModelId(String nature) {
try {
String[] split = nature.split(DictWordType.NATURE_SPILT);
if (split.length <= 1) {
return null;
}
return Long.valueOf(split[1]);
} catch (NumberFormatException e) {
log.error("", e);
}
return null;
}
private static Nature changeModel2View(String nature, Long viewId) {
try {
String[] split = nature.split(DictWordType.NATURE_SPILT);
if (split.length <= 1) {
return null;
}
split[1] = String.valueOf(viewId);
return Nature.create(StringUtils.join(split, DictWordType.NATURE_SPILT));
} catch (NumberFormatException e) {
log.error("", e);
}
return null;
}
public static List<String> changeModel2View(String nature, Map<Long, List<Long>> modelIdToViewIds) {
Long modelId = getModelId(nature);
List<Long> viewIds = modelIdToViewIds.get(modelId);
if (CollectionUtils.isEmpty(viewIds)) {
return Lists.newArrayList();
}
return viewIds.stream().map(viewId -> String.valueOf(changeModel2View(nature, viewId)))
.collect(Collectors.toList());
}
public static boolean isDimensionValueViewId(String nature) { public static boolean isDimensionValueViewId(String nature) {
if (StringUtils.isEmpty(nature)) { if (StringUtils.isEmpty(nature)) {
return false; return false;

View File

@@ -0,0 +1,120 @@
package com.tencent.supersonic.headless.core.utils;
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
import com.tencent.supersonic.headless.api.pojo.Param;
import com.tencent.supersonic.headless.api.pojo.SqlVariable;
import com.tencent.supersonic.headless.api.pojo.enums.VariableValueType;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
import org.stringtemplate.v4.ST;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import static com.tencent.supersonic.common.pojo.Constants.COMMA;
import static com.tencent.supersonic.common.pojo.Constants.EMPTY;
@Slf4j
public class SqlVariableParseUtils {
public static final String REG_SENSITIVE_SQL = "drop\\s|alter\\s|grant\\s|insert\\s|replace\\s|delete\\s|"
+ "truncate\\s|update\\s|remove\\s";
public static final Pattern PATTERN_SENSITIVE_SQL = Pattern.compile(REG_SENSITIVE_SQL);
public static final String APOSTROPHE = "'";
private static final char delimiter = '$';
public static String parse(String sql, List<SqlVariable> sqlVariables, List<Param> params) {
if (CollectionUtils.isEmpty(sqlVariables)) {
return sql;
}
Map<String, Object> queryParams = new HashMap<>();
//1. handle default variable value
sqlVariables.forEach(variable -> {
queryParams.put(variable.getName().trim(),
getValues(variable.getValueType(), variable.getDefaultValues()));
});
//override by variable param
if (!CollectionUtils.isEmpty(params)) {
Map<String, List<SqlVariable>> map =
sqlVariables.stream().collect(Collectors.groupingBy(SqlVariable::getName));
params.forEach(p -> {
if (map.containsKey(p.getName())) {
List<SqlVariable> list = map.get(p.getName());
if (!CollectionUtils.isEmpty(list)) {
SqlVariable v = list.get(list.size() - 1);
queryParams.put(p.getName().trim(), getValue(v.getValueType(), p.getValue()));
}
}
});
}
queryParams.forEach((k, v) -> {
if (v instanceof List && ((List) v).size() > 0) {
v = ((List) v).stream().collect(Collectors.joining(COMMA)).toString();
}
queryParams.put(k, v);
});
ST st = new ST(sql, delimiter, delimiter);
if (!CollectionUtils.isEmpty(queryParams)) {
queryParams.forEach(st::add);
}
return st.render();
}
public static List<String> getValues(VariableValueType valueType, List<Object> values) {
if (CollectionUtils.isEmpty(values)) {
return new ArrayList<>();
}
if (null != valueType) {
switch (valueType) {
case STRING:
return values.stream().map(String::valueOf)
.map(s -> s.startsWith(APOSTROPHE) && s.endsWith(APOSTROPHE)
? s : String.join(EMPTY, APOSTROPHE, s, APOSTROPHE))
.collect(Collectors.toList());
case EXPR:
values.stream().map(String::valueOf).forEach(SqlVariableParseUtils::checkSensitiveSql);
return values.stream().map(String::valueOf).collect(Collectors.toList());
case NUMBER:
return values.stream().map(String::valueOf).collect(Collectors.toList());
default:
return values.stream().map(String::valueOf).collect(Collectors.toList());
}
}
return values.stream().map(String::valueOf).collect(Collectors.toList());
}
public static Object getValue(VariableValueType valueType, String value) {
if (!StringUtils.isEmpty(value)) {
if (null != valueType) {
switch (valueType) {
case STRING:
return String.join(EMPTY, value.startsWith(APOSTROPHE) ? EMPTY : APOSTROPHE,
value, value.endsWith(APOSTROPHE) ? EMPTY : APOSTROPHE);
case NUMBER:
case EXPR:
default:
return value;
}
}
}
return value;
}
public static void checkSensitiveSql(String sql) {
Matcher matcher = PATTERN_SENSITIVE_SQL.matcher(sql.toLowerCase());
if (matcher.find()) {
String group = matcher.group();
log.warn("Sensitive SQL operations are not allowed: {}", group.toUpperCase());
throw new InvalidArgumentException("Sensitive SQL operations are not allowed: " + group.toUpperCase());
}
}
}

View File

@@ -11,11 +11,13 @@ import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp; import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp;
import com.tencent.supersonic.headless.core.pojo.QueryStatement; import com.tencent.supersonic.headless.core.pojo.QueryStatement;
import com.tencent.supersonic.headless.server.service.MetricService;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.aspectj.lang.ProceedingJoinPoint; import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around; import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect; import org.aspectj.lang.annotation.Aspect;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.Collection; import java.util.Collection;
@@ -28,6 +30,9 @@ import java.util.stream.Collectors;
@Slf4j @Slf4j
public class MetricDrillDownChecker { public class MetricDrillDownChecker {
@Autowired
private MetricService metricService;
@Around("execution(* com.tencent.supersonic.headless.core.parser.QueryParser.parse(..))") @Around("execution(* com.tencent.supersonic.headless.core.parser.QueryParser.parse(..))")
public Object doAround(ProceedingJoinPoint joinPoint) throws Throwable { public Object doAround(ProceedingJoinPoint joinPoint) throws Throwable {
Object[] objects = joinPoint.getArgs(); Object[] objects = joinPoint.getArgs();
@@ -52,7 +57,7 @@ public class MetricDrillDownChecker {
List<DimensionResp> necessaryDimensions = getNecessaryDimensions(metric, semanticSchemaResp); List<DimensionResp> necessaryDimensions = getNecessaryDimensions(metric, semanticSchemaResp);
List<DimensionResp> dimensionsMissing = getNecessaryDimensionMissing(necessaryDimensions, dimensionFields); List<DimensionResp> dimensionsMissing = getNecessaryDimensionMissing(necessaryDimensions, dimensionFields);
if (!CollectionUtils.isEmpty(dimensionsMissing)) { if (!CollectionUtils.isEmpty(dimensionsMissing)) {
String errMsg = String.format("指标:%s 缺失必要维度:%s", metric.getName(), String errMsg = String.format("指标:%s 缺失必要下钻维度:%s", metric.getName(),
dimensionsMissing.stream().map(DimensionResp::getName).collect(Collectors.toList())); dimensionsMissing.stream().map(DimensionResp::getName).collect(Collectors.toList()));
throw new InvalidArgumentException(errMsg); throw new InvalidArgumentException(errMsg);
} }
@@ -92,8 +97,9 @@ public class MetricDrillDownChecker {
return true; return true;
} }
List<String> relateDimensions = metricResps.stream() List<String> relateDimensions = metricResps.stream()
.filter(metric -> !CollectionUtils.isEmpty(metric.getDrillDownDimensions())) .map(this::getDrillDownDimensions)
.map(metric -> metric.getDrillDownDimensions().stream() .filter(drillDownDimensions -> !CollectionUtils.isEmpty(drillDownDimensions))
.map(drillDownDimensions -> drillDownDimensions.stream()
.map(DrillDownDimension::getDimensionId).collect(Collectors.toList())) .map(DrillDownDimension::getDimensionId).collect(Collectors.toList()))
.flatMap(Collection::stream) .flatMap(Collection::stream)
.map(id -> convertDimensionIdToBizName(id, semanticSchemaResp)) .map(id -> convertDimensionIdToBizName(id, semanticSchemaResp))
@@ -111,7 +117,7 @@ public class MetricDrillDownChecker {
if (metric == null) { if (metric == null) {
return Lists.newArrayList(); return Lists.newArrayList();
} }
List<DrillDownDimension> drillDownDimensions = metric.getDrillDownDimensions(); List<DrillDownDimension> drillDownDimensions = getDrillDownDimensions(metric);
if (CollectionUtils.isEmpty(drillDownDimensions)) { if (CollectionUtils.isEmpty(drillDownDimensions)) {
return Lists.newArrayList(); return Lists.newArrayList();
} }
@@ -147,4 +153,8 @@ public class MetricDrillDownChecker {
return dimension.getBizName(); return dimension.getBizName();
} }
private List<DrillDownDimension> getDrillDownDimensions(MetricResp metricResp) {
return metricService.getDrillDownDimension(metricResp.getId());
}
} }

View File

@@ -30,7 +30,7 @@ import java.util.stream.Collectors;
@Slf4j @Slf4j
public class ModelYamlManager { public class ModelYamlManager {
public static DataModelYamlTpl convert2YamlObj(ModelResp modelResp, DatabaseResp databaseResp) { public static synchronized DataModelYamlTpl convert2YamlObj(ModelResp modelResp, DatabaseResp databaseResp) {
ModelDetail modelDetail = modelResp.getModelDetail(); ModelDetail modelDetail = modelResp.getModelDetail();
DbAdaptor engineAdaptor = DbAdaptorFactory.getEngineAdaptor(databaseResp.getType()); DbAdaptor engineAdaptor = DbAdaptorFactory.getEngineAdaptor(databaseResp.getType());
SysTimeDimensionBuilder.addSysTimeDimension(modelDetail.getDimensions(), engineAdaptor); SysTimeDimensionBuilder.addSysTimeDimension(modelDetail.getDimensions(), engineAdaptor);

View File

@@ -35,8 +35,6 @@ public class ViewDO {
private String updatedBy; private String updatedBy;
private String filterSql;
private String queryConfig; private String queryConfig;
private String admin; private String admin;

View File

@@ -2,9 +2,9 @@ package com.tencent.supersonic.headless.server.persistence.mapper;
import com.tencent.supersonic.headless.server.persistence.dataobject.DimensionDO; import com.tencent.supersonic.headless.server.persistence.dataobject.DimensionDO;
import com.tencent.supersonic.headless.server.pojo.DimensionFilter; import com.tencent.supersonic.headless.server.pojo.DimensionFilter;
import org.apache.ibatis.annotations.Mapper; import com.tencent.supersonic.headless.server.pojo.DimensionsFilter;
import java.util.List; import java.util.List;
import org.apache.ibatis.annotations.Mapper;
@Mapper @Mapper
public interface DimensionDOCustomMapper { public interface DimensionDOCustomMapper {
@@ -16,4 +16,7 @@ public interface DimensionDOCustomMapper {
void batchUpdateStatus(List<DimensionDO> dimensionDOS); void batchUpdateStatus(List<DimensionDO> dimensionDOS);
List<DimensionDO> query(DimensionFilter dimensionFilter); List<DimensionDO> query(DimensionFilter dimensionFilter);
List<DimensionDO> queryDimensions(DimensionsFilter dimensionsFilter);
} }

View File

@@ -2,9 +2,9 @@ package com.tencent.supersonic.headless.server.persistence.mapper;
import com.tencent.supersonic.headless.server.persistence.dataobject.MetricDO; import com.tencent.supersonic.headless.server.persistence.dataobject.MetricDO;
import com.tencent.supersonic.headless.server.pojo.MetricFilter; import com.tencent.supersonic.headless.server.pojo.MetricFilter;
import org.apache.ibatis.annotations.Mapper; import com.tencent.supersonic.headless.server.pojo.MetricsFilter;
import java.util.List; import java.util.List;
import org.apache.ibatis.annotations.Mapper;
@Mapper @Mapper
public interface MetricDOCustomMapper { public interface MetricDOCustomMapper {
@@ -15,4 +15,6 @@ public interface MetricDOCustomMapper {
List<MetricDO> query(MetricFilter metricFilter); List<MetricDO> query(MetricFilter metricFilter);
List<MetricDO> queryMetrics(MetricsFilter metricsFilter);
} }

View File

@@ -4,6 +4,7 @@ package com.tencent.supersonic.headless.server.persistence.repository;
import com.tencent.supersonic.headless.server.persistence.dataobject.DimensionDO; import com.tencent.supersonic.headless.server.persistence.dataobject.DimensionDO;
import com.tencent.supersonic.headless.server.pojo.DimensionFilter; import com.tencent.supersonic.headless.server.pojo.DimensionFilter;
import com.tencent.supersonic.headless.server.pojo.DimensionsFilter;
import java.util.List; import java.util.List;
public interface DimensionRepository { public interface DimensionRepository {
@@ -19,4 +20,6 @@ public interface DimensionRepository {
DimensionDO getDimensionById(Long id); DimensionDO getDimensionById(Long id);
List<DimensionDO> getDimension(DimensionFilter dimensionFilter); List<DimensionDO> getDimension(DimensionFilter dimensionFilter);
List<DimensionDO> getDimensions(DimensionsFilter dimensionsFilter);
} }

View File

@@ -5,6 +5,7 @@ import com.tencent.supersonic.headless.server.persistence.dataobject.MetricDO;
import com.tencent.supersonic.headless.server.persistence.dataobject.MetricQueryDefaultConfigDO; import com.tencent.supersonic.headless.server.persistence.dataobject.MetricQueryDefaultConfigDO;
import com.tencent.supersonic.headless.server.pojo.MetricFilter; import com.tencent.supersonic.headless.server.pojo.MetricFilter;
import com.tencent.supersonic.headless.server.pojo.MetricsFilter;
import java.util.List; import java.util.List;
public interface MetricRepository { public interface MetricRepository {
@@ -21,6 +22,8 @@ public interface MetricRepository {
List<MetricDO> getMetric(MetricFilter metricFilter); List<MetricDO> getMetric(MetricFilter metricFilter);
List<MetricDO> getMetrics(MetricsFilter metricsFilter);
void saveDefaultQueryConfig(MetricQueryDefaultConfigDO defaultConfigDO); void saveDefaultQueryConfig(MetricQueryDefaultConfigDO defaultConfigDO);
void updateDefaultQueryConfig(MetricQueryDefaultConfigDO defaultConfigDO); void updateDefaultQueryConfig(MetricQueryDefaultConfigDO defaultConfigDO);

View File

@@ -5,9 +5,9 @@ import com.tencent.supersonic.headless.server.persistence.mapper.DimensionDOCust
import com.tencent.supersonic.headless.server.persistence.mapper.DimensionDOMapper; import com.tencent.supersonic.headless.server.persistence.mapper.DimensionDOMapper;
import com.tencent.supersonic.headless.server.persistence.repository.DimensionRepository; import com.tencent.supersonic.headless.server.persistence.repository.DimensionRepository;
import com.tencent.supersonic.headless.server.pojo.DimensionFilter; import com.tencent.supersonic.headless.server.pojo.DimensionFilter;
import org.springframework.stereotype.Service; import com.tencent.supersonic.headless.server.pojo.DimensionsFilter;
import java.util.List; import java.util.List;
import org.springframework.stereotype.Service;
@Service @Service
public class DimensionRepositoryImpl implements DimensionRepository { public class DimensionRepositoryImpl implements DimensionRepository {
@@ -52,4 +52,9 @@ public class DimensionRepositoryImpl implements DimensionRepository {
return dimensionDOCustomMapper.query(dimensionFilter); return dimensionDOCustomMapper.query(dimensionFilter);
} }
@Override
public List<DimensionDO> getDimensions(DimensionsFilter dimensionsFilter) {
return dimensionDOCustomMapper.queryDimensions(dimensionsFilter);
}
} }

View File

@@ -8,9 +8,9 @@ import com.tencent.supersonic.headless.server.persistence.mapper.MetricDOMapper;
import com.tencent.supersonic.headless.server.persistence.mapper.MetricQueryDefaultConfigDOMapper; import com.tencent.supersonic.headless.server.persistence.mapper.MetricQueryDefaultConfigDOMapper;
import com.tencent.supersonic.headless.server.persistence.repository.MetricRepository; import com.tencent.supersonic.headless.server.persistence.repository.MetricRepository;
import com.tencent.supersonic.headless.server.pojo.MetricFilter; import com.tencent.supersonic.headless.server.pojo.MetricFilter;
import org.springframework.stereotype.Component; import com.tencent.supersonic.headless.server.pojo.MetricsFilter;
import java.util.List; import java.util.List;
import org.springframework.stereotype.Component;
@Component @Component
@@ -62,6 +62,11 @@ public class MetricRepositoryImpl implements MetricRepository {
return metricDOCustomMapper.query(metricFilter); return metricDOCustomMapper.query(metricFilter);
} }
@Override
public List<MetricDO> getMetrics(MetricsFilter metricsFilter) {
return metricDOCustomMapper.queryMetrics(metricsFilter);
}
@Override @Override
public void saveDefaultQueryConfig(MetricQueryDefaultConfigDO defaultConfigDO) { public void saveDefaultQueryConfig(MetricQueryDefaultConfigDO defaultConfigDO) {
metricQueryDefaultConfigDOMapper.insert(defaultConfigDO); metricQueryDefaultConfigDOMapper.insert(defaultConfigDO);

View File

@@ -0,0 +1,15 @@
package com.tencent.supersonic.headless.server.pojo;
import java.util.List;
import lombok.Data;
@Data
public class DimensionsFilter {
private List<Long> modelIds;
private List<Long> dimensionIds;
private List<String> dimensionNames;
}

View File

@@ -0,0 +1,15 @@
package com.tencent.supersonic.headless.server.pojo;
import java.util.List;
import lombok.Data;
@Data
public class MetricsFilter {
private List<Long> modelIds;
private List<Long> metricIds;
private List<String> metricNames;
}

View File

@@ -0,0 +1,60 @@
package com.tencent.supersonic.headless.server.pojo;
import lombok.Data;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
@Data
public class ModelCluster {
private static final String split = "_";
private Set<Long> modelIds = new LinkedHashSet<>();
private Set<String> modelNames = new LinkedHashSet<>();
private String key;
private String name;
public static ModelCluster build(Set<Long> modelIds) {
ModelCluster modelCluster = new ModelCluster();
modelCluster.setModelIds(modelIds);
modelCluster.setKey(StringUtils.join(modelIds, split));
return modelCluster;
}
public static ModelCluster build(String key) {
ModelCluster modelCluster = new ModelCluster();
modelCluster.setModelIds(getModelIdFromKey(key));
modelCluster.setKey(key);
return modelCluster;
}
public void buildName(Map<Long, String> modelNameMap) {
modelNames = modelNameMap.entrySet().stream().filter(entry ->
modelIds.contains(entry.getKey())).map(Map.Entry::getValue)
.collect(Collectors.toSet());
name = String.join(split, modelNames);
}
public static Set<Long> getModelIdFromKey(String key) {
return Arrays.stream(key.split(split))
.map(Long::parseLong).collect(Collectors.toSet());
}
public Long getFirstModel() {
if (CollectionUtils.isEmpty(modelIds)) {
return -1L;
}
return new ArrayList<>(modelIds).get(0);
}
}

View File

@@ -8,7 +8,6 @@ import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.server.pojo.DatabaseParameter; import com.tencent.supersonic.headless.server.pojo.DatabaseParameter;
import com.tencent.supersonic.headless.server.service.DatabaseService; import com.tencent.supersonic.headless.server.service.DatabaseService;
import java.util.Map;
import org.springframework.web.bind.annotation.DeleteMapping; import org.springframework.web.bind.annotation.DeleteMapping;
import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.PathVariable;
@@ -20,6 +19,7 @@ import org.springframework.web.bind.annotation.RestController;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import java.util.List; import java.util.List;
import java.util.Map;
@RestController @RestController
@RequestMapping("/api/semantic/database") @RequestMapping("/api/semantic/database")
@@ -49,8 +49,10 @@ public class DatabaseController {
} }
@GetMapping("/{id}") @GetMapping("/{id}")
public DatabaseResp getDatabase(@PathVariable("id") Long id) { public DatabaseResp getDatabase(@PathVariable("id") Long id, HttpServletRequest request,
return databaseService.getDatabase(id); HttpServletResponse response) {
User user = UserHolder.findUser(request, response);
return databaseService.getDatabase(id, user);
} }
@GetMapping("/getDatabaseList") @GetMapping("/getDatabaseList")

View File

@@ -54,7 +54,7 @@ public class QueryController {
HttpServletRequest request, HttpServletRequest request,
HttpServletResponse response) throws Exception { HttpServletResponse response) throws Exception {
User user = UserHolder.findUser(request, response); User user = UserHolder.findUser(request, response);
QuerySqlReq querySqlReq = queryStructReq.convert(queryStructReq, true); QuerySqlReq querySqlReq = queryStructReq.convert(true);
return queryService.queryByReq(querySqlReq, user); return queryService.queryByReq(querySqlReq, user);
} }

View File

@@ -15,6 +15,10 @@ public interface DatabaseService {
SemanticQueryResp executeSql(String sql, Long id, User user); SemanticQueryResp executeSql(String sql, Long id, User user);
DatabaseResp getDatabase(Long id, User user);
DatabaseResp getDatabase(Long id);
Map<String, List<DatabaseParameter>> getDatabaseParameters(); Map<String, List<DatabaseParameter>> getDatabaseParameters();
boolean testConnect(DatabaseReq databaseReq, User user); boolean testConnect(DatabaseReq databaseReq, User user);
@@ -25,8 +29,6 @@ public interface DatabaseService {
void deleteDatabase(Long databaseId); void deleteDatabase(Long databaseId);
DatabaseResp getDatabase(Long id);
SemanticQueryResp getDbNames(Long id); SemanticQueryResp getDbNames(Long id);
SemanticQueryResp getTables(Long id, String db); SemanticQueryResp getTables(Long id, String db);

View File

@@ -8,8 +8,8 @@ import com.tencent.supersonic.headless.api.pojo.request.DimensionReq;
import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq; import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq;
import com.tencent.supersonic.headless.api.pojo.request.PageDimensionReq; import com.tencent.supersonic.headless.api.pojo.request.PageDimensionReq;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp; import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
import com.tencent.supersonic.headless.server.pojo.DimensionsFilter;
import com.tencent.supersonic.headless.server.pojo.MetaFilter; import com.tencent.supersonic.headless.server.pojo.MetaFilter;
import java.util.List; import java.util.List;
public interface DimensionService { public interface DimensionService {
@@ -30,6 +30,8 @@ public interface DimensionService {
PageInfo<DimensionResp> queryDimension(PageDimensionReq pageDimensionReq); PageInfo<DimensionResp> queryDimension(PageDimensionReq pageDimensionReq);
List<DimensionResp> queryDimensions(DimensionsFilter dimensionsFilter);
void deleteDimension(Long id, User user); void deleteDimension(Long id, User user);
List<DimensionResp> getDimensionInModelCluster(Long modelId); List<DimensionResp> getDimensionInModelCluster(Long modelId);

View File

@@ -11,7 +11,7 @@ import com.tencent.supersonic.headless.api.pojo.request.MetricReq;
import com.tencent.supersonic.headless.api.pojo.request.PageMetricReq; import com.tencent.supersonic.headless.api.pojo.request.PageMetricReq;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp; import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.server.pojo.MetaFilter; import com.tencent.supersonic.headless.server.pojo.MetaFilter;
import com.tencent.supersonic.headless.server.pojo.MetricsFilter;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
@@ -50,4 +50,6 @@ public interface MetricService {
MetricQueryDefaultConfig getMetricQueryDefaultConfig(Long metricId, User user); MetricQueryDefaultConfig getMetricQueryDefaultConfig(Long metricId, User user);
void sendMetricEventBatch(List<Long> modelIds, EventType eventType); void sendMetricEventBatch(List<Long> modelIds, EventType eventType);
List<MetricResp> queryMetrics(MetricsFilter metricsFilter);
} }

View File

@@ -37,6 +37,8 @@ public interface ModelService {
List<ModelResp> getModelByDomainIds(List<Long> domainIds); List<ModelResp> getModelByDomainIds(List<Long> domainIds);
List<ModelResp> getAllModelByDomainIds(List<Long> domainIds);
ModelResp getModel(Long id); ModelResp getModel(Long id);
List<String> getModelAdmin(Long id); List<String> getModelAdmin(Long id);

View File

@@ -15,6 +15,7 @@ import com.tencent.supersonic.headless.api.pojo.response.ItemResp;
import com.tencent.supersonic.headless.api.pojo.response.ItemUseResp; import com.tencent.supersonic.headless.api.pojo.response.ItemUseResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp; import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp; import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.ViewResp; import com.tencent.supersonic.headless.api.pojo.response.ViewResp;
import com.tencent.supersonic.headless.api.pojo.response.ViewSchemaResp; import com.tencent.supersonic.headless.api.pojo.response.ViewSchemaResp;
@@ -26,6 +27,8 @@ public interface SchemaService {
List<ViewSchemaResp> fetchViewSchema(ViewFilterReq filter); List<ViewSchemaResp> fetchViewSchema(ViewFilterReq filter);
List<ModelSchemaResp> fetchModelSchemaResps(List<Long> modelIds);
PageInfo<DimensionResp> queryDimension(PageDimensionReq pageDimensionReq, User user); PageInfo<DimensionResp> queryDimension(PageDimensionReq pageDimensionReq, User user);
PageInfo<MetricResp> queryMetric(PageMetricReq pageMetricReq, User user); PageInfo<MetricResp> queryMetric(PageMetricReq pageMetricReq, User user);

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless.server.service.impl; package com.tencent.supersonic.headless.server.service.impl;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.exception.InvalidPermissionException;
import com.tencent.supersonic.headless.api.pojo.request.DatabaseReq; import com.tencent.supersonic.headless.api.pojo.request.DatabaseReq;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp; import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp; import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
@@ -18,15 +19,16 @@ import com.tencent.supersonic.headless.server.pojo.ModelFilter;
import com.tencent.supersonic.headless.server.service.DatabaseService; import com.tencent.supersonic.headless.server.service.DatabaseService;
import com.tencent.supersonic.headless.server.service.ModelService; import com.tencent.supersonic.headless.server.service.ModelService;
import com.tencent.supersonic.headless.server.utils.DatabaseConverter; import com.tencent.supersonic.headless.server.utils.DatabaseConverter;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Lazy; import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@Slf4j @Slf4j
@Service @Service
@@ -58,12 +60,12 @@ public class DatabaseServiceImpl implements DatabaseService {
database.updatedBy(user.getName()); database.updatedBy(user.getName());
DatabaseConverter.convert(database, databaseDO); DatabaseConverter.convert(database, databaseDO);
databaseRepository.updateDatabase(databaseDO); databaseRepository.updateDatabase(databaseDO);
return DatabaseConverter.convert(databaseDO); return DatabaseConverter.convertWithPassword(databaseDO);
} }
database.createdBy(user.getName()); database.createdBy(user.getName());
databaseDO = DatabaseConverter.convert(database); databaseDO = DatabaseConverter.convert(database);
databaseRepository.createDatabase(databaseDO); databaseRepository.createDatabase(databaseDO);
return DatabaseConverter.convert(databaseDO); return DatabaseConverter.convertWithPassword(databaseDO);
} }
@Override @Override
@@ -108,7 +110,19 @@ public class DatabaseServiceImpl implements DatabaseService {
@Override @Override
public DatabaseResp getDatabase(Long id) { public DatabaseResp getDatabase(Long id) {
DatabaseDO databaseDO = databaseRepository.getDatabase(id); DatabaseDO databaseDO = databaseRepository.getDatabase(id);
return DatabaseConverter.convert(databaseDO); return DatabaseConverter.convertWithPassword(databaseDO);
}
@Override
public DatabaseResp getDatabase(Long id, User user) {
DatabaseResp databaseResp = getDatabase(id);
if (!databaseResp.getAdmins().contains(user.getName())
&& !databaseResp.getViewers().contains(user.getName())
&& !databaseResp.getCreatedBy().equals(user.getName())) {
throw new InvalidPermissionException("您暂无查看该数据库详情的权限, 请联系创建人: "
+ databaseResp.getCreatedBy());
}
return databaseResp;
} }
@Override @Override

View File

@@ -29,6 +29,7 @@ import com.tencent.supersonic.headless.api.pojo.response.ViewResp;
import com.tencent.supersonic.headless.server.persistence.dataobject.DimensionDO; import com.tencent.supersonic.headless.server.persistence.dataobject.DimensionDO;
import com.tencent.supersonic.headless.server.persistence.repository.DimensionRepository; import com.tencent.supersonic.headless.server.persistence.repository.DimensionRepository;
import com.tencent.supersonic.headless.server.pojo.DimensionFilter; import com.tencent.supersonic.headless.server.pojo.DimensionFilter;
import com.tencent.supersonic.headless.server.pojo.DimensionsFilter;
import com.tencent.supersonic.headless.server.pojo.MetaFilter; import com.tencent.supersonic.headless.server.pojo.MetaFilter;
import com.tencent.supersonic.headless.server.service.DatabaseService; import com.tencent.supersonic.headless.server.service.DatabaseService;
import com.tencent.supersonic.headless.server.service.DimensionService; import com.tencent.supersonic.headless.server.service.DimensionService;
@@ -37,19 +38,18 @@ import com.tencent.supersonic.headless.server.service.ModelService;
import com.tencent.supersonic.headless.server.service.ViewService; import com.tencent.supersonic.headless.server.service.ViewService;
import com.tencent.supersonic.headless.server.utils.DimensionConverter; import com.tencent.supersonic.headless.server.utils.DimensionConverter;
import com.tencent.supersonic.headless.server.utils.NameCheckUtils; import com.tencent.supersonic.headless.server.utils.NameCheckUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Date; import java.util.Date;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@Service @Service
@Slf4j @Slf4j
@@ -217,6 +217,12 @@ public class DimensionServiceImpl implements DimensionService {
return dimensionRepository.getDimension(dimensionFilter); return dimensionRepository.getDimension(dimensionFilter);
} }
@Override
public List<DimensionResp> queryDimensions(DimensionsFilter dimensionsFilter) {
List<DimensionDO> dimensions = dimensionRepository.getDimensions(dimensionsFilter);
return convertList(dimensions, modelService.getModelMap());
}
@Override @Override
public List<DimensionResp> getDimensions(MetaFilter metaFilter) { public List<DimensionResp> getDimensions(MetaFilter metaFilter) {
DimensionFilter dimensionFilter = new DimensionFilter(); DimensionFilter dimensionFilter = new DimensionFilter();

View File

@@ -66,7 +66,7 @@ public class DownloadServiceImpl implements DownloadService {
private QueryService queryService; private QueryService queryService;
public DownloadServiceImpl(MetricService metricService, public DownloadServiceImpl(MetricService metricService,
DimensionService dimensionService, QueryService queryService) { DimensionService dimensionService, QueryService queryService) {
this.metricService = metricService; this.metricService = metricService;
this.dimensionService = dimensionService; this.dimensionService = dimensionService;
this.queryService = queryService; this.queryService = queryService;
@@ -74,11 +74,11 @@ public class DownloadServiceImpl implements DownloadService {
@Override @Override
public void downloadByStruct(DownloadStructReq downloadStructReq, public void downloadByStruct(DownloadStructReq downloadStructReq,
User user, HttpServletResponse response) throws Exception { User user, HttpServletResponse response) throws Exception {
String fileName = String.format("%s_%s.xlsx", "supersonic", DateUtils.format(new Date(), DateUtils.FORMAT)); String fileName = String.format("%s_%s.xlsx", "supersonic", DateUtils.format(new Date(), DateUtils.FORMAT));
File file = FileUtils.createTmpFile(fileName); File file = FileUtils.createTmpFile(fileName);
try { try {
QuerySqlReq querySqlReq = downloadStructReq.convert(downloadStructReq, true); QuerySqlReq querySqlReq = downloadStructReq.convert(true);
SemanticQueryResp queryResult = queryService.queryByReq(querySqlReq, user); SemanticQueryResp queryResult = queryService.queryByReq(querySqlReq, user);
DataDownload dataDownload = buildDataDownload(queryResult, downloadStructReq); DataDownload dataDownload = buildDataDownload(queryResult, downloadStructReq);
EasyExcel.write(file).sheet("Sheet1").head(dataDownload.getHeaders()).doWrite(dataDownload.getData()); EasyExcel.write(file).sheet("Sheet1").head(dataDownload.getHeaders()).doWrite(dataDownload.getData());
@@ -92,7 +92,7 @@ public class DownloadServiceImpl implements DownloadService {
@Override @Override
public void batchDownload(BatchDownloadReq batchDownloadReq, User user, public void batchDownload(BatchDownloadReq batchDownloadReq, User user,
HttpServletResponse response) throws Exception { HttpServletResponse response) throws Exception {
String fileName = String.format("%s_%s.xlsx", "supersonic", DateUtils.format(new Date(), DateUtils.FORMAT)); String fileName = String.format("%s_%s.xlsx", "supersonic", DateUtils.format(new Date(), DateUtils.FORMAT));
File file = FileUtils.createTmpFile(fileName); File file = FileUtils.createTmpFile(fileName);
List<Long> metricIds = batchDownloadReq.getMetricIds(); List<Long> metricIds = batchDownloadReq.getMetricIds();
@@ -109,8 +109,9 @@ public class DownloadServiceImpl implements DownloadService {
metaFilter.setIds(metricIds); metaFilter.setIds(metricIds);
List<MetricResp> metricResps = metricService.getMetrics(metaFilter); List<MetricResp> metricResps = metricService.getMetrics(metaFilter);
Map<String, List<MetricResp>> metricMap = getMetricMap(metricResps); Map<String, List<MetricResp>> metricMap = getMetricMap(metricResps);
List<Long> dimensionIds = metricResps.stream().map(MetricResp::getRelateDimension) List<Long> dimensionIds = metricResps.stream()
.map(RelateDimension::getDrillDownDimensions).flatMap(Collection::stream) .map(metricResp -> metricService.getDrillDownDimension(metricResp.getId()))
.flatMap(Collection::stream)
.map(DrillDownDimension::getDimensionId).collect(Collectors.toList()); .map(DrillDownDimension::getDimensionId).collect(Collectors.toList());
metaFilter.setIds(dimensionIds); metaFilter.setIds(dimensionIds);
Map<Long, DimensionResp> dimensionRespMap = dimensionService.getDimensions(metaFilter) Map<Long, DimensionResp> dimensionRespMap = dimensionService.getDimensions(metaFilter)
@@ -126,7 +127,7 @@ public class DownloadServiceImpl implements DownloadService {
for (MetricResp metric : metrics) { for (MetricResp metric : metrics) {
try { try {
DownloadStructReq downloadStructReq = buildDownloadReq(dimensions, metric, batchDownloadReq); DownloadStructReq downloadStructReq = buildDownloadReq(dimensions, metric, batchDownloadReq);
QuerySqlReq querySqlReq = downloadStructReq.convert(downloadStructReq); QuerySqlReq querySqlReq = downloadStructReq.convert();
querySqlReq.setNeedAuth(true); querySqlReq.setNeedAuth(true);
SemanticQueryResp queryResult = queryService.queryByReq(querySqlReq, user); SemanticQueryResp queryResult = queryService.queryByReq(querySqlReq, user);
DataDownload dataDownload = buildDataDownload(queryResult, downloadStructReq); DataDownload dataDownload = buildDataDownload(queryResult, downloadStructReq);
@@ -192,7 +193,7 @@ public class DownloadServiceImpl implements DownloadService {
} }
private List<List<String>> buildData(List<List<String>> headers, Map<String, String> nameMap, private List<List<String>> buildData(List<List<String>> headers, Map<String, String> nameMap,
List<Map<String, Object>> dataTransformed, String metricName) { List<Map<String, Object>> dataTransformed, String metricName) {
List<List<String>> data = Lists.newArrayList(); List<List<String>> data = Lists.newArrayList();
for (Map<String, Object> map : dataTransformed) { for (Map<String, Object> map : dataTransformed) {
List<String> row = Lists.newArrayList(); List<String> row = Lists.newArrayList();
@@ -234,7 +235,7 @@ public class DownloadServiceImpl implements DownloadService {
} }
private DownloadStructReq buildDownloadReq(List<DimensionResp> dimensionResps, MetricResp metricResp, private DownloadStructReq buildDownloadReq(List<DimensionResp> dimensionResps, MetricResp metricResp,
BatchDownloadReq batchDownloadReq) { BatchDownloadReq batchDownloadReq) {
DateConf dateConf = batchDownloadReq.getDateInfo(); DateConf dateConf = batchDownloadReq.getDateInfo();
Set<Long> modelIds = dimensionResps.stream().map(DimensionResp::getModelId).collect(Collectors.toSet()); Set<Long> modelIds = dimensionResps.stream().map(DimensionResp::getModelId).collect(Collectors.toSet());
modelIds.add(metricResp.getModelId()); modelIds.add(metricResp.getModelId());
@@ -277,7 +278,7 @@ public class DownloadServiceImpl implements DownloadService {
} }
private List<DimensionResp> getMetricRelaDimensions(MetricResp metricResp, private List<DimensionResp> getMetricRelaDimensions(MetricResp metricResp,
Map<Long, DimensionResp> dimensionRespMap) { Map<Long, DimensionResp> dimensionRespMap) {
if (metricResp.getRelateDimension() == null if (metricResp.getRelateDimension() == null
|| CollectionUtils.isEmpty(metricResp.getRelateDimension().getDrillDownDimensions())) { || CollectionUtils.isEmpty(metricResp.getRelateDimension().getDrillDownDimensions())) {
return Lists.newArrayList(); return Lists.newArrayList();

View File

@@ -8,13 +8,14 @@ import com.tencent.supersonic.headless.core.knowledge.SearchService;
import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper; import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper;
import com.tencent.supersonic.headless.server.service.KnowledgeService; import com.tencent.supersonic.headless.server.service.KnowledgeService;
import com.tencent.supersonic.headless.server.service.ViewService; import com.tencent.supersonic.headless.server.service.ViewService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
@Service @Service
@Slf4j @Slf4j
@@ -68,17 +69,19 @@ public class KnowledgeServiceImpl implements KnowledgeService {
@Override @Override
public List<S2Term> getTerms(String text) { public List<S2Term> getTerms(String text) {
return HanlpHelper.getTerms(text); Map<Long, List<Long>> modelIdToViewIds = viewService.getModelIdToViewIds(new ArrayList<>());
return HanlpHelper.getTerms(text, modelIdToViewIds);
} }
@Override @Override
public List<HanlpMapResult> prefixSearch(String key, int limit, Set<Long> viewIds) { public List<HanlpMapResult> prefixSearch(String key, int limit, Set<Long> viewIds) {
Map<Long, List<Long>> modelIdToViewIds = viewService.getModelIdToViewIds(new ArrayList<>(viewIds)); Map<Long, List<Long>> modelIdToViewIds = viewService.getModelIdToViewIds(new ArrayList<>(viewIds));
return prefixSearchByModel(key, limit, modelIdToViewIds.keySet()); return prefixSearchByModel(key, limit, modelIdToViewIds);
} }
public List<HanlpMapResult> prefixSearchByModel(String key, int limit, Set<Long> models) { public List<HanlpMapResult> prefixSearchByModel(String key, int limit,
return SearchService.prefixSearch(key, limit, models); Map<Long, List<Long>> modelIdToViewIds) {
return SearchService.prefixSearch(key, limit, modelIdToViewIds);
} }
@Override @Override

View File

@@ -5,6 +5,7 @@ import com.alibaba.fastjson.TypeReference;
import com.github.pagehelper.PageHelper; import com.github.pagehelper.PageHelper;
import com.github.pagehelper.PageInfo; import com.github.pagehelper.PageInfo;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DataEvent; import com.tencent.supersonic.common.pojo.DataEvent;
@@ -15,15 +16,15 @@ import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.pojo.enums.TypeEnums; import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.common.util.BeanMapper; import com.tencent.supersonic.common.util.BeanMapper;
import com.tencent.supersonic.common.util.ChatGptHelper; import com.tencent.supersonic.common.util.ChatGptHelper;
import com.tencent.supersonic.headless.api.pojo.enums.MetricDefineType;
import com.tencent.supersonic.headless.api.pojo.DrillDownDimension; import com.tencent.supersonic.headless.api.pojo.DrillDownDimension;
import com.tencent.supersonic.headless.api.pojo.MeasureParam;
import com.tencent.supersonic.headless.api.pojo.MetricParam; import com.tencent.supersonic.headless.api.pojo.MetricParam;
import com.tencent.supersonic.headless.api.pojo.MetricQueryDefaultConfig; import com.tencent.supersonic.headless.api.pojo.MetricQueryDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.enums.MetricDefineType;
import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq; import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq;
import com.tencent.supersonic.headless.api.pojo.request.MetricBaseReq; import com.tencent.supersonic.headless.api.pojo.request.MetricBaseReq;
import com.tencent.supersonic.headless.api.pojo.request.MetricReq; import com.tencent.supersonic.headless.api.pojo.request.MetricReq;
import com.tencent.supersonic.headless.api.pojo.request.PageMetricReq; import com.tencent.supersonic.headless.api.pojo.request.PageMetricReq;
import com.tencent.supersonic.headless.api.pojo.response.DomainResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp; import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp; import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.api.pojo.response.ViewResp; import com.tencent.supersonic.headless.api.pojo.response.ViewResp;
@@ -33,6 +34,7 @@ import com.tencent.supersonic.headless.server.persistence.dataobject.MetricQuery
import com.tencent.supersonic.headless.server.persistence.repository.MetricRepository; import com.tencent.supersonic.headless.server.persistence.repository.MetricRepository;
import com.tencent.supersonic.headless.server.pojo.MetaFilter; import com.tencent.supersonic.headless.server.pojo.MetaFilter;
import com.tencent.supersonic.headless.server.pojo.MetricFilter; import com.tencent.supersonic.headless.server.pojo.MetricFilter;
import com.tencent.supersonic.headless.server.pojo.MetricsFilter;
import com.tencent.supersonic.headless.server.service.CollectService; import com.tencent.supersonic.headless.server.service.CollectService;
import com.tencent.supersonic.headless.server.service.DomainService; import com.tencent.supersonic.headless.server.service.DomainService;
import com.tencent.supersonic.headless.server.service.MetricService; import com.tencent.supersonic.headless.server.service.MetricService;
@@ -46,6 +48,7 @@ import org.springframework.context.ApplicationEventPublisher;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Date; import java.util.Date;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
@@ -184,9 +187,7 @@ public class MetricServiceImpl implements MetricService {
public PageInfo<MetricResp> queryMetric(PageMetricReq pageMetricReq, User user) { public PageInfo<MetricResp> queryMetric(PageMetricReq pageMetricReq, User user) {
MetricFilter metricFilter = new MetricFilter(); MetricFilter metricFilter = new MetricFilter();
BeanUtils.copyProperties(pageMetricReq, metricFilter); BeanUtils.copyProperties(pageMetricReq, metricFilter);
Set<DomainResp> domainResps = domainService.getDomainChildren(pageMetricReq.getDomainIds()); List<ModelResp> modelResps = modelService.getAllModelByDomainIds(pageMetricReq.getDomainIds());
List<Long> domainIds = domainResps.stream().map(DomainResp::getId).collect(Collectors.toList());
List<ModelResp> modelResps = modelService.getModelByDomainIds(domainIds);
List<Long> modelIds = modelResps.stream().map(ModelResp::getId).collect(Collectors.toList()); List<Long> modelIds = modelResps.stream().map(ModelResp::getId).collect(Collectors.toList());
pageMetricReq.getModelIds().addAll(modelIds); pageMetricReq.getModelIds().addAll(modelIds);
metricFilter.setModelIds(pageMetricReq.getModelIds()); metricFilter.setModelIds(pageMetricReq.getModelIds());
@@ -230,28 +231,44 @@ public class MetricServiceImpl implements MetricService {
} }
private List<MetricResp> filterByField(List<MetricResp> metricResps, List<String> fields) { private List<MetricResp> filterByField(List<MetricResp> metricResps, List<String> fields) {
List<MetricResp> metricRespFiltered = Lists.newArrayList(); Set<MetricResp> metricRespFiltered = Sets.newHashSet();
for (MetricResp metricResp : metricResps) { for (MetricResp metricResp : metricResps) {
for (String field : fields) { filterByField(metricResps, metricResp, fields, metricRespFiltered);
if (MetricDefineType.METRIC.equals(metricResp.getMetricDefineType())) { }
List<Long> ids = metricResp.getMetricDefineByMetricParams().getMetrics() return new ArrayList<>(metricRespFiltered);
.stream().map(MetricParam::getId).collect(Collectors.toList()); }
List<MetricResp> metricById = metricResps.stream()
.filter(metric -> ids.contains(metric.getId())) private boolean filterByField(List<MetricResp> metricResps, MetricResp metricResp,
.collect(Collectors.toList()); List<String> fields, Set<MetricResp> metricRespFiltered) {
for (MetricResp metric : metricById) { if (MetricDefineType.METRIC.equals(metricResp.getMetricDefineType())) {
if (metric.getExpr().contains(field)) { List<Long> ids = metricResp.getMetricDefineByMetricParams().getMetrics()
metricRespFiltered.add(metricResp); .stream().map(MetricParam::getId).collect(Collectors.toList());
} List<MetricResp> metricById = metricResps.stream()
} .filter(metric -> ids.contains(metric.getId()))
} else { .collect(Collectors.toList());
if (metricResp.getExpr().contains(field)) { for (MetricResp metric : metricById) {
metricRespFiltered.add(metricResp); if (filterByField(metricResps, metric, fields, metricRespFiltered)) {
} metricRespFiltered.add(metricResp);
return true;
} }
} }
} else if (MetricDefineType.FIELD.equals(metricResp.getMetricDefineType())) {
if (fields.stream().anyMatch(field -> metricResp.getExpr().contains(field))) {
metricRespFiltered.add(metricResp);
return true;
}
} else if (MetricDefineType.MEASURE.equals(metricResp.getMetricDefineType())) {
List<MeasureParam> measures = metricResp.getMetricDefineByMeasureParams().getMeasures();
List<String> fieldNameDepended = measures.stream().map(MeasureParam::getBizName)
//measure bizName = model bizName_fieldName
.map(name -> name.replaceFirst(metricResp.getModelBizName() + "_", ""))
.collect(Collectors.toList());
if (fields.stream().anyMatch(fieldNameDepended::contains)) {
metricRespFiltered.add(metricResp);
return true;
}
} }
return metricRespFiltered; return false;
} }
@Override @Override
@@ -343,7 +360,12 @@ public class MetricServiceImpl implements MetricService {
} }
if (metricResp.getRelateDimension() != null if (metricResp.getRelateDimension() != null
&& !CollectionUtils.isEmpty(metricResp.getRelateDimension().getDrillDownDimensions())) { && !CollectionUtils.isEmpty(metricResp.getRelateDimension().getDrillDownDimensions())) {
drillDownDimensions.addAll(metricResp.getRelateDimension().getDrillDownDimensions()); for (DrillDownDimension drillDownDimension : metricResp.getRelateDimension().getDrillDownDimensions()) {
if (drillDownDimension.isInheritedFromModel() && !drillDownDimension.isNecessary()) {
continue;
}
drillDownDimensions.add(drillDownDimension);
}
} }
ModelResp modelResp = modelService.getModel(metricResp.getModelId()); ModelResp modelResp = modelService.getModel(metricResp.getModelId());
if (modelResp.getDrillDownDimensions() == null) { if (modelResp.getDrillDownDimensions() == null) {
@@ -435,6 +457,12 @@ public class MetricServiceImpl implements MetricService {
sendEventBatch(metricDOS, eventType); sendEventBatch(metricDOS, eventType);
} }
@Override
public List<MetricResp> queryMetrics(MetricsFilter metricsFilter) {
List<MetricDO> metricDOS = metricRepository.getMetrics(metricsFilter);
return convertList(metricDOS, new ArrayList<>());
}
private void sendEventBatch(List<MetricDO> metricDOS, EventType eventType) { private void sendEventBatch(List<MetricDO> metricDOS, EventType eventType) {
List<DataItem> dataItems = metricDOS.stream().map(this::getDataItem) List<DataItem> dataItems = metricDOS.stream().map(this::getDataItem)
.collect(Collectors.toList()); .collect(Collectors.toList());

View File

@@ -81,13 +81,13 @@ public class ModelServiceImpl implements ModelService {
private DateInfoRepository dateInfoRepository; private DateInfoRepository dateInfoRepository;
public ModelServiceImpl(ModelRepository modelRepository, public ModelServiceImpl(ModelRepository modelRepository,
DatabaseService databaseService, DatabaseService databaseService,
@Lazy DimensionService dimensionService, @Lazy DimensionService dimensionService,
@Lazy MetricService metricService, @Lazy MetricService metricService,
DomainService domainService, DomainService domainService,
UserService userService, UserService userService,
ViewService viewService, ViewService viewService,
DateInfoRepository dateInfoRepository) { DateInfoRepository dateInfoRepository) {
this.modelRepository = modelRepository; this.modelRepository = modelRepository;
this.databaseService = databaseService; this.databaseService = databaseService;
this.dimensionService = dimensionService; this.dimensionService = dimensionService;
@@ -350,6 +350,13 @@ public class ModelServiceImpl implements ModelService {
domainIds.contains(modelResp.getDomainId())).collect(Collectors.toList()); domainIds.contains(modelResp.getDomainId())).collect(Collectors.toList());
} }
@Override
public List<ModelResp> getAllModelByDomainIds(List<Long> domainIds) {
Set<DomainResp> domainResps = domainService.getDomainChildren(domainIds);
List<Long> allDomainIds = domainResps.stream().map(DomainResp::getId).collect(Collectors.toList());
return getModelByDomainIds(allDomainIds);
}
@Override @Override
public ModelResp getModel(Long id) { public ModelResp getModel(Long id) {
ModelDO modelDO = getModelDO(id); ModelDO modelDO = getModelDO(id);

View File

@@ -13,6 +13,7 @@ import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
import com.tencent.supersonic.headless.api.pojo.Dim; import com.tencent.supersonic.headless.api.pojo.Dim;
import com.tencent.supersonic.headless.api.pojo.Item; import com.tencent.supersonic.headless.api.pojo.Item;
import com.tencent.supersonic.headless.api.pojo.QueryParam; import com.tencent.supersonic.headless.api.pojo.QueryParam;
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
import com.tencent.supersonic.headless.api.pojo.SingleItemQueryResult; import com.tencent.supersonic.headless.api.pojo.SingleItemQueryResult;
import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq; import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq;
import com.tencent.supersonic.headless.api.pojo.request.ItemUseReq; import com.tencent.supersonic.headless.api.pojo.request.ItemUseReq;
@@ -44,13 +45,24 @@ import com.tencent.supersonic.headless.server.annotation.S2DataPermission;
import com.tencent.supersonic.headless.server.aspect.ApiHeaderCheckAspect; import com.tencent.supersonic.headless.server.aspect.ApiHeaderCheckAspect;
import com.tencent.supersonic.headless.server.manager.SemanticSchemaManager; import com.tencent.supersonic.headless.server.manager.SemanticSchemaManager;
import com.tencent.supersonic.headless.server.pojo.DimensionFilter; import com.tencent.supersonic.headless.server.pojo.DimensionFilter;
import com.tencent.supersonic.headless.server.pojo.DimensionsFilter;
import com.tencent.supersonic.headless.server.pojo.MetricsFilter;
import com.tencent.supersonic.headless.server.pojo.ModelCluster;
import com.tencent.supersonic.headless.server.service.AppService; import com.tencent.supersonic.headless.server.service.AppService;
import com.tencent.supersonic.headless.server.service.Catalog; import com.tencent.supersonic.headless.server.service.Catalog;
import com.tencent.supersonic.headless.server.service.DimensionService;
import com.tencent.supersonic.headless.server.service.MetricService;
import com.tencent.supersonic.headless.server.service.ModelService;
import com.tencent.supersonic.headless.server.service.QueryService; import com.tencent.supersonic.headless.server.service.QueryService;
import com.tencent.supersonic.headless.server.utils.ModelClusterBuilder;
import com.tencent.supersonic.headless.server.utils.QueryReqConverter; import com.tencent.supersonic.headless.server.utils.QueryReqConverter;
import com.tencent.supersonic.headless.server.utils.QueryUtils; import com.tencent.supersonic.headless.server.utils.QueryUtils;
import com.tencent.supersonic.headless.server.utils.StatUtils; import com.tencent.supersonic.headless.server.utils.StatUtils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
@@ -60,6 +72,7 @@ import javax.servlet.http.HttpServletRequest;
import lombok.SneakyThrows; import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@@ -79,6 +92,12 @@ public class QueryServiceImpl implements QueryService {
private final QueryPlanner queryPlanner; private final QueryPlanner queryPlanner;
private final MetricService metricService;
private final ModelService modelService;
private final DimensionService dimensionService;
public QueryServiceImpl( public QueryServiceImpl(
StatUtils statUtils, StatUtils statUtils,
QueryUtils queryUtils, QueryUtils queryUtils,
@@ -88,7 +107,10 @@ public class QueryServiceImpl implements QueryService {
QueryCache queryCache, QueryCache queryCache,
SemanticSchemaManager semanticSchemaManager, SemanticSchemaManager semanticSchemaManager,
DefaultQueryParser queryParser, DefaultQueryParser queryParser,
QueryPlanner queryPlanner) { QueryPlanner queryPlanner,
MetricService metricService,
ModelService modelService,
DimensionService dimensionService) {
this.statUtils = statUtils; this.statUtils = statUtils;
this.queryUtils = queryUtils; this.queryUtils = queryUtils;
this.queryReqConverter = queryReqConverter; this.queryReqConverter = queryReqConverter;
@@ -98,6 +120,9 @@ public class QueryServiceImpl implements QueryService {
this.semanticSchemaManager = semanticSchemaManager; this.semanticSchemaManager = semanticSchemaManager;
this.queryParser = queryParser; this.queryParser = queryParser;
this.queryPlanner = queryPlanner; this.queryPlanner = queryPlanner;
this.metricService = metricService;
this.modelService = modelService;
this.dimensionService = dimensionService;
} }
@Override @Override
@@ -241,8 +266,130 @@ public class QueryServiceImpl implements QueryService {
} }
@Override @Override
public SemanticQueryResp queryByMetric(QueryMetricReq queryMetricReq, User user) throws Exception { public SemanticQueryResp queryByMetric(QueryMetricReq queryMetricReq, User user) {
return null; QueryStructReq queryStructReq = buildQueryStructReq(queryMetricReq);
return queryByReq(queryStructReq.convert(), user);
}
private QueryStructReq buildQueryStructReq(QueryMetricReq queryMetricReq) {
//1. If a domainId exists, the modelIds obtained from the domainId.
Set<Long> modelIdsByDomainId = getModelIdsByDomainId(queryMetricReq);
//2. get metrics and dimensions
List<MetricResp> metricResps = getMetricResps(queryMetricReq, modelIdsByDomainId);
List<DimensionResp> dimensionResps = getDimensionResps(queryMetricReq, modelIdsByDomainId);
//3. choose ModelCluster
Set<Long> modelIds = getModelIds(modelIdsByDomainId, metricResps, dimensionResps);
ModelCluster modelCluster = getModelCluster(metricResps, modelIds);
//4. set groups
List<String> dimensionBizNames = dimensionResps.stream()
.filter(entry -> modelCluster.getModelIds().contains(entry.getModelId()))
.map(entry -> entry.getBizName()).collect(Collectors.toList());
QueryStructReq queryStructReq = new QueryStructReq();
if (CollectionUtils.isNotEmpty(dimensionBizNames)) {
queryStructReq.setGroups(dimensionBizNames);
}
//5. set aggregators
List<String> metricBizNames = metricResps.stream()
.filter(entry -> modelCluster.getModelIds().contains(entry.getModelId()))
.map(entry -> entry.getBizName()).collect(Collectors.toList());
if (CollectionUtils.isEmpty(metricBizNames)) {
throw new IllegalArgumentException("Invalid input parameters, unable to obtain valid metrics");
}
List<Aggregator> aggregators = new ArrayList<>();
for (String metricBizName : metricBizNames) {
Aggregator aggregator = new Aggregator();
aggregator.setColumn(metricBizName);
aggregators.add(aggregator);
}
queryStructReq.setAggregators(aggregators);
queryStructReq.setLimit(queryMetricReq.getLimit());
//6. set modelIds
for (Long modelId : modelCluster.getModelIds()) {
queryStructReq.addModelId(modelId);
}
//7. set dateInfo
queryStructReq.setDateInfo(queryMetricReq.getDateInfo());
return queryStructReq;
}
private QueryStructReq buildQueryStructReq(List<DimensionResp> dimensionResps,
MetricResp metricResp, DateConf dateConf, Long limit) {
Set<Long> modelIds = dimensionResps.stream().map(DimensionResp::getModelId).collect(Collectors.toSet());
modelIds.add(metricResp.getModelId());
QueryStructReq queryStructReq = new QueryStructReq();
queryStructReq.setGroups(dimensionResps.stream()
.map(DimensionResp::getBizName).collect(Collectors.toList()));
queryStructReq.getGroups().add(0, getTimeDimension(dateConf));
Aggregator aggregator = new Aggregator();
aggregator.setColumn(metricResp.getBizName());
queryStructReq.setAggregators(Lists.newArrayList(aggregator));
queryStructReq.setDateInfo(dateConf);
queryStructReq.setModelIds(modelIds);
queryStructReq.setLimit(limit);
return queryStructReq;
}
private ModelCluster getModelCluster(List<MetricResp> metricResps, Set<Long> modelIds) {
Map<String, ModelCluster> modelClusterMap = ModelClusterBuilder.buildModelClusters(new ArrayList<>(modelIds));
Map<String, List<SchemaItem>> modelClusterToMatchCount = new HashMap<>();
for (ModelCluster modelCluster : modelClusterMap.values()) {
for (MetricResp metricResp : metricResps) {
if (modelCluster.getModelIds().contains(metricResp.getModelId())) {
modelClusterToMatchCount.computeIfAbsent(modelCluster.getKey(), k -> new ArrayList<>())
.add(metricResp);
}
}
}
String keyWithMaxSize = modelClusterToMatchCount.entrySet().stream()
.max(Comparator.comparingInt(entry -> entry.getValue().size()))
.map(Map.Entry::getKey)
.orElse(null);
return modelClusterMap.get(keyWithMaxSize);
}
private Set<Long> getModelIds(Set<Long> modelIdsByDomainId, List<MetricResp> metricResps,
List<DimensionResp> dimensionResps) {
Set<Long> result = new HashSet<>();
if (CollectionUtils.isNotEmpty(modelIdsByDomainId)) {
result.addAll(modelIdsByDomainId);
return result;
}
Set<Long> metricModelIds = metricResps.stream().map(entry -> entry.getModelId())
.collect(Collectors.toSet());
result.addAll(metricModelIds);
Set<Long> dimensionModelIds = dimensionResps.stream().map(entry -> entry.getModelId())
.collect(Collectors.toSet());
result.addAll(dimensionModelIds);
return result;
}
private List<DimensionResp> getDimensionResps(QueryMetricReq queryMetricReq, Set<Long> modelIds) {
DimensionsFilter dimensionsFilter = new DimensionsFilter();
BeanUtils.copyProperties(queryMetricReq, dimensionsFilter);
dimensionsFilter.setModelIds(new ArrayList<>(modelIds));
List<DimensionResp> dimensionResps = dimensionService.queryDimensions(dimensionsFilter);
return dimensionResps;
}
private List<MetricResp> getMetricResps(QueryMetricReq queryMetricReq, Set<Long> modelIds) {
MetricsFilter metricsFilter = new MetricsFilter();
BeanUtils.copyProperties(queryMetricReq, metricsFilter);
metricsFilter.setModelIds(new ArrayList<>(modelIds));
return metricService.queryMetrics(metricsFilter);
}
private Set<Long> getModelIdsByDomainId(QueryMetricReq queryMetricReq) {
List<ModelResp> modelResps = modelService.getAllModelByDomainIds(
Collections.singletonList(queryMetricReq.getDomainId()));
return modelResps.stream().map(ModelResp::getId).collect(Collectors.toSet());
} }
private SingleItemQueryResult dataQuery(Integer appId, Item item, DateConf dateConf, Long limit) throws Exception { private SingleItemQueryResult dataQuery(Integer appId, Item item, DateConf dateConf, Long limit) throws Exception {
@@ -271,23 +418,6 @@ public class QueryServiceImpl implements QueryService {
return appService.getApp(appId); return appService.getApp(appId);
} }
private QueryStructReq buildQueryStructReq(List<DimensionResp> dimensionResps,
MetricResp metricResp, DateConf dateConf, Long limit) {
Set<Long> modelIds = dimensionResps.stream().map(DimensionResp::getModelId).collect(Collectors.toSet());
modelIds.add(metricResp.getModelId());
QueryStructReq queryStructReq = new QueryStructReq();
queryStructReq.setGroups(dimensionResps.stream()
.map(DimensionResp::getBizName).collect(Collectors.toList()));
queryStructReq.getGroups().add(0, getTimeDimension(dateConf));
Aggregator aggregator = new Aggregator();
aggregator.setColumn(metricResp.getBizName());
queryStructReq.setAggregators(Lists.newArrayList(aggregator));
queryStructReq.setDateInfo(dateConf);
queryStructReq.setModelIds(modelIds);
queryStructReq.setLimit(limit);
return queryStructReq;
}
private String getTimeDimension(DateConf dateConf) { private String getTimeDimension(DateConf dateConf) {
if (Constants.MONTH.equals(dateConf.getPeriod())) { if (Constants.MONTH.equals(dateConf.getPeriod())) {
return TimeDimensionEnum.MONTH.getName(); return TimeDimensionEnum.MONTH.getName();

View File

@@ -64,6 +64,13 @@ public class SchemaServiceImpl implements SchemaService {
protected final Cache<String, List<ItemUseResp>> itemUseCache = protected final Cache<String, List<ItemUseResp>> itemUseCache =
CacheBuilder.newBuilder().expireAfterWrite(1, TimeUnit.DAYS).build(); CacheBuilder.newBuilder().expireAfterWrite(1, TimeUnit.DAYS).build();
protected final Cache<ViewFilterReq, List<ViewSchemaResp>> viewSchemaCache =
CacheBuilder.newBuilder().expireAfterWrite(30, TimeUnit.SECONDS).build();
protected final Cache<SchemaFilterReq, SemanticSchemaResp> semanticSchemaCache =
CacheBuilder.newBuilder().expireAfterWrite(30, TimeUnit.SECONDS).build();
private final StatUtils statUtils; private final StatUtils statUtils;
private final ModelService modelService; private final ModelService modelService;
private final DimensionService dimensionService; private final DimensionService dimensionService;
@@ -91,6 +98,22 @@ public class SchemaServiceImpl implements SchemaService {
@SneakyThrows @SneakyThrows
@Override @Override
public List<ViewSchemaResp> fetchViewSchema(ViewFilterReq filter) { public List<ViewSchemaResp> fetchViewSchema(ViewFilterReq filter) {
List<ViewSchemaResp> viewList = viewSchemaCache.getIfPresent(filter);
if (CollectionUtils.isEmpty(viewList)) {
viewList = buildViewSchema(filter);
viewSchemaCache.put(filter, viewList);
}
return viewList;
}
public ViewSchemaResp fetchViewSchema(Long viewId) {
if (viewId == null) {
return null;
}
return fetchViewSchema(new ViewFilterReq(viewId)).stream().findFirst().orElse(null);
}
public List<ViewSchemaResp> buildViewSchema(ViewFilterReq filter) {
List<ViewSchemaResp> viewSchemaResps = new ArrayList<>(); List<ViewSchemaResp> viewSchemaResps = new ArrayList<>();
List<Long> viewIds = filter.getViewIds(); List<Long> viewIds = filter.getViewIds();
MetaFilter metaFilter = new MetaFilter(); MetaFilter metaFilter = new MetaFilter();
@@ -127,13 +150,6 @@ public class SchemaServiceImpl implements SchemaService {
return viewSchemaResps; return viewSchemaResps;
} }
public ViewSchemaResp fetchViewSchema(Long viewId) {
if (viewId == null) {
return null;
}
return fetchViewSchema(new ViewFilterReq(viewId)).stream().findFirst().orElse(null);
}
public List<ModelSchemaResp> fetchModelSchemaResps(List<Long> modelIds) { public List<ModelSchemaResp> fetchModelSchemaResps(List<Long> modelIds) {
List<ModelSchemaResp> modelSchemaResps = Lists.newArrayList(); List<ModelSchemaResp> modelSchemaResps = Lists.newArrayList();
if (CollectionUtils.isEmpty(modelIds)) { if (CollectionUtils.isEmpty(modelIds)) {
@@ -258,8 +274,7 @@ public class SchemaServiceImpl implements SchemaService {
return viewService.getViewList(metaFilter); return viewService.getViewList(metaFilter);
} }
@Override public SemanticSchemaResp buildSemanticSchema(SchemaFilterReq schemaFilterReq) {
public SemanticSchemaResp fetchSemanticSchema(SchemaFilterReq schemaFilterReq) {
SemanticSchemaResp semanticSchemaResp = new SemanticSchemaResp(); SemanticSchemaResp semanticSchemaResp = new SemanticSchemaResp();
semanticSchemaResp.setViewId(schemaFilterReq.getViewId()); semanticSchemaResp.setViewId(schemaFilterReq.getViewId());
semanticSchemaResp.setModelIds(schemaFilterReq.getModelIds()); semanticSchemaResp.setModelIds(schemaFilterReq.getModelIds());
@@ -294,6 +309,16 @@ public class SchemaServiceImpl implements SchemaService {
return semanticSchemaResp; return semanticSchemaResp;
} }
@Override
public SemanticSchemaResp fetchSemanticSchema(SchemaFilterReq schemaFilterReq) {
SemanticSchemaResp semanticSchemaResp = semanticSchemaCache.getIfPresent(schemaFilterReq);
if (semanticSchemaResp == null) {
semanticSchemaResp = buildSemanticSchema(schemaFilterReq);
semanticSchemaCache.put(schemaFilterReq, semanticSchemaResp);
}
return semanticSchemaResp;
}
@SneakyThrows @SneakyThrows
@Override @Override
public List<ItemUseResp> getStatInfo(ItemUseReq itemUseReq) { public List<ItemUseResp> getStatInfo(ItemUseReq itemUseReq) {

View File

@@ -10,17 +10,29 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.enums.AuthType; import com.tencent.supersonic.common.pojo.enums.AuthType;
import com.tencent.supersonic.common.pojo.enums.StatusEnum; import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.pojo.enums.TypeEnums; import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
import com.tencent.supersonic.common.util.BeanMapper; import com.tencent.supersonic.common.util.BeanMapper;
import com.tencent.supersonic.headless.api.pojo.QueryConfig; import com.tencent.supersonic.headless.api.pojo.QueryConfig;
import com.tencent.supersonic.headless.api.pojo.ViewDetail; import com.tencent.supersonic.headless.api.pojo.ViewDetail;
import com.tencent.supersonic.headless.api.pojo.request.ViewReq; import com.tencent.supersonic.headless.api.pojo.request.ViewReq;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
import com.tencent.supersonic.headless.api.pojo.response.DomainResp; import com.tencent.supersonic.headless.api.pojo.response.DomainResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.api.pojo.response.ViewResp; import com.tencent.supersonic.headless.api.pojo.response.ViewResp;
import com.tencent.supersonic.headless.server.persistence.dataobject.ViewDO; import com.tencent.supersonic.headless.server.persistence.dataobject.ViewDO;
import com.tencent.supersonic.headless.server.persistence.mapper.ViewDOMapper; import com.tencent.supersonic.headless.server.persistence.mapper.ViewDOMapper;
import com.tencent.supersonic.headless.server.pojo.MetaFilter; import com.tencent.supersonic.headless.server.pojo.MetaFilter;
import com.tencent.supersonic.headless.server.service.DimensionService;
import com.tencent.supersonic.headless.server.service.DomainService; import com.tencent.supersonic.headless.server.service.DomainService;
import com.tencent.supersonic.headless.server.service.MetricService;
import com.tencent.supersonic.headless.server.service.ViewService; import com.tencent.supersonic.headless.server.service.ViewService;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.Arrays; import java.util.Arrays;
import java.util.Comparator; import java.util.Comparator;
import java.util.Date; import java.util.Date;
@@ -29,12 +41,8 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@Service @Service
public class ViewServiceImpl public class ViewServiceImpl
@@ -46,21 +54,33 @@ public class ViewServiceImpl
@Autowired @Autowired
private DomainService domainService; private DomainService domainService;
@Lazy
@Autowired
private DimensionService dimensionService;
@Lazy
@Autowired
private MetricService metricService;
@Override @Override
public ViewResp save(ViewReq viewReq, User user) { public ViewResp save(ViewReq viewReq, User user) {
viewReq.createdBy(user.getName()); viewReq.createdBy(user.getName());
ViewDO viewDO = convert(viewReq); ViewDO viewDO = convert(viewReq);
viewDO.setStatus(StatusEnum.ONLINE.getCode()); viewDO.setStatus(StatusEnum.ONLINE.getCode());
ViewResp viewResp = convert(viewDO);
conflictCheck(viewResp);
save(viewDO); save(viewDO);
return convert(viewDO); return viewResp;
} }
@Override @Override
public ViewResp update(ViewReq viewReq, User user) { public ViewResp update(ViewReq viewReq, User user) {
viewReq.updatedBy(user.getName()); viewReq.updatedBy(user.getName());
ViewDO viewDO = convert(viewReq); ViewDO viewDO = convert(viewReq);
ViewResp viewResp = convert(viewDO);
conflictCheck(viewResp);
updateById(viewDO); updateById(viewDO);
return convert(viewDO); return viewResp;
} }
@Override @Override
@@ -78,6 +98,9 @@ public class ViewServiceImpl
if (!CollectionUtils.isEmpty(metaFilter.getIds())) { if (!CollectionUtils.isEmpty(metaFilter.getIds())) {
wrapper.lambda().in(ViewDO::getId, metaFilter.getIds()); wrapper.lambda().in(ViewDO::getId, metaFilter.getIds());
} }
if (metaFilter.getStatus() != null) {
wrapper.lambda().eq(ViewDO::getStatus, metaFilter.getStatus());
}
wrapper.lambda().ne(ViewDO::getStatus, StatusEnum.DELETED.getCode()); wrapper.lambda().ne(ViewDO::getStatus, StatusEnum.DELETED.getCode());
return list(wrapper).stream().map(this::convert).collect(Collectors.toList()); return list(wrapper).stream().map(this::convert).collect(Collectors.toList());
} }
@@ -175,4 +198,46 @@ public class ViewServiceImpl
viewResp -> viewResp.getAllModels().stream().map(modelId -> Pair.of(modelId, viewResp.getId()))) viewResp -> viewResp.getAllModels().stream().map(modelId -> Pair.of(modelId, viewResp.getId())))
.collect(Collectors.groupingBy(Pair::getLeft, Collectors.mapping(Pair::getRight, Collectors.toList()))); .collect(Collectors.groupingBy(Pair::getLeft, Collectors.mapping(Pair::getRight, Collectors.toList())));
} }
private void conflictCheck(ViewResp viewResp) {
List<Long> allDimensionIds = viewResp.getAllDimensions();
List<Long> allMetricIds = viewResp.getAllMetrics();
MetaFilter metaFilter = new MetaFilter();
if (!CollectionUtils.isEmpty(allDimensionIds)) {
metaFilter.setIds(allDimensionIds);
List<DimensionResp> dimensionResps = dimensionService.getDimensions(metaFilter);
List<String> duplicateDimensionNames = findDuplicates(dimensionResps, DimensionResp::getName);
List<String> duplicateDimensionBizNames = findDuplicates(dimensionResps, DimensionResp::getBizName);
if (!duplicateDimensionNames.isEmpty()) {
throw new InvalidArgumentException("存在相同的维度名: " + duplicateDimensionNames);
}
if (!duplicateDimensionBizNames.isEmpty()) {
throw new InvalidArgumentException("存在相同的维度英文名: " + duplicateDimensionBizNames);
}
}
if (!CollectionUtils.isEmpty(allMetricIds)) {
metaFilter.setIds(allMetricIds);
List<MetricResp> metricResps = metricService.getMetrics(metaFilter);
List<String> duplicateMetricNames = findDuplicates(metricResps, MetricResp::getName);
List<String> duplicateMetricBizNames = findDuplicates(metricResps, MetricResp::getBizName);
if (!duplicateMetricNames.isEmpty()) {
throw new InvalidArgumentException("存在相同的指标名: " + duplicateMetricNames);
}
if (!duplicateMetricBizNames.isEmpty()) {
throw new InvalidArgumentException("存在相同的指标英文名: " + duplicateMetricBizNames);
}
}
}
private <T, R> List<String> findDuplicates(List<T> list, Function<T, R> keyExtractor) {
return list.stream()
.collect(Collectors.groupingBy(keyExtractor, Collectors.counting()))
.entrySet().stream()
.filter(entry -> entry.getValue() > 1)
.map(Map.Entry::getKey)
.map(Object::toString)
.collect(Collectors.toList());
}
} }

View File

@@ -63,7 +63,6 @@ public class DatabaseConverter {
BeanUtils.copyProperties(databaseDO, databaseResp); BeanUtils.copyProperties(databaseDO, databaseResp);
ConnectInfo connectInfo = JSONObject.parseObject(databaseDO.getConfig(), ConnectInfo.class); ConnectInfo connectInfo = JSONObject.parseObject(databaseDO.getConfig(), ConnectInfo.class);
databaseResp.setUrl(connectInfo.getUrl()); databaseResp.setUrl(connectInfo.getUrl());
databaseResp.setPassword(connectInfo.getPassword());
databaseResp.setUsername(connectInfo.getUserName()); databaseResp.setUsername(connectInfo.getUserName());
databaseResp.setDatabase(connectInfo.getDatabase()); databaseResp.setDatabase(connectInfo.getDatabase());
if (StringUtils.isNotBlank(databaseDO.getAdmin())) { if (StringUtils.isNotBlank(databaseDO.getAdmin())) {
@@ -75,4 +74,11 @@ public class DatabaseConverter {
return databaseResp; return databaseResp;
} }
public static DatabaseResp convertWithPassword(DatabaseDO databaseDO) {
DatabaseResp databaseResp = convert(databaseDO);
ConnectInfo connectInfo = JSONObject.parseObject(databaseDO.getConfig(), ConnectInfo.class);
databaseResp.setPassword(connectInfo.getPassword());
return databaseResp;
}
} }

View File

@@ -1,5 +1,10 @@
package com.tencent.supersonic.headless.server.utils; package com.tencent.supersonic.headless.server.utils;
import static com.tencent.supersonic.common.pojo.Constants.AND_UPPER;
import static com.tencent.supersonic.common.pojo.Constants.APOSTROPHE;
import static com.tencent.supersonic.common.pojo.Constants.COMMA;
import static com.tencent.supersonic.common.pojo.Constants.SPACE;
import com.google.common.base.Strings; import com.google.common.base.Strings;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.Aggregator; import com.tencent.supersonic.common.pojo.Aggregator;
@@ -31,11 +36,6 @@ import com.tencent.supersonic.headless.server.service.DimensionService;
import com.tencent.supersonic.headless.server.service.MetricService; import com.tencent.supersonic.headless.server.service.MetricService;
import com.tencent.supersonic.headless.server.service.ModelService; import com.tencent.supersonic.headless.server.service.ModelService;
import com.tencent.supersonic.headless.server.service.QueryService; import com.tencent.supersonic.headless.server.service.QueryService;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.time.LocalDate; import java.time.LocalDate;
import java.time.format.DateTimeFormatter; import java.time.format.DateTimeFormatter;
import java.util.ArrayList; import java.util.ArrayList;
@@ -48,11 +48,10 @@ import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.StringJoiner; import java.util.StringJoiner;
import org.springframework.beans.BeanUtils;
import static com.tencent.supersonic.common.pojo.Constants.AND_UPPER; import org.springframework.beans.factory.annotation.Value;
import static com.tencent.supersonic.common.pojo.Constants.APOSTROPHE; import org.springframework.stereotype.Component;
import static com.tencent.supersonic.common.pojo.Constants.COMMA; import org.springframework.util.CollectionUtils;
import static com.tencent.supersonic.common.pojo.Constants.SPACE;
@Component @Component
public class DictUtils { public class DictUtils {
@@ -79,9 +78,9 @@ public class DictUtils {
private final ModelService modelService; private final ModelService modelService;
public DictUtils(DimensionService dimensionService, public DictUtils(DimensionService dimensionService,
MetricService metricService, MetricService metricService,
QueryService queryService, QueryService queryService,
ModelService modelService) { ModelService modelService) {
this.dimensionService = dimensionService; this.dimensionService = dimensionService;
this.metricService = metricService; this.metricService = metricService;
this.queryService = queryService; this.queryService = queryService;
@@ -222,7 +221,7 @@ public class DictUtils {
&& Objects.nonNull(dictItemResp.getConfig().getMetricId())) { && Objects.nonNull(dictItemResp.getConfig().getMetricId())) {
// 查询默认指标 // 查询默认指标
QueryStructReq queryStructReq = generateQueryStruct(dictItemResp); QueryStructReq queryStructReq = generateQueryStruct(dictItemResp);
return queryStructReq.convert(queryStructReq, true); return queryStructReq.convert(true);
} }
// count(1) 作为指标 // count(1) 作为指标
return constructQuerySqlReq(dictItemResp); return constructQuerySqlReq(dictItemResp);

View File

@@ -72,6 +72,7 @@ public class MetricConverter {
ModelResp modelResp = modelMap.get(metricDO.getModelId()); ModelResp modelResp = modelMap.get(metricDO.getModelId());
if (modelResp != null) { if (modelResp != null) {
metricResp.setModelName(modelResp.getName()); metricResp.setModelName(modelResp.getName());
metricResp.setModelBizName(modelResp.getBizName());
metricResp.setDomainId(modelResp.getDomainId()); metricResp.setDomainId(modelResp.getDomainId());
} }
metricResp.setIsCollect(collect != null && collect.contains(metricDO.getId())); metricResp.setIsCollect(collect != null && collect.contains(metricDO.getId()));

View File

@@ -0,0 +1,47 @@
package com.tencent.supersonic.headless.server.utils;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.response.ModelSchemaResp;
import com.tencent.supersonic.headless.server.pojo.ModelCluster;
import com.tencent.supersonic.headless.server.service.SchemaService;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
public class ModelClusterBuilder {
public static Map<String, ModelCluster> buildModelClusters(List<Long> modelIds) {
SchemaService schemaService = ContextUtils.getBean(SchemaService.class);
List<ModelSchemaResp> modelSchemaResps = schemaService.fetchModelSchemaResps(modelIds);
Map<Long, ModelSchemaResp> modelIdToModelSchema = modelSchemaResps.stream()
.collect(Collectors.toMap(ModelSchemaResp::getId, value -> value, (k1, k2) -> k1));
Set<Long> visited = new HashSet<>();
List<Set<Long>> modelClusters = new ArrayList<>();
for (ModelSchemaResp model : modelSchemaResps) {
if (!visited.contains(model.getId())) {
Set<Long> modelCluster = new HashSet<>();
dfs(model, modelIdToModelSchema, visited, modelCluster);
modelClusters.add(modelCluster);
}
}
return modelClusters.stream().map(ModelCluster::build)
.collect(Collectors.toMap(ModelCluster::getKey, value -> value, (k1, k2) -> k1));
}
private static void dfs(ModelSchemaResp model, Map<Long, ModelSchemaResp> modelMap,
Set<Long> visited, Set<Long> modelCluster) {
visited.add(model.getId());
modelCluster.add(model.getId());
for (Long neighborId : model.getModelClusterSet()) {
if (!visited.contains(neighborId)) {
dfs(modelMap.get(neighborId), modelMap, visited, modelCluster);
}
}
}
}

View File

@@ -167,4 +167,40 @@
</select> </select>
<select id="queryDimensions" resultMap="ResultMapWithBLOBs">
select *
from s2_dimension
where status != 3
<if test="modelIds != null and modelIds.size >0">
and model_id in
<foreach collection="modelIds" index="index" item="model" open="(" close=")"
separator=",">
#{model}
</foreach>
</if>
<if test="dimensionIds != null and dimensionIds.size >0">
and id in
<foreach collection="dimensionIds" index="index" item="dimensionId" open="(" close=")"
separator=",">
#{dimensionId}
</foreach>
</if>
<if test="dimensionNames != null and dimensionNames.size > 0">
AND (
(name IN
<foreach collection="dimensionNames" index="index" item="dimensionName" open="(" close=")"
separator=",">
#{dimensionName}
</foreach>)
OR
(biz_name IN
<foreach collection="dimensionNames" index="index" item="dimensionName" open="(" close=")"
separator=",">
#{dimensionName}
</foreach>)
)
</if>
</select>
</mapper> </mapper>

View File

@@ -2,27 +2,29 @@
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" <!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
"http://mybatis.org/dtd/mybatis-3-mapper.dtd"> "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.tencent.supersonic.headless.server.persistence.mapper.MetricDOCustomMapper"> <mapper namespace="com.tencent.supersonic.headless.server.persistence.mapper.MetricDOCustomMapper">
<resultMap id="BaseResultMap" type="com.tencent.supersonic.headless.server.persistence.dataobject.MetricDO"> <resultMap id="BaseResultMap"
<id column="id" jdbcType="BIGINT" property="id" /> type="com.tencent.supersonic.headless.server.persistence.dataobject.MetricDO">
<result column="model_id" jdbcType="BIGINT" property="modelId" /> <id column="id" jdbcType="BIGINT" property="id"/>
<result column="name" jdbcType="VARCHAR" property="name" /> <result column="model_id" jdbcType="BIGINT" property="modelId"/>
<result column="biz_name" jdbcType="VARCHAR" property="bizName" /> <result column="name" jdbcType="VARCHAR" property="name"/>
<result column="description" jdbcType="VARCHAR" property="description" /> <result column="biz_name" jdbcType="VARCHAR" property="bizName"/>
<result column="status" jdbcType="INTEGER" property="status" /> <result column="description" jdbcType="VARCHAR" property="description"/>
<result column="sensitive_level" jdbcType="INTEGER" property="sensitiveLevel" /> <result column="status" jdbcType="INTEGER" property="status"/>
<result column="type" jdbcType="VARCHAR" property="type" /> <result column="sensitive_level" jdbcType="INTEGER" property="sensitiveLevel"/>
<result column="created_at" jdbcType="TIMESTAMP" property="createdAt" /> <result column="type" jdbcType="VARCHAR" property="type"/>
<result column="created_by" jdbcType="VARCHAR" property="createdBy" /> <result column="created_at" jdbcType="TIMESTAMP" property="createdAt"/>
<result column="updated_at" jdbcType="TIMESTAMP" property="updatedAt" /> <result column="created_by" jdbcType="VARCHAR" property="createdBy"/>
<result column="updated_by" jdbcType="VARCHAR" property="updatedBy" /> <result column="updated_at" jdbcType="TIMESTAMP" property="updatedAt"/>
<result column="data_format_type" jdbcType="VARCHAR" property="dataFormatType" /> <result column="updated_by" jdbcType="VARCHAR" property="updatedBy"/>
<result column="data_format" jdbcType="VARCHAR" property="dataFormat" /> <result column="data_format_type" jdbcType="VARCHAR" property="dataFormatType"/>
<result column="alias" jdbcType="VARCHAR" property="alias" /> <result column="data_format" jdbcType="VARCHAR" property="dataFormat"/>
<result column="tags" jdbcType="VARCHAR" property="tags" /> <result column="alias" jdbcType="VARCHAR" property="alias"/>
<result column="define_type" jdbcType="VARCHAR" property="defineType" /> <result column="tags" jdbcType="VARCHAR" property="tags"/>
<result column="define_type" jdbcType="VARCHAR" property="defineType"/>
</resultMap> </resultMap>
<resultMap extends="BaseResultMap" id="ResultMapWithBLOBs" type="com.tencent.supersonic.headless.server.persistence.dataobject.MetricDO"> <resultMap extends="BaseResultMap" id="ResultMapWithBLOBs"
<result column="type_params" jdbcType="LONGVARCHAR" property="typeParams" /> type="com.tencent.supersonic.headless.server.persistence.dataobject.MetricDO">
<result column="type_params" jdbcType="LONGVARCHAR" property="typeParams"/>
</resultMap> </resultMap>
<sql id="Example_Where_Clause"> <sql id="Example_Where_Clause">
<where> <where>
@@ -56,14 +58,16 @@
</where> </where>
</sql> </sql>
<sql id="Base_Column_List"> <sql id="Base_Column_List">
id, model_id, name, biz_name, description, status, sensitive_level, type, created_at, id
, model_id, name, biz_name, description, status, sensitive_level, type, created_at,
created_by, updated_at, updated_by, data_format_type, data_format, alias, tags, define_type created_by, updated_at, updated_by, data_format_type, data_format, alias, tags, define_type
</sql> </sql>
<sql id="Blob_Column_List"> <sql id="Blob_Column_List">
type_params type_params
</sql> </sql>
<insert id="batchInsert" parameterType="java.util.List" useGeneratedKeys="true" keyProperty="id"> <insert id="batchInsert" parameterType="java.util.List" useGeneratedKeys="true"
keyProperty="id">
insert into s2_metric (model_id, name, insert into s2_metric (model_id, name,
biz_name, description, type,status,sensitive_level, biz_name, description, type,status,sensitive_level,
created_at, created_by, updated_at, created_at, created_by, updated_at,
@@ -94,9 +98,9 @@
</update> </update>
<select id="query" resultMap="ResultMapWithBLOBs"> <select id="query" resultMap="ResultMapWithBLOBs">
select * select *
from s2_metric from s2_metric
where status != 3 where status != 3
<if test="type != null and type != ''"> <if test="type != null and type != ''">
and type = #{type} and type = #{type}
</if> </if>
@@ -126,14 +130,14 @@
<if test="modelIds != null and modelIds.size >0"> <if test="modelIds != null and modelIds.size >0">
and model_id in and model_id in
<foreach collection="modelIds" index="index" item="model" open="(" close=")" <foreach collection="modelIds" index="index" item="model" open="(" close=")"
separator=","> separator=",">
#{model} #{model}
</foreach> </foreach>
</if> </if>
<if test="ids != null and ids.size >0"> <if test="ids != null and ids.size >0">
and id in and id in
<foreach collection="ids" index="index" item="id" open="(" close=")" <foreach collection="ids" index="index" item="id" open="(" close=")"
separator=","> separator=",">
#{id} #{id}
</foreach> </foreach>
</if> </if>
@@ -142,4 +146,40 @@
</if> </if>
</select> </select>
<select id="queryMetrics" resultMap="ResultMapWithBLOBs">
select *
from s2_metric
where status != 3
<if test="modelIds != null and modelIds.size >0">
and model_id in
<foreach collection="modelIds" index="index" item="model" open="(" close=")"
separator=",">
#{model}
</foreach>
</if>
<if test="metricIds != null and metricIds.size >0">
and id in
<foreach collection="metricIds" index="index" item="metricId" open="(" close=")"
separator=",">
#{metricId}
</foreach>
</if>
<if test="metricNames != null and metricNames.size > 0">
AND (
(name IN
<foreach collection="metricNames" index="index" item="metricName" open="(" close=")"
separator=",">
#{metricName}
</foreach>)
OR
(biz_name IN
<foreach collection="metricNames" index="index" item="metricName" open="(" close=")"
separator=",">
#{metricName}
</foreach>)
)
</if>
</select>
</mapper> </mapper>

View File

@@ -0,0 +1,71 @@
package com.tencent.supersonic.headless.server.utils;
import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.Param;
import com.tencent.supersonic.headless.api.pojo.SqlVariable;
import com.tencent.supersonic.headless.api.pojo.enums.VariableValueType;
import com.tencent.supersonic.headless.core.utils.SqlVariableParseUtils;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.List;
public class SqlVariableParseUtilsTest {
@Test
void testParseSql_defaultVariableValue() {
String sql = "select * from t_$interval$ where id = $id$ and name = $name$";
List<SqlVariable> variables = Lists.newArrayList(mockNumSqlVariable(),
mockExprSqlVariable(), mockStrSqlVariable());
String actualSql = SqlVariableParseUtils.parse(sql, variables, Lists.newArrayList());
String expectedSql = "select * from t_d where id = 1 and name = 'tom'";
Assertions.assertEquals(expectedSql, actualSql);
}
@Test
void testParseSql() {
String sql = "select * from t_$interval$ where id = $id$ and name = $name$";
List<SqlVariable> variables = Lists.newArrayList(mockNumSqlVariable(),
mockExprSqlVariable(), mockStrSqlVariable());
List<Param> params = Lists.newArrayList(mockIdParam(), mockNameParam(), mockIntervalParam());
String actualSql = SqlVariableParseUtils.parse(sql, variables, params);
String expectedSql = "select * from t_wk where id = 2 and name = 'alice'";
Assertions.assertEquals(expectedSql, actualSql);
}
private SqlVariable mockNumSqlVariable() {
return mockSqlVariable("id", VariableValueType.NUMBER, 1);
}
private SqlVariable mockStrSqlVariable() {
return mockSqlVariable("name", VariableValueType.STRING, "tom");
}
private SqlVariable mockExprSqlVariable() {
return mockSqlVariable("interval", VariableValueType.EXPR, "d");
}
private SqlVariable mockSqlVariable(String name, VariableValueType variableValueType, Object value) {
SqlVariable sqlVariable = new SqlVariable();
sqlVariable.setName(name);
sqlVariable.setValueType(variableValueType);
sqlVariable.setDefaultValues(Lists.newArrayList(value));
return sqlVariable;
}
private Param mockIdParam() {
return mockParam("id", "2");
}
private Param mockNameParam() {
return mockParam("name", "alice");
}
private Param mockIntervalParam() {
return mockParam("interval", "wk");
}
private Param mockParam(String name, String value) {
return new Param(name, value);
}
}

View File

@@ -224,7 +224,7 @@ public class BenchMarkDemoDataLoader {
new ViewModelConfig(5L, Lists.newArrayList(8L), Lists.newArrayList()), new ViewModelConfig(5L, Lists.newArrayList(8L), Lists.newArrayList()),
new ViewModelConfig(6L, Lists.newArrayList(9L, 10L), Lists.newArrayList()), new ViewModelConfig(6L, Lists.newArrayList(9L, 10L), Lists.newArrayList()),
new ViewModelConfig(7L, Lists.newArrayList(11L, 12L), Lists.newArrayList()), new ViewModelConfig(7L, Lists.newArrayList(11L, 12L), Lists.newArrayList()),
new ViewModelConfig(8L, Lists.newArrayList(13L, 14L, 15L), Lists.newArrayList(8L, 9L)) new ViewModelConfig(8L, Lists.newArrayList(13L, 14L), Lists.newArrayList(8L, 9L))
); );
ViewDetail viewDetail = new ViewDetail(); ViewDetail viewDetail = new ViewDetail();
viewDetail.setViewModelConfigs(viewModelConfigs); viewDetail.setViewModelConfigs(viewModelConfigs);

View File

@@ -23,7 +23,6 @@ import com.tencent.supersonic.common.pojo.SysParameter;
import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.service.SysParameterService; import com.tencent.supersonic.common.service.SysParameterService;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.server.service.KnowledgeService;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.beans.factory.annotation.Qualifier;
@@ -53,8 +52,6 @@ public class ChatDemoLoader implements CommandLineRunner {
private AgentService agentService; private AgentService agentService;
@Autowired @Autowired
private SysParameterService sysParameterService; private SysParameterService sysParameterService;
@Autowired
private KnowledgeService knowledgeService;
@Value("${demo.enabled:false}") @Value("${demo.enabled:false}")
private boolean demoEnabled; private boolean demoEnabled;

View File

@@ -1,10 +1,10 @@
package com.tencent.supersonic; package com.tencent.supersonic;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.parser.JavaLLMProxy; import com.tencent.supersonic.chat.core.parser.JavaLLMProxy;
import com.tencent.supersonic.chat.core.parser.sql.llm.SqlExamplarLoader; import com.tencent.supersonic.chat.core.parser.sql.llm.SqlExamplarLoader;
import com.tencent.supersonic.chat.core.parser.sql.llm.SqlExample; import com.tencent.supersonic.chat.core.parser.sql.llm.SqlExample;
import com.tencent.supersonic.chat.core.utils.ComponentFactory; import com.tencent.supersonic.chat.core.utils.ComponentFactory;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import java.util.List; import java.util.List;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
@@ -20,7 +20,7 @@ public class EmbeddingInitListener implements CommandLineRunner {
@Autowired @Autowired
private SqlExamplarLoader sqlExamplarLoader; private SqlExamplarLoader sqlExamplarLoader;
@Autowired @Autowired
private OptimizationConfig optimizationConfig; private EmbeddingConfig embeddingConfig;
@Override @Override
public void run(String... args) { public void run(String... args) {
@@ -31,7 +31,7 @@ public class EmbeddingInitListener implements CommandLineRunner {
try { try {
if (ComponentFactory.getLLMProxy() instanceof JavaLLMProxy) { if (ComponentFactory.getLLMProxy() instanceof JavaLLMProxy) {
List<SqlExample> sqlExamples = sqlExamplarLoader.getSqlExamples(); List<SqlExample> sqlExamples = sqlExamplarLoader.getSqlExamples();
String collectionName = optimizationConfig.getText2sqlCollectionName(); String collectionName = embeddingConfig.getText2sqlCollectionName();
sqlExamplarLoader.addEmbeddingStore(sqlExamples, collectionName); sqlExamplarLoader.addEmbeddingStore(sqlExamples, collectionName);
} }
} catch (Exception e) { } catch (Exception e) {

View File

@@ -46,4 +46,4 @@ com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor=\
com.tencent.supersonic.chat.server.processor.execute.MetricRatioProcessor com.tencent.supersonic.chat.server.processor.execute.MetricRatioProcessor
com.tencent.supersonic.common.util.embedding.S2EmbeddingStore=\ com.tencent.supersonic.common.util.embedding.S2EmbeddingStore=\
com.tencent.supersonic.common.util.embedding.InMemoryS2EmbeddingStore com.tencent.supersonic.common.util.embedding.InMemoryS2EmbeddingStore

View File

@@ -42,6 +42,9 @@ metric:
mybatis: mybatis:
mapper-locations=classpath:mappers/custom/*.xml,classpath*:/mappers/*.xml mapper-locations=classpath:mappers/custom/*.xml,classpath*:/mappers/*.xml
corrector:
additional:
information: true
llm: llm:
parser: parser:

View File

@@ -183,7 +183,10 @@ CREATE TABLE s2_view(
created_at datetime, created_at datetime,
created_by VARCHAR(255), created_by VARCHAR(255),
updated_at datetime, updated_at datetime,
updated_by VARCHAR(255) updated_by VARCHAR(255),
query_config VARCHAR(3000),
`admin` varchar(3000) DEFAULT NULL,
`admin_org` varchar(3000) DEFAULT NULL
)ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; )ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
alter table s2_plugin change column model `view` varchar(100); alter table s2_plugin change column model `view` varchar(100);

View File

@@ -5,14 +5,14 @@ dean _1_2 36
john _1_2 50 john _1_2 50
jack _1_2 38 jack _1_2 38
admin _1_2 70 admin _1_2 70
周杰伦 _2_7 100 周杰伦 _4_7 100
陈奕迅 _2_7 100 陈奕迅 _4_7 100
林俊杰 _2_7 100 林俊杰 _4_7 100
张碧晨 _2_7 100 张碧晨 _4_7 100
程响 _2_7 100 程响 _4_7 100
Taylor#Swift _2_7 100 Taylor#Swift _4_7 100
内地 _2_4 100 内地 _4_4 100
欧美 _2_4 100 欧美 _4_4 100
港台 _2_4 100 港台 _4_4 100
流行 _2_6 100 流行 _4_6 100
国风 _2_6 100 国风 _4_6 100

View File

@@ -1,6 +1,6 @@
p1 _2_3 52 p1 _3_3 52
p2 _2_3 47 p2 _3_3 47
p3 _2_3 31 p3 _3_3 31
p4 _2_3 36 p4 _3_3 36
p5 _2_3 50 p5 _3_3 50
p6 _2_3 38 p6 _3_3 38

View File

@@ -1,9 +1,9 @@
周杰伦 _2_7 9000 周杰伦 _4_7 9000
周深 _2_7 8000 周深 _4_7 8000
周传雄 _2_7 7000 周传雄 _4_7 7000
周华建 _2_7 6000 周华建 _4_7 6000
陈奕迅 _2_7 8000 陈奕迅 _4_7 8000
林俊杰 _2_7 7000 林俊杰 _4_7 7000
张碧晨 _2_7 7000 张碧晨 _4_7 7000
程响 _2_7 7000 程响 _4_7 7000
Taylor#Swift _2_7 7000 Taylor#Swift _4_7 7000

View File

@@ -1,4 +1,4 @@
美国 _3_8 1 美国 _5_8 1
加拿大 _3_8 1 加拿大 _5_8 1
锡尔赫特、吉大港、库斯蒂亚 _3_8 1 锡尔赫特、吉大港、库斯蒂亚 _5_8 1
孟加拉国 _3_8 3 孟加拉国 _5_8 3

View File

@@ -1,6 +1,6 @@
现代 _3_9 1 现代 _5_9 1
tagore _3_9 1 tagore _5_9 1
蓝调 _3_9 1 蓝调 _5_9 1
流行 _3_9 1 流行 _5_9 1
民间 _3_9 1 民间 _5_9 1
nazrul _3_9 1 nazrul _5_9 1

View File

@@ -1,4 +1,4 @@
美国 _3_11 1 美国 _6_11 1
印度 _3_11 2 印度 _6_11 2
英国 _3_11 1 英国 _6_11 1
孟加拉国 _3_11 2 孟加拉国 _6_11 2

View File

@@ -1,2 +1,2 @@
男性 _3_12 3 男性 _6_12 3
女性 _3_12 3 女性 _6_12 3

View File

@@ -1,2 +1,2 @@
mp4 _3_14 4 mp4 _7_14 4
mp3 _3_14 2 mp3 _7_14 2

View File

@@ -1,4 +1,4 @@
美国 _3_17 1 美国 _8_17 1
印度 _3_17 2 印度 _8_17 2
英国 _3_17 1 英国 _8_17 1
孟加拉国 _3_17 2 孟加拉国 _8_17 2

Some files were not shown because too many files have changed in this diff Show More