(improvement)(Chat) Move chat-core to headless (#805)

Co-authored-by: jolunoluo
This commit is contained in:
LXW
2024-03-12 22:20:30 +08:00
committed by GitHub
parent f152deeb81
commit f93bee81cb
301 changed files with 2256 additions and 4527 deletions

View File

@@ -47,7 +47,7 @@ public class DimValueAspect {
@Autowired
private DimensionService dimensionService;
@Around("execution(* com.tencent.supersonic.headless.server.service.QueryService.queryByReq(..))")
@Around("execution(* com.tencent.supersonic.headless.server.service.ChatQueryService.queryByReq(..))")
public Object handleDimValue(ProceedingJoinPoint joinPoint) throws Throwable {
if (!dimensionValueMapEnable) {
log.debug("dimensionValueMapEnable is false, skip dimensionValueMap");

View File

@@ -0,0 +1,82 @@
package com.tencent.supersonic.headless.server.listener;
import com.tencent.supersonic.headless.core.knowledge.DictWord;
import com.tencent.supersonic.headless.core.knowledge.KnowledgeService;
import com.tencent.supersonic.headless.server.service.impl.WordService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.CommandLineRunner;
import org.springframework.core.annotation.Order;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
import java.util.List;
import java.util.concurrent.CompletableFuture;
@Slf4j
@Component
@Order(2)
public class ApplicationStartedListener implements CommandLineRunner {

    @Autowired
    private KnowledgeService knowledgeService;
    @Autowired
    private WordService wordService;

    /**
     * Runs once on application startup: loads all dictionary words into the
     * knowledge base before the service starts answering queries.
     */
    @Override
    public void run(String... args) {
        updateKnowledgeDimValue();
    }

    /**
     * Reloads all dictionary words synchronously.
     *
     * @return true if the reload completed without error, false otherwise
     */
    public Boolean updateKnowledgeDimValue() {
        boolean isOk = false;
        try {
            log.debug("ApplicationStartedInit start");
            List<DictWord> dictWords = wordService.getAllDictWords();
            // remember the loaded snapshot so reloadKnowledge() can detect changes later
            wordService.setPreDictWords(dictWords);
            knowledgeService.reloadAllData(dictWords);
            log.debug("ApplicationStartedInit end");
            isOk = true;
        } catch (Exception e) {
            log.error("ApplicationStartedInit error", e);
        }
        return isOk;
    }

    /**
     * Fire-and-forget variant of {@link #updateKnowledgeDimValue()}.
     *
     * @return always true; the actual reload outcome is only logged
     */
    public Boolean updateKnowledgeDimValueAsync() {
        // runAsync suffices: the reload result is never consumed by the caller
        CompletableFuture.runAsync(this::updateKnowledgeDimValue);
        return true;
    }

    /***
     * reload knowledge task
     */
    // NOTE(review): the property key "reload.knowledge.corn" looks like a typo of
    // "cron", but renaming it would break existing configurations — confirm first.
    @Scheduled(cron = "${reload.knowledge.corn:0 0/1 * * * ?}")
    public void reloadKnowledge() {
        log.debug("reloadKnowledge start");
        try {
            List<DictWord> dictWords = wordService.getAllDictWords();
            List<DictWord> preDictWords = wordService.getPreDictWords();
            if (CollectionUtils.isEqualCollection(dictWords, preDictWords)) {
                log.debug("dictWords has not changed, reloadKnowledge end");
                return;
            }
            log.info("dictWords has changed");
            wordService.setPreDictWords(dictWords);
            // reuse the list fetched above instead of querying all dict words a second
            // time, which avoids an extra load and keeps the snapshot consistent
            knowledgeService.updateOnlineKnowledge(dictWords);
        } catch (Exception e) {
            log.error("reloadKnowledge error", e);
        }
        log.debug("reloadKnowledge end");
    }
}

View File

@@ -0,0 +1,45 @@
package com.tencent.supersonic.headless.server.listener;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DataEvent;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.common.pojo.enums.EventType;
import com.tencent.supersonic.headless.core.knowledge.DictWord;
import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.ApplicationListener;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
@Component
@Slf4j
public class SchemaDictUpdateListener implements ApplicationListener<DataEvent> {

    /**
     * Applies schema data changes to the Hanlp custom dictionary:
     * ADD inserts the word, DELETE removes it, and UPDATE removes the old
     * entry before inserting one carrying the new name.
     */
    @Async
    @Override
    public void onApplicationEvent(DataEvent dataEvent) {
        if (CollectionUtils.isEmpty(dataEvent.getDataItems())) {
            return;
        }
        dataEvent.getDataItems().forEach(item -> {
            // nature encodes modelId + itemId + a type-specific suffix, separated by the split sign
            String nature = DictWordType.NATURE_SPILT + item.getModelId() + item.getId()
                    + DictWordType.getSuffixNature(item.getType());
            DictWord word = new DictWord();
            word.setWord(item.getName());
            word.setNature(nature);
            word.setNatureWithFrequency(nature + " " + Constants.DEFAULT_FREQUENCY);
            if (EventType.ADD.equals(dataEvent.getEventType())) {
                HanlpHelper.addToCustomDictionary(word);
            } else if (EventType.DELETE.equals(dataEvent.getEventType())) {
                HanlpHelper.removeFromCustomDictionary(word);
            } else if (EventType.UPDATE.equals(dataEvent.getEventType())) {
                HanlpHelper.removeFromCustomDictionary(word);
                word.setWord(item.getNewName());
                HanlpHelper.addToCustomDictionary(word);
            }
        });
    }
}

View File

@@ -0,0 +1,16 @@
package com.tencent.supersonic.headless.server.persistence.dataobject;
import lombok.Data;
import java.io.Serializable;
import java.time.Instant;
/**
 * Persistence object for a chat session's parsing context.
 */
@Data
public class ChatContextDO implements Serializable {

    // Serializable classes should pin their serial version explicitly
    private static final long serialVersionUID = 1L;

    /** chat session identifier */
    private Integer chatId;
    /** last modification time */
    private Instant modifiedAt;
    /** user the context belongs to */
    private String user;
    /** latest query text of the chat */
    private String queryText;
    /** SemanticParseInfo serialized as JSON */
    private String semanticParse;
}

View File

@@ -0,0 +1,54 @@
package com.tencent.supersonic.headless.server.persistence.dataobject;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import java.util.Date;
/**
 * Persistence object for chat query statistics (one row per handled query).
 */
// @Getter removed: it is redundant because @Data already generates getters
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class StatisticsDO {
    /** questionId */
    private Long questionId;
    /** chatId */
    private Long chatId;
    /** createTime */
    private Date createTime;
    /** queryText */
    private String queryText;
    /** userName */
    private String userName;
    /** name of the service interface that handled the query */
    private String interfaceName;
    /** cost — presumably elapsed time in milliseconds; confirm against the writer */
    private Integer cost;
    /** statistics type — semantics not visible from this file */
    private Integer type;
}

View File

@@ -0,0 +1,14 @@
package com.tencent.supersonic.headless.server.persistence.mapper;
import com.tencent.supersonic.headless.server.persistence.dataobject.ChatContextDO;
import org.apache.ibatis.annotations.Mapper;
/**
 * MyBatis mapper for reading and writing chat context rows.
 */
@Mapper
public interface ChatContextMapper {
    /** Returns the context row for the given chat id, or null when none exists. */
    ChatContextDO getContextByChatId(int chatId);
    /** Updates an existing context row; returns the affected row count. */
    int updateContext(ChatContextDO contextDO);
    /** Inserts a new context row; returns the affected row count. */
    int addContext(ChatContextDO contextDO);
}

View File

@@ -0,0 +1,12 @@
package com.tencent.supersonic.headless.server.persistence.mapper;
import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
import java.util.List;
/**
 * MyBatis mapper for persisting chat query statistics.
 */
@Mapper
public interface StatisticsMapper {
    /** Batch-inserts the given statistics rows; returns whether the insert succeeded. */
    boolean batchSaveStatistics(@Param("list") List<StatisticsDO> list);
}

View File

@@ -0,0 +1,11 @@
package com.tencent.supersonic.headless.server.persistence.repository;
import com.tencent.supersonic.headless.core.pojo.ChatContext;
/**
 * Repository abstraction over chat context storage.
 */
public interface ChatContextRepository {
    /** Returns the stored context for the chat id, or a fresh empty context when none exists. */
    ChatContext getOrCreateContext(int chatId);
    /** Persists the given context (insert when absent, update otherwise). */
    void updateContext(ChatContext chatCtx);
}

View File

@@ -0,0 +1,72 @@
package com.tencent.supersonic.headless.server.persistence.repository.impl;
import com.google.gson.Gson;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.core.pojo.ChatContext;
import com.tencent.supersonic.headless.server.persistence.dataobject.ChatContextDO;
import com.tencent.supersonic.headless.server.persistence.mapper.ChatContextMapper;
import com.tencent.supersonic.headless.server.persistence.repository.ChatContextRepository;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Repository;
/**
 * Database-backed implementation of {@link ChatContextRepository}, converting
 * between the domain object and its persistence row (parse info stored as JSON).
 */
@Repository
@Primary
@Slf4j
public class ChatContextRepositoryImpl implements ChatContextRepository {

    // constructor injection only: the previous field-level @Autowired(required = false)
    // was contradictory on a final, constructor-injected field and has been removed
    private final ChatContextMapper chatContextMapper;

    public ChatContextRepositoryImpl(ChatContextMapper chatContextMapper) {
        this.chatContextMapper = chatContextMapper;
    }

    /**
     * Loads the persisted context for the chat id, or returns a new, not-yet-persisted
     * context carrying only the chat id when no row exists.
     */
    @Override
    public ChatContext getOrCreateContext(int chatId) {
        ChatContextDO context = chatContextMapper.getContextByChatId(chatId);
        if (context == null) {
            ChatContext chatContext = new ChatContext();
            chatContext.setChatId(chatId);
            return chatContext;
        }
        return cast(context);
    }

    /** Upserts the context: insert when no row exists for the chat id, update otherwise. */
    @Override
    public void updateContext(ChatContext chatCtx) {
        ChatContextDO context = cast(chatCtx);
        if (chatContextMapper.getContextByChatId(chatCtx.getChatId()) == null) {
            chatContextMapper.addContext(context);
        } else {
            chatContextMapper.updateContext(context);
        }
    }

    /** Converts a persistence row to the domain object, deserializing the parse info JSON. */
    private ChatContext cast(ChatContextDO contextDO) {
        ChatContext chatContext = new ChatContext();
        chatContext.setChatId(contextDO.getChatId());
        chatContext.setUser(contextDO.getUser());
        chatContext.setQueryText(contextDO.getQueryText());
        if (contextDO.getSemanticParse() != null && !contextDO.getSemanticParse().isEmpty()) {
            SemanticParseInfo semanticParseInfo = JsonUtil.toObject(contextDO.getSemanticParse(),
                    SemanticParseInfo.class);
            chatContext.setParseInfo(semanticParseInfo);
        }
        return chatContext;
    }

    /** Converts the domain object to a persistence row, serializing the parse info to JSON. */
    private ChatContextDO cast(ChatContext chatContext) {
        ChatContextDO chatContextDO = new ChatContextDO();
        chatContextDO.setChatId(chatContext.getChatId());
        chatContextDO.setQueryText(chatContext.getQueryText());
        chatContextDO.setUser(chatContext.getUser());
        if (chatContext.getParseInfo() != null) {
            // serialize with JsonUtil for symmetry with deserialization in cast(ChatContextDO);
            // the previous Gson round-trip mixed two JSON libraries with differing defaults
            chatContextDO.setSemanticParse(JsonUtil.toString(chatContext.getParseInfo()));
        }
        return chatContextDO;
    }
}

View File

@@ -0,0 +1,43 @@
package com.tencent.supersonic.headless.server.processor;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.core.chat.query.QueryManager;
import com.tencent.supersonic.headless.core.pojo.ChatContext;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.server.service.impl.SemanticService;
import org.springframework.util.CollectionUtils;
import java.util.List;
/**
 * EntityInfoProcessor fills core attributes of an entity so that
 * users get to know which entity is parsed out.
 */
public class EntityInfoProcessor implements ResultProcessor {

    @Override
    public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
        List<SemanticParseInfo> parses = parseResp.getSelectedParses();
        if (CollectionUtils.isEmpty(parses)) {
            return;
        }
        for (SemanticParseInfo parseInfo : parses) {
            String queryMode = parseInfo.getQueryMode();
            // skip rule-based query modes
            if (QueryManager.containsRuleQuery(queryMode)) {
                continue;
            }
            //1. set entity info
            DataSetSchema schema = queryContext.getSemanticSchema()
                    .getDataSetSchemaMap().get(parseInfo.getDataSetId());
            SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
            EntityInfo entityInfo =
                    semanticService.getEntityInfo(parseInfo, schema, queryContext.getUser());
            boolean tagOrMetric = QueryManager.isTagQuery(queryMode)
                    || QueryManager.isMetricQuery(queryMode);
            if (tagOrMetric) {
                parseInfo.setEntityInfo(entityInfo);
            }
        }
    }
}

View File

@@ -0,0 +1,222 @@
package com.tencent.supersonic.headless.server.processor;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.FieldExpression;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.core.pojo.ChatContext;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.core.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.server.service.impl.SemanticService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/**
 * ParseInfoProcessor extracts structured info from S2SQL so that
 * users get to know the details.
 **/
@Slf4j
public class ParseInfoProcessor implements ResultProcessor {

    @Override
    public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
        List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries();
        if (CollectionUtils.isEmpty(candidateQueries)) {
            return;
        }
        List<SemanticParseInfo> candidateParses = candidateQueries.stream()
                .map(SemanticQuery::getParseInfo).collect(Collectors.toList());
        candidateParses.forEach(this::updateParseInfo);
    }

    /**
     * Derives date range, dimension filters, metrics and dimensions from the
     * corrected S2SQL and writes them back into the parse info. Skipped when
     * the SQL was not corrected (correctS2SQL blank or identical to S2SQL).
     */
    public void updateParseInfo(SemanticParseInfo parseInfo) {
        SqlInfo sqlInfo = parseInfo.getSqlInfo();
        String correctS2SQL = sqlInfo.getCorrectS2SQL();
        if (StringUtils.isBlank(correctS2SQL)) {
            return;
        }
        // if S2SQL equals correctS2SQL, then not update the parseInfo.
        if (correctS2SQL.equals(sqlInfo.getS2SQL())) {
            return;
        }
        List<FieldExpression> expressions = SqlSelectHelper.getFilterExpression(correctS2SQL);
        //set dateInfo
        try {
            if (!CollectionUtils.isEmpty(expressions)) {
                DateConf dateInfo = getDateInfo(expressions);
                // keep a date range already resolved by earlier stages
                if (dateInfo != null && parseInfo.getDateInfo() == null) {
                    parseInfo.setDateInfo(dateInfo);
                }
            }
        } catch (Exception e) {
            log.error("set dateInfo error :", e);
        }
        //set filter
        Long dataSetId = parseInfo.getDataSetId();
        try {
            Map<String, SchemaElement> fieldNameToElement = getNameToElement(dataSetId);
            List<QueryFilter> result = getDimensionFilter(fieldNameToElement, expressions);
            parseInfo.getDimensionFilters().addAll(result);
        } catch (Exception e) {
            log.error("set dimensionFilter error :", e);
        }
        SemanticSchema semanticSchema = ContextUtils.getBean(SemanticService.class).getSemanticSchema();
        if (Objects.isNull(semanticSchema)) {
            return;
        }
        List<String> allFields = getFieldsExceptDate(SqlSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL()));
        Set<SchemaElement> metrics = getElements(dataSetId, allFields, semanticSchema.getMetrics());
        parseInfo.setMetrics(metrics);
        if (QueryType.METRIC.equals(parseInfo.getQueryType())) {
            // METRIC queries take their dimensions from the GROUP BY clause
            List<String> groupByFields = SqlSelectHelper.getGroupByFields(sqlInfo.getCorrectS2SQL());
            List<String> groupByDimensions = getFieldsExceptDate(groupByFields);
            parseInfo.setDimensions(getElements(dataSetId, groupByDimensions, semanticSchema.getDimensions()));
        } else if (QueryType.TAG.equals(parseInfo.getQueryType())) {
            // TAG queries take their dimensions from the SELECT clause
            List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getCorrectS2SQL());
            List<String> selectDimensions = getFieldsExceptDate(selectFields);
            parseInfo.setDimensions(getElements(dataSetId, selectDimensions, semanticSchema.getDimensions()));
        }
    }

    /**
     * Returns the schema elements of the data set whose name, or any alias,
     * appears in the given field list.
     */
    private Set<SchemaElement> getElements(Long dataSetId, List<String> allFields, List<SchemaElement> elements) {
        Set<String> fieldSet = new HashSet<>(allFields);
        return elements.stream()
                .filter(schemaElement -> dataSetId.equals(schemaElement.getDataSet()))
                .filter(schemaElement -> allFields.contains(schemaElement.getName())
                        || matchesAnyAlias(schemaElement, fieldSet))
                .collect(Collectors.toSet());
    }

    /** True when the element has at least one alias contained in the field set. */
    private boolean matchesAnyAlias(SchemaElement schemaElement, Set<String> fieldSet) {
        List<String> aliasList = schemaElement.getAlias();
        if (CollectionUtils.isEmpty(aliasList)) {
            return false;
        }
        return aliasList.stream().anyMatch(fieldSet::contains);
    }

    /** Removes the built-in day dimension from the field list (case-insensitive). */
    private List<String> getFieldsExceptDate(List<String> allFields) {
        if (CollectionUtils.isEmpty(allFields)) {
            return new ArrayList<>();
        }
        return allFields.stream()
                .filter(entry -> !TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(entry))
                .collect(Collectors.toList());
    }

    /**
     * Converts filter expressions into query filters, keeping only expressions
     * whose field name resolves to a known schema element.
     */
    private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> fieldNameToElement,
                                                 List<FieldExpression> fieldExpressions) {
        List<QueryFilter> result = Lists.newArrayList();
        for (FieldExpression expression : fieldExpressions) {
            SchemaElement schemaElement = fieldNameToElement.get(expression.getFieldName());
            if (Objects.isNull(schemaElement)) {
                continue;
            }
            QueryFilter dimensionFilter = new QueryFilter();
            dimensionFilter.setValue(expression.getFieldValue());
            dimensionFilter.setName(schemaElement.getName());
            dimensionFilter.setBizName(schemaElement.getBizName());
            dimensionFilter.setElementID(schemaElement.getId());
            FilterOperatorEnum operatorEnum = FilterOperatorEnum.getSqlOperator(expression.getOperator());
            dimensionFilter.setOperator(operatorEnum);
            dimensionFilter.setFunction(expression.getFunction());
            result.add(dimensionFilter);
        }
        return result;
    }

    /**
     * Builds a BETWEEN date range from filter expressions on the day dimension.
     * Returns null when no date expression is present.
     */
    private DateConf getDateInfo(List<FieldExpression> fieldExpressions) {
        List<FieldExpression> dateExpressions = fieldExpressions.stream()
                .filter(expression -> TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(expression.getFieldName()))
                .collect(Collectors.toList());
        if (CollectionUtils.isEmpty(dateExpressions)) {
            return null;
        }
        DateConf dateInfo = new DateConf();
        dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
        FieldExpression firstExpression = dateExpressions.get(0);
        FilterOperatorEnum firstOperator = FilterOperatorEnum.getSqlOperator(firstExpression.getOperator());
        // "= date" collapses the range to a single day
        if (FilterOperatorEnum.EQUALS.equals(firstOperator) && Objects.nonNull(firstExpression.getFieldValue())) {
            dateInfo.setStartDate(firstExpression.getFieldValue().toString());
            dateInfo.setEndDate(firstExpression.getFieldValue().toString());
            dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
            return dateInfo;
        }
        // ">"/">=" sets the start; a second expression, if any, closes the range
        if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.GREATER_THAN,
                FilterOperatorEnum.GREATER_THAN_EQUALS)) {
            dateInfo.setStartDate(firstExpression.getFieldValue().toString());
            if (hasSecondDate(dateExpressions)) {
                dateInfo.setEndDate(dateExpressions.get(1).getFieldValue().toString());
            }
        }
        // "<"/"<=" sets the end; a second expression, if any, opens the range
        if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.MINOR_THAN,
                FilterOperatorEnum.MINOR_THAN_EQUALS)) {
            dateInfo.setEndDate(firstExpression.getFieldValue().toString());
            if (hasSecondDate(dateExpressions)) {
                dateInfo.setStartDate(dateExpressions.get(1).getFieldValue().toString());
            }
        }
        return dateInfo;
    }

    /** True when the operator is one of the given ones and the expression carries a value. */
    private boolean containOperators(FieldExpression expression, FilterOperatorEnum firstOperator,
                                     FilterOperatorEnum... operatorEnums) {
        return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(
                expression.getFieldValue()));
    }

    private boolean hasSecondDate(List<FieldExpression> dateExpressions) {
        return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue());
    }

    /**
     * Maps every dimension/metric name — and each of its aliases — of the data
     * set to its schema element; on duplicate keys the later element wins.
     */
    protected Map<String, SchemaElement> getNameToElement(Long dataSetId) {
        SemanticSchema semanticSchema = ContextUtils.getBean(SemanticService.class).getSemanticSchema();
        List<SchemaElement> dimensions = semanticSchema.getDimensions(dataSetId);
        List<SchemaElement> metrics = semanticSchema.getMetrics(dataSetId);
        List<SchemaElement> allElements = Lists.newArrayList();
        allElements.addAll(dimensions);
        allElements.addAll(metrics);
        //support alias
        return allElements.stream()
                .flatMap(schemaElement -> {
                    Set<Pair<String, SchemaElement>> result = new HashSet<>();
                    result.add(Pair.of(schemaElement.getName(), schemaElement));
                    List<String> aliasList = schemaElement.getAlias();
                    if (!CollectionUtils.isEmpty(aliasList)) {
                        for (String alias : aliasList) {
                            result.add(Pair.of(alias, schemaElement));
                        }
                    }
                    return result.stream();
                })
                .collect(Collectors.toMap(Pair::getLeft, Pair::getRight,
                        (value1, value2) -> value2));
    }
}

View File

@@ -0,0 +1,84 @@
package com.tencent.supersonic.headless.server.processor;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.core.pojo.ChatContext;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.core.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.core.chat.query.rule.RuleSemanticQuery;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
/**
 * QueryRankProcessor ranks candidate parsing results based on
 * a heuristic scoring algorithm and then takes topN.
 **/
@Slf4j
public class QueryRankProcessor implements ResultProcessor {

    /** maximum number of candidate queries kept after ranking */
    private static final int CANDIDATE_TOP_SIZE = 5;

    @Override
    public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
        List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries();
        candidateQueries = rank(candidateQueries);
        queryContext.setCandidateQueries(candidateQueries);
    }

    /**
     * Ranks candidates by parse score (descending), drops fully-inherited rule
     * queries, keeps at most {@value #CANDIDATE_TOP_SIZE} and assigns 1-based
     * parse ids to the survivors.
     */
    public List<SemanticQuery> rank(List<SemanticQuery> candidateQueries) {
        log.debug("pick before [{}]", candidateQueries);
        if (CollectionUtils.isEmpty(candidateQueries)) {
            return candidateQueries;
        }
        List<SemanticQuery> selectedQueries = new ArrayList<>();
        if (candidateQueries.size() == 1) {
            // a single candidate needs no ranking or filtering
            selectedQueries.addAll(candidateQueries);
        } else {
            selectedQueries = getTopCandidateQuery(candidateQueries);
        }
        generateParseInfoId(selectedQueries);
        log.debug("pick after [{}]", selectedQueries);
        return selectedQueries;
    }

    /** Filters out fully-inherited queries and returns the highest-scored topN. */
    public List<SemanticQuery> getTopCandidateQuery(List<SemanticQuery> semanticQueries) {
        return semanticQueries.stream()
                .filter(query -> !checkFullyInherited(query))
                .sorted((o1, o2) -> {
                    // descending by score
                    if (o1.getParseInfo().getScore() < o2.getParseInfo().getScore()) {
                        return 1;
                    } else if (o1.getParseInfo().getScore() > o2.getParseInfo().getScore()) {
                        return -1;
                    }
                    return 0;
                }).limit(CANDIDATE_TOP_SIZE)
                .collect(Collectors.toList());
    }

    /** Assigns sequential 1-based ids so callers can reference each parse. */
    private void generateParseInfoId(List<SemanticQuery> semanticQueries) {
        for (int i = 0; i < semanticQueries.size(); i++) {
            SemanticQuery query = semanticQueries.get(i);
            query.getParseInfo().setId(i + 1);
        }
    }

    /**
     * A rule query whose element matches and date info are all inherited from
     * the chat context adds no new information and can be dropped.
     */
    private boolean checkFullyInherited(SemanticQuery query) {
        SemanticParseInfo parseInfo = query.getParseInfo();
        if (!(query instanceof RuleSemanticQuery)) {
            return false;
        }
        for (SchemaElementMatch match : parseInfo.getElementMatches()) {
            if (!match.isInherited()) {
                return false;
            }
        }
        return parseInfo.getDateInfo() == null || parseInfo.getDateInfo().isInherited();
    }
}

View File

@@ -0,0 +1,33 @@
package com.tencent.supersonic.headless.server.processor;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.core.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.core.pojo.ChatContext;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import lombok.extern.slf4j.Slf4j;
import java.util.List;
import java.util.stream.Collectors;
/**
 * RespBuildProcessor fill response object with parsing results.
 **/
@Slf4j
public class RespBuildProcessor implements ResultProcessor {

    /**
     * Copies the chat id and query text into the response and marks it
     * COMPLETED with the selected parses, or FAILED when no candidate
     * query survived parsing.
     */
    @Override
    public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
        parseResp.setChatId(queryContext.getChatId());
        parseResp.setQueryText(queryContext.getQueryText());
        List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries();
        // guard against a null candidate list as well as an empty one
        if (candidateQueries != null && !candidateQueries.isEmpty()) {
            List<SemanticParseInfo> candidateParses = candidateQueries.stream()
                    .map(SemanticQuery::getParseInfo).collect(Collectors.toList());
            parseResp.setSelectedParses(candidateParses);
            parseResp.setState(ParseResp.ParseState.COMPLETED);
        } else {
            parseResp.setState(ParseResp.ParseState.FAILED);
        }
    }
}

View File

@@ -0,0 +1,14 @@
package com.tencent.supersonic.headless.server.processor;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.core.pojo.ChatContext;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
/**
 * A ParseResultProcessor wraps things up before returning results to users in parse stage.
 */
public interface ResultProcessor {
    /**
     * Post-processes the parsing result.
     *
     * @param parseResp    response object being assembled for the caller
     * @param queryContext current query context holding candidate queries
     * @param chatContext  persisted context of the ongoing chat session
     */
    void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext);
}

View File

@@ -0,0 +1,84 @@
package com.tencent.supersonic.headless.server.processor;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.ExplainResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.core.chat.query.QueryManager;
import com.tencent.supersonic.headless.core.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.headless.core.pojo.ChatContext;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.server.service.QueryService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
/**
 * SqlInfoProcessor adds S2SQL to the parsing results so that
 * technical users could verify SQL by themselves.
 **/
@Slf4j
public class SqlInfoProcessor implements ResultProcessor {

    private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");

    @Override
    public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
        List<SemanticQuery> semanticQueries = queryContext.getCandidateQueries();
        if (CollectionUtils.isEmpty(semanticQueries)) {
            return;
        }
        List<SemanticParseInfo> parses = semanticQueries.stream()
                .map(SemanticQuery::getParseInfo)
                .collect(Collectors.toList());
        long sqlStart = System.currentTimeMillis();
        addSqlInfo(queryContext, parses);
        parseResp.getParseTimeCost().setSqlTime(System.currentTimeMillis() - sqlStart);
    }

    /** Fills SQL info for every parse, logging and skipping any parse that fails. */
    private void addSqlInfo(QueryContext queryContext, List<SemanticParseInfo> semanticParseInfos) {
        if (CollectionUtils.isEmpty(semanticParseInfos)) {
            return;
        }
        for (SemanticParseInfo parseInfo : semanticParseInfos) {
            try {
                addSqlInfo(queryContext, parseInfo);
            } catch (Exception e) {
                log.warn("get sql info failed:{}", parseInfo, e);
            }
        }
    }

    /** Rebuilds the semantic query for one parse and has the query service explain it into SQL. */
    private void addSqlInfo(QueryContext queryContext, SemanticParseInfo parseInfo) throws Exception {
        SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
        if (Objects.isNull(semanticQuery)) {
            return;
        }
        semanticQuery.setParseInfo(parseInfo);
        SemanticQueryReq queryReq = semanticQuery.buildSemanticQueryReq();
        ExplainSqlReq<Object> explainReq = ExplainSqlReq.builder()
                .queryReq(queryReq)
                .queryTypeEnum(QueryMethod.SQL)
                .build();
        QueryService queryService = ContextUtils.getBean(QueryService.class);
        ExplainResp explain = queryService.explain(explainReq, queryContext.getUser());
        String explainSql = explain.getSql();
        if (StringUtils.isBlank(explainSql)) {
            return;
        }
        SqlInfo sqlInfo = parseInfo.getSqlInfo();
        if (semanticQuery instanceof LLMSqlQuery) {
            keyPipelineLog.info("\ns2sql:{}\ncorrectS2SQL:{}\nquerySQL:{}", sqlInfo.getS2SQL(),
                    sqlInfo.getCorrectS2SQL(), explainSql);
        }
        sqlInfo.setQuerySQL(explainSql);
    }
}

View File

@@ -0,0 +1,22 @@
package com.tencent.supersonic.headless.server.processor;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.core.pojo.ChatContext;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import lombok.extern.slf4j.Slf4j;
/**
 * TimeCostProcessor adds time cost of parsing.
 **/
@Slf4j
public class TimeCostProcessor implements ResultProcessor {

    @Override
    public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
        // parse time = total elapsed since parse start, minus the portion
        // already attributed to SQL generation
        long startTime = parseResp.getParseTimeCost().getParseStartTime();
        long sqlTime = parseResp.getParseTimeCost().getSqlTime();
        long elapsed = System.currentTimeMillis() - startTime;
        parseResp.getParseTimeCost().setParseTime(elapsed - sqlTime);
    }
}

View File

@@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.server.rest;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
import com.tencent.supersonic.headless.api.pojo.request.BatchDownloadReq;
import com.tencent.supersonic.headless.api.pojo.request.DownloadStructReq;
import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq;
@@ -19,10 +19,6 @@ import com.tencent.supersonic.headless.api.pojo.response.ItemUseResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.server.service.DownloadService;
import com.tencent.supersonic.headless.server.service.QueryService;
import java.util.List;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.validation.Valid;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.PostMapping;
@@ -30,6 +26,11 @@ import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.validation.Valid;
import java.util.List;
@RestController
@RequestMapping("/api/semantic/query")
@Slf4j
@@ -109,13 +110,13 @@ public class QueryController {
User user = UserHolder.findUser(request, response);
String queryReqJson = JsonUtil.toString(explainSqlReq.getQueryReq());
if (QueryType.SQL.equals(explainSqlReq.getQueryTypeEnum())) {
if (QueryMethod.SQL.equals(explainSqlReq.getQueryTypeEnum())) {
ExplainSqlReq<QuerySqlReq> explainSqlReqNew = ExplainSqlReq.<QuerySqlReq>builder()
.queryReq(JsonUtil.toObject(queryReqJson, QuerySqlReq.class))
.queryTypeEnum(explainSqlReq.getQueryTypeEnum()).build();
return queryService.explain(explainSqlReqNew, user);
}
if (QueryType.STRUCT.equals(explainSqlReq.getQueryTypeEnum())) {
if (QueryMethod.STRUCT.equals(explainSqlReq.getQueryTypeEnum())) {
ExplainSqlReq<QueryStructReq> explainSqlReqNew = ExplainSqlReq.<QueryStructReq>builder()
.queryReq(JsonUtil.toObject(queryReqJson, QueryStructReq.class))
.queryTypeEnum(explainSqlReq.getQueryTypeEnum()).build();

View File

@@ -6,8 +6,6 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryMetricReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.server.service.MetricService;
import com.tencent.supersonic.headless.server.service.QueryService;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.PostMapping;
@@ -15,6 +13,9 @@ import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
@RestController
@RequestMapping("/api/semantic/query")
@Slf4j

View File

@@ -4,8 +4,6 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.server.service.QueryService;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.PostMapping;
@@ -13,6 +11,9 @@ import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
@RestController
@RequestMapping("/api/semantic/query")
@Slf4j

View File

@@ -4,8 +4,6 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.headless.api.pojo.request.QueryTagReq;
import com.tencent.supersonic.headless.server.service.QueryService;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.PostMapping;
@@ -13,6 +11,9 @@ import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
@RestController
@RequestMapping("/api/semantic/query")
@Slf4j

View File

@@ -0,0 +1,18 @@
package com.tencent.supersonic.headless.server.service;
import com.tencent.supersonic.headless.core.pojo.ChatContext;
/**
 * Manages per-chat parsing context so multi-turn conversations can reuse the
 * previous turn's semantic parse result.
 */
public interface ChatContextService {
/**
 * Returns the data-set id recorded in the chat's context, or null when the
 * chat id is null, no context exists, or the stored parse info has no
 * data-set id.
 * NOTE(review): despite the name, the known implementation returns the
 * data-set id held by the context's parse info, not a model id.
 *
 * @param chatId chat session identifier; may be null
 * @return the data-set id from the previous parse, or null
 */
Long getContextModel(Integer chatId);
/**
 * Fetches the context for the given chat, creating one if absent.
 *
 * @param chatId chat session identifier
 * @return the existing or newly created context
 */
ChatContext getOrCreateContext(int chatId);
/**
 * Persists the given chat context.
 *
 * @param chatCtx context to save
 */
void updateContext(ChatContext chatCtx);
}

View File

@@ -0,0 +1,30 @@
package com.tencent.supersonic.headless.server.service;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
/**
 * Chat-facing query service: parses natural-language chat requests into
 * semantic queries and executes them.
 */
public interface ChatQueryService {
/** Parses a chat query into candidate semantic queries. */
ParseResp performParsing(QueryReq queryReq);
/** Executes a previously parsed semantic query. */
QueryResult performExecution(ExecuteQueryReq queryReq) throws Exception;
/** Returns the semantic parse info stored in the chat's context. */
SemanticParseInfo queryContext(Integer chatId);
/** Re-executes a query after the user revised its metrics/filters/date range. */
QueryResult executeDirectQuery(QueryDataReq queryData, User user) throws Exception;
/** Resolves entity info for a parse result on behalf of the given user. */
EntityInfo getEntityInfo(SemanticParseInfo parseInfo, User user);
/** Queries values for a dimension described by the request. */
Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception;
}

View File

@@ -24,9 +24,12 @@ public interface DataSetService {
Map<Long, List<Long>> getModelIdToDataSetIds(List<Long> dataSetIds);
Map<Long, List<Long>> getModelIdToDataSetIds();
List<DataSetResp> getDataSets(User user);
List<DataSetResp> getDataSetsInheritAuth(User user, Long domainId);
SemanticQueryReq convert(QueryDataSetReq queryDataSetReq);
}

View File

@@ -1,25 +0,0 @@
package com.tencent.supersonic.headless.server.service;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.knowledge.DictWord;
import com.tencent.supersonic.headless.core.knowledge.HanlpMapResult;
import java.util.List;
import java.util.Set;
/**
 * Maintains the dictionary-based knowledge used for term extraction and
 * prefix/suffix matching of chat text against schema vocabulary.
 */
public interface KnowledgeService {
/** Extracts dictionary terms found in the given text. */
List<S2Term> getTerms(String text);
/** Prefix-matches the key against dictionary entries of the given data sets. */
List<HanlpMapResult> prefixSearch(String key, int limit, Set<Long> dataSetIds);
/** Suffix-matches the key against dictionary entries of the given data sets. */
List<HanlpMapResult> suffixSearch(String key, int limit, Set<Long> dataSetIds);
/** Merges the given dictionary words into the semantic knowledge. */
void updateSemanticKnowledge(List<DictWord> natures);
/** Rebuilds the knowledge data from the given word list. */
void reloadAllData(List<DictWord> natures);
/** Applies the given words to the online knowledge. */
void updateOnlineKnowledge(List<DictWord> natures);
}

View File

@@ -1,11 +0,0 @@
package com.tencent.supersonic.headless.server.service;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import java.util.List;
/**
 * Embedding-based retrieval over metadata.
 */
public interface MetaEmbeddingService {
/**
 * Retrieves embedding matches for the query, restricted to the given data sets.
 *
 * @param dataSetIds data sets to search within
 * @param retrieveQuery the retrieval query
 * @param num result-count limit — presumably a maximum; confirm against the implementation
 * @return matching results
 */
List<RetrieveQueryResult> retrieveQuery(List<Long> dataSetIds, RetrieveQuery retrieveQuery, int num);
}

View File

@@ -11,8 +11,9 @@ import com.tencent.supersonic.headless.api.pojo.response.ItemQueryResultResp;
import com.tencent.supersonic.headless.api.pojo.response.ItemUseResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.server.annotation.ApiHeaderCheck;
import java.util.List;
import javax.servlet.http.HttpServletRequest;
import java.util.List;
public interface QueryService {

View File

@@ -3,12 +3,14 @@ package com.tencent.supersonic.headless.server.service;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.enums.AuthType;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.request.DataSetFilterReq;
import com.tencent.supersonic.headless.api.pojo.request.ItemUseReq;
import com.tencent.supersonic.headless.api.pojo.request.PageDimensionReq;
import com.tencent.supersonic.headless.api.pojo.request.PageMetricReq;
import com.tencent.supersonic.headless.api.pojo.request.SchemaFilterReq;
import com.tencent.supersonic.headless.api.pojo.request.SchemaItemQueryReq;
import com.tencent.supersonic.headless.api.pojo.request.DataSetFilterReq;
import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
import com.tencent.supersonic.headless.api.pojo.response.DomainResp;
import com.tencent.supersonic.headless.api.pojo.response.ItemResp;
@@ -17,8 +19,6 @@ import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp;
import java.util.List;
import java.util.concurrent.ExecutionException;
@@ -27,6 +27,10 @@ public interface SchemaService {
List<DataSetSchemaResp> fetchDataSetSchema(DataSetFilterReq filter);
DataSetSchema getDataSetSchema(Long dataSetId);
List<DataSetSchema> getDataSetSchema();
List<ModelSchemaResp> fetchModelSchemaResps(List<Long> modelIds);
PageInfo<DimensionResp> queryDimension(PageDimensionReq pageDimensionReq, User user);
@@ -39,8 +43,6 @@ public interface SchemaService {
List<ModelResp> getModelList(User user, AuthType authType, Long domainId);
List<DataSetResp> getDataSetList(Long domainId);
SemanticSchemaResp fetchSemanticSchema(SchemaFilterReq schemaFilterReq);
List<ItemUseResp> getStatInfo(ItemUseReq itemUseReq) throws ExecutionException;

View File

@@ -0,0 +1,16 @@
package com.tencent.supersonic.headless.server.service;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
import java.util.List;
/**
 * Search service: returns search results for a chat query request.
 */
public interface SearchService {
/**
 * Searches using the given query request.
 *
 * @param queryCtx the chat query request
 * @return matching search results
 */
List<SearchResult> search(QueryReq queryCtx);
}

View File

@@ -0,0 +1,50 @@
package com.tencent.supersonic.headless.server.service.impl;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.core.pojo.ChatContext;
import com.tencent.supersonic.headless.server.persistence.repository.ChatContextRepository;
import com.tencent.supersonic.headless.server.service.ChatContextService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.Objects;
@Slf4j
@Service
public class ChatContextServiceImpl implements ChatContextService {

    /** Backing store for chat contexts; constructor-injected, hence final. */
    private final ChatContextRepository chatContextRepository;

    public ChatContextServiceImpl(ChatContextRepository chatContextRepository) {
        this.chatContextRepository = chatContextRepository;
    }

    /**
     * Returns the data-set id recorded by the chat's previous parse, or null when
     * the chat id is null, the context is missing, or no parse info was stored.
     *
     * @param chatId chat session identifier; may be null
     * @return the data-set id from the previous parse, or null
     */
    @Override
    public Long getContextModel(Integer chatId) {
        if (Objects.isNull(chatId)) {
            return null;
        }
        ChatContext chatContext = getOrCreateContext(chatId);
        if (Objects.isNull(chatContext)) {
            return null;
        }
        SemanticParseInfo originalSemanticParse = chatContext.getParseInfo();
        if (Objects.nonNull(originalSemanticParse) && Objects.nonNull(originalSemanticParse.getDataSetId())) {
            return originalSemanticParse.getDataSetId();
        }
        return null;
    }

    /** Fetches the context for the given chat, creating one if absent. */
    @Override
    public ChatContext getOrCreateContext(int chatId) {
        return chatContextRepository.getOrCreateContext(chatId);
    }

    /** Persists the given chat context. */
    @Override
    public void updateContext(ChatContext chatCtx) {
        log.debug("save ChatContext {}", chatCtx);
        chatContextRepository.updateContext(chatCtx);
    }
}

View File

@@ -0,0 +1,638 @@
package com.tencent.supersonic.headless.server.service.impl;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.FieldExpression;
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.enums.CostType;
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.ExplainResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.core.chat.corrector.SemanticCorrector;
import com.tencent.supersonic.headless.core.chat.mapper.SchemaMapper;
import com.tencent.supersonic.headless.core.chat.parser.SemanticParser;
import com.tencent.supersonic.headless.core.chat.query.QueryManager;
import com.tencent.supersonic.headless.core.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.headless.core.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.headless.core.knowledge.HanlpMapResult;
import com.tencent.supersonic.headless.core.knowledge.KnowledgeService;
import com.tencent.supersonic.headless.core.knowledge.SearchService;
import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper;
import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper;
import com.tencent.supersonic.headless.core.pojo.ChatContext;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.headless.server.processor.ResultProcessor;
import com.tencent.supersonic.headless.server.service.ChatContextService;
import com.tencent.supersonic.headless.server.service.ChatQueryService;
import com.tencent.supersonic.headless.server.service.DataSetService;
import com.tencent.supersonic.headless.server.service.QueryService;
import com.tencent.supersonic.headless.server.utils.ComponentFactory;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
import net.sf.jsqlparser.schema.Column;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
@Service
@Slf4j
public class ChatQueryServiceImpl implements ChatQueryService {
@Autowired
private SemanticService semanticService;
@Autowired
private ChatContextService chatContextService;
@Autowired
private KnowledgeService knowledgeService;
@Autowired
private QueryService queryService;
@Autowired
private DataSetService dataSetService;
@Value("${time.threshold: 100}")
private Integer timeThreshold;
private List<SchemaMapper> schemaMappers = ComponentFactory.getSchemaMappers();
private List<SemanticParser> semanticParsers = ComponentFactory.getSemanticParsers();
private List<SemanticCorrector> semanticCorrectors = ComponentFactory.getSemanticCorrectors();
private List<ResultProcessor> resultProcessors = ComponentFactory.getResultProcessors();
/**
 * Runs the full natural-language parsing pipeline over one chat query:
 * mapper -> parser -> corrector -> result processor, in that fixed order.
 * Per-component timings are collected into StatisticsDO entries.
 *
 * @param queryReq the raw chat query request
 * @return the parse response assembled by the result processors
 */
@Override
public ParseResp performParsing(QueryReq queryReq) {
ParseResp parseResult = new ParseResp();
// build queryContext and chatContext
QueryContext queryCtx = buildQueryContext(queryReq);
// in order to support multi-turn conversation, chat context is needed
ChatContext chatCtx = chatContextService.getOrCreateContext(queryReq.getChatId());
// NOTE(review): timeCostDOList is filled below but never read or persisted in
// this method — confirm whether the stats were meant to be saved somewhere.
List<StatisticsDO> timeCostDOList = new ArrayList<>();
// 1. mapper: map query-text fragments onto schema elements
schemaMappers.forEach(mapper -> {
long startTime = System.currentTimeMillis();
mapper.map(queryCtx);
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
.interfaceName(mapper.getClass().getSimpleName()).type(CostType.MAPPER.getType()).build());
});
// 2. parser: generate candidate semantic queries from the mapped context
semanticParsers.forEach(parser -> {
long startTime = System.currentTimeMillis();
parser.parse(queryCtx, chatCtx);
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
.interfaceName(parser.getClass().getSimpleName()).type(CostType.PARSER.getType()).build());
log.debug("{} result:{}", parser.getClass().getSimpleName(), JsonUtil.toString(queryCtx));
});
// 3. corrector: post-correct each candidate's parse info
List<SemanticQuery> candidateQueries = queryCtx.getCandidateQueries();
if (CollectionUtils.isNotEmpty(candidateQueries)) {
for (SemanticQuery semanticQuery : candidateQueries) {
// rule-based queries are taken as-is and are not corrected
if (semanticQuery instanceof RuleSemanticQuery) {
continue;
}
semanticCorrectors.forEach(corrector -> {
corrector.correct(queryCtx, semanticQuery.getParseInfo());
});
}
}
// 4. processor: fold the contexts into the final ParseResp
resultProcessors.forEach(processor -> {
processor.process(parseResult, queryCtx, chatCtx);
});
return parseResult;
}
/**
 * Assembles a fresh QueryContext for one parse run: seeds it with the current
 * semantic schema, the model-to-dataset mapping and empty candidate/map slots,
 * then copies the matching request properties onto it.
 *
 * @param queryReq the incoming chat query request
 * @return a fully initialized query context
 */
private QueryContext buildQueryContext(QueryReq queryReq) {
    QueryContext context = QueryContext.builder()
            .semanticSchema(semanticService.getSemanticSchema())
            .modelIdToDataSetIds(dataSetService.getModelIdToDataSetIds())
            .queryFilters(queryReq.getQueryFilters())
            .candidateQueries(new ArrayList<>())
            .mapInfo(new SchemaMapInfo())
            .build();
    BeanUtils.copyProperties(queryReq, context);
    return context;
}
/**
 * Executes a previously parsed semantic query and, on success (and when the
 * request asks to save the answer), stores the parse info back into the chat
 * context for the next conversation turn.
 *
 * @param queryReq execution request carrying parse info, chat id and user
 * @return the query result, or null when no query mode matches the parse info
 * @throws Exception if the underlying semantic query fails
 */
@Override
public QueryResult performExecution(ExecuteQueryReq queryReq) throws Exception {
List<StatisticsDO> timeCostDOList = new ArrayList<>();
SemanticParseInfo parseInfo = queryReq.getParseInfo();
// resolve the concrete query implementation from the parse's query mode
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
if (semanticQuery == null) {
return null;
}
semanticQuery.setParseInfo(parseInfo);
// in order to support multi-turn conversation, chat context is needed
ChatContext chatCtx = chatContextService.getOrCreateContext(queryReq.getChatId());
long startTime = System.currentTimeMillis();
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
QueryResult queryResult = doExecution(semanticQueryReq, parseInfo, queryReq.getUser());
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
.interfaceName(semanticQuery.getClass().getSimpleName()).type(CostType.QUERY.getType()).build());
queryResult.setQueryTimeCost(timeCostDOList.get(0).getCost().longValue());
queryResult.setChatContext(parseInfo);
// update chat context after a successful semantic query
if (QueryState.SUCCESS.equals(queryResult.getQueryState()) && queryReq.isSaveAnswer()) {
chatCtx.setParseInfo(parseInfo);
chatContextService.updateContext(chatCtx);
}
// NOTE(review): queryText/user are set AFTER updateContext above, so they are
// only persisted if the repository keeps a live reference to chatCtx — confirm.
chatCtx.setQueryText(queryReq.getQueryText());
chatCtx.setUser(queryReq.getUser().getName());
return queryResult;
}
/**
 * Runs the semantic query against the underlying query service and adapts the
 * response into a QueryResult. A null response yields a null SQL string and
 * empty result/column lists; the result state is always SUCCESS here.
 *
 * @param semanticQueryReq the query to execute
 * @param parseInfo parse info supplying the query mode
 * @param user the requesting user
 * @return the adapted query result, never null
 * @throws Exception if the query service fails
 */
private QueryResult doExecution(SemanticQueryReq semanticQueryReq,
        SemanticParseInfo parseInfo, User user) throws Exception {
    SemanticQueryResp queryResp = queryService.queryByReq(semanticQueryReq, user);
    QueryResult result = new QueryResult();
    if (queryResp == null) {
        result.setQuerySql(null);
        result.setQueryResults(new ArrayList<>());
        result.setQueryColumns(new ArrayList<>());
    } else {
        result.setQueryAuthorization(queryResp.getQueryAuthorization());
        result.setQuerySql(queryResp.getSql());
        result.setQueryResults(queryResp.getResultList());
        result.setQueryColumns(queryResp.getColumns());
    }
    result.setQueryMode(parseInfo.getQueryMode());
    result.setQueryState(QueryState.SUCCESS);
    return result;
}
/**
 * Returns the semantic parse info stored in the chat's context.
 * NOTE(review): a null chatId would NPE on unboxing to int — confirm callers
 * always pass a valid id.
 */
@Override
public SemanticParseInfo queryContext(Integer chatId) {
    return chatContextService.getOrCreateContext(chatId).getParseInfo();
}
// Mainly used for executing after the user revises filters, for example:
// "fans_cnt>=100000"->"fans_cnt>500000", "style='流行'"->"style in ['流行','爱国']"
/**
 * Re-executes a query after the user edited its metrics/filters/date range.
 * For LLM-generated SQL the corrected S2SQL is patched textually (either a
 * metric swap or a filter revision plus explain); for rule-based queries the
 * S2SQL is rebuilt from the revised parse info. The result is enriched with
 * entity info for the parse's data set.
 *
 * @param queryData the revised metrics/filters/date info
 * @param user the requesting user
 * @return the query result with entity info attached
 * @throws Exception if explain or execution fails
 */
@Override
public QueryResult executeDirectQuery(QueryDataReq queryData, User user) throws Exception {
SemanticParseInfo parseInfo = getSemanticParseInfo(queryData);
SemanticSchema semanticSchema = semanticService.getSemanticSchema();
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
semanticQuery.setParseInfo(parseInfo);
// fields currently selected by the corrected S2SQL, if any
List<String> fields = new ArrayList<>();
if (Objects.nonNull(parseInfo.getSqlInfo())
&& StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectS2SQL())) {
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
fields = SqlSelectHelper.getAllFields(correctorSql);
}
if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())
&& checkMetricReplace(fields, queryData.getMetrics())) {
// replace metrics: swap the first requested metric into the LLM SQL
log.info("llm begin replace metrics!");
SchemaElement metricToReplace = queryData.getMetrics().iterator().next();
replaceMetrics(parseInfo, metricToReplace);
} else if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())) {
// revise filters in the LLM SQL, then explain to refresh the final query SQL
log.info("llm begin revise filters!");
String correctorSql = reviseCorrectS2SQL(queryData, parseInfo);
parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql);
semanticQuery.setParseInfo(parseInfo);
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
ExplainSqlReq<Object> explainSqlReq = ExplainSqlReq.builder().queryReq(semanticQueryReq)
.queryTypeEnum(QueryMethod.SQL).build();
ExplainResp explain = queryService.explain(explainSqlReq, user);
if (StringUtils.isNotBlank(explain.getSql())) {
parseInfo.getSqlInfo().setQuerySQL(explain.getSql());
}
} else {
log.info("rule begin replace metrics and revise filters!");
// remove invalid filters
validFilter(semanticQuery.getParseInfo().getDimensionFilters());
validFilter(semanticQuery.getParseInfo().getMetricFilters());
// init s2sql
semanticQuery.initS2Sql(semanticSchema, user);
// NOTE(review): this QueryReq is built but never used afterwards — looks
// like dead code; confirm before removing.
QueryReq queryReq = new QueryReq();
queryReq.setQueryFilters(new QueryFilters());
queryReq.setUser(user);
}
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
QueryResult queryResult = doExecution(semanticQueryReq, semanticQuery.getParseInfo(), user);
queryResult.setChatContext(semanticQuery.getParseInfo());
DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(parseInfo.getDataSetId());
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, user);
queryResult.setEntityInfo(entityInfo);
return queryResult;
}
/**
 * True when a metric replacement is needed: both the SQL's current field list
 * and the requested metrics are non-empty, and at least one requested metric
 * name is not already among the selected fields.
 *
 * @param oriFields fields currently selected by the corrected SQL
 * @param metrics metrics requested by the user
 * @return whether the SQL needs a metric swap
 */
private boolean checkMetricReplace(List<String> oriFields, Set<SchemaElement> metrics) {
    if (CollectionUtils.isEmpty(oriFields) || CollectionUtils.isEmpty(metrics)) {
        return false;
    }
    return metrics.stream()
            .map(SchemaElement::getName)
            .anyMatch(name -> !oriFields.contains(name));
}
/**
 * Rewrites the corrected S2SQL to reflect the user's revised filters and date
 * range. Where an existing condition matches, its value is replaced in place;
 * otherwise the old condition is removed and a new one appended — removals run
 * before additions so replacements are not clobbered.
 *
 * @param queryData user-revised filters/date info
 * @param parseInfo parse info holding the current corrected S2SQL
 * @return the rewritten SQL
 */
public String reviseCorrectS2SQL(QueryDataReq queryData, SemanticParseInfo parseInfo) {
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
// NOTE(review): havingFiledNameToValueMap is never populated before being
// passed to replaceHavingValue below — confirm whether the having branch was
// meant to collect value replacements like the where branch does.
Map<String, Map<String, String>> havingFiledNameToValueMap = new HashMap<>();
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
log.info("correctorSql before replacing:{}", correctorSql);
// get where filter and having filter
List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctorSql);
List<FieldExpression> havingExpressionList = SqlSelectHelper.getHavingExpressions(correctorSql);
List<Expression> addWhereConditions = new ArrayList<>();
List<Expression> addHavingConditions = new ArrayList<>();
Set<String> removeWhereFieldNames = new HashSet<>();
Set<String> removeHavingFieldNames = new HashSet<>();
// replace where filter
updateFilters(whereExpressionList, queryData.getDimensionFilters(),
parseInfo.getDimensionFilters(), addWhereConditions, removeWhereFieldNames);
updateDateInfo(queryData, parseInfo, filedNameToValueMap,
whereExpressionList, addWhereConditions, removeWhereFieldNames);
correctorSql = SqlReplaceHelper.replaceValue(correctorSql, filedNameToValueMap);
correctorSql = SqlRemoveHelper.removeWhereCondition(correctorSql, removeWhereFieldNames);
// replace having filter
updateFilters(havingExpressionList, queryData.getDimensionFilters(),
parseInfo.getDimensionFilters(), addHavingConditions, removeHavingFieldNames);
correctorSql = SqlReplaceHelper.replaceHavingValue(correctorSql, havingFiledNameToValueMap);
correctorSql = SqlRemoveHelper.removeHavingCondition(correctorSql, removeHavingFieldNames);
// append the rebuilt conditions last, after all removals
correctorSql = SqlAddHelper.addWhere(correctorSql, addWhereConditions);
correctorSql = SqlAddHelper.addHaving(correctorSql, addHavingConditions);
log.info("correctorSql after replacing:{}", correctorSql);
return correctorSql;
}
/**
 * Swaps the first metric referenced by the corrected S2SQL for the requested
 * one (using its default aggregation) and writes the SQL back into parseInfo.
 * No-op when the parse has no metrics or already contains the requested one.
 *
 * @param parseInfo parse info whose corrected S2SQL is rewritten in place
 * @param metric the metric the user wants instead
 */
private void replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) {
    List<String> oriMetrics = new ArrayList<>();
    for (SchemaElement element : parseInfo.getMetrics()) {
        oriMetrics.add(element.getName());
    }
    String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
    log.info("before replaceMetrics:{}", correctorSql);
    log.info("filteredMetrics:{},metrics:{}", oriMetrics, metric);
    if (CollectionUtils.isNotEmpty(oriMetrics) && !oriMetrics.contains(metric.getName())) {
        Map<String, Pair<String, String>> fieldMap = new HashMap<>();
        fieldMap.put(oriMetrics.get(0), Pair.of(metric.getName(), metric.getDefaultAgg()));
        correctorSql = SqlReplaceHelper.replaceAggFields(correctorSql, fieldMap);
    }
    log.info("after replaceMetrics:{}", correctorSql);
    parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql);
}
/**
 * Resolves entity info for the parse result by looking up the data-set schema
 * in the current semantic schema.
 *
 * @param parseInfo parse info identifying the data set
 * @param user the requesting user
 * @return the entity info produced by the semantic service
 */
@Override
public EntityInfo getEntityInfo(SemanticParseInfo parseInfo, User user) {
    SemanticService service = ContextUtils.getBean(SemanticService.class);
    SemanticSchema schema = service.getSemanticSchema();
    DataSetSchema dataSetSchema = schema.getDataSetSchemaMap().get(parseInfo.getDataSetId());
    return service.getEntityInfo(parseInfo, dataSetSchema, user);
}
/**
 * Folds the user's revised date range into the SQL-rewrite plan. For a single
 * day it replaces or rebuilds an equality condition on the day dimension; for
 * a range it replaces the bound values of existing >=/>/<=/< conditions and
 * turns an existing equality into a pair of range conditions. Also copies the
 * revised date info onto parseInfo.
 *
 * NOTE(review): the single-date branch compares getOperator() against the
 * FilterOperatorEnum constant itself, while the range branch compares it
 * against the enum's string value — one of these comparisons may never match;
 * confirm the actual type of FieldExpression.getOperator().
 *
 * @param queryData revised date info (unit > 1 is normalized to a rolling window)
 * @param parseInfo parse info updated with the revised date info
 * @param filedNameToValueMap out: field -> (old value -> new value) replacements
 * @param fieldExpressionList current where-clause expressions of the SQL
 * @param addConditions out: conditions to append to the SQL
 * @param removeFieldNames out: field names whose old conditions must be removed
 */
private void updateDateInfo(QueryDataReq queryData, SemanticParseInfo parseInfo,
Map<String, Map<String, String>> filedNameToValueMap,
List<FieldExpression> fieldExpressionList,
List<Expression> addConditions,
Set<String> removeFieldNames) {
if (Objects.isNull(queryData.getDateInfo())) {
return;
}
Map<String, String> map = new HashMap<>();
String dateField = TimeDimensionEnum.DAY.getChName();
// unit > 1: normalize into a rolling window ending yesterday
if (queryData.getDateInfo().getUnit() > 1) {
queryData.getDateInfo().setStartDate(DateUtils.getBeforeDate(queryData.getDateInfo().getUnit() + 1));
queryData.getDateInfo().setEndDate(DateUtils.getBeforeDate(1));
}
// startDate equals to endDate: single-day revision
if (queryData.getDateInfo().getStartDate().equals(queryData.getDateInfo().getEndDate())) {
for (FieldExpression fieldExpression : fieldExpressionList) {
if (TimeDimensionEnum.DAY.getChName().equals(fieldExpression.getFieldName())) {
// sql where condition exists 'equals' operator about date, just replace
if (fieldExpression.getOperator().equals(FilterOperatorEnum.EQUALS)) {
dateField = fieldExpression.getFieldName();
map.put(fieldExpression.getFieldValue().toString(),
queryData.getDateInfo().getStartDate());
filedNameToValueMap.put(dateField, map);
} else {
// first remove, then add an equality on the day dimension
removeFieldNames.add(TimeDimensionEnum.DAY.getChName());
EqualsTo equalsTo = new EqualsTo();
Column column = new Column(TimeDimensionEnum.DAY.getChName());
StringValue stringValue = new StringValue(queryData.getDateInfo().getStartDate());
equalsTo.setLeftExpression(column);
equalsTo.setRightExpression(stringValue);
addConditions.add(equalsTo);
}
break;
}
}
} else {
// date-range revision
for (FieldExpression fieldExpression : fieldExpressionList) {
if (TimeDimensionEnum.DAY.getChName().equals(fieldExpression.getFieldName())) {
dateField = fieldExpression.getFieldName();
// just replace the bound value of an existing range condition
if (FilterOperatorEnum.GREATER_THAN_EQUALS.getValue().equals(fieldExpression.getOperator())
|| FilterOperatorEnum.GREATER_THAN.getValue().equals(fieldExpression.getOperator())) {
map.put(fieldExpression.getFieldValue().toString(),
queryData.getDateInfo().getStartDate());
}
if (FilterOperatorEnum.MINOR_THAN_EQUALS.getValue().equals(fieldExpression.getOperator())
|| FilterOperatorEnum.MINOR_THAN.getValue().equals(fieldExpression.getOperator())) {
map.put(fieldExpression.getFieldValue().toString(),
queryData.getDateInfo().getEndDate());
}
filedNameToValueMap.put(dateField, map);
// an equality cannot express a range: first remove, then add both bounds
if (FilterOperatorEnum.EQUALS.getValue().equals(fieldExpression.getOperator())) {
removeFieldNames.add(TimeDimensionEnum.DAY.getChName());
GreaterThanEquals greaterThanEquals = new GreaterThanEquals();
addTimeFilters(queryData.getDateInfo().getStartDate(), greaterThanEquals, addConditions);
MinorThanEquals minorThanEquals = new MinorThanEquals();
addTimeFilters(queryData.getDateInfo().getEndDate(), minorThanEquals, addConditions);
}
}
}
}
parseInfo.setDateInfo(queryData.getDateInfo());
}
/**
 * Completes the given comparison operator as "&lt;day-dimension&gt; &lt;op&gt; '&lt;date&gt;'"
 * and appends it to the pending condition list.
 *
 * @param date the date literal for the right-hand side
 * @param comparisonExpression the operator node to fill in
 * @param addConditions the list collecting conditions to append to the SQL
 */
private <T extends ComparisonOperator> void addTimeFilters(String date,
        T comparisonExpression,
        List<Expression> addConditions) {
    comparisonExpression.setLeftExpression(new Column(TimeDimensionEnum.DAY.getChName()));
    comparisonExpression.setRightExpression(new StringValue(date));
    addConditions.add(comparisonExpression);
}
/**
 * For each revised filter that matches a field already present in the SQL,
 * schedules removal of the old condition and appends a rebuilt one with the
 * revised operator and value; context filters with the same name are updated
 * in the addWhere* helpers.
 *
 * NOTE(review): the match uses fieldName.contains(filterName), so a filter
 * named "age" would also match a field like "average" or "sum(age)" —
 * presumably intended for function-wrapped columns; confirm.
 *
 * @param fieldExpressionList current where/having expressions of the SQL
 * @param metricFilters the user's revised filters
 * @param contextMetricFilters filters from the parse context, updated in place
 * @param addConditions out: conditions to append to the SQL
 * @param removeFieldNames out: filter names whose old conditions must be removed
 */
private void updateFilters(List<FieldExpression> fieldExpressionList,
Set<QueryFilter> metricFilters,
Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions,
Set<String> removeFieldNames) {
if (CollectionUtils.isEmpty(metricFilters)) {
return;
}
for (QueryFilter dslQueryFilter : metricFilters) {
for (FieldExpression fieldExpression : fieldExpressionList) {
if (fieldExpression.getFieldName() != null
&& fieldExpression.getFieldName().contains(dslQueryFilter.getName())) {
removeFieldNames.add(dslQueryFilter.getName());
// rebuild the condition with the operator the user chose
if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.EQUALS)) {
EqualsTo equalsTo = new EqualsTo();
addWhereFilters(dslQueryFilter, equalsTo, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.GREATER_THAN_EQUALS)) {
GreaterThanEquals greaterThanEquals = new GreaterThanEquals();
addWhereFilters(dslQueryFilter, greaterThanEquals, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.GREATER_THAN)) {
GreaterThan greaterThan = new GreaterThan();
addWhereFilters(dslQueryFilter, greaterThan, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.MINOR_THAN_EQUALS)) {
MinorThanEquals minorThanEquals = new MinorThanEquals();
addWhereFilters(dslQueryFilter, minorThanEquals, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.MINOR_THAN)) {
MinorThan minorThan = new MinorThan();
addWhereFilters(dslQueryFilter, minorThan, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.IN)) {
InExpression inExpression = new InExpression();
addWhereInFilters(dslQueryFilter, inExpression, contextMetricFilters, addConditions);
}
break;
}
}
}
}
/**
 * Builds "&lt;column&gt; IN (v1, v2, ...)" from the filter's JSON-encoded value list,
 * appends it to the pending conditions, and updates context filters with the
 * same name to carry the revised value and operator. No-op when the value list
 * is empty.
 *
 * @param dslQueryFilter the revised IN filter
 * @param inExpression the IN node to fill in
 * @param contextMetricFilters context filters, updated in place on name match
 * @param addConditions the list collecting conditions to append to the SQL
 */
private void addWhereInFilters(QueryFilter dslQueryFilter,
        InExpression inExpression,
        Set<QueryFilter> contextMetricFilters,
        List<Expression> addConditions) {
    List<String> valueList = JsonUtil.toList(
            JsonUtil.toString(dslQueryFilter.getValue()), String.class);
    if (CollectionUtils.isEmpty(valueList)) {
        return;
    }
    List<Expression> expressions = new ArrayList<>();
    for (String value : valueList) {
        expressions.add(new StringValue(value));
    }
    ExpressionList expressionList = new ExpressionList();
    expressionList.setExpressions(expressions);
    inExpression.setLeftExpression(new Column(dslQueryFilter.getName()));
    inExpression.setRightItemsList(expressionList);
    addConditions.add(inExpression);
    for (QueryFilter contextFilter : contextMetricFilters) {
        if (contextFilter.getName().equals(dslQueryFilter.getName())) {
            contextFilter.setValue(dslQueryFilter.getValue());
            contextFilter.setOperator(dslQueryFilter.getOperator());
        }
    }
}
// add where filter
private <T extends ComparisonOperator> void addWhereFilters(QueryFilter dslQueryFilter,
T comparisonExpression,
Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions) {
String columnName = dslQueryFilter.getName();
if (StringUtils.isNotBlank(dslQueryFilter.getFunction())) {
columnName = dslQueryFilter.getFunction() + "(" + dslQueryFilter.getName() + ")";
}
if (Objects.isNull(dslQueryFilter.getValue())) {
return;
}
Column column = new Column(columnName);
comparisonExpression.setLeftExpression(column);
if (StringUtils.isNumeric(dslQueryFilter.getValue().toString())) {
LongValue longValue = new LongValue(Long.parseLong(dslQueryFilter.getValue().toString()));
comparisonExpression.setRightExpression(longValue);
} else {
StringValue stringValue = new StringValue(dslQueryFilter.getValue().toString());
comparisonExpression.setRightExpression(stringValue);
}
addConditions.add(comparisonExpression);
contextMetricFilters.stream().forEach(o -> {
if (o.getName().equals(dslQueryFilter.getName())) {
o.setValue(dslQueryFilter.getValue());
o.setOperator(dslQueryFilter.getOperator());
}
});
}
    /**
     * Merges user-adjusted query data into the request's parse info.
     * LLM-generated SQL queries are returned untouched, since their structure is not
     * driven by these structured fields. Every other field is overridden only when the
     * user actually supplied a non-empty value.
     *
     * @param queryData user-edited query request carrying optional overrides
     * @return the (possibly updated) parse info held by the request
     */
    private SemanticParseInfo getSemanticParseInfo(QueryDataReq queryData) {
        SemanticParseInfo parseInfo = queryData.getParseInfo();
        if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
            return parseInfo;
        }
        if (CollectionUtils.isNotEmpty(queryData.getDimensions())) {
            parseInfo.setDimensions(queryData.getDimensions());
        }
        if (CollectionUtils.isNotEmpty(queryData.getMetrics())) {
            parseInfo.setMetrics(queryData.getMetrics());
        }
        if (CollectionUtils.isNotEmpty(queryData.getDimensionFilters())) {
            parseInfo.setDimensionFilters(queryData.getDimensionFilters());
        }
        if (CollectionUtils.isNotEmpty(queryData.getMetricFilters())) {
            parseInfo.setMetricFilters(queryData.getMetricFilters());
        }
        if (Objects.nonNull(queryData.getDateInfo())) {
            parseInfo.setDateInfo(queryData.getDateInfo());
        }
        return parseInfo;
    }
private void validFilter(Set<QueryFilter> filters) {
for (QueryFilter queryFilter : filters) {
if (Objects.isNull(queryFilter.getValue())) {
filters.remove(queryFilter);
}
if (queryFilter.getOperator().equals(FilterOperatorEnum.IN) && CollectionUtils.isEmpty(
JsonUtil.toList(JsonUtil.toString(queryFilter.getValue()), String.class))) {
filters.remove(queryFilter);
}
}
}
@Override
public Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception {
SemanticQueryResp semanticQueryResp = new SemanticQueryResp();
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
SemanticSchema semanticSchema = semanticService.getSemanticSchema();
SchemaElement schemaElement = semanticSchema.getDimension(dimensionValueReq.getElementID());
Set<Long> detectDataSetIds = new HashSet<>();
detectDataSetIds.add(schemaElement.getDataSet());
dimensionValueReq.setModelId(schemaElement.getModel());
List<String> dimensionValues = getDimensionValues(dimensionValueReq, detectDataSetIds);
// if the search results is null,search dimensionValue from database
if (CollectionUtils.isEmpty(dimensionValues)) {
semanticQueryResp = queryDatabase(dimensionValueReq, user);
return semanticQueryResp;
}
List<QueryColumn> columns = new ArrayList<>();
QueryColumn queryColumn = new QueryColumn();
queryColumn.setNameEn(dimensionValueReq.getBizName());
queryColumn.setShowType("CATEGORY");
queryColumn.setAuthorized(true);
queryColumn.setType("CHAR");
columns.add(queryColumn);
List<Map<String, Object>> resultList = new ArrayList<>();
dimensionValues.stream().forEach(o -> {
Map<String, Object> map = new HashMap<>();
map.put(dimensionValueReq.getBizName(), o);
resultList.add(map);
});
semanticQueryResp.setColumns(columns);
semanticQueryResp.setResultList(resultList);
return semanticQueryResp;
}
private List<String> getDimensionValues(DimensionValueReq dimensionValueReq, Set<Long> dataSetIds) {
//if value is null ,then search from NATURE_TO_VALUES
if (StringUtils.isBlank(dimensionValueReq.getValue())) {
return SearchService.getDimensionValue(dimensionValueReq);
}
Map<Long, List<Long>> modelIdToDataSetIds = new HashMap<>();
modelIdToDataSetIds.put(dimensionValueReq.getModelId(), new ArrayList<>(dataSetIds));
//search from prefixSearch
List<HanlpMapResult> hanlpMapResultList = knowledgeService.prefixSearch(dimensionValueReq.getValue(),
2000, modelIdToDataSetIds);
HanlpHelper.transLetterOriginal(hanlpMapResultList);
return hanlpMapResultList.stream()
.filter(o -> {
for (String nature : o.getNatures()) {
Long elementID = NatureHelper.getElementID(nature);
if (dimensionValueReq.getElementID().equals(elementID)) {
return true;
}
}
return false;
})
.map(mapResult -> mapResult.getName())
.collect(Collectors.toList());
}
private SemanticQueryResp queryDatabase(DimensionValueReq dimensionValueReq, User user) throws Exception {
QueryStructReq queryStructReq = new QueryStructReq();
DateConf dateConf = new DateConf();
dateConf.setDateMode(DateConf.DateMode.RECENT);
dateConf.setUnit(1);
dateConf.setPeriod("DAY");
queryStructReq.setDateInfo(dateConf);
queryStructReq.setLimit(20L);
queryStructReq.setDataSetId(dimensionValueReq.getModelId());
queryStructReq.setQueryType(QueryType.ID);
List<String> groups = new ArrayList<>();
groups.add(dimensionValueReq.getBizName());
queryStructReq.setGroups(groups);
return queryService.queryByReq(queryStructReq, user);
}
}

View File

@@ -239,6 +239,11 @@ public class DataSetServiceImpl
Collectors.mapping(Pair::getRight, Collectors.toList())));
}
    @Override
    public Map<Long, List<Long>> getModelIdToDataSetIds() {
        // No-filter overload: delegates with an empty id list — presumably the filtered
        // overload treats an empty list as "all data sets"; confirm against its implementation.
        return getModelIdToDataSetIds(Lists.newArrayList());
    }
private void conflictCheck(DataSetResp dataSetResp) {
List<Long> allDimensionIds = dataSetResp.getAllDimensions();
List<Long> allMetricIds = dataSetResp.getAllMetrics();

View File

@@ -9,11 +9,11 @@ import com.tencent.supersonic.headless.api.pojo.request.DictSingleTaskReq;
import com.tencent.supersonic.headless.api.pojo.response.DictItemResp;
import com.tencent.supersonic.headless.api.pojo.response.DictTaskResp;
import com.tencent.supersonic.headless.core.file.FileHandler;
import com.tencent.supersonic.headless.core.knowledge.KnowledgeService;
import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper;
import com.tencent.supersonic.headless.server.persistence.dataobject.DictTaskDO;
import com.tencent.supersonic.headless.server.persistence.repository.DictRepository;
import com.tencent.supersonic.headless.server.service.DictTaskService;
import com.tencent.supersonic.headless.server.service.KnowledgeService;
import com.tencent.supersonic.headless.server.utils.DictUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
@@ -41,7 +41,6 @@ public class DictTaskServiceImpl implements DictTaskService {
private final DictUtils dictConverter;
private final DictUtils dictUtils;
private final FileHandler fileHandler;
private final KnowledgeService knowledgeService;
public DictTaskServiceImpl(DictRepository dictRepository,
DictUtils dictConverter,
@@ -52,7 +51,6 @@ public class DictTaskServiceImpl implements DictTaskService {
this.dictConverter = dictConverter;
this.dictUtils = dictUtils;
this.fileHandler = fileHandler;
this.knowledgeService = knowledgeService;
}
@Override

View File

@@ -1,97 +0,0 @@
package com.tencent.supersonic.headless.server.service.impl;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.knowledge.DictWord;
import com.tencent.supersonic.headless.core.knowledge.HanlpMapResult;
import com.tencent.supersonic.headless.core.knowledge.SearchService;
import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper;
import com.tencent.supersonic.headless.server.service.KnowledgeService;
import com.tencent.supersonic.headless.server.service.DataSetService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
@Service
@Slf4j
public class KnowledgeServiceImpl implements KnowledgeService {

    private final DataSetService dataSetService;

    public KnowledgeServiceImpl(DataSetService dataSetService) {
        this.dataSetService = dataSetService;
    }

    /**
     * Loads dictionary words into the runtime knowledge: prefix words go into the HanLP
     * custom dictionary one by one, suffix words (nature contains the SUFFIX marker) are
     * loaded into the suffix search index in one batch.
     */
    @Override
    public void updateSemanticKnowledge(List<DictWord> natures) {
        List<DictWord> prefixWords = new ArrayList<>();
        List<DictWord> suffixWords = new ArrayList<>();
        for (DictWord nature : natures) {
            if (nature.getNatureWithFrequency().contains(DictWordType.SUFFIX.getTypeWithSpilt())) {
                suffixWords.add(nature);
            } else {
                prefixWords.add(nature);
            }
        }
        for (DictWord prefixWord : prefixWords) {
            HanlpHelper.addToCustomDictionary(prefixWord);
        }
        SearchService.loadSuffix(suffixWords);
    }

    /**
     * Full reload: first reloads the custom dictionary from disk (best effort — a
     * failure is logged and does not block the second step), then pushes the given
     * words into the online knowledge.
     */
    @Override
    public void reloadAllData(List<DictWord> natures) {
        try {
            HanlpHelper.reloadCustomDictionary();
        } catch (Exception e) {
            log.error("reloadCustomDictionary error", e);
        }
        updateOnlineKnowledge(natures);
    }

    /** Best-effort update of in-memory knowledge; failures are logged, never thrown. */
    @Override
    public void updateOnlineKnowledge(List<DictWord> natures) {
        try {
            updateSemanticKnowledge(natures);
        } catch (Exception e) {
            log.error("updateSemanticKnowledge error", e);
        }
    }

    /** Segments the text into terms using the full model-to-data-set mapping. */
    @Override
    public List<S2Term> getTerms(String text) {
        return HanlpHelper.getTerms(text, dataSetService.getModelIdToDataSetIds(new ArrayList<>()));
    }

    /** Prefix search scoped to the models backing the given data sets. */
    @Override
    public List<HanlpMapResult> prefixSearch(String key, int limit, Set<Long> dataSetIds) {
        return prefixSearchByModel(key, limit,
                dataSetService.getModelIdToDataSetIds(new ArrayList<>(dataSetIds)));
    }

    public List<HanlpMapResult> prefixSearchByModel(String key, int limit,
                                                    Map<Long, List<Long>> modelIdToDataSetIds) {
        return SearchService.prefixSearch(key, limit, modelIdToDataSetIds);
    }

    /** Suffix search; note the underlying index is keyed by model id only. */
    @Override
    public List<HanlpMapResult> suffixSearch(String key, int limit, Set<Long> dataSetIds) {
        return suffixSearchByModel(key, limit,
                dataSetService.getModelIdToDataSetIds(new ArrayList<>(dataSetIds)).keySet());
    }

    public List<HanlpMapResult> suffixSearchByModel(String key, int limit, Set<Long> models) {
        return SearchService.suffixSearch(key, limit, models);
    }
}

View File

@@ -1,97 +0,0 @@
package com.tencent.supersonic.headless.server.service.impl;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ComponentFactory;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
import com.tencent.supersonic.headless.server.service.DataSetService;
import com.tencent.supersonic.headless.server.service.MetaEmbeddingService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
@Service
@Slf4j
public class MetaEmbeddingServiceImpl implements MetaEmbeddingService {
    // Embedding store implementation is resolved once via the component factory.
    private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
    @Autowired
    private EmbeddingConfig embeddingConfig;
    @Autowired
    private DataSetService dataSetService;
    /**
     * Runs an embedding retrieval against the meta collection, restricted to the
     * models backing the given data sets.
     * Filtering happens in two stages: when exactly one candidate model exists, a
     * modelId filter is pushed down to the store; afterwards every result is filtered
     * client-side by model id, and each surviving retrieval is fanned out to one copy
     * per data set of its model (tagging metadata with "dataSetId"). Results that end
     * up with no retrievals are dropped.
     *
     * @param dataSetIds    data sets to scope the search; empty means no model filtering
     * @param retrieveQuery the query to run; its filter condition may be mutated here
     * @param num           maximum number of results requested from the store
     * @return filtered retrieval results, possibly empty, never null
     */
    @Override
    public List<RetrieveQueryResult> retrieveQuery(List<Long> dataSetIds, RetrieveQuery retrieveQuery, int num) {
        // dataSetIds->modelIds
        Map<Long, List<Long>> modelIdToDataSetIds = dataSetService.getModelIdToDataSetIds(dataSetIds);
        Set<Long> allModels = modelIdToDataSetIds.keySet();
        // Exactly one candidate model: push the filter down to the embedding store.
        if (CollectionUtils.isNotEmpty(allModels) && allModels.size() == 1) {
            Map<String, String> filterCondition = new HashMap<>();
            filterCondition.put("modelId", allModels.stream().findFirst().get().toString());
            retrieveQuery.setFilterCondition(filterCondition);
        }
        String collectionName = embeddingConfig.getMetaCollectionName();
        List<RetrieveQueryResult> resultList = s2EmbeddingStore.retrieveQuery(collectionName, retrieveQuery, num);
        if (CollectionUtils.isEmpty(resultList)) {
            return new ArrayList<>();
        }
        //filter by modelId
        if (CollectionUtils.isEmpty(allModels)) {
            return resultList;
        }
        return resultList.stream()
                .map(retrieveQueryResult -> {
                    List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
                    if (CollectionUtils.isEmpty(retrievals)) {
                        return retrieveQueryResult;
                    }
                    //filter by modelId: drop retrievals whose model is not a candidate.
                    retrievals.removeIf(retrieval -> {
                        Long modelId = Retrieval.getLongId(retrieval.getMetadata().get("modelId"));
                        if (Objects.isNull(modelId)) {
                            return CollectionUtils.isEmpty(allModels);
                        }
                        return !allModels.contains(modelId);
                    });
                    //add dataSetId: fan out each retrieval to one copy per data set of
                    //its model, tagging metadata with the data set id.
                    retrievals = retrievals.stream().flatMap(retrieval -> {
                        Long modelId = Retrieval.getLongId(retrieval.getMetadata().get("modelId"));
                        List<Long> dataSetIdsByModelId = modelIdToDataSetIds.get(modelId);
                        if (!CollectionUtils.isEmpty(dataSetIdsByModelId)) {
                            Set<Retrieval> result = new HashSet<>();
                            for (Long dataSetId : dataSetIdsByModelId) {
                                Retrieval retrievalNew = new Retrieval();
                                BeanUtils.copyProperties(retrieval, retrievalNew);
                                retrievalNew.getMetadata().putIfAbsent("dataSetId", dataSetId + Constants.UNDERLINE);
                                result.add(retrievalNew);
                            }
                            return result.stream();
                        }
                        // Model has no mapped data sets: keep the retrieval untouched.
                        Set<Retrieval> result = new HashSet<>();
                        result.add(retrieval);
                        return result.stream();
                    }).collect(Collectors.toList());
                    retrieveQueryResult.setRetrieval(retrievals);
                    return retrieveQueryResult;
                })
                .filter(retrieveQueryResult -> CollectionUtils.isNotEmpty(retrieveQueryResult.getRetrieval()))
                .collect(Collectors.toList());
    }
}

View File

@@ -458,14 +458,14 @@ public class MetricServiceImpl implements MetricService {
if (bizNameMap.containsKey(metricReq.getBizName())) {
MetricResp metricResp = bizNameMap.get(metricReq.getBizName());
if (!metricResp.getId().equals(metricReq.getId())) {
throw new RuntimeException(String.format("主题域下存在相同的指标字段名:%s 创建人:%s",
throw new RuntimeException(String.format("模型下存在相同的指标字段名:%s 创建人:%s",
metricReq.getBizName(), metricResp.getCreatedBy()));
}
}
if (nameMap.containsKey(metricReq.getName())) {
MetricResp metricResp = nameMap.get(metricReq.getName());
if (!metricResp.getId().equals(metricReq.getId())) {
throw new RuntimeException(String.format("主题域下存在相同的指标名:%s 创建人:%s",
throw new RuntimeException(String.format("模型下存在相同的指标名:%s 创建人:%s",
metricReq.getName(), metricResp.getCreatedBy()));
}
}

View File

@@ -1,7 +1,5 @@
package com.tencent.supersonic.headless.server.service.impl;
import static com.tencent.supersonic.common.pojo.Constants.AT_SYMBOL;
import com.github.pagehelper.PageInfo;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
@@ -13,6 +11,7 @@ import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.enums.SchemaType;
import com.tencent.supersonic.headless.api.pojo.request.DataSetFilterReq;
import com.tencent.supersonic.headless.api.pojo.request.ItemUseReq;
@@ -44,23 +43,28 @@ import com.tencent.supersonic.headless.server.service.ModelRelaService;
import com.tencent.supersonic.headless.server.service.ModelService;
import com.tencent.supersonic.headless.server.service.SchemaService;
import com.tencent.supersonic.headless.server.service.TagMetaService;
import com.tencent.supersonic.headless.server.utils.DataSetSchemaBuilder;
import com.tencent.supersonic.headless.server.utils.DimensionConverter;
import com.tencent.supersonic.headless.server.utils.MetricConverter;
import com.tencent.supersonic.headless.server.utils.StatUtils;
import com.tencent.supersonic.headless.server.utils.TagConverter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import static com.tencent.supersonic.common.pojo.Constants.AT_SYMBOL;
@Slf4j
@Service
public class SchemaServiceImpl implements SchemaService {
@@ -118,6 +122,43 @@ public class SchemaServiceImpl implements SchemaService {
return fetchDataSetSchema(new DataSetFilterReq(dataSetId)).stream().findFirst().orElse(null);
}
    // Convenience overload: wraps the ids into a filter request and delegates.
    private List<DataSetSchemaResp> fetchDataSetSchema(List<Long> ids) {
        DataSetFilterReq dataSetFilterReq = new DataSetFilterReq();
        dataSetFilterReq.setDataSetIds(ids);
        return fetchDataSetSchema(dataSetFilterReq);
    }
    /**
     * Fetches the schema of a single data set and converts it to a {@link DataSetSchema}.
     * Returns null when the data set is unknown or the fetch yields no matching entry.
     */
    @Override
    public DataSetSchema getDataSetSchema(Long dataSetId) {
        List<Long> ids = new ArrayList<>();
        ids.add(dataSetId);
        List<DataSetSchemaResp> dataSetSchemaResps = fetchDataSetSchema(ids);
        if (!CollectionUtils.isEmpty(dataSetSchemaResps)) {
            // Defensive re-check by id, in case the fetch returns extra entries.
            Optional<DataSetSchemaResp> dataSetSchemaResp = dataSetSchemaResps.stream()
                    .filter(d -> d.getId().equals(dataSetId)).findFirst();
            if (dataSetSchemaResp.isPresent()) {
                DataSetSchemaResp dataSetSchema = dataSetSchemaResp.get();
                return DataSetSchemaBuilder.build(dataSetSchema);
            }
        }
        return null;
    }
    // No-filter overload: an empty id list fetches schemas for all data sets.
    @Override
    public List<DataSetSchema> getDataSetSchema() {
        return getDataSetSchema(new ArrayList<>());
    }
    // Fetches schema responses for the given ids and converts each to a DataSetSchema.
    public List<DataSetSchema> getDataSetSchema(List<Long> ids) {
        List<DataSetSchema> domainSchemaList = new ArrayList<>();
        for (DataSetSchemaResp resp : fetchDataSetSchema(ids)) {
            domainSchemaList.add(DataSetSchemaBuilder.build(resp));
        }
        return domainSchemaList;
    }
public List<DataSetSchemaResp> buildDataSetSchema(DataSetFilterReq filter) {
MetaFilter metaFilter = new MetaFilter();
metaFilter.setStatus(StatusEnum.ONLINE.getCode());
@@ -284,13 +325,6 @@ public class SchemaServiceImpl implements SchemaService {
return modelService.getModelListWithAuth(user, domainId, authTypeEnum);
}
    // Lists all data sets belonging to the given domain.
    @Override
    public List<DataSetResp> getDataSetList(Long domainId) {
        MetaFilter metaFilter = new MetaFilter();
        metaFilter.setDomainId(domainId);
        return dataSetService.getDataSetList(metaFilter);
    }
public SemanticSchemaResp buildSemanticSchema(SchemaFilterReq schemaFilterReq) {
SemanticSchemaResp semanticSchemaResp = new SemanticSchemaResp();
semanticSchemaResp.setDataSetId(schemaFilterReq.getDataSetId());

View File

@@ -0,0 +1,321 @@
package com.tencent.supersonic.headless.server.service.impl;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
import com.tencent.supersonic.headless.core.chat.mapper.MatchText;
import com.tencent.supersonic.headless.core.chat.mapper.ModelWithSemanticType;
import com.tencent.supersonic.headless.core.chat.mapper.SearchMatchStrategy;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.core.knowledge.DataSetInfoStat;
import com.tencent.supersonic.headless.core.knowledge.DictWord;
import com.tencent.supersonic.headless.core.knowledge.HanlpMapResult;
import com.tencent.supersonic.headless.core.knowledge.KnowledgeService;
import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper;
import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper;
import com.tencent.supersonic.headless.server.service.ChatContextService;
import com.tencent.supersonic.headless.server.service.DataSetService;
import com.tencent.supersonic.headless.server.service.SearchService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
/**
 * Search service implementation: recommends complete questions for a partially typed
 * query by matching dictionary terms (metrics, dimensions, dimension values) against
 * the semantic schema.
 */
@Service
@Slf4j
public class SearchServiceImpl implements SearchService {
    // Maximum number of recommendations returned to the caller.
    private static final int RESULT_SIZE = 10;
    @Autowired
    private SemanticService semanticService;
    @Autowired
    private SearchMatchStrategy searchMatchStrategy;
    @Autowired
    private ChatContextService chatContextService;
    @Autowired
    private KnowledgeService knowledgeService;
    @Autowired
    private DataSetService dataSetService;
    /**
     * Pipeline: load schema meta → segment the query text → match segments against the
     * dictionary → pick the longest-covering match → emit metric/dimension suggestions
     * first, then dimension-value based suggestions, capped at RESULT_SIZE.
     */
    @Override
    public List<SearchResult> search(QueryReq queryReq) {
        String queryText = queryReq.getQueryText();
        // 1.get meta info
        SemanticSchema semanticSchemaDb = semanticService.getSemanticSchema();
        List<SchemaElement> metricsDb = semanticSchemaDb.getMetrics();
        final Map<Long, String> dataSetIdToName = semanticSchemaDb.getDataSetIdToName();
        Map<Long, List<Long>> modelIdToDataSetIds =
                dataSetService.getModelIdToDataSetIds(new ArrayList<>(dataSetIdToName.keySet()));
        // 2.detect by segment
        List<S2Term> originals = knowledgeService.getTerms(queryText, modelIdToDataSetIds);
        log.info("hanlp parse result: {}", originals);
        Set<Long> dataSetIds = queryReq.getDataSetIds();
        QueryContext queryContext = new QueryContext();
        BeanUtils.copyProperties(queryReq, queryContext);
        Map<MatchText, List<HanlpMapResult>> regTextMap =
                searchMatchStrategy.match(queryContext, originals, dataSetIds);
        regTextMap.entrySet().stream().forEach(m -> HanlpHelper.transLetterOriginal(m.getValue()));
        // 3.get the most matching data: prefer the entry covering the longest detect segment
        Optional<Entry<MatchText, List<HanlpMapResult>>> mostSimilarSearchResult = regTextMap.entrySet()
                .stream()
                .filter(entry -> CollectionUtils.isNotEmpty(entry.getValue()))
                .reduce((entry1, entry2) ->
                        entry1.getKey().getDetectSegment().length() >= entry2.getKey().getDetectSegment().length()
                                ? entry1 : entry2);
        // 4.optimize the results after the query
        if (!mostSimilarSearchResult.isPresent()) {
            return Lists.newArrayList();
        }
        Map.Entry<MatchText, List<HanlpMapResult>> searchTextEntry = mostSimilarSearchResult.get();
        log.info("searchTextEntry:{},queryReq:{}", searchTextEntry, queryReq);
        // NOTE(review): raw type — should be new LinkedHashSet<>()
        Set<SearchResult> searchResults = new LinkedHashSet();
        DataSetInfoStat dataSetInfoStat = NatureHelper.getDataSetStat(originals);
        List<Long> possibleModels = getPossibleDataSets(queryReq, originals, dataSetInfoStat, dataSetIds);
        // 5.1 priority dimension metric
        boolean existMetricAndDimension = searchMetricAndDimension(new HashSet<>(possibleModels), dataSetIdToName,
                searchTextEntry, searchResults);
        // 5.2 process based on dimension values
        MatchText matchText = searchTextEntry.getKey();
        Map<String, String> natureToNameMap = getNatureToNameMap(searchTextEntry, new HashSet<>(possibleModels));
        log.debug("possibleModels:{},natureToNameMap:{}", possibleModels, natureToNameMap);
        for (Map.Entry<String, String> natureToNameEntry : natureToNameMap.entrySet()) {
            Set<SearchResult> searchResultSet = searchDimensionValue(metricsDb, dataSetIdToName,
                    dataSetInfoStat.getMetricDataSetCount(), existMetricAndDimension,
                    matchText, natureToNameMap, natureToNameEntry, queryReq.getQueryFilters());
            searchResults.addAll(searchResultSet);
        }
        return searchResults.stream().limit(RESULT_SIZE).collect(Collectors.toList());
    }
    /**
     * Determines the candidate data sets for this query: explicit request ids win;
     * otherwise data sets inferred from the segmented terms; when nothing (or only
     * metrics) was recognized, fall back to the chat-context model.
     */
    private List<Long> getPossibleDataSets(QueryReq queryCtx, List<S2Term> originals,
                                           DataSetInfoStat dataSetInfoStat, Set<Long> dataSetIds) {
        if (CollectionUtils.isNotEmpty(dataSetIds)) {
            return new ArrayList<>(dataSetIds);
        }
        List<Long> possibleModels = NatureHelper.selectPossibleDataSets(originals);
        Long contextModel = chatContextService.getContextModel(queryCtx.getChatId());
        log.debug("possibleModels:{},dataSetInfoStat:{},contextModel:{}",
                possibleModels, dataSetInfoStat, contextModel);
        // If nothing is recognized or only metric are present, then add the contextModel.
        // NOTE(review): effectiveModel(contextModel) is defined below but never consulted
        // here, so a null contextModel can end up in the returned list — confirm intended.
        if (nothingOrOnlyMetric(dataSetInfoStat)) {
            return Lists.newArrayList(contextModel);
        }
        return possibleModels;
    }
    // True when no dimension/dimension-value/data-set terms were recognized (metrics allowed).
    private boolean nothingOrOnlyMetric(DataSetInfoStat modelStat) {
        return modelStat.getMetricDataSetCount() >= 0 && modelStat.getDimensionDataSetCount() <= 0
                && modelStat.getDimensionValueDataSetCount() <= 0 && modelStat.getDataSetCount() <= 0;
    }
    // A usable context model id is non-null and positive. NOTE(review): currently unused.
    private boolean effectiveModel(Long contextModel) {
        return Objects.nonNull(contextModel) && contextModel > 0;
    }
    /**
     * Builds suggestions from a single matched dimension value. Entity natures are
     * skipped. When neither metrics nor dimensions matched anywhere, each dimension
     * value is additionally combined with the model's most-used metrics to form
     * complete candidate questions.
     */
    private Set<SearchResult> searchDimensionValue(List<SchemaElement> metricsDb,
                                                   Map<Long, String> modelToName,
                                                   long metricModelCount,
                                                   boolean existMetricAndDimension,
                                                   MatchText matchText,
                                                   Map<String, String> natureToNameMap,
                                                   Map.Entry<String, String> natureToNameEntry,
                                                   QueryFilters queryFilters) {
        // NOTE(review): raw type — should be new LinkedHashSet<>()
        Set<SearchResult> searchResults = new LinkedHashSet();
        String nature = natureToNameEntry.getKey();
        String wordName = natureToNameEntry.getValue();
        Long modelId = NatureHelper.getDataSetId(nature);
        SchemaElementType schemaElementType = NatureHelper.convertToElementType(nature);
        if (SchemaElementType.ENTITY.equals(schemaElementType)) {
            return searchResults;
        }
        // If there are no metric/dimension, complete the metric information
        SearchResult searchResult = SearchResult.builder()
                .modelId(modelId)
                .modelName(modelToName.get(modelId))
                .recommend(matchText.getRegText() + wordName)
                .schemaElementType(schemaElementType)
                .subRecommend(wordName)
                .build();
        if (metricModelCount <= 0 && !existMetricAndDimension) {
            if (filterByQueryFilter(wordName, queryFilters)) {
                return searchResults;
            }
            searchResults.add(searchResult);
            int metricSize = getMetricSize(natureToNameMap);
            // Over-fetch 3x then trim, to survive later filtering.
            List<String> metrics = filerMetricsByModel(metricsDb, modelId, metricSize * 3)
                    .stream()
                    .limit(metricSize).collect(Collectors.toList());
            for (String metric : metrics) {
                SearchResult result = SearchResult.builder()
                        .modelId(modelId)
                        .modelName(modelToName.get(modelId))
                        .recommend(matchText.getRegText() + wordName + DictWordType.SPACE + metric)
                        .subRecommend(wordName + DictWordType.SPACE + metric)
                        .isComplete(false)
                        .build();
                searchResults.add(result);
            }
        } else {
            searchResults.add(searchResult);
        }
        return searchResults;
    }
    // Shares RESULT_SIZE across all matched natures, at least one metric per nature.
    private int getMetricSize(Map<String, String> natureToNameMap) {
        int metricSize = RESULT_SIZE / (natureToNameMap.entrySet().size());
        if (metricSize <= 1) {
            metricSize = 1;
        }
        return metricSize;
    }
    // Returns true (i.e. suppress the suggestion) only when filters exist and none of
    // them carries this word as its value.
    private boolean filterByQueryFilter(String wordName, QueryFilters queryFilters) {
        if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) {
            return false;
        }
        List<QueryFilter> filters = queryFilters.getFilters();
        for (QueryFilter filter : filters) {
            if (wordName.equalsIgnoreCase(String.valueOf(filter.getValue()))) {
                return false;
            }
        }
        return true;
    }
    // Top metrics of a model ordered by usage count, limited to metricSize.
    protected List<String> filerMetricsByModel(List<SchemaElement> metricsDb, Long model, int metricSize) {
        if (CollectionUtils.isEmpty(metricsDb)) {
            return Lists.newArrayList();
        }
        return metricsDb.stream()
                .filter(mapDO -> Objects.nonNull(mapDO) && model.equals(mapDO.getDataSet()))
                .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
                .flatMap(entry -> {
                    List<String> result = new ArrayList<>();
                    result.add(entry.getName());
                    return result.stream();
                })
                .limit(metricSize).collect(Collectors.toList());
    }
    /***
     * Converts matched words to a nature→word map, keeping only natures belonging to
     * the possible models, preferring shorter words and first-seen natures.
     * @param recommendTextListEntry the best-matching text entry and its dictionary hits
     * @return insertion-ordered map from nature string to word
     */
    private Map<String, String> getNatureToNameMap(Map.Entry<MatchText, List<HanlpMapResult>> recommendTextListEntry,
                                                   Set<Long> possibleModels) {
        List<HanlpMapResult> recommendValues = recommendTextListEntry.getValue();
        return recommendValues.stream()
                .flatMap(entry -> entry.getNatures().stream()
                        .filter(nature -> {
                            if (CollectionUtils.isEmpty(possibleModels)) {
                                return true;
                            }
                            Long model = NatureHelper.getDataSetId(nature);
                            return possibleModels.contains(model);
                        })
                        .map(nature -> {
                            DictWord posDO = new DictWord();
                            posDO.setWord(entry.getName());
                            posDO.setNature(nature);
                            return posDO;
                        })).sorted(Comparator.comparingInt(a -> a.getWord().length()))
                .collect(Collectors.toMap(DictWord::getNature, DictWord::getWord, (value1, value2) -> value1,
                        LinkedHashMap::new));
    }
    /**
     * Emits one suggestion per matched metric/dimension term within the possible models.
     * Returns whether at least one such term was found (used by the caller to decide on
     * value-completion).
     */
    private boolean searchMetricAndDimension(Set<Long> possibleModels, Map<Long, String> modelToName,
                                             Map.Entry<MatchText, List<HanlpMapResult>> searchTextEntry, Set<SearchResult> searchResults) {
        boolean existMetric = false;
        log.info("searchMetricAndDimension searchTextEntry:{}", searchTextEntry);
        MatchText matchText = searchTextEntry.getKey();
        List<HanlpMapResult> hanlpMapResults = searchTextEntry.getValue();
        for (HanlpMapResult hanlpMapResult : hanlpMapResults) {
            List<ModelWithSemanticType> dimensionMetricClassIds = hanlpMapResult.getNatures().stream()
                    .map(nature -> new ModelWithSemanticType(NatureHelper.getDataSetId(nature),
                            NatureHelper.convertToElementType(nature)))
                    .filter(entry -> matchCondition(entry, possibleModels)).collect(Collectors.toList());
            if (CollectionUtils.isEmpty(dimensionMetricClassIds)) {
                continue;
            }
            for (ModelWithSemanticType modelWithSemanticType : dimensionMetricClassIds) {
                existMetric = true;
                Long modelId = modelWithSemanticType.getModel();
                SchemaElementType schemaElementType = modelWithSemanticType.getSchemaElementType();
                SearchResult searchResult = SearchResult.builder()
                        .modelId(modelId)
                        .modelName(modelToName.get(modelId))
                        .recommend(matchText.getRegText() + hanlpMapResult.getName())
                        .subRecommend(hanlpMapResult.getName())
                        .schemaElementType(schemaElementType)
                        .build();
                //visibility to filter metrics
                searchResults.add(searchResult);
            }
            log.info("parseResult:{},dimensionMetricClassIds:{},possibleModels:{}", hanlpMapResult,
                    dimensionMetricClassIds, possibleModels);
        }
        log.info("searchMetricAndDimension searchResults:{}", searchResults);
        return existMetric;
    }
    // Keeps only METRIC/DIMENSION natures; when possibleModels is empty, any model matches.
    private boolean matchCondition(ModelWithSemanticType entry, Set<Long> possibleModels) {
        if (!(SchemaElementType.METRIC.equals(entry.getSchemaElementType()) || SchemaElementType.DIMENSION.equals(
                entry.getSchemaElementType()))) {
            return false;
        }
        if (CollectionUtils.isEmpty(possibleModels)) {
            return true;
        }
        return possibleModels.contains(entry.getModel());
    }
}

View File

@@ -0,0 +1,229 @@
package com.tencent.supersonic.headless.server.service.impl;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.DataInfo;
import com.tencent.supersonic.headless.api.pojo.DataSetInfo;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.core.utils.QueryReqBuilder;
import com.tencent.supersonic.headless.server.service.QueryService;
import com.tencent.supersonic.headless.server.service.SchemaService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.time.LocalDate;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
@Service
@Slf4j
public class SemanticService {

    @Autowired
    private SchemaService schemaService;

    @Autowired
    private QueryService queryService;

    /** Returns a semantic schema view over all data-set schemas. */
    public SemanticSchema getSemanticSchema() {
        return new SemanticSchema(schemaService.getDataSetSchema());
    }

    /** Returns the schema of a single data set by id. */
    public DataSetSchema getDataSetSchema(Long id) {
        return schemaService.getDataSetSchema(id);
    }

    /**
     * Resolves the entity info (primary key, default display dimensions/metrics
     * and their latest values) for a parsed query.
     *
     * @return the filled entity info; an info with null dimensions/metrics when no
     *         usable filter exists; null when there is no data set, no primary key,
     *         or value filling fails
     */
    public EntityInfo getEntityInfo(SemanticParseInfo parseInfo, DataSetSchema dataSetSchema, User user) {
        if (parseInfo == null || parseInfo.getDataSetId() <= 0) {
            return null;
        }
        EntityInfo entityInfo = getEntityBasicInfo(dataSetSchema);
        if (CollectionUtils.isEmpty(parseInfo.getDimensionFilters()) || entityInfo.getDataSetInfo() == null) {
            // Without a filter on the entity there is nothing concrete to look up.
            entityInfo.setMetrics(null);
            entityInfo.setDimensions(null);
            return entityInfo;
        }
        String primaryKey = entityInfo.getDataSetInfo().getPrimaryKey();
        if (StringUtils.isBlank(primaryKey)) {
            return null;
        }
        // Pick the entity id from an EQUALS filter on the primary key, if any.
        String entityId = "";
        for (QueryFilter chatFilter : parseInfo.getDimensionFilters()) {
            if (chatFilter != null && primaryKey.equals(chatFilter.getBizName())
                    && FilterOperatorEnum.EQUALS.equals(chatFilter.getOperator())) {
                entityId = chatFilter.getValue().toString();
            }
        }
        entityInfo.setEntityId(entityId);
        try {
            fillEntityInfoValue(entityInfo, dataSetSchema, user);
            return entityInfo;
        } catch (Exception e) {
            log.error("fillEntityInfoValue error", e);
        }
        return null;
    }

    /**
     * Builds the static part of the entity info from the data-set schema:
     * data-set metadata, primary key and the configured default display
     * dimensions/metrics (without values).
     */
    private EntityInfo getEntityBasicInfo(DataSetSchema dataSetSchema) {
        EntityInfo entityInfo = new EntityInfo();
        if (dataSetSchema == null) {
            return entityInfo;
        }
        Long dataSetId = dataSetSchema.getDataSet().getDataSet();
        DataSetInfo dataSetInfo = new DataSetInfo();
        dataSetInfo.setItemId(dataSetId.intValue());
        dataSetInfo.setName(dataSetSchema.getDataSet().getName());
        dataSetInfo.setWords(dataSetSchema.getDataSet().getAlias());
        dataSetInfo.setBizName(dataSetSchema.getDataSet().getBizName());
        if (Objects.nonNull(dataSetSchema.getEntity())) {
            dataSetInfo.setPrimaryKey(dataSetSchema.getEntity().getBizName());
        }
        entityInfo.setDataSetInfo(dataSetInfo);
        TagTypeDefaultConfig tagTypeDefaultConfig = dataSetSchema.getTagTypeDefaultConfig();
        if (tagTypeDefaultConfig == null || tagTypeDefaultConfig.getDefaultDisplayInfo() == null) {
            return entityInfo;
        }
        entityInfo.setDimensions(buildDataInfos(dataSetSchema, SchemaElementType.DIMENSION,
                tagTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds()));
        // BUGFIX: metrics were previously built from getDimensionIds() (copy-paste),
        // so the configured default metrics were never populated.
        entityInfo.setMetrics(buildDataInfos(dataSetSchema, SchemaElementType.METRIC,
                tagTypeDefaultConfig.getDefaultDisplayInfo().getMetricIds()));
        return entityInfo;
    }

    /** Maps element ids of the given type to display entries; unknown ids are skipped. */
    private List<DataInfo> buildDataInfos(DataSetSchema dataSetSchema, SchemaElementType type, List<Long> ids) {
        return ids.stream()
                .map(id -> {
                    SchemaElement element = dataSetSchema.getElement(type, id);
                    if (element == null) {
                        return null;
                    }
                    return new DataInfo(element.getId().intValue(), element.getName(), element.getBizName(), null);
                })
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
    }

    /**
     * Queries the entity's latest values and writes them into the matching
     * dimension/metric entries of {@code entityInfo} (matched by bizName).
     */
    public void fillEntityInfoValue(EntityInfo entityInfo, DataSetSchema dataSetSchema, User user) {
        SemanticQueryResp queryResultWithColumns =
                getQueryResultWithSchemaResp(entityInfo, dataSetSchema, user);
        if (queryResultWithColumns == null || CollectionUtils.isEmpty(queryResultWithColumns.getResultList())) {
            return;
        }
        // Only the first row is relevant: the query filters on the entity primary key.
        Map<String, Object> result = queryResultWithColumns.getResultList().get(0);
        for (Map.Entry<String, Object> entry : result.entrySet()) {
            String entryKey = getEntryKey(entry);
            if (entry.getValue() == null || entryKey == null) {
                continue;
            }
            entityInfo.getDimensions().stream().filter(i -> entryKey.equals(i.getBizName()))
                    .forEach(i -> i.setValue(entry.getValue().toString()));
            entityInfo.getMetrics().stream().filter(i -> entryKey.equals(i.getBizName()))
                    .forEach(i -> i.setValue(entry.getValue().toString()));
        }
    }

    /**
     * Runs a TAG query over the entity's default display dimensions/metrics,
     * filtered on the entity id, and returns the raw result (null on failure).
     */
    public SemanticQueryResp getQueryResultWithSchemaResp(EntityInfo entityInfo,
            DataSetSchema dataSetSchema, User user) {
        SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
        semanticParseInfo.setDataSet(dataSetSchema.getDataSet());
        semanticParseInfo.setQueryType(QueryType.TAG);
        semanticParseInfo.setMetrics(getMetrics(entityInfo));
        semanticParseInfo.setDimensions(getDimensions(entityInfo));
        semanticParseInfo.setDateInfo(buildDateConf(dataSetSchema));
        // Filter on the entity primary key so only this entity's row comes back.
        Set<QueryFilter> chatFilters = new LinkedHashSet<>();
        chatFilters.add(getQueryFilter(entityInfo));
        semanticParseInfo.setDimensionFilters(chatFilters);
        try {
            QueryStructReq queryStructReq = QueryReqBuilder.buildStructReq(semanticParseInfo);
            return queryService.queryByReq(queryStructReq, user);
        } catch (Exception e) {
            log.warn("setMainModel queryByStruct error, e:", e);
        }
        return null;
    }

    /**
     * Builds the date window: the single day {@code unit} days back when a tag
     * time config exists, otherwise RECENT with unit 1.
     */
    private DateConf buildDateConf(DataSetSchema dataSetSchema) {
        DateConf dateInfo = new DateConf();
        int unit = 1;
        TimeDefaultConfig timeDefaultConfig = dataSetSchema.getTagTypeTimeDefaultConfig();
        if (Objects.nonNull(timeDefaultConfig)) {
            unit = timeDefaultConfig.getUnit();
            String date = LocalDate.now().plusDays(-unit).toString();
            dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
            dateInfo.setStartDate(date);
            dateInfo.setEndDate(date);
        } else {
            dateInfo.setUnit(unit);
            dateInfo.setDateMode(DateConf.DateMode.RECENT);
        }
        return dateInfo;
    }

    /** Builds the EQUALS filter on the entity primary key. */
    private QueryFilter getQueryFilter(EntityInfo entityInfo) {
        QueryFilter chatFilter = new QueryFilter();
        chatFilter.setValue(entityInfo.getEntityId());
        chatFilter.setOperator(FilterOperatorEnum.EQUALS);
        chatFilter.setBizName(getEntityPrimaryName(entityInfo));
        return chatFilter;
    }

    /** Converts the entity's display dimensions into query dimensions (bizName only). */
    private Set<SchemaElement> getDimensions(EntityInfo modelInfo) {
        Set<SchemaElement> dimensions = new LinkedHashSet<>();
        for (DataInfo mainEntityDimension : modelInfo.getDimensions()) {
            SchemaElement dimension = new SchemaElement();
            dimension.setBizName(mainEntityDimension.getBizName());
            dimensions.add(dimension);
        }
        return dimensions;
    }

    /** Strips the "prefix__" added by the metric parser from a result column key. */
    private String getEntryKey(Map.Entry<String, Object> entry) {
        // metric parser special handle, TODO delete
        String entryKey = entry.getKey();
        if (entryKey.contains("__")) {
            entryKey = entryKey.split("__")[1];
        }
        return entryKey;
    }

    /** Converts the entity's display metrics into query metrics. */
    private Set<SchemaElement> getMetrics(EntityInfo modelInfo) {
        Set<SchemaElement> metrics = new LinkedHashSet<>();
        for (DataInfo metricValue : modelInfo.getMetrics()) {
            SchemaElement metric = new SchemaElement();
            BeanUtils.copyProperties(metricValue, metric);
            metrics.add(metric);
        }
        return metrics;
    }

    private String getEntityPrimaryName(EntityInfo entityInfo) {
        return entityInfo.getDataSetInfo().getPrimaryKey();
    }
}

View File

@@ -0,0 +1,65 @@
package com.tencent.supersonic.headless.server.service.impl;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.core.knowledge.DictWord;
import com.tencent.supersonic.headless.core.knowledge.builder.WordBuilderFactory;
import com.tencent.supersonic.headless.server.service.SchemaService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
@Service
@Slf4j
public class WordService {

    @Autowired
    private SchemaService schemaService;

    // Snapshot of the dictionary words from a previous load, kept for comparison.
    private List<DictWord> preDictWords = new ArrayList<>();

    /**
     * Builds the complete dictionary word list from the current data-set schema,
     * covering dimensions, metrics, entities, dimension values and tags.
     */
    public List<DictWord> getAllDictWords() {
        SemanticSchema semanticSchema = new SemanticSchema(schemaService.getDataSetSchema());
        List<DictWord> words = new ArrayList<>();
        collectWords(DictWordType.DIMENSION, semanticSchema.getDimensions(), words);
        collectWords(DictWordType.METRIC, semanticSchema.getMetrics(), words);
        collectWords(DictWordType.ENTITY, semanticSchema.getEntities(), words);
        collectWords(DictWordType.VALUE, semanticSchema.getDimensionValues(), words);
        collectWords(DictWordType.TAG, semanticSchema.getTags(), words);
        return words;
    }

    /** Converts schema elements of one word type into dict words and appends them. */
    private void collectWords(DictWordType wordType, List<SchemaElement> elements, List<DictWord> words) {
        List<SchemaElement> uniqueElements = distinct(elements);
        List<DictWord> converted = WordBuilderFactory.get(wordType).getDictWords(uniqueElements);
        log.debug("nature type:{} , nature size:{}", wordType.name(), converted.size());
        words.addAll(converted);
    }

    public List<DictWord> getPreDictWords() {
        return preDictWords;
    }

    public void setPreDictWords(List<DictWord> preDictWords) {
        this.preDictWords = preDictWords;
    }

    /** De-duplicates schema elements by id, keeping the first occurrence per id. */
    private List<SchemaElement> distinct(List<SchemaElement> elements) {
        return new ArrayList<>(elements.stream()
                .collect(Collectors.toMap(SchemaElement::getId, Function.identity(), (first, second) -> first))
                .values());
    }
}

View File

@@ -0,0 +1,58 @@
package com.tencent.supersonic.headless.server.utils;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.core.chat.corrector.SemanticCorrector;
import com.tencent.supersonic.headless.core.chat.mapper.SchemaMapper;
import com.tencent.supersonic.headless.core.chat.parser.SemanticParser;
import com.tencent.supersonic.headless.server.processor.ResultProcessor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.core.io.support.SpringFactoriesLoader;
import java.util.ArrayList;
import java.util.List;
/**
 * Object factory for the chat pipeline components (schema mappers, semantic
 * parsers, semantic correctors and result processors).
 *
 * <p>Implementations are discovered through {@link SpringFactoriesLoader}
 * (META-INF/spring.factories) on first access and cached in static lists.
 * Note: the lazy initialization is not thread-safe; concurrent first calls may
 * load twice. Callers so far invoke it from single-threaded startup paths.
 */
@Slf4j
public class ComponentFactory {

    private static final List<ResultProcessor> resultProcessors = new ArrayList<>();
    private static final List<SchemaMapper> schemaMappers = new ArrayList<>();
    private static final List<SemanticParser> semanticParsers = new ArrayList<>();
    private static final List<SemanticCorrector> semanticCorrectors = new ArrayList<>();

    // Utility class: not instantiable.
    private ComponentFactory() {
    }

    public static List<ResultProcessor> getResultProcessors() {
        return CollectionUtils.isEmpty(resultProcessors) ? init(ResultProcessor.class,
                resultProcessors) : resultProcessors;
    }

    public static List<SchemaMapper> getSchemaMappers() {
        return CollectionUtils.isEmpty(schemaMappers) ? init(SchemaMapper.class, schemaMappers) : schemaMappers;
    }

    public static List<SemanticParser> getSemanticParsers() {
        return CollectionUtils.isEmpty(semanticParsers) ? init(SemanticParser.class, semanticParsers) : semanticParsers;
    }

    public static List<SemanticCorrector> getSemanticCorrectors() {
        return CollectionUtils.isEmpty(semanticCorrectors) ? init(SemanticCorrector.class,
                semanticCorrectors) : semanticCorrectors;
    }

    /** Looks up a named bean from the Spring context. */
    public static <T> T getBean(String name, Class<T> tClass) {
        return ContextUtils.getContext().getBean(name, tClass);
    }

    /** Loads all factory implementations of {@code factoryType} and caches them in {@code list}. */
    private static <T> List<T> init(Class<T> factoryType, List<T> list) {
        list.addAll(SpringFactoriesLoader.loadFactories(factoryType,
                Thread.currentThread().getContextClassLoader()));
        return list;
    }
}

View File

@@ -0,0 +1,209 @@
package com.tencent.supersonic.headless.server.utils;
import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.DimValueMap;
import com.tencent.supersonic.headless.api.pojo.RelateDimension;
import com.tencent.supersonic.headless.api.pojo.RelatedSchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
import com.tencent.supersonic.headless.api.pojo.SchemaValueMap;
import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.TagResp;
import org.apache.logging.log4j.util.Strings;
import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/**
 * Builds an in-memory {@link DataSetSchema} (data set, metrics, dimensions,
 * dimension values, tags, tag values, entity) from a {@link DataSetSchemaResp}.
 */
public class DataSetSchemaBuilder {

    // Utility class: not instantiable.
    private DataSetSchemaBuilder() {
    }

    /** Assembles the full schema for one data set response. */
    public static DataSetSchema build(DataSetSchemaResp resp) {
        DataSetSchema dataSetSchema = new DataSetSchema();
        dataSetSchema.setQueryConfig(resp.getQueryConfig());
        dataSetSchema.setQueryType(resp.getQueryType());
        SchemaElement dataSet = SchemaElement.builder()
                .dataSet(resp.getId())
                .id(resp.getId())
                .name(resp.getName())
                .bizName(resp.getBizName())
                .type(SchemaElementType.DATASET)
                .build();
        dataSetSchema.setDataSet(dataSet);
        dataSetSchema.getMetrics().addAll(getMetrics(resp));
        dataSetSchema.getDimensions().addAll(getDimensions(resp));
        dataSetSchema.getDimensionValues().addAll(getDimensionValues(resp));
        dataSetSchema.getTags().addAll(getTags(resp));
        dataSetSchema.getTagValues().addAll(getTagValues(resp));
        SchemaElement entity = getEntity(resp);
        if (Objects.nonNull(entity)) {
            dataSetSchema.setEntity(entity);
        }
        return dataSetSchema;
    }

    /** Builds the ENTITY element from the primary-key dimension, or null when absent. */
    private static SchemaElement getEntity(DataSetSchemaResp resp) {
        DimSchemaResp dim = resp.getPrimaryKey();
        if (Objects.isNull(dim)) {
            return null;
        }
        return SchemaElement.builder()
                .dataSet(resp.getId())
                .model(dim.getModelId())
                .id(dim.getId())
                .name(dim.getName())
                .bizName(dim.getBizName())
                .type(SchemaElementType.ENTITY)
                .useCnt(dim.getUseCnt())
                .alias(dim.getEntityAlias())
                .build();
    }

    private static Set<SchemaElement> getTags(DataSetSchemaResp resp) {
        return buildTagElements(resp, SchemaElementType.TAG);
    }

    private static Set<SchemaElement> getTagValues(DataSetSchemaResp resp) {
        return buildTagElements(resp, SchemaElementType.TAG_VALUE);
    }

    /** Shared builder for TAG / TAG_VALUE elements (they differ only in element type). */
    private static Set<SchemaElement> buildTagElements(DataSetSchemaResp resp, SchemaElementType type) {
        Set<SchemaElement> elements = new HashSet<>();
        if (CollectionUtils.isEmpty(resp.getTags())) {
            return elements;
        }
        for (TagResp tagResp : resp.getTags()) {
            elements.add(SchemaElement.builder()
                    .dataSet(resp.getId())
                    .model(tagResp.getModelId())
                    .id(tagResp.getId())
                    .name(tagResp.getName())
                    .bizName(tagResp.getBizName())
                    .type(type)
                    .build());
        }
        return elements;
    }

    /** Builds DIMENSION elements, carrying alias lists and configured value mappings. */
    private static Set<SchemaElement> getDimensions(DataSetSchemaResp resp) {
        Set<SchemaElement> dimensions = new HashSet<>();
        if (CollectionUtils.isEmpty(resp.getDimensions())) {
            return dimensions;
        }
        for (DimSchemaResp dim : resp.getDimensions()) {
            List<String> alias = SchemaItem.getAliasList(dim.getAlias());
            List<SchemaValueMap> schemaValueMaps = new ArrayList<>();
            if (!CollectionUtils.isEmpty(dim.getDimValueMaps())) {
                for (DimValueMap dimValueMap : dim.getDimValueMaps()) {
                    SchemaValueMap schemaValueMap = new SchemaValueMap();
                    BeanUtils.copyProperties(dimValueMap, schemaValueMap);
                    schemaValueMaps.add(schemaValueMap);
                }
            }
            dimensions.add(SchemaElement.builder()
                    .dataSet(resp.getId())
                    .model(dim.getModelId())
                    .id(dim.getId())
                    .name(dim.getName())
                    .bizName(dim.getBizName())
                    .type(SchemaElementType.DIMENSION)
                    .useCnt(dim.getUseCnt())
                    .alias(alias)
                    .schemaValueMaps(schemaValueMaps)
                    .build());
        }
        return dimensions;
    }

    /**
     * Builds VALUE elements whose alias list collects every configured value
     * bizName/alias of the dimension, so value text can be matched back to it.
     */
    private static Set<SchemaElement> getDimensionValues(DataSetSchemaResp resp) {
        Set<SchemaElement> dimensionValues = new HashSet<>();
        if (CollectionUtils.isEmpty(resp.getDimensions())) {
            return dimensionValues;
        }
        for (DimSchemaResp dim : resp.getDimensions()) {
            Set<String> dimValueAlias = new HashSet<>();
            if (!CollectionUtils.isEmpty(dim.getDimValueMaps())) {
                for (DimValueMap dimValueMap : dim.getDimValueMaps()) {
                    if (Strings.isNotEmpty(dimValueMap.getBizName())) {
                        dimValueAlias.add(dimValueMap.getBizName());
                    }
                    if (!CollectionUtils.isEmpty(dimValueMap.getAlias())) {
                        dimValueAlias.addAll(dimValueMap.getAlias());
                    }
                }
            }
            dimensionValues.add(SchemaElement.builder()
                    .dataSet(resp.getId())
                    .model(dim.getModelId())
                    .id(dim.getId())
                    .name(dim.getName())
                    .bizName(dim.getBizName())
                    .type(SchemaElementType.VALUE)
                    .useCnt(dim.getUseCnt())
                    // Direct copy instead of the old toArray/asList roundtrip.
                    .alias(new ArrayList<>(dimValueAlias))
                    .build());
        }
        return dimensionValues;
    }

    /** Builds METRIC elements including drill-down dimension relations and default agg. */
    private static Set<SchemaElement> getMetrics(DataSetSchemaResp resp) {
        Set<SchemaElement> metrics = new HashSet<>();
        if (CollectionUtils.isEmpty(resp.getMetrics())) {
            return metrics;
        }
        for (MetricSchemaResp metric : resp.getMetrics()) {
            List<String> alias = SchemaItem.getAliasList(metric.getAlias());
            metrics.add(SchemaElement.builder()
                    .dataSet(resp.getId())
                    .model(metric.getModelId())
                    .id(metric.getId())
                    .name(metric.getName())
                    .bizName(metric.getBizName())
                    .type(SchemaElementType.METRIC)
                    .useCnt(metric.getUseCnt())
                    .alias(alias)
                    .relatedSchemaElements(getRelateSchemaElement(metric))
                    .defaultAgg(metric.getDefaultAgg())
                    .build());
        }
        return metrics;
    }

    /** Maps the metric's drill-down dimensions to related schema elements (empty list when none). */
    private static List<RelatedSchemaElement> getRelateSchemaElement(MetricSchemaResp metricSchemaResp) {
        RelateDimension relateDimension = metricSchemaResp.getRelateDimension();
        if (relateDimension == null || CollectionUtils.isEmpty(relateDimension.getDrillDownDimensions())) {
            return Lists.newArrayList();
        }
        return relateDimension.getDrillDownDimensions().stream().map(dimension -> {
            RelatedSchemaElement relateSchemaElement = new RelatedSchemaElement();
            BeanUtils.copyProperties(dimension, relateSchemaElement);
            return relateSchemaElement;
        }).collect(Collectors.toList());
    }
}

View File

@@ -1,11 +1,5 @@
package com.tencent.supersonic.headless.server.utils;
import static com.tencent.supersonic.common.pojo.Constants.AND_UPPER;
import static com.tencent.supersonic.common.pojo.Constants.APOSTROPHE;
import static com.tencent.supersonic.common.pojo.Constants.COMMA;
import static com.tencent.supersonic.common.pojo.Constants.POUND;
import static com.tencent.supersonic.common.pojo.Constants.SPACE;
import com.google.common.base.Strings;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.Aggregator;
@@ -40,7 +34,13 @@ import com.tencent.supersonic.headless.server.service.DimensionService;
import com.tencent.supersonic.headless.server.service.MetricService;
import com.tencent.supersonic.headless.server.service.ModelService;
import com.tencent.supersonic.headless.server.service.QueryService;
import com.tencent.supersonic.headless.server.service.TagMetaService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.time.LocalDate;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
@@ -53,14 +53,11 @@ import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.StringJoiner;
import com.tencent.supersonic.headless.server.service.TagMetaService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import static com.tencent.supersonic.common.pojo.Constants.AND_UPPER;
import static com.tencent.supersonic.common.pojo.Constants.APOSTROPHE;
import static com.tencent.supersonic.common.pojo.Constants.COMMA;
import static com.tencent.supersonic.common.pojo.Constants.POUND;
import static com.tencent.supersonic.common.pojo.Constants.SPACE;
@Slf4j
@Component
@@ -298,6 +295,7 @@ public class DictUtils {
}
private QuerySqlReq constructQuerySqlReq(DictItemResp dictItemResp) {
// todo tag
String sqlPattern = "select %s,count(1) from tbl %s group by %s order by count(1) desc limit %d";
String bizName = dictItemResp.getBizName();

View File

@@ -10,7 +10,7 @@ import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.QueryStat;
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
import com.tencent.supersonic.headless.api.pojo.enums.QueryOptMode;
import com.tencent.supersonic.headless.api.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
import com.tencent.supersonic.headless.api.pojo.enums.QueryTypeBack;
import com.tencent.supersonic.headless.api.pojo.request.ItemUseReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
@@ -110,7 +110,7 @@ public class StatUtils {
queryStatInfo.setTraceId(traceId)
.setDataSetId(queryTagReq.getDataSetId())
.setUser(user)
.setQueryType(QueryType.STRUCT.getValue())
.setQueryType(QueryMethod.STRUCT.getValue())
.setQueryTypeBack(QueryTypeBack.NORMAL.getState())
.setQueryStructCmd(queryTagReq.toString())
.setQueryStructCmdMd5(DigestUtils.md5Hex(queryTagReq.toString()))
@@ -147,7 +147,7 @@ public class StatUtils {
queryStatInfo.setTraceId("")
.setUser(userName)
.setDataSetId(querySqlReq.getDataSetId())
.setQueryType(QueryType.SQL.getValue())
.setQueryType(QueryMethod.SQL.getValue())
.setQueryTypeBack(QueryTypeBack.NORMAL.getState())
.setQuerySqlCmd(querySqlReq.toString())
.setQuerySqlCmdMd5(DigestUtils.md5Hex(querySqlReq.toString()))
@@ -178,7 +178,7 @@ public class StatUtils {
queryStatInfo.setTraceId(traceId)
.setDataSetId(queryStructReq.getDataSetId())
.setUser(user)
.setQueryType(QueryType.STRUCT.getValue())
.setQueryType(QueryMethod.STRUCT.getValue())
.setQueryTypeBack(QueryTypeBack.NORMAL.getState())
.setQueryStructCmd(queryStructReq.toString())
.setQueryStructCmdMd5(DigestUtils.md5Hex(queryStructReq.toString()))

View File

@@ -0,0 +1,30 @@
<?xml version="1.0" encoding="UTF-8" ?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<!-- MyBatis mapper for s2_chat_context: one row of parsed-query context per chat session,
     keyed by chat_id. -->
<mapper namespace="com.tencent.supersonic.headless.server.persistence.mapper.ChatContextMapper">
<!-- Column-to-property mapping for ChatContextDO. -->
<resultMap id="ChatContextDO"
type="com.tencent.supersonic.headless.server.persistence.dataobject.ChatContextDO">
<id column="chat_id" property="chatId"/>
<result column="modified_at" property="modifiedAt"/>
<result column="user" property="user"/>
<result column="query_text" property="queryText"/>
<result column="semantic_parse" property="semanticParse"/>
<!--<result column="ext_data" property="extData"/>-->
</resultMap>
<!-- Fetches the single context row for a chat session (limit 1 guards against duplicates). -->
<select id="getContextByChatId" resultMap="ChatContextDO">
select *
from s2_chat_context where chat_id=#{chatId} limit 1
</select>
<!-- Inserts the initial context row for a chat session. -->
<insert id="addContext" parameterType="com.tencent.supersonic.headless.server.persistence.dataobject.ChatContextDO" >
insert into s2_chat_context (chat_id,user,query_text,semantic_parse) values (#{chatId}, #{user},#{queryText}, #{semanticParse})
</insert>
<!-- Replaces the stored query text and semantic parse for an existing chat session. -->
<update id="updateContext">
update s2_chat_context set query_text=#{queryText},semantic_parse=#{semanticParse} where chat_id=#{chatId}
</update>
</mapper>

View File

@@ -0,0 +1,28 @@
<?xml version="1.0" encoding="UTF-8" ?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<!-- MyBatis mapper for s2_chat_statistics: per-question timing/usage records. -->
<mapper namespace="com.tencent.supersonic.headless.server.persistence.mapper.StatisticsMapper">
<!-- Column-to-property mapping for StatisticsDO. -->
<resultMap id="Statistics" type="com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO">
<id column="question_id" property="questionId"/>
<result column="chat_id" property="chatId"/>
<result column="user_name" property="userName"/>
<result column="query_text" property="queryText"/>
<result column="interface_name" property="interfaceName"/>
<result column="cost" property="cost"/>
<result column="type" property="type"/>
<result column="create_time" property="createTime"/>
</resultMap>
<!-- Bulk-inserts a batch of statistics rows in one multi-values statement. -->
<insert id="batchSaveStatistics" parameterType="com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO">
insert into s2_chat_statistics
(question_id,chat_id, user_name, query_text, interface_name,cost,type ,create_time)
values
<foreach collection="list" item="item" index="index" separator=",">
(#{item.questionId}, #{item.chatId}, #{item.userName}, #{item.queryText}, #{item.interfaceName}, #{item.cost}, #{item.type},#{item.createTime})
</foreach>
</insert>
</mapper>

View File

@@ -0,0 +1,80 @@
package com.tencent.supersonic.headless.server.utils;
import com.tencent.supersonic.common.pojo.Aggregator;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateModeUtils;
import com.tencent.supersonic.common.util.SqlFilterUtils;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import org.junit.jupiter.api.Test;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.testng.Assert;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
* QueryReqBuilderTest
*/
class QueryReqBuilderTest {
@Test
void buildS2SQLReq() {
init();
QueryStructReq queryStructReq = new QueryStructReq();
queryStructReq.setDataSetId(1L);
queryStructReq.setDataSetName("内容库");
queryStructReq.setQueryType(QueryType.METRIC);
Aggregator aggregator = new Aggregator();
aggregator.setFunc(AggOperatorEnum.UNKNOWN);
aggregator.setColumn("pv");
queryStructReq.setAggregators(Arrays.asList(aggregator));
queryStructReq.setGroups(Arrays.asList("department"));
DateConf dateConf = new DateConf();
dateConf.setDateMode(DateMode.LIST);
dateConf.setDateList(Arrays.asList("2023-08-01"));
queryStructReq.setDateInfo(dateConf);
List<Order> orders = new ArrayList<>();
Order order = new Order();
order.setColumn("uv");
orders.add(order);
queryStructReq.setOrders(orders);
QuerySqlReq querySQLReq = queryStructReq.convert();
Assert.assertEquals(
"SELECT department, SUM(pv) AS pv FROM 内容库 "
+ "WHERE (sys_imp_date IN ('2023-08-01')) GROUP "
+ "BY department ORDER BY uv LIMIT 2000", querySQLReq.getSql());
queryStructReq.setQueryType(QueryType.TAG);
querySQLReq = queryStructReq.convert();
Assert.assertEquals(
"SELECT department, pv FROM 内容库 WHERE (sys_imp_date IN ('2023-08-01')) "
+ "ORDER BY uv LIMIT 2000",
querySQLReq.getSql());
}
private void init() {
MockedStatic<ContextUtils> mockContextUtils = Mockito.mockStatic(ContextUtils.class);
SqlFilterUtils sqlFilterUtils = new SqlFilterUtils();
mockContextUtils.when(() -> ContextUtils.getBean(SqlFilterUtils.class)).thenReturn(sqlFilterUtils);
DateModeUtils dateModeUtils = new DateModeUtils();
mockContextUtils.when(() -> ContextUtils.getBean(DateModeUtils.class)).thenReturn(dateModeUtils);
dateModeUtils.setSysDateCol("sys_imp_date");
dateModeUtils.setSysDateWeekCol("sys_imp_week");
dateModeUtils.setSysDateMonthCol("sys_imp_month");
}
}