mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(chat) support updateFieldValueByLinkingValue and fix EntityListQuery NullPointerException (#182)
This commit is contained in:
@@ -27,6 +27,8 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector {
|
||||
|
||||
updateFieldNameByLinkingValue(semanticCorrectInfo);
|
||||
|
||||
updateFieldValueByLinkingValue(semanticCorrectInfo);
|
||||
|
||||
correctFieldName(semanticCorrectInfo);
|
||||
}
|
||||
|
||||
@@ -45,17 +47,7 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
|
||||
private void updateFieldNameByLinkingValue(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
Object context = semanticCorrectInfo.getParseInfo().getProperties().get(Constants.CONTEXT);
|
||||
if (Objects.isNull(context)) {
|
||||
return;
|
||||
}
|
||||
|
||||
DSLParseResult dslParseResult = JsonUtil.toObject(JsonUtil.toString(context), DSLParseResult.class);
|
||||
if (Objects.isNull(dslParseResult) || Objects.isNull(dslParseResult.getLlmReq())) {
|
||||
return;
|
||||
}
|
||||
LLMReq llmReq = dslParseResult.getLlmReq();
|
||||
List<ElementValue> linking = llmReq.getLinking();
|
||||
List<ElementValue> linking = getLinkingValues(semanticCorrectInfo);
|
||||
if (CollectionUtils.isEmpty(linking)) {
|
||||
return;
|
||||
}
|
||||
@@ -68,4 +60,37 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector {
|
||||
fieldValueToFieldNames);
|
||||
semanticCorrectInfo.setSql(sql);
|
||||
}
|
||||
|
||||
private List<ElementValue> getLinkingValues(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
Object context = semanticCorrectInfo.getParseInfo().getProperties().get(Constants.CONTEXT);
|
||||
if (Objects.isNull(context)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
DSLParseResult dslParseResult = JsonUtil.toObject(JsonUtil.toString(context), DSLParseResult.class);
|
||||
if (Objects.isNull(dslParseResult) || Objects.isNull(dslParseResult.getLlmReq())) {
|
||||
return null;
|
||||
}
|
||||
LLMReq llmReq = dslParseResult.getLlmReq();
|
||||
return llmReq.getLinking();
|
||||
}
|
||||
|
||||
|
||||
private void updateFieldValueByLinkingValue(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
List<ElementValue> linking = getLinkingValues(semanticCorrectInfo);
|
||||
if (CollectionUtils.isEmpty(linking)) {
|
||||
return;
|
||||
}
|
||||
|
||||
Map<String, Map<String, String>> filedNameToValueMap = linking.stream().collect(
|
||||
Collectors.groupingBy(ElementValue::getFieldName,
|
||||
Collectors.mapping(ElementValue::getFieldValue, Collectors.toMap(
|
||||
oldValue -> oldValue,
|
||||
newValue -> newValue,
|
||||
(existingValue, newValue) -> newValue)
|
||||
)));
|
||||
|
||||
String sql = SqlParserUpdateHelper.replaceValue(semanticCorrectInfo.getSql(), filedNameToValueMap, false);
|
||||
semanticCorrectInfo.setSql(sql);
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,9 @@
|
||||
package com.tencent.supersonic.chat.query.rule.entity;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatDefaultRichConfigResp;
|
||||
@@ -12,10 +12,10 @@ import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.Order;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
public abstract class EntityListQuery extends EntitySemanticQuery {
|
||||
|
||||
@@ -41,13 +41,17 @@ public abstract class EntityListQuery extends EntitySemanticQuery {
|
||||
ChatDefaultRichConfigResp chatDefaultConfig = chaConfigRichDesc
|
||||
.getChatDetailRichConfig().getChatDefaultConfig();
|
||||
if (chatDefaultConfig != null) {
|
||||
chatDefaultConfig.getMetrics().stream()
|
||||
.forEach(metric -> {
|
||||
metrics.add(metric);
|
||||
orders.add(new Order(metric.getBizName(), Constants.DESC_UPPER));
|
||||
});
|
||||
chatDefaultConfig.getDimensions().stream()
|
||||
.forEach(dimension -> dimensions.add(dimension));
|
||||
if (CollectionUtils.isNotEmpty(chatDefaultConfig.getMetrics())) {
|
||||
chatDefaultConfig.getMetrics().stream()
|
||||
.forEach(metric -> {
|
||||
metrics.add(metric);
|
||||
orders.add(new Order(metric.getBizName(), Constants.DESC_UPPER));
|
||||
});
|
||||
}
|
||||
if (CollectionUtils.isNotEmpty(chatDefaultConfig.getDimensions())) {
|
||||
chatDefaultConfig.getDimensions().stream()
|
||||
.forEach(dimension -> dimensions.add(dimension));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
@@ -43,15 +42,6 @@ public class EntityInfoExecuteResponder implements ExecuteResponder {
|
||||
if (CollectionUtils.isEmpty(entities)) {
|
||||
return;
|
||||
}
|
||||
QueryResultWithSchemaResp queryResultWithSchemaResp = semanticService.getQueryResultWithSchemaResp(entityInfo,
|
||||
semanticParseInfo.getModelId(), entities, user);
|
||||
if (Objects.isNull(queryResultWithSchemaResp)) {
|
||||
return;
|
||||
}
|
||||
List<Map<String, Object>> entityResultList = queryResultWithSchemaResp.getResultList();
|
||||
if (CollectionUtils.isEmpty(entityResultList)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -21,43 +21,19 @@ import org.springframework.util.CollectionUtils;
|
||||
|
||||
public class FieldlValueReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
|
||||
ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper();
|
||||
private boolean exactReplace;
|
||||
private Map<String, Map<String, String>> filedNameToValueMap;
|
||||
|
||||
public FieldlValueReplaceVisitor(Map<String, Map<String, String>> filedNameToValueMap) {
|
||||
|
||||
public FieldlValueReplaceVisitor(boolean exactReplace, Map<String, Map<String, String>> filedNameToValueMap) {
|
||||
this.exactReplace = exactReplace;
|
||||
this.filedNameToValueMap = filedNameToValueMap;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(EqualsTo expr) {
|
||||
Expression leftExpression = expr.getLeftExpression();
|
||||
Expression rightExpression = expr.getRightExpression();
|
||||
if (!(rightExpression instanceof StringValue)) {
|
||||
return;
|
||||
}
|
||||
if (!(leftExpression instanceof Column)) {
|
||||
return;
|
||||
}
|
||||
if (CollectionUtils.isEmpty(filedNameToValueMap)) {
|
||||
return;
|
||||
}
|
||||
if (Objects.isNull(rightExpression) || Objects.isNull(leftExpression)) {
|
||||
return;
|
||||
}
|
||||
Column leftColumnName = (Column) leftExpression;
|
||||
StringValue rightStringValue = (StringValue) rightExpression;
|
||||
|
||||
String columnName = leftColumnName.getColumnName();
|
||||
if (StringUtils.isEmpty(columnName)) {
|
||||
return;
|
||||
}
|
||||
Map<String, String> valueMap = filedNameToValueMap.get(columnName);
|
||||
if (Objects.isNull(valueMap) || valueMap.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
String replaceValue = valueMap.get(rightStringValue.getValue());
|
||||
if (StringUtils.isNotEmpty(replaceValue)) {
|
||||
rightStringValue.setValue(replaceValue);
|
||||
}
|
||||
replaceComparisonExpression(expr);
|
||||
}
|
||||
|
||||
public void visit(GreaterThan expr) {
|
||||
@@ -77,47 +53,57 @@ public class FieldlValueReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
}
|
||||
|
||||
public <T extends Expression> void replaceComparisonExpression(T expression) {
|
||||
if ((expression instanceof GreaterThanEquals) || (expression instanceof GreaterThan)
|
||||
|| (expression instanceof MinorThanEquals) || (expression instanceof MinorThan)) {
|
||||
Expression leftExpression = ((ComparisonOperator) expression).getLeftExpression();
|
||||
Expression rightExpression = ((ComparisonOperator) expression).getRightExpression();
|
||||
if (!(leftExpression instanceof Column)) {
|
||||
return;
|
||||
}
|
||||
if (CollectionUtils.isEmpty(filedNameToValueMap)) {
|
||||
return;
|
||||
}
|
||||
if (Objects.isNull(rightExpression) || Objects.isNull(leftExpression)) {
|
||||
return;
|
||||
}
|
||||
Column leftColumnName = (Column) leftExpression;
|
||||
Expression leftExpression = ((ComparisonOperator) expression).getLeftExpression();
|
||||
Expression rightExpression = ((ComparisonOperator) expression).getRightExpression();
|
||||
if (!(leftExpression instanceof Column)) {
|
||||
return;
|
||||
}
|
||||
if (CollectionUtils.isEmpty(filedNameToValueMap)) {
|
||||
return;
|
||||
}
|
||||
if (Objects.isNull(rightExpression) || Objects.isNull(leftExpression)) {
|
||||
return;
|
||||
}
|
||||
Column leftColumnName = (Column) leftExpression;
|
||||
|
||||
String columnName = leftColumnName.getColumnName();
|
||||
if (StringUtils.isEmpty(columnName)) {
|
||||
return;
|
||||
}
|
||||
Map<String, String> valueMap = filedNameToValueMap.get(columnName);
|
||||
if (Objects.isNull(valueMap) || valueMap.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
for (String oriValue : valueMap.keySet()) {
|
||||
String replaceValue = valueMap.get(oriValue);
|
||||
if (StringUtils.isNotEmpty(replaceValue)) {
|
||||
if (rightExpression instanceof LongValue) {
|
||||
LongValue rightStringValue = (LongValue) rightExpression;
|
||||
rightStringValue.setValue(Long.parseLong(replaceValue));
|
||||
}
|
||||
if (rightExpression instanceof DoubleValue) {
|
||||
DoubleValue rightStringValue = (DoubleValue) rightExpression;
|
||||
rightStringValue.setValue(Double.parseDouble(replaceValue));
|
||||
}
|
||||
if (rightExpression instanceof StringValue) {
|
||||
StringValue rightStringValue = (StringValue) rightExpression;
|
||||
rightStringValue.setValue(replaceValue);
|
||||
}
|
||||
}
|
||||
}
|
||||
String columnName = leftColumnName.getColumnName();
|
||||
if (StringUtils.isEmpty(columnName)) {
|
||||
return;
|
||||
}
|
||||
|
||||
Map<String, String> valueMap = filedNameToValueMap.get(columnName);
|
||||
if (Objects.isNull(valueMap) || valueMap.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (rightExpression instanceof LongValue) {
|
||||
LongValue rightStringValue = (LongValue) rightExpression;
|
||||
String replaceValue = getReplaceValue(valueMap, String.valueOf(rightStringValue.getValue()));
|
||||
if (StringUtils.isNotEmpty(replaceValue)) {
|
||||
rightStringValue.setValue(Long.parseLong(replaceValue));
|
||||
}
|
||||
}
|
||||
if (rightExpression instanceof DoubleValue) {
|
||||
DoubleValue rightStringValue = (DoubleValue) rightExpression;
|
||||
String replaceValue = getReplaceValue(valueMap, String.valueOf(rightStringValue.getValue()));
|
||||
if (StringUtils.isNotEmpty(replaceValue)) {
|
||||
rightStringValue.setValue(Double.parseDouble(replaceValue));
|
||||
}
|
||||
}
|
||||
if (rightExpression instanceof StringValue) {
|
||||
StringValue rightStringValue = (StringValue) rightExpression;
|
||||
String replaceValue = getReplaceValue(valueMap, String.valueOf(rightStringValue.getValue()));
|
||||
if (StringUtils.isNotEmpty(replaceValue)) {
|
||||
rightStringValue.setValue(replaceValue);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private String getReplaceValue(Map<String, String> valueMap, String beforeValue) {
|
||||
String afterValue = valueMap.get(String.valueOf(beforeValue));
|
||||
if (StringUtils.isEmpty(afterValue) && !exactReplace) {
|
||||
return parseVisitorHelper.getReplaceValue(beforeValue, valueMap, false);
|
||||
}
|
||||
return afterValue;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ public class GroupByReplaceVisitor implements GroupByVisitor {
|
||||
((Function) expression).getParameters().getExpressions().get(0))) {
|
||||
columnName = ((Function) expression).getParameters().getExpressions().get(0).toString();
|
||||
}
|
||||
String replaceColumn = parseVisitorHelper.getReplaceColumn(columnName, fieldNameMap,
|
||||
String replaceColumn = parseVisitorHelper.getReplaceValue(columnName, fieldNameMap,
|
||||
exactReplace);
|
||||
if (StringUtils.isNotEmpty(replaceColumn)) {
|
||||
if (expression instanceof Column) {
|
||||
|
||||
@@ -13,32 +13,32 @@ public class ParseVisitorHelper {
|
||||
|
||||
public void replaceColumn(Column column, Map<String, String> fieldNameMap, boolean exactReplace) {
|
||||
String columnName = column.getColumnName();
|
||||
String replaceColumn = getReplaceColumn(columnName, fieldNameMap, exactReplace);
|
||||
String replaceColumn = getReplaceValue(columnName, fieldNameMap, exactReplace);
|
||||
if (StringUtils.isNotBlank(replaceColumn)) {
|
||||
column.setColumnName(replaceColumn);
|
||||
}
|
||||
}
|
||||
|
||||
public String getReplaceColumn(String columnName, Map<String, String> fieldNameMap, boolean exactReplace) {
|
||||
String fieldName = fieldNameMap.get(columnName);
|
||||
if (StringUtils.isNotBlank(fieldName)) {
|
||||
return fieldName;
|
||||
public String getReplaceValue(String beforeValue, Map<String, String> valueMap, boolean exactReplace) {
|
||||
String value = valueMap.get(beforeValue);
|
||||
if (StringUtils.isNotBlank(value)) {
|
||||
return value;
|
||||
}
|
||||
if (exactReplace) {
|
||||
return null;
|
||||
}
|
||||
Optional<Entry<String, String>> first = fieldNameMap.entrySet().stream().sorted((k1, k2) -> {
|
||||
String k1FieldNameDb = k1.getKey();
|
||||
String k2FieldNameDb = k2.getKey();
|
||||
Double k1Similarity = getSimilarity(columnName, k1FieldNameDb);
|
||||
Double k2Similarity = getSimilarity(columnName, k2FieldNameDb);
|
||||
Optional<Entry<String, String>> first = valueMap.entrySet().stream().sorted((k1, k2) -> {
|
||||
String k1Value = k1.getKey();
|
||||
String k2Value = k2.getKey();
|
||||
Double k1Similarity = getSimilarity(beforeValue, k1Value);
|
||||
Double k2Similarity = getSimilarity(beforeValue, k2Value);
|
||||
return k2Similarity.compareTo(k1Similarity);
|
||||
}).collect(Collectors.toList()).stream().findFirst();
|
||||
|
||||
if (first.isPresent()) {
|
||||
return first.get().getValue();
|
||||
}
|
||||
return columnName;
|
||||
return beforeValue;
|
||||
}
|
||||
|
||||
public static int editDistance(String word1, String word2) {
|
||||
|
||||
@@ -39,6 +39,11 @@ import org.springframework.util.CollectionUtils;
|
||||
public class SqlParserUpdateHelper {
|
||||
|
||||
public static String replaceValue(String sql, Map<String, Map<String, String>> filedNameToValueMap) {
|
||||
return replaceValue(sql, filedNameToValueMap, true);
|
||||
}
|
||||
|
||||
public static String replaceValue(String sql, Map<String, Map<String, String>> filedNameToValueMap,
|
||||
boolean exactReplace) {
|
||||
Select selectStatement = SqlParserSelectHelper.getSelect(sql);
|
||||
SelectBody selectBody = selectStatement.getSelectBody();
|
||||
if (!(selectBody instanceof PlainSelect)) {
|
||||
@@ -46,7 +51,7 @@ public class SqlParserUpdateHelper {
|
||||
}
|
||||
PlainSelect plainSelect = (PlainSelect) selectBody;
|
||||
Expression where = plainSelect.getWhere();
|
||||
FieldlValueReplaceVisitor visitor = new FieldlValueReplaceVisitor(filedNameToValueMap);
|
||||
FieldlValueReplaceVisitor visitor = new FieldlValueReplaceVisitor(exactReplace, filedNameToValueMap);
|
||||
if (Objects.nonNull(where)) {
|
||||
where.accept(visitor);
|
||||
}
|
||||
|
||||
@@ -18,6 +18,47 @@ import org.junit.jupiter.api.Test;
|
||||
*/
|
||||
class SqlParserUpdateHelperTest {
|
||||
|
||||
@Test
|
||||
void replaceValue() {
|
||||
|
||||
String replaceSql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 "
|
||||
+ "and 歌手名 = '杰伦' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'"
|
||||
+ " order by 播放量 desc limit 11";
|
||||
|
||||
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
|
||||
|
||||
Map<String, String> valueMap = new HashMap<>();
|
||||
valueMap.put("杰伦", "周杰伦");
|
||||
filedNameToValueMap.put("歌手名", valueMap);
|
||||
|
||||
replaceSql = SqlParserUpdateHelper.replaceValue(replaceSql, filedNameToValueMap);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 AND "
|
||||
+ "歌手名 = '周杰伦' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' "
|
||||
+ "ORDER BY 播放量 DESC LIMIT 11", replaceSql);
|
||||
|
||||
replaceSql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 "
|
||||
+ "and 歌手名 = '周杰' and 歌手名 = '林俊' and 歌手名 = '陈' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'"
|
||||
+ " order by 播放量 desc limit 11";
|
||||
|
||||
Map<String, Map<String, String>> filedNameToValueMap2 = new HashMap<>();
|
||||
|
||||
Map<String, String> valueMap2 = new HashMap<>();
|
||||
valueMap2.put("周杰伦", "周杰伦");
|
||||
valueMap2.put("林俊杰", "林俊杰");
|
||||
valueMap2.put("陈奕迅", "陈奕迅");
|
||||
filedNameToValueMap2.put("歌手名", valueMap2);
|
||||
|
||||
replaceSql = SqlParserUpdateHelper.replaceValue(replaceSql, filedNameToValueMap2, false);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' "
|
||||
+ "AND 歌手名 = '林俊杰' AND 歌手名 = '陈奕迅' AND 数据日期 = '2023-08-09' AND "
|
||||
+ "歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11", replaceSql);
|
||||
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
void replaceFieldNameByValue() {
|
||||
|
||||
Reference in New Issue
Block a user