(improvement)(chat) support remove InExpression and partly complete fillResponse if queryResults exist primaryEntityBizName (#181)

This commit is contained in:
lexluo09
2023-10-09 18:07:14 +08:00
committed by GitHub
parent 7cb8208065
commit 719b797037
6 changed files with 153 additions and 55 deletions

View File

@@ -8,5 +8,6 @@ import lombok.Data;
public class ModelInfo extends DataInfo implements Serializable {
private List<String> words;
private String primaryEntityName;
private String primaryEntityBizName;
}

View File

@@ -1,19 +1,57 @@
package com.tencent.supersonic.chat.responder.execute;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.commons.collections.CollectionUtils;
public class EntityInfoExecuteResponder implements ExecuteResponder {
@Override
public void fillResponse(QueryResult queryResult, SemanticParseInfo semanticParseInfo, ExecuteQueryReq queryReq) {
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
EntityInfo entityInfo = semanticService.getEntityInfo(semanticParseInfo, queryReq.getUser());
queryResult.setEntityInfo(entityInfo);
User user = queryReq.getUser();
queryResult.setEntityInfo(semanticService.getEntityInfo(semanticParseInfo, user));
EntityInfo entityInfo = semanticService.getEntityInfo(semanticParseInfo.getModelId());
if (Objects.isNull(entityInfo) || Objects.isNull(entityInfo.getModelInfo())
|| Objects.isNull(entityInfo.getModelInfo().getPrimaryEntityName())) {
return;
}
String primaryEntityBizName = entityInfo.getModelInfo().getPrimaryEntityBizName();
boolean existPrimaryEntityName = queryResult.getQueryColumns().stream()
.anyMatch(queryColumn -> primaryEntityBizName.equals(queryColumn.getNameEn()));
if (!existPrimaryEntityName) {
return;
}
List<Map<String, Object>> queryResults = queryResult.getQueryResults();
List<String> entities = queryResults.stream()
.map(entry -> entry.get(primaryEntityBizName))
.filter(Objects::nonNull)
.map(String::valueOf)
.collect(Collectors.toList());
if (CollectionUtils.isEmpty(entities)) {
return;
}
QueryResultWithSchemaResp queryResultWithSchemaResp = semanticService.getQueryResultWithSchemaResp(entityInfo,
semanticParseInfo.getModelId(), entities, user);
if (Objects.isNull(queryResultWithSchemaResp)) {
return;
}
List<Map<String, Object>> entityResultList = queryResultWithSchemaResp.getResultList();
if (CollectionUtils.isEmpty(entityResultList)) {
return;
}
}
}

View File

@@ -26,9 +26,9 @@ import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatDefaultRichConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.DataInfo;
import com.tencent.supersonic.chat.api.pojo.response.ModelInfo;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.MetricInfo;
import com.tencent.supersonic.chat.api.pojo.response.ModelInfo;
import com.tencent.supersonic.chat.config.AggregatorConfig;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
@@ -51,6 +51,7 @@ import java.time.YearMonth;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashSet;
@@ -62,6 +63,7 @@ import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@@ -113,15 +115,12 @@ public class SemanticService {
}
}
}
if (!"".equals(modelInfoId)) {
try {
setMainModel(entityInfo, parseInfo.getModelId(),
modelInfoId, user);
return entityInfo;
} catch (Exception e) {
log.error("setMainModel error {}", e);
}
try {
setMainModel(entityInfo, parseInfo.getModelId(),
modelInfoId, user);
return entityInfo;
} catch (Exception e) {
log.error("setMainModel error {}", e);
}
}
}
@@ -152,6 +151,7 @@ public class SemanticService {
modelInfo.setWords(modelSchema.getModel().getAlias());
modelInfo.setBizName(modelSchema.getModel().getBizName());
if (Objects.nonNull(modelSchema.getEntity())) {
modelInfo.setPrimaryEntityName(modelSchema.getEntity().getName());
modelInfo.setPrimaryEntityBizName(modelSchema.getEntity().getBizName());
}
@@ -190,9 +190,40 @@ public class SemanticService {
}
public void setMainModel(EntityInfo modelInfo, Long model, String entity, User user) {
ModelSchema modelSchema = schemaService.getModelSchema(model);
if (StringUtils.isEmpty(entity)) {
return;
}
modelInfo.setEntityId(entity);
List<String> entities = Collections.singletonList(entity);
QueryResultWithSchemaResp queryResultWithColumns = getQueryResultWithSchemaResp(modelInfo, model, entities,
user);
if (queryResultWithColumns != null) {
if (!CollectionUtils.isEmpty(queryResultWithColumns.getResultList())
&& queryResultWithColumns.getResultList().size() > 0) {
Map<String, Object> result = queryResultWithColumns.getResultList().get(0);
for (Map.Entry<String, Object> entry : result.entrySet()) {
String entryKey = getEntryKey(entry);
if (entry.getValue() == null || entryKey == null) {
continue;
}
modelInfo.getDimensions().stream().filter(i -> entryKey.equals(i.getBizName()))
.forEach(i -> i.setValue(entry.getValue().toString()));
modelInfo.getMetrics().stream().filter(i -> entryKey.equals(i.getBizName()))
.forEach(i -> i.setValue(entry.getValue().toString()));
}
}
}
}
public QueryResultWithSchemaResp getQueryResultWithSchemaResp(EntityInfo modelInfo, Long model,
List<String> entities, User user) {
if (CollectionUtils.isEmpty(entities)) {
return null;
}
ModelSchema modelSchema = schemaService.getModelSchema(model);
modelInfo.setEntityId(entities.get(0));
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.setModel(modelSchema.getModel());
semanticParseInfo.setNativeQuery(true);
@@ -217,10 +248,7 @@ public class SemanticService {
semanticParseInfo.setDateInfo(dateInfo);
// add filter
QueryFilter chatFilter = new QueryFilter();
chatFilter.setValue(String.valueOf(entity));
chatFilter.setOperator(FilterOperatorEnum.EQUALS);
chatFilter.setBizName(getEntityPrimaryName(modelInfo));
QueryFilter chatFilter = getQueryFilter(modelInfo, entities);
Set<QueryFilter> chatFilters = new LinkedHashSet();
chatFilters.add(chatFilter);
semanticParseInfo.setDimensionFilters(chatFilters);
@@ -232,23 +260,20 @@ public class SemanticService {
} catch (Exception e) {
log.warn("setMainModel queryByStruct error, e:", e);
}
return queryResultWithColumns;
}
if (queryResultWithColumns != null) {
if (!CollectionUtils.isEmpty(queryResultWithColumns.getResultList())
&& queryResultWithColumns.getResultList().size() > 0) {
Map<String, Object> result = queryResultWithColumns.getResultList().get(0);
for (Map.Entry<String, Object> entry : result.entrySet()) {
String entryKey = getEntryKey(entry);
if (entry.getValue() == null || entryKey == null) {
continue;
}
modelInfo.getDimensions().stream().filter(i -> entryKey.equals(i.getBizName()))
.forEach(i -> i.setValue(entry.getValue().toString()));
modelInfo.getMetrics().stream().filter(i -> entryKey.equals(i.getBizName()))
.forEach(i -> i.setValue(entry.getValue().toString()));
}
}
private QueryFilter getQueryFilter(EntityInfo modelInfo, List<String> entities) {
QueryFilter chatFilter = new QueryFilter();
if (entities.size() == 1) {
chatFilter.setValue(entities.get(0));
chatFilter.setOperator(FilterOperatorEnum.EQUALS);
} else {
chatFilter.setValue(entities);
chatFilter.setOperator(FilterOperatorEnum.IN);
}
chatFilter.setBizName(getEntityPrimaryName(modelInfo));
return chatFilter;
}
private Set<SchemaElement> getDimensions(EntityInfo modelInfo) {
@@ -332,7 +357,7 @@ public class SemanticService {
}
public AggregateInfo getAggregateInfo(User user, SemanticParseInfo semanticParseInfo,
QueryResultWithSchemaResp result) {
QueryResultWithSchemaResp result) {
if (CollectionUtils.isEmpty(semanticParseInfo.getMetrics()) || !aggregatorConfig.getEnableRatio()) {
return new AggregateInfo();
}
@@ -384,7 +409,7 @@ public class SemanticService {
}
private MetricInfo queryRatio(User user, SemanticParseInfo semanticParseInfo, SchemaElement metric,
AggOperatorEnum aggOperatorEnum, QueryResultWithSchemaResp results) {
AggOperatorEnum aggOperatorEnum, QueryResultWithSchemaResp results) {
MetricInfo metricInfo = new MetricInfo();
metricInfo.setStatistics(new HashMap<>());
QueryStructReq queryStructReq = QueryReqBuilder.buildStructRatioReq(semanticParseInfo, metric, aggOperatorEnum);
@@ -432,7 +457,7 @@ public class SemanticService {
}
private DateConf getRatioDateConf(AggOperatorEnum aggOperatorEnum, SemanticParseInfo semanticParseInfo,
QueryResultWithSchemaResp results) {
QueryResultWithSchemaResp results) {
String dateField = QueryReqBuilder.getDateField(semanticParseInfo.getDateInfo());
Optional<String> lastDayOp = results.getResultList().stream()
.map(r -> r.get(dateField).toString())

View File

@@ -15,4 +15,6 @@ public class JsqlConstants {
public static final String GREATER_THAN_EQUALS_CONSTANT = " 1 >= 1 ";
public static final String EQUAL_CONSTANT = " 1 = 1 ";
public static final String IN_CONSTANT = " 1 in (1) ";
}

View File

@@ -16,6 +16,7 @@ import net.sf.jsqlparser.expression.operators.conditional.XorExpression;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
@@ -526,29 +527,50 @@ public class SqlParserUpdateHelper {
}
private static void removeExpressionWithConstant(Expression expression, Set<String> removeFieldNames) {
if (!(expression instanceof EqualsTo)) {
return;
if (expression instanceof EqualsTo) {
ComparisonOperator comparisonOperator = (ComparisonOperator) expression;
String columnName = getColumnName(comparisonOperator.getLeftExpression(),
comparisonOperator.getRightExpression());
if (!removeFieldNames.contains(columnName)) {
return;
}
try {
ComparisonOperator constantExpression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(
JsqlConstants.EQUAL_CONSTANT);
comparisonOperator.setLeftExpression(constantExpression.getLeftExpression());
comparisonOperator.setRightExpression(constantExpression.getRightExpression());
comparisonOperator.setASTNode(constantExpression.getASTNode());
} catch (JSQLParserException e) {
log.error("JSQLParserException", e);
}
}
ComparisonOperator comparisonOperator = (ComparisonOperator) expression;
if (expression instanceof InExpression) {
InExpression inExpression = (InExpression) expression;
String columnName = getColumnName(inExpression.getLeftExpression(), inExpression.getRightExpression());
if (!removeFieldNames.contains(columnName)) {
return;
}
try {
InExpression constantExpression = (InExpression) CCJSqlParserUtil.parseCondExpression(
JsqlConstants.IN_CONSTANT);
inExpression.setLeftExpression(constantExpression.getLeftExpression());
inExpression.setRightItemsList(constantExpression.getRightItemsList());
inExpression.setASTNode(constantExpression.getASTNode());
} catch (JSQLParserException e) {
log.error("JSQLParserException", e);
}
}
}
private static String getColumnName(Expression leftExpression, Expression rightExpression) {
String columnName = "";
if (comparisonOperator.getRightExpression() instanceof Column) {
columnName = ((Column) (comparisonOperator).getRightExpression()).getColumnName();
if (leftExpression instanceof Column) {
columnName = ((Column) leftExpression).getColumnName();
}
if (comparisonOperator.getLeftExpression() instanceof Column) {
columnName = ((Column) (comparisonOperator).getLeftExpression()).getColumnName();
}
if (!removeFieldNames.contains(columnName)) {
return;
}
try {
ComparisonOperator constantExpression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(
JsqlConstants.EQUAL_CONSTANT);
comparisonOperator.setLeftExpression(constantExpression.getLeftExpression());
comparisonOperator.setRightExpression(constantExpression.getRightExpression());
comparisonOperator.setASTNode(constantExpression.getASTNode());
} catch (JSQLParserException e) {
log.error("JSQLParserException", e);
if (rightExpression instanceof Column) {
columnName = ((Column) rightExpression).getColumnName();
}
return columnName;
}
}

View File

@@ -469,6 +469,16 @@ class SqlParserUpdateHelperTest {
+ "AND 1 = 1 AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' "
+ "ORDER BY 播放量 DESC LIMIT 11",
replaceSql);
sql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 "
+ "and 歌曲名 in ('邓紫棋','周杰伦') and 歌曲名 in ('邓紫棋') and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'"
+ " order by 播放量 desc limit 11";
replaceSql = SqlParserUpdateHelper.removeWhereCondition(sql, removeFieldNames);
Assert.assertEquals(
"SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 "
+ "AND 1 IN (1) AND 1 IN (1) AND 数据日期 = '2023-08-09' AND "
+ "歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11",
replaceSql);
}
private Map<String, String> initParams() {