(improvement)(chat) support remove InExpression and partly complete fillResponse if queryResults exist primaryEntityBizName (#181)

This commit is contained in:
lexluo09
2023-10-09 18:07:14 +08:00
committed by GitHub
parent 7cb8208065
commit 719b797037
6 changed files with 153 additions and 55 deletions

View File

@@ -8,5 +8,6 @@ import lombok.Data;
public class ModelInfo extends DataInfo implements Serializable { public class ModelInfo extends DataInfo implements Serializable {
private List<String> words; private List<String> words;
private String primaryEntityName;
private String primaryEntityBizName; private String primaryEntityBizName;
} }

View File

@@ -1,19 +1,57 @@
package com.tencent.supersonic.chat.responder.execute; package com.tencent.supersonic.chat.responder.execute;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo; import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.commons.collections.CollectionUtils;
public class EntityInfoExecuteResponder implements ExecuteResponder { public class EntityInfoExecuteResponder implements ExecuteResponder {
@Override @Override
public void fillResponse(QueryResult queryResult, SemanticParseInfo semanticParseInfo, ExecuteQueryReq queryReq) { public void fillResponse(QueryResult queryResult, SemanticParseInfo semanticParseInfo, ExecuteQueryReq queryReq) {
SemanticService semanticService = ContextUtils.getBean(SemanticService.class); SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
EntityInfo entityInfo = semanticService.getEntityInfo(semanticParseInfo, queryReq.getUser()); User user = queryReq.getUser();
queryResult.setEntityInfo(entityInfo); queryResult.setEntityInfo(semanticService.getEntityInfo(semanticParseInfo, user));
EntityInfo entityInfo = semanticService.getEntityInfo(semanticParseInfo.getModelId());
if (Objects.isNull(entityInfo) || Objects.isNull(entityInfo.getModelInfo())
|| Objects.isNull(entityInfo.getModelInfo().getPrimaryEntityName())) {
return;
}
String primaryEntityBizName = entityInfo.getModelInfo().getPrimaryEntityBizName();
boolean existPrimaryEntityName = queryResult.getQueryColumns().stream()
.anyMatch(queryColumn -> primaryEntityBizName.equals(queryColumn.getNameEn()));
if (!existPrimaryEntityName) {
return;
}
List<Map<String, Object>> queryResults = queryResult.getQueryResults();
List<String> entities = queryResults.stream()
.map(entry -> entry.get(primaryEntityBizName))
.filter(Objects::nonNull)
.map(String::valueOf)
.collect(Collectors.toList());
if (CollectionUtils.isEmpty(entities)) {
return;
}
QueryResultWithSchemaResp queryResultWithSchemaResp = semanticService.getQueryResultWithSchemaResp(entityInfo,
semanticParseInfo.getModelId(), entities, user);
if (Objects.isNull(queryResultWithSchemaResp)) {
return;
}
List<Map<String, Object>> entityResultList = queryResultWithSchemaResp.getResultList();
if (CollectionUtils.isEmpty(entityResultList)) {
return;
}
} }
} }

View File

@@ -26,9 +26,9 @@ import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp; import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatDefaultRichConfigResp; import com.tencent.supersonic.chat.api.pojo.response.ChatDefaultRichConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.DataInfo; import com.tencent.supersonic.chat.api.pojo.response.DataInfo;
import com.tencent.supersonic.chat.api.pojo.response.ModelInfo;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo; import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.MetricInfo; import com.tencent.supersonic.chat.api.pojo.response.MetricInfo;
import com.tencent.supersonic.chat.api.pojo.response.ModelInfo;
import com.tencent.supersonic.chat.config.AggregatorConfig; import com.tencent.supersonic.chat.config.AggregatorConfig;
import com.tencent.supersonic.chat.utils.ComponentFactory; import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.QueryReqBuilder; import com.tencent.supersonic.chat.utils.QueryReqBuilder;
@@ -51,6 +51,7 @@ import java.time.YearMonth;
import java.time.format.DateTimeFormatter; import java.time.format.DateTimeFormatter;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator; import java.util.Comparator;
import java.util.HashMap; import java.util.HashMap;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
@@ -62,6 +63,7 @@ import java.util.Set;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@@ -113,15 +115,12 @@ public class SemanticService {
} }
} }
} }
if (!"".equals(modelInfoId)) { try {
try { setMainModel(entityInfo, parseInfo.getModelId(),
setMainModel(entityInfo, parseInfo.getModelId(), modelInfoId, user);
modelInfoId, user); return entityInfo;
} catch (Exception e) {
return entityInfo; log.error("setMainModel error {}", e);
} catch (Exception e) {
log.error("setMainModel error {}", e);
}
} }
} }
} }
@@ -152,6 +151,7 @@ public class SemanticService {
modelInfo.setWords(modelSchema.getModel().getAlias()); modelInfo.setWords(modelSchema.getModel().getAlias());
modelInfo.setBizName(modelSchema.getModel().getBizName()); modelInfo.setBizName(modelSchema.getModel().getBizName());
if (Objects.nonNull(modelSchema.getEntity())) { if (Objects.nonNull(modelSchema.getEntity())) {
modelInfo.setPrimaryEntityName(modelSchema.getEntity().getName());
modelInfo.setPrimaryEntityBizName(modelSchema.getEntity().getBizName()); modelInfo.setPrimaryEntityBizName(modelSchema.getEntity().getBizName());
} }
@@ -190,9 +190,40 @@ public class SemanticService {
} }
public void setMainModel(EntityInfo modelInfo, Long model, String entity, User user) { public void setMainModel(EntityInfo modelInfo, Long model, String entity, User user) {
ModelSchema modelSchema = schemaService.getModelSchema(model); if (StringUtils.isEmpty(entity)) {
return;
}
modelInfo.setEntityId(entity); List<String> entities = Collections.singletonList(entity);
QueryResultWithSchemaResp queryResultWithColumns = getQueryResultWithSchemaResp(modelInfo, model, entities,
user);
if (queryResultWithColumns != null) {
if (!CollectionUtils.isEmpty(queryResultWithColumns.getResultList())
&& queryResultWithColumns.getResultList().size() > 0) {
Map<String, Object> result = queryResultWithColumns.getResultList().get(0);
for (Map.Entry<String, Object> entry : result.entrySet()) {
String entryKey = getEntryKey(entry);
if (entry.getValue() == null || entryKey == null) {
continue;
}
modelInfo.getDimensions().stream().filter(i -> entryKey.equals(i.getBizName()))
.forEach(i -> i.setValue(entry.getValue().toString()));
modelInfo.getMetrics().stream().filter(i -> entryKey.equals(i.getBizName()))
.forEach(i -> i.setValue(entry.getValue().toString()));
}
}
}
}
public QueryResultWithSchemaResp getQueryResultWithSchemaResp(EntityInfo modelInfo, Long model,
List<String> entities, User user) {
if (CollectionUtils.isEmpty(entities)) {
return null;
}
ModelSchema modelSchema = schemaService.getModelSchema(model);
modelInfo.setEntityId(entities.get(0));
SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.setModel(modelSchema.getModel()); semanticParseInfo.setModel(modelSchema.getModel());
semanticParseInfo.setNativeQuery(true); semanticParseInfo.setNativeQuery(true);
@@ -217,10 +248,7 @@ public class SemanticService {
semanticParseInfo.setDateInfo(dateInfo); semanticParseInfo.setDateInfo(dateInfo);
// add filter // add filter
QueryFilter chatFilter = new QueryFilter(); QueryFilter chatFilter = getQueryFilter(modelInfo, entities);
chatFilter.setValue(String.valueOf(entity));
chatFilter.setOperator(FilterOperatorEnum.EQUALS);
chatFilter.setBizName(getEntityPrimaryName(modelInfo));
Set<QueryFilter> chatFilters = new LinkedHashSet(); Set<QueryFilter> chatFilters = new LinkedHashSet();
chatFilters.add(chatFilter); chatFilters.add(chatFilter);
semanticParseInfo.setDimensionFilters(chatFilters); semanticParseInfo.setDimensionFilters(chatFilters);
@@ -232,23 +260,20 @@ public class SemanticService {
} catch (Exception e) { } catch (Exception e) {
log.warn("setMainModel queryByStruct error, e:", e); log.warn("setMainModel queryByStruct error, e:", e);
} }
return queryResultWithColumns;
}
if (queryResultWithColumns != null) { private QueryFilter getQueryFilter(EntityInfo modelInfo, List<String> entities) {
if (!CollectionUtils.isEmpty(queryResultWithColumns.getResultList()) QueryFilter chatFilter = new QueryFilter();
&& queryResultWithColumns.getResultList().size() > 0) { if (entities.size() == 1) {
Map<String, Object> result = queryResultWithColumns.getResultList().get(0); chatFilter.setValue(entities.get(0));
for (Map.Entry<String, Object> entry : result.entrySet()) { chatFilter.setOperator(FilterOperatorEnum.EQUALS);
String entryKey = getEntryKey(entry); } else {
if (entry.getValue() == null || entryKey == null) { chatFilter.setValue(entities);
continue; chatFilter.setOperator(FilterOperatorEnum.IN);
}
modelInfo.getDimensions().stream().filter(i -> entryKey.equals(i.getBizName()))
.forEach(i -> i.setValue(entry.getValue().toString()));
modelInfo.getMetrics().stream().filter(i -> entryKey.equals(i.getBizName()))
.forEach(i -> i.setValue(entry.getValue().toString()));
}
}
} }
chatFilter.setBizName(getEntityPrimaryName(modelInfo));
return chatFilter;
} }
private Set<SchemaElement> getDimensions(EntityInfo modelInfo) { private Set<SchemaElement> getDimensions(EntityInfo modelInfo) {
@@ -332,7 +357,7 @@ public class SemanticService {
} }
public AggregateInfo getAggregateInfo(User user, SemanticParseInfo semanticParseInfo, public AggregateInfo getAggregateInfo(User user, SemanticParseInfo semanticParseInfo,
QueryResultWithSchemaResp result) { QueryResultWithSchemaResp result) {
if (CollectionUtils.isEmpty(semanticParseInfo.getMetrics()) || !aggregatorConfig.getEnableRatio()) { if (CollectionUtils.isEmpty(semanticParseInfo.getMetrics()) || !aggregatorConfig.getEnableRatio()) {
return new AggregateInfo(); return new AggregateInfo();
} }
@@ -384,7 +409,7 @@ public class SemanticService {
} }
private MetricInfo queryRatio(User user, SemanticParseInfo semanticParseInfo, SchemaElement metric, private MetricInfo queryRatio(User user, SemanticParseInfo semanticParseInfo, SchemaElement metric,
AggOperatorEnum aggOperatorEnum, QueryResultWithSchemaResp results) { AggOperatorEnum aggOperatorEnum, QueryResultWithSchemaResp results) {
MetricInfo metricInfo = new MetricInfo(); MetricInfo metricInfo = new MetricInfo();
metricInfo.setStatistics(new HashMap<>()); metricInfo.setStatistics(new HashMap<>());
QueryStructReq queryStructReq = QueryReqBuilder.buildStructRatioReq(semanticParseInfo, metric, aggOperatorEnum); QueryStructReq queryStructReq = QueryReqBuilder.buildStructRatioReq(semanticParseInfo, metric, aggOperatorEnum);
@@ -432,7 +457,7 @@ public class SemanticService {
} }
private DateConf getRatioDateConf(AggOperatorEnum aggOperatorEnum, SemanticParseInfo semanticParseInfo, private DateConf getRatioDateConf(AggOperatorEnum aggOperatorEnum, SemanticParseInfo semanticParseInfo,
QueryResultWithSchemaResp results) { QueryResultWithSchemaResp results) {
String dateField = QueryReqBuilder.getDateField(semanticParseInfo.getDateInfo()); String dateField = QueryReqBuilder.getDateField(semanticParseInfo.getDateInfo());
Optional<String> lastDayOp = results.getResultList().stream() Optional<String> lastDayOp = results.getResultList().stream()
.map(r -> r.get(dateField).toString()) .map(r -> r.get(dateField).toString())

View File

@@ -15,4 +15,6 @@ public class JsqlConstants {
public static final String GREATER_THAN_EQUALS_CONSTANT = " 1 >= 1 "; public static final String GREATER_THAN_EQUALS_CONSTANT = " 1 >= 1 ";
public static final String EQUAL_CONSTANT = " 1 = 1 "; public static final String EQUAL_CONSTANT = " 1 = 1 ";
public static final String IN_CONSTANT = " 1 in (1) ";
} }

View File

@@ -16,6 +16,7 @@ import net.sf.jsqlparser.expression.operators.conditional.XorExpression;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo; import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList; import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil; import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column; import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table; import net.sf.jsqlparser.schema.Table;
@@ -526,29 +527,50 @@ public class SqlParserUpdateHelper {
} }
private static void removeExpressionWithConstant(Expression expression, Set<String> removeFieldNames) { private static void removeExpressionWithConstant(Expression expression, Set<String> removeFieldNames) {
if (!(expression instanceof EqualsTo)) { if (expression instanceof EqualsTo) {
return; ComparisonOperator comparisonOperator = (ComparisonOperator) expression;
String columnName = getColumnName(comparisonOperator.getLeftExpression(),
comparisonOperator.getRightExpression());
if (!removeFieldNames.contains(columnName)) {
return;
}
try {
ComparisonOperator constantExpression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(
JsqlConstants.EQUAL_CONSTANT);
comparisonOperator.setLeftExpression(constantExpression.getLeftExpression());
comparisonOperator.setRightExpression(constantExpression.getRightExpression());
comparisonOperator.setASTNode(constantExpression.getASTNode());
} catch (JSQLParserException e) {
log.error("JSQLParserException", e);
}
} }
ComparisonOperator comparisonOperator = (ComparisonOperator) expression; if (expression instanceof InExpression) {
InExpression inExpression = (InExpression) expression;
String columnName = getColumnName(inExpression.getLeftExpression(), inExpression.getRightExpression());
if (!removeFieldNames.contains(columnName)) {
return;
}
try {
InExpression constantExpression = (InExpression) CCJSqlParserUtil.parseCondExpression(
JsqlConstants.IN_CONSTANT);
inExpression.setLeftExpression(constantExpression.getLeftExpression());
inExpression.setRightItemsList(constantExpression.getRightItemsList());
inExpression.setASTNode(constantExpression.getASTNode());
} catch (JSQLParserException e) {
log.error("JSQLParserException", e);
}
}
}
private static String getColumnName(Expression leftExpression, Expression rightExpression) {
String columnName = ""; String columnName = "";
if (comparisonOperator.getRightExpression() instanceof Column) { if (leftExpression instanceof Column) {
columnName = ((Column) (comparisonOperator).getRightExpression()).getColumnName(); columnName = ((Column) leftExpression).getColumnName();
} }
if (comparisonOperator.getLeftExpression() instanceof Column) { if (rightExpression instanceof Column) {
columnName = ((Column) (comparisonOperator).getLeftExpression()).getColumnName(); columnName = ((Column) rightExpression).getColumnName();
}
if (!removeFieldNames.contains(columnName)) {
return;
}
try {
ComparisonOperator constantExpression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(
JsqlConstants.EQUAL_CONSTANT);
comparisonOperator.setLeftExpression(constantExpression.getLeftExpression());
comparisonOperator.setRightExpression(constantExpression.getRightExpression());
comparisonOperator.setASTNode(constantExpression.getASTNode());
} catch (JSQLParserException e) {
log.error("JSQLParserException", e);
} }
return columnName;
} }
} }

View File

@@ -469,6 +469,16 @@ class SqlParserUpdateHelperTest {
+ "AND 1 = 1 AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' " + "AND 1 = 1 AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' "
+ "ORDER BY 播放量 DESC LIMIT 11", + "ORDER BY 播放量 DESC LIMIT 11",
replaceSql); replaceSql);
sql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 "
+ "and 歌曲名 in ('邓紫棋','周杰伦') and 歌曲名 in ('邓紫棋') and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'"
+ " order by 播放量 desc limit 11";
replaceSql = SqlParserUpdateHelper.removeWhereCondition(sql, removeFieldNames);
Assert.assertEquals(
"SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 "
+ "AND 1 IN (1) AND 1 IN (1) AND 数据日期 = '2023-08-09' AND "
+ "歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11",
replaceSql);
} }
private Map<String, String> initParams() { private Map<String, String> initParams() {