mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(chat) support remove InExpression and partly complete fillResponse if queryResults exist primaryEntityBizName (#181)
This commit is contained in:
@@ -8,5 +8,6 @@ import lombok.Data;
|
||||
public class ModelInfo extends DataInfo implements Serializable {
|
||||
|
||||
private List<String> words;
|
||||
private String primaryEntityName;
|
||||
private String primaryEntityBizName;
|
||||
}
|
||||
|
||||
@@ -1,19 +1,57 @@
|
||||
package com.tencent.supersonic.chat.responder.execute;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
public class EntityInfoExecuteResponder implements ExecuteResponder {
|
||||
|
||||
@Override
|
||||
public void fillResponse(QueryResult queryResult, SemanticParseInfo semanticParseInfo, ExecuteQueryReq queryReq) {
|
||||
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||
EntityInfo entityInfo = semanticService.getEntityInfo(semanticParseInfo, queryReq.getUser());
|
||||
queryResult.setEntityInfo(entityInfo);
|
||||
User user = queryReq.getUser();
|
||||
queryResult.setEntityInfo(semanticService.getEntityInfo(semanticParseInfo, user));
|
||||
|
||||
EntityInfo entityInfo = semanticService.getEntityInfo(semanticParseInfo.getModelId());
|
||||
if (Objects.isNull(entityInfo) || Objects.isNull(entityInfo.getModelInfo())
|
||||
|| Objects.isNull(entityInfo.getModelInfo().getPrimaryEntityName())) {
|
||||
return;
|
||||
}
|
||||
String primaryEntityBizName = entityInfo.getModelInfo().getPrimaryEntityBizName();
|
||||
boolean existPrimaryEntityName = queryResult.getQueryColumns().stream()
|
||||
.anyMatch(queryColumn -> primaryEntityBizName.equals(queryColumn.getNameEn()));
|
||||
|
||||
if (!existPrimaryEntityName) {
|
||||
return;
|
||||
}
|
||||
List<Map<String, Object>> queryResults = queryResult.getQueryResults();
|
||||
List<String> entities = queryResults.stream()
|
||||
.map(entry -> entry.get(primaryEntityBizName))
|
||||
.filter(Objects::nonNull)
|
||||
.map(String::valueOf)
|
||||
.collect(Collectors.toList());
|
||||
if (CollectionUtils.isEmpty(entities)) {
|
||||
return;
|
||||
}
|
||||
QueryResultWithSchemaResp queryResultWithSchemaResp = semanticService.getQueryResultWithSchemaResp(entityInfo,
|
||||
semanticParseInfo.getModelId(), entities, user);
|
||||
if (Objects.isNull(queryResultWithSchemaResp)) {
|
||||
return;
|
||||
}
|
||||
List<Map<String, Object>> entityResultList = queryResultWithSchemaResp.getResultList();
|
||||
if (CollectionUtils.isEmpty(entityResultList)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -26,9 +26,9 @@ import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatDefaultRichConfigResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.DataInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ModelInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.MetricInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ModelInfo;
|
||||
import com.tencent.supersonic.chat.config.AggregatorConfig;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
|
||||
@@ -51,6 +51,7 @@ import java.time.YearMonth;
|
||||
import java.time.format.DateTimeFormatter;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashSet;
|
||||
@@ -62,6 +63,7 @@ import java.util.Set;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
@@ -113,15 +115,12 @@ public class SemanticService {
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!"".equals(modelInfoId)) {
|
||||
try {
|
||||
setMainModel(entityInfo, parseInfo.getModelId(),
|
||||
modelInfoId, user);
|
||||
|
||||
return entityInfo;
|
||||
} catch (Exception e) {
|
||||
log.error("setMainModel error {}", e);
|
||||
}
|
||||
try {
|
||||
setMainModel(entityInfo, parseInfo.getModelId(),
|
||||
modelInfoId, user);
|
||||
return entityInfo;
|
||||
} catch (Exception e) {
|
||||
log.error("setMainModel error {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -152,6 +151,7 @@ public class SemanticService {
|
||||
modelInfo.setWords(modelSchema.getModel().getAlias());
|
||||
modelInfo.setBizName(modelSchema.getModel().getBizName());
|
||||
if (Objects.nonNull(modelSchema.getEntity())) {
|
||||
modelInfo.setPrimaryEntityName(modelSchema.getEntity().getName());
|
||||
modelInfo.setPrimaryEntityBizName(modelSchema.getEntity().getBizName());
|
||||
}
|
||||
|
||||
@@ -190,9 +190,40 @@ public class SemanticService {
|
||||
}
|
||||
|
||||
public void setMainModel(EntityInfo modelInfo, Long model, String entity, User user) {
|
||||
ModelSchema modelSchema = schemaService.getModelSchema(model);
|
||||
if (StringUtils.isEmpty(entity)) {
|
||||
return;
|
||||
}
|
||||
|
||||
modelInfo.setEntityId(entity);
|
||||
List<String> entities = Collections.singletonList(entity);
|
||||
|
||||
QueryResultWithSchemaResp queryResultWithColumns = getQueryResultWithSchemaResp(modelInfo, model, entities,
|
||||
user);
|
||||
|
||||
if (queryResultWithColumns != null) {
|
||||
if (!CollectionUtils.isEmpty(queryResultWithColumns.getResultList())
|
||||
&& queryResultWithColumns.getResultList().size() > 0) {
|
||||
Map<String, Object> result = queryResultWithColumns.getResultList().get(0);
|
||||
for (Map.Entry<String, Object> entry : result.entrySet()) {
|
||||
String entryKey = getEntryKey(entry);
|
||||
if (entry.getValue() == null || entryKey == null) {
|
||||
continue;
|
||||
}
|
||||
modelInfo.getDimensions().stream().filter(i -> entryKey.equals(i.getBizName()))
|
||||
.forEach(i -> i.setValue(entry.getValue().toString()));
|
||||
modelInfo.getMetrics().stream().filter(i -> entryKey.equals(i.getBizName()))
|
||||
.forEach(i -> i.setValue(entry.getValue().toString()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public QueryResultWithSchemaResp getQueryResultWithSchemaResp(EntityInfo modelInfo, Long model,
|
||||
List<String> entities, User user) {
|
||||
if (CollectionUtils.isEmpty(entities)) {
|
||||
return null;
|
||||
}
|
||||
ModelSchema modelSchema = schemaService.getModelSchema(model);
|
||||
modelInfo.setEntityId(entities.get(0));
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
semanticParseInfo.setModel(modelSchema.getModel());
|
||||
semanticParseInfo.setNativeQuery(true);
|
||||
@@ -217,10 +248,7 @@ public class SemanticService {
|
||||
semanticParseInfo.setDateInfo(dateInfo);
|
||||
|
||||
// add filter
|
||||
QueryFilter chatFilter = new QueryFilter();
|
||||
chatFilter.setValue(String.valueOf(entity));
|
||||
chatFilter.setOperator(FilterOperatorEnum.EQUALS);
|
||||
chatFilter.setBizName(getEntityPrimaryName(modelInfo));
|
||||
QueryFilter chatFilter = getQueryFilter(modelInfo, entities);
|
||||
Set<QueryFilter> chatFilters = new LinkedHashSet();
|
||||
chatFilters.add(chatFilter);
|
||||
semanticParseInfo.setDimensionFilters(chatFilters);
|
||||
@@ -232,23 +260,20 @@ public class SemanticService {
|
||||
} catch (Exception e) {
|
||||
log.warn("setMainModel queryByStruct error, e:", e);
|
||||
}
|
||||
return queryResultWithColumns;
|
||||
}
|
||||
|
||||
if (queryResultWithColumns != null) {
|
||||
if (!CollectionUtils.isEmpty(queryResultWithColumns.getResultList())
|
||||
&& queryResultWithColumns.getResultList().size() > 0) {
|
||||
Map<String, Object> result = queryResultWithColumns.getResultList().get(0);
|
||||
for (Map.Entry<String, Object> entry : result.entrySet()) {
|
||||
String entryKey = getEntryKey(entry);
|
||||
if (entry.getValue() == null || entryKey == null) {
|
||||
continue;
|
||||
}
|
||||
modelInfo.getDimensions().stream().filter(i -> entryKey.equals(i.getBizName()))
|
||||
.forEach(i -> i.setValue(entry.getValue().toString()));
|
||||
modelInfo.getMetrics().stream().filter(i -> entryKey.equals(i.getBizName()))
|
||||
.forEach(i -> i.setValue(entry.getValue().toString()));
|
||||
}
|
||||
}
|
||||
private QueryFilter getQueryFilter(EntityInfo modelInfo, List<String> entities) {
|
||||
QueryFilter chatFilter = new QueryFilter();
|
||||
if (entities.size() == 1) {
|
||||
chatFilter.setValue(entities.get(0));
|
||||
chatFilter.setOperator(FilterOperatorEnum.EQUALS);
|
||||
} else {
|
||||
chatFilter.setValue(entities);
|
||||
chatFilter.setOperator(FilterOperatorEnum.IN);
|
||||
}
|
||||
chatFilter.setBizName(getEntityPrimaryName(modelInfo));
|
||||
return chatFilter;
|
||||
}
|
||||
|
||||
private Set<SchemaElement> getDimensions(EntityInfo modelInfo) {
|
||||
@@ -332,7 +357,7 @@ public class SemanticService {
|
||||
}
|
||||
|
||||
public AggregateInfo getAggregateInfo(User user, SemanticParseInfo semanticParseInfo,
|
||||
QueryResultWithSchemaResp result) {
|
||||
QueryResultWithSchemaResp result) {
|
||||
if (CollectionUtils.isEmpty(semanticParseInfo.getMetrics()) || !aggregatorConfig.getEnableRatio()) {
|
||||
return new AggregateInfo();
|
||||
}
|
||||
@@ -384,7 +409,7 @@ public class SemanticService {
|
||||
}
|
||||
|
||||
private MetricInfo queryRatio(User user, SemanticParseInfo semanticParseInfo, SchemaElement metric,
|
||||
AggOperatorEnum aggOperatorEnum, QueryResultWithSchemaResp results) {
|
||||
AggOperatorEnum aggOperatorEnum, QueryResultWithSchemaResp results) {
|
||||
MetricInfo metricInfo = new MetricInfo();
|
||||
metricInfo.setStatistics(new HashMap<>());
|
||||
QueryStructReq queryStructReq = QueryReqBuilder.buildStructRatioReq(semanticParseInfo, metric, aggOperatorEnum);
|
||||
@@ -432,7 +457,7 @@ public class SemanticService {
|
||||
}
|
||||
|
||||
private DateConf getRatioDateConf(AggOperatorEnum aggOperatorEnum, SemanticParseInfo semanticParseInfo,
|
||||
QueryResultWithSchemaResp results) {
|
||||
QueryResultWithSchemaResp results) {
|
||||
String dateField = QueryReqBuilder.getDateField(semanticParseInfo.getDateInfo());
|
||||
Optional<String> lastDayOp = results.getResultList().stream()
|
||||
.map(r -> r.get(dateField).toString())
|
||||
|
||||
@@ -15,4 +15,6 @@ public class JsqlConstants {
|
||||
public static final String GREATER_THAN_EQUALS_CONSTANT = " 1 >= 1 ";
|
||||
public static final String EQUAL_CONSTANT = " 1 = 1 ";
|
||||
|
||||
public static final String IN_CONSTANT = " 1 in (1) ";
|
||||
|
||||
}
|
||||
@@ -16,6 +16,7 @@ import net.sf.jsqlparser.expression.operators.conditional.XorExpression;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
|
||||
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
|
||||
import net.sf.jsqlparser.expression.operators.relational.InExpression;
|
||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
import net.sf.jsqlparser.schema.Table;
|
||||
@@ -526,29 +527,50 @@ public class SqlParserUpdateHelper {
|
||||
}
|
||||
|
||||
private static void removeExpressionWithConstant(Expression expression, Set<String> removeFieldNames) {
|
||||
if (!(expression instanceof EqualsTo)) {
|
||||
return;
|
||||
if (expression instanceof EqualsTo) {
|
||||
ComparisonOperator comparisonOperator = (ComparisonOperator) expression;
|
||||
String columnName = getColumnName(comparisonOperator.getLeftExpression(),
|
||||
comparisonOperator.getRightExpression());
|
||||
if (!removeFieldNames.contains(columnName)) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
ComparisonOperator constantExpression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(
|
||||
JsqlConstants.EQUAL_CONSTANT);
|
||||
comparisonOperator.setLeftExpression(constantExpression.getLeftExpression());
|
||||
comparisonOperator.setRightExpression(constantExpression.getRightExpression());
|
||||
comparisonOperator.setASTNode(constantExpression.getASTNode());
|
||||
} catch (JSQLParserException e) {
|
||||
log.error("JSQLParserException", e);
|
||||
}
|
||||
}
|
||||
ComparisonOperator comparisonOperator = (ComparisonOperator) expression;
|
||||
if (expression instanceof InExpression) {
|
||||
InExpression inExpression = (InExpression) expression;
|
||||
String columnName = getColumnName(inExpression.getLeftExpression(), inExpression.getRightExpression());
|
||||
if (!removeFieldNames.contains(columnName)) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
InExpression constantExpression = (InExpression) CCJSqlParserUtil.parseCondExpression(
|
||||
JsqlConstants.IN_CONSTANT);
|
||||
inExpression.setLeftExpression(constantExpression.getLeftExpression());
|
||||
inExpression.setRightItemsList(constantExpression.getRightItemsList());
|
||||
inExpression.setASTNode(constantExpression.getASTNode());
|
||||
} catch (JSQLParserException e) {
|
||||
log.error("JSQLParserException", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static String getColumnName(Expression leftExpression, Expression rightExpression) {
|
||||
String columnName = "";
|
||||
if (comparisonOperator.getRightExpression() instanceof Column) {
|
||||
columnName = ((Column) (comparisonOperator).getRightExpression()).getColumnName();
|
||||
if (leftExpression instanceof Column) {
|
||||
columnName = ((Column) leftExpression).getColumnName();
|
||||
}
|
||||
if (comparisonOperator.getLeftExpression() instanceof Column) {
|
||||
columnName = ((Column) (comparisonOperator).getLeftExpression()).getColumnName();
|
||||
}
|
||||
if (!removeFieldNames.contains(columnName)) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
ComparisonOperator constantExpression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(
|
||||
JsqlConstants.EQUAL_CONSTANT);
|
||||
comparisonOperator.setLeftExpression(constantExpression.getLeftExpression());
|
||||
comparisonOperator.setRightExpression(constantExpression.getRightExpression());
|
||||
comparisonOperator.setASTNode(constantExpression.getASTNode());
|
||||
} catch (JSQLParserException e) {
|
||||
log.error("JSQLParserException", e);
|
||||
if (rightExpression instanceof Column) {
|
||||
columnName = ((Column) rightExpression).getColumnName();
|
||||
}
|
||||
return columnName;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -469,6 +469,16 @@ class SqlParserUpdateHelperTest {
|
||||
+ "AND 1 = 1 AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' "
|
||||
+ "ORDER BY 播放量 DESC LIMIT 11",
|
||||
replaceSql);
|
||||
|
||||
sql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 "
|
||||
+ "and 歌曲名 in ('邓紫棋','周杰伦') and 歌曲名 in ('邓紫棋') and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'"
|
||||
+ " order by 播放量 desc limit 11";
|
||||
replaceSql = SqlParserUpdateHelper.removeWhereCondition(sql, removeFieldNames);
|
||||
Assert.assertEquals(
|
||||
"SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 "
|
||||
+ "AND 1 IN (1) AND 1 IN (1) AND 数据日期 = '2023-08-09' AND "
|
||||
+ "歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11",
|
||||
replaceSql);
|
||||
}
|
||||
|
||||
private Map<String, String> initParams() {
|
||||
|
||||
Reference in New Issue
Block a user