(improvement)(chat) support updateFieldValueByLinkingValue and fix EntityListQuery NullPointerException (#182)

This commit is contained in:
lexluo09
2023-10-10 11:16:52 +08:00
committed by GitHub
parent 719b797037
commit eee39f56a8
8 changed files with 164 additions and 113 deletions

View File

@@ -27,6 +27,8 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector {
updateFieldNameByLinkingValue(semanticCorrectInfo); updateFieldNameByLinkingValue(semanticCorrectInfo);
updateFieldValueByLinkingValue(semanticCorrectInfo);
correctFieldName(semanticCorrectInfo); correctFieldName(semanticCorrectInfo);
} }
@@ -45,17 +47,7 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector {
} }
private void updateFieldNameByLinkingValue(SemanticCorrectInfo semanticCorrectInfo) { private void updateFieldNameByLinkingValue(SemanticCorrectInfo semanticCorrectInfo) {
Object context = semanticCorrectInfo.getParseInfo().getProperties().get(Constants.CONTEXT); List<ElementValue> linking = getLinkingValues(semanticCorrectInfo);
if (Objects.isNull(context)) {
return;
}
DSLParseResult dslParseResult = JsonUtil.toObject(JsonUtil.toString(context), DSLParseResult.class);
if (Objects.isNull(dslParseResult) || Objects.isNull(dslParseResult.getLlmReq())) {
return;
}
LLMReq llmReq = dslParseResult.getLlmReq();
List<ElementValue> linking = llmReq.getLinking();
if (CollectionUtils.isEmpty(linking)) { if (CollectionUtils.isEmpty(linking)) {
return; return;
} }
@@ -68,4 +60,37 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector {
fieldValueToFieldNames); fieldValueToFieldNames);
semanticCorrectInfo.setSql(sql); semanticCorrectInfo.setSql(sql);
} }
private List<ElementValue> getLinkingValues(SemanticCorrectInfo semanticCorrectInfo) {
Object context = semanticCorrectInfo.getParseInfo().getProperties().get(Constants.CONTEXT);
if (Objects.isNull(context)) {
return null;
}
DSLParseResult dslParseResult = JsonUtil.toObject(JsonUtil.toString(context), DSLParseResult.class);
if (Objects.isNull(dslParseResult) || Objects.isNull(dslParseResult.getLlmReq())) {
return null;
}
LLMReq llmReq = dslParseResult.getLlmReq();
return llmReq.getLinking();
}
private void updateFieldValueByLinkingValue(SemanticCorrectInfo semanticCorrectInfo) {
List<ElementValue> linking = getLinkingValues(semanticCorrectInfo);
if (CollectionUtils.isEmpty(linking)) {
return;
}
Map<String, Map<String, String>> filedNameToValueMap = linking.stream().collect(
Collectors.groupingBy(ElementValue::getFieldName,
Collectors.mapping(ElementValue::getFieldValue, Collectors.toMap(
oldValue -> oldValue,
newValue -> newValue,
(existingValue, newValue) -> newValue)
)));
String sql = SqlParserUpdateHelper.replaceValue(semanticCorrectInfo.getSql(), filedNameToValueMap, false);
semanticCorrectInfo.setSql(sql);
}
} }

View File

@@ -1,9 +1,9 @@
package com.tencent.supersonic.chat.query.rule.entity; package com.tencent.supersonic.chat.query.rule.entity;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.ModelSchema; import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp; import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatDefaultRichConfigResp; import com.tencent.supersonic.chat.api.pojo.response.ChatDefaultRichConfigResp;
@@ -12,10 +12,10 @@ import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.Order; import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.Objects; import java.util.Objects;
import java.util.Set; import java.util.Set;
import org.apache.commons.collections.CollectionUtils;
public abstract class EntityListQuery extends EntitySemanticQuery { public abstract class EntityListQuery extends EntitySemanticQuery {
@@ -41,13 +41,17 @@ public abstract class EntityListQuery extends EntitySemanticQuery {
ChatDefaultRichConfigResp chatDefaultConfig = chaConfigRichDesc ChatDefaultRichConfigResp chatDefaultConfig = chaConfigRichDesc
.getChatDetailRichConfig().getChatDefaultConfig(); .getChatDetailRichConfig().getChatDefaultConfig();
if (chatDefaultConfig != null) { if (chatDefaultConfig != null) {
chatDefaultConfig.getMetrics().stream() if (CollectionUtils.isNotEmpty(chatDefaultConfig.getMetrics())) {
.forEach(metric -> { chatDefaultConfig.getMetrics().stream()
metrics.add(metric); .forEach(metric -> {
orders.add(new Order(metric.getBizName(), Constants.DESC_UPPER)); metrics.add(metric);
}); orders.add(new Order(metric.getBizName(), Constants.DESC_UPPER));
chatDefaultConfig.getDimensions().stream() });
.forEach(dimension -> dimensions.add(dimension)); }
if (CollectionUtils.isNotEmpty(chatDefaultConfig.getDimensions())) {
chatDefaultConfig.getDimensions().stream()
.forEach(dimension -> dimensions.add(dimension));
}
} }

View File

@@ -7,7 +7,6 @@ import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
@@ -43,15 +42,6 @@ public class EntityInfoExecuteResponder implements ExecuteResponder {
if (CollectionUtils.isEmpty(entities)) { if (CollectionUtils.isEmpty(entities)) {
return; return;
} }
QueryResultWithSchemaResp queryResultWithSchemaResp = semanticService.getQueryResultWithSchemaResp(entityInfo,
semanticParseInfo.getModelId(), entities, user);
if (Objects.isNull(queryResultWithSchemaResp)) {
return;
}
List<Map<String, Object>> entityResultList = queryResultWithSchemaResp.getResultList();
if (CollectionUtils.isEmpty(entityResultList)) {
return;
}
} }
} }

View File

@@ -21,43 +21,19 @@ import org.springframework.util.CollectionUtils;
public class FieldlValueReplaceVisitor extends ExpressionVisitorAdapter { public class FieldlValueReplaceVisitor extends ExpressionVisitorAdapter {
ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper();
private boolean exactReplace;
private Map<String, Map<String, String>> filedNameToValueMap; private Map<String, Map<String, String>> filedNameToValueMap;
public FieldlValueReplaceVisitor(Map<String, Map<String, String>> filedNameToValueMap) {
public FieldlValueReplaceVisitor(boolean exactReplace, Map<String, Map<String, String>> filedNameToValueMap) {
this.exactReplace = exactReplace;
this.filedNameToValueMap = filedNameToValueMap; this.filedNameToValueMap = filedNameToValueMap;
} }
@Override @Override
public void visit(EqualsTo expr) { public void visit(EqualsTo expr) {
Expression leftExpression = expr.getLeftExpression(); replaceComparisonExpression(expr);
Expression rightExpression = expr.getRightExpression();
if (!(rightExpression instanceof StringValue)) {
return;
}
if (!(leftExpression instanceof Column)) {
return;
}
if (CollectionUtils.isEmpty(filedNameToValueMap)) {
return;
}
if (Objects.isNull(rightExpression) || Objects.isNull(leftExpression)) {
return;
}
Column leftColumnName = (Column) leftExpression;
StringValue rightStringValue = (StringValue) rightExpression;
String columnName = leftColumnName.getColumnName();
if (StringUtils.isEmpty(columnName)) {
return;
}
Map<String, String> valueMap = filedNameToValueMap.get(columnName);
if (Objects.isNull(valueMap) || valueMap.isEmpty()) {
return;
}
String replaceValue = valueMap.get(rightStringValue.getValue());
if (StringUtils.isNotEmpty(replaceValue)) {
rightStringValue.setValue(replaceValue);
}
} }
public void visit(GreaterThan expr) { public void visit(GreaterThan expr) {
@@ -77,47 +53,57 @@ public class FieldlValueReplaceVisitor extends ExpressionVisitorAdapter {
} }
public <T extends Expression> void replaceComparisonExpression(T expression) { public <T extends Expression> void replaceComparisonExpression(T expression) {
if ((expression instanceof GreaterThanEquals) || (expression instanceof GreaterThan) Expression leftExpression = ((ComparisonOperator) expression).getLeftExpression();
|| (expression instanceof MinorThanEquals) || (expression instanceof MinorThan)) { Expression rightExpression = ((ComparisonOperator) expression).getRightExpression();
Expression leftExpression = ((ComparisonOperator) expression).getLeftExpression(); if (!(leftExpression instanceof Column)) {
Expression rightExpression = ((ComparisonOperator) expression).getRightExpression(); return;
if (!(leftExpression instanceof Column)) { }
return; if (CollectionUtils.isEmpty(filedNameToValueMap)) {
} return;
if (CollectionUtils.isEmpty(filedNameToValueMap)) { }
return; if (Objects.isNull(rightExpression) || Objects.isNull(leftExpression)) {
} return;
if (Objects.isNull(rightExpression) || Objects.isNull(leftExpression)) { }
return; Column leftColumnName = (Column) leftExpression;
}
Column leftColumnName = (Column) leftExpression;
String columnName = leftColumnName.getColumnName(); String columnName = leftColumnName.getColumnName();
if (StringUtils.isEmpty(columnName)) { if (StringUtils.isEmpty(columnName)) {
return; return;
} }
Map<String, String> valueMap = filedNameToValueMap.get(columnName);
if (Objects.isNull(valueMap) || valueMap.isEmpty()) {
return;
}
for (String oriValue : valueMap.keySet()) {
String replaceValue = valueMap.get(oriValue);
if (StringUtils.isNotEmpty(replaceValue)) {
if (rightExpression instanceof LongValue) {
LongValue rightStringValue = (LongValue) rightExpression;
rightStringValue.setValue(Long.parseLong(replaceValue));
}
if (rightExpression instanceof DoubleValue) {
DoubleValue rightStringValue = (DoubleValue) rightExpression;
rightStringValue.setValue(Double.parseDouble(replaceValue));
}
if (rightExpression instanceof StringValue) {
StringValue rightStringValue = (StringValue) rightExpression;
rightStringValue.setValue(replaceValue);
}
}
}
Map<String, String> valueMap = filedNameToValueMap.get(columnName);
if (Objects.isNull(valueMap) || valueMap.isEmpty()) {
return;
}
if (rightExpression instanceof LongValue) {
LongValue rightStringValue = (LongValue) rightExpression;
String replaceValue = getReplaceValue(valueMap, String.valueOf(rightStringValue.getValue()));
if (StringUtils.isNotEmpty(replaceValue)) {
rightStringValue.setValue(Long.parseLong(replaceValue));
}
}
if (rightExpression instanceof DoubleValue) {
DoubleValue rightStringValue = (DoubleValue) rightExpression;
String replaceValue = getReplaceValue(valueMap, String.valueOf(rightStringValue.getValue()));
if (StringUtils.isNotEmpty(replaceValue)) {
rightStringValue.setValue(Double.parseDouble(replaceValue));
}
}
if (rightExpression instanceof StringValue) {
StringValue rightStringValue = (StringValue) rightExpression;
String replaceValue = getReplaceValue(valueMap, String.valueOf(rightStringValue.getValue()));
if (StringUtils.isNotEmpty(replaceValue)) {
rightStringValue.setValue(replaceValue);
}
} }
} }
private String getReplaceValue(Map<String, String> valueMap, String beforeValue) {
String afterValue = valueMap.get(String.valueOf(beforeValue));
if (StringUtils.isEmpty(afterValue) && !exactReplace) {
return parseVisitorHelper.getReplaceValue(beforeValue, valueMap, false);
}
return afterValue;
}
} }

View File

@@ -38,7 +38,7 @@ public class GroupByReplaceVisitor implements GroupByVisitor {
((Function) expression).getParameters().getExpressions().get(0))) { ((Function) expression).getParameters().getExpressions().get(0))) {
columnName = ((Function) expression).getParameters().getExpressions().get(0).toString(); columnName = ((Function) expression).getParameters().getExpressions().get(0).toString();
} }
String replaceColumn = parseVisitorHelper.getReplaceColumn(columnName, fieldNameMap, String replaceColumn = parseVisitorHelper.getReplaceValue(columnName, fieldNameMap,
exactReplace); exactReplace);
if (StringUtils.isNotEmpty(replaceColumn)) { if (StringUtils.isNotEmpty(replaceColumn)) {
if (expression instanceof Column) { if (expression instanceof Column) {

View File

@@ -13,32 +13,32 @@ public class ParseVisitorHelper {
public void replaceColumn(Column column, Map<String, String> fieldNameMap, boolean exactReplace) { public void replaceColumn(Column column, Map<String, String> fieldNameMap, boolean exactReplace) {
String columnName = column.getColumnName(); String columnName = column.getColumnName();
String replaceColumn = getReplaceColumn(columnName, fieldNameMap, exactReplace); String replaceColumn = getReplaceValue(columnName, fieldNameMap, exactReplace);
if (StringUtils.isNotBlank(replaceColumn)) { if (StringUtils.isNotBlank(replaceColumn)) {
column.setColumnName(replaceColumn); column.setColumnName(replaceColumn);
} }
} }
public String getReplaceColumn(String columnName, Map<String, String> fieldNameMap, boolean exactReplace) { public String getReplaceValue(String beforeValue, Map<String, String> valueMap, boolean exactReplace) {
String fieldName = fieldNameMap.get(columnName); String value = valueMap.get(beforeValue);
if (StringUtils.isNotBlank(fieldName)) { if (StringUtils.isNotBlank(value)) {
return fieldName; return value;
} }
if (exactReplace) { if (exactReplace) {
return null; return null;
} }
Optional<Entry<String, String>> first = fieldNameMap.entrySet().stream().sorted((k1, k2) -> { Optional<Entry<String, String>> first = valueMap.entrySet().stream().sorted((k1, k2) -> {
String k1FieldNameDb = k1.getKey(); String k1Value = k1.getKey();
String k2FieldNameDb = k2.getKey(); String k2Value = k2.getKey();
Double k1Similarity = getSimilarity(columnName, k1FieldNameDb); Double k1Similarity = getSimilarity(beforeValue, k1Value);
Double k2Similarity = getSimilarity(columnName, k2FieldNameDb); Double k2Similarity = getSimilarity(beforeValue, k2Value);
return k2Similarity.compareTo(k1Similarity); return k2Similarity.compareTo(k1Similarity);
}).collect(Collectors.toList()).stream().findFirst(); }).collect(Collectors.toList()).stream().findFirst();
if (first.isPresent()) { if (first.isPresent()) {
return first.get().getValue(); return first.get().getValue();
} }
return columnName; return beforeValue;
} }
public static int editDistance(String word1, String word2) { public static int editDistance(String word1, String word2) {

View File

@@ -39,6 +39,11 @@ import org.springframework.util.CollectionUtils;
public class SqlParserUpdateHelper { public class SqlParserUpdateHelper {
public static String replaceValue(String sql, Map<String, Map<String, String>> filedNameToValueMap) { public static String replaceValue(String sql, Map<String, Map<String, String>> filedNameToValueMap) {
return replaceValue(sql, filedNameToValueMap, true);
}
public static String replaceValue(String sql, Map<String, Map<String, String>> filedNameToValueMap,
boolean exactReplace) {
Select selectStatement = SqlParserSelectHelper.getSelect(sql); Select selectStatement = SqlParserSelectHelper.getSelect(sql);
SelectBody selectBody = selectStatement.getSelectBody(); SelectBody selectBody = selectStatement.getSelectBody();
if (!(selectBody instanceof PlainSelect)) { if (!(selectBody instanceof PlainSelect)) {
@@ -46,7 +51,7 @@ public class SqlParserUpdateHelper {
} }
PlainSelect plainSelect = (PlainSelect) selectBody; PlainSelect plainSelect = (PlainSelect) selectBody;
Expression where = plainSelect.getWhere(); Expression where = plainSelect.getWhere();
FieldlValueReplaceVisitor visitor = new FieldlValueReplaceVisitor(filedNameToValueMap); FieldlValueReplaceVisitor visitor = new FieldlValueReplaceVisitor(exactReplace, filedNameToValueMap);
if (Objects.nonNull(where)) { if (Objects.nonNull(where)) {
where.accept(visitor); where.accept(visitor);
} }

View File

@@ -18,6 +18,47 @@ import org.junit.jupiter.api.Test;
*/ */
class SqlParserUpdateHelperTest { class SqlParserUpdateHelperTest {
@Test
void replaceValue() {
String replaceSql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 "
+ "and 歌手名 = '杰伦' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'"
+ " order by 播放量 desc limit 11";
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
Map<String, String> valueMap = new HashMap<>();
valueMap.put("杰伦", "周杰伦");
filedNameToValueMap.put("歌手名", valueMap);
replaceSql = SqlParserUpdateHelper.replaceValue(replaceSql, filedNameToValueMap);
Assert.assertEquals(
"SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 AND "
+ "歌手名 = '周杰伦' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' "
+ "ORDER BY 播放量 DESC LIMIT 11", replaceSql);
replaceSql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 "
+ "and 歌手名 = '周杰' and 歌手名 = '林俊' and 歌手名 = '陈' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'"
+ " order by 播放量 desc limit 11";
Map<String, Map<String, String>> filedNameToValueMap2 = new HashMap<>();
Map<String, String> valueMap2 = new HashMap<>();
valueMap2.put("周杰伦", "周杰伦");
valueMap2.put("林俊杰", "林俊杰");
valueMap2.put("陈奕迅", "陈奕迅");
filedNameToValueMap2.put("歌手名", valueMap2);
replaceSql = SqlParserUpdateHelper.replaceValue(replaceSql, filedNameToValueMap2, false);
Assert.assertEquals(
"SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' "
+ "AND 歌手名 = '林俊杰' AND 歌手名 = '陈奕迅' AND 数据日期 = '2023-08-09' AND "
+ "歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11", replaceSql);
}
@Test @Test
void replaceFieldNameByValue() { void replaceFieldNameByValue() {