(improvement)(chat) only score model matches with similarity >= 1, and change HanlpHelper FILE_SPILT to File.separator for Windows compatibility (#284)

This commit is contained in:
lexluo09
2023-10-24 16:34:04 +08:00
committed by GitHub
parent e4e39e0496
commit 9a3c71df4a
6 changed files with 101 additions and 8 deletions

View File

@@ -66,9 +66,10 @@ public class HeuristicModelResolver implements ModelResolver {
if (Objects.nonNull(modelElementMatches)) {
for (Entry<Long, List<SchemaElementMatch>> modelElementMatch : modelElementMatches.entrySet()) {
Long modelId = modelElementMatch.getKey();
List<Double> modelMatchesScore = modelElementMatch.getValue().stream().filter(
elementMatch -> SchemaElementType.MODEL.equals(elementMatch.getElement().getType()))
.map(elementMatch -> elementMatch.isInherited() ? 0.5 : 1.0).collect(Collectors.toList());
List<Double> modelMatchesScore = modelElementMatch.getValue().stream()
.filter(elementMatch -> elementMatch.getSimilarity() >= 1)
.filter(elementMatch -> SchemaElementType.MODEL.equals(elementMatch.getElement().getType()))
.map(elementMatch -> elementMatch.isInherited() ? 0.5 : 1.0).collect(Collectors.toList());
if (!CollectionUtils.isEmpty(modelMatchesScore)) {
// get sum of model match score

View File

@@ -213,7 +213,6 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
return queryResult;
}
@Override
public ExplainResp explain(User user) {
ExplainSqlReq explainSqlReq = null;

View File

@@ -12,12 +12,14 @@ import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import com.tencent.supersonic.semantic.api.query.pojo.Filter;
import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq;
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
@@ -25,6 +27,19 @@ import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.select.GroupByElement;
import net.sf.jsqlparser.statement.select.Limit;
import net.sf.jsqlparser.statement.select.OrderByElement;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectExpressionItem;
import net.sf.jsqlparser.statement.select.SelectItem;
import org.apache.logging.log4j.util.Strings;
import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils;
@@ -138,6 +153,65 @@ public class QueryReqBuilder {
return queryS2QLReq;
}
/**
 * Converts a {@link QueryStructReq} into a {@link QueryS2QLReq}.
 *
 * <p>NOTE(review): the SQL is currently assembled from hard-coded placeholders
 * ({@code column1}, {@code column2}, {@code table1}, {@code condition2}, row limit 10).
 * Only the model id is read from {@code queryStructReq}; its remaining fields still
 * have to be mapped into the statement.
 *
 * @param queryStructReq the struct query to convert; must not be {@code null}
 * @return a new request holding the generated SQL text, the model id of
 *         {@code queryStructReq} and an empty, mutable variables map
 */
public static QueryS2QLReq buildS2QLReq(QueryStructReq queryStructReq) {
    PlainSelect plainSelect = new PlainSelect();

    // Select items (columns).
    // TODO: native queries (queryStructReq.getNativeQuery()) are not handled yet; the
    // empty branch that only evaluated the flag was removed as dead code.
    List<SelectItem> selectItems = new ArrayList<>();
    selectItems.add(new SelectExpressionItem(new Column("column1")));
    selectItems.add(new SelectExpressionItem(new Column("column2")));
    plainSelect.setSelectItems(selectItems);

    // Table name.
    plainSelect.setFromItem(new Table("table1"));

    // Order by clause.
    OrderByElement orderByElement = new OrderByElement();
    orderByElement.setExpression(new Column("column1"));
    plainSelect.setOrderByElements(Collections.singletonList(orderByElement));

    // Group by clause.
    GroupByElement groupByElement = new GroupByElement();
    groupByElement.addGroupByExpression(new Column("column1"));
    plainSelect.setGroupByElement(groupByElement);

    // Having clause: on a parse failure the statement is built without one, but the
    // failure is logged with the offending condition and the exception itself.
    String havingCondition = "condition2";
    Expression havingExpression = null;
    try {
        havingExpression = CCJSqlParserUtil.parseCondExpression(havingCondition);
    } catch (JSQLParserException e) {
        log.error("failed to parse having condition '{}'", havingCondition, e);
    }
    plainSelect.setHaving(havingExpression);

    // Limit clause.
    Limit limit = new Limit();
    limit.setRowCount(new LongValue(10));
    plainSelect.setLimit(limit);

    Select select = new Select();
    select.setSelectBody(plainSelect);

    QueryS2QLReq result = new QueryS2QLReq();
    result.setSql(select.toString());
    result.setModelId(queryStructReq.getModelId());
    result.setVariables(new HashMap<>());
    return result;
}
private static List<Aggregator> getAggregatorByMetric(AggregateTypeEnum aggregateType, SchemaElement metric) {
List<Aggregator> aggregators = new ArrayList<>();
@@ -233,4 +307,6 @@ public class QueryReqBuilder {
queryStructCmd.setAggregators(aggregators);
return queryStructCmd;
}
}

View File

@@ -0,0 +1,17 @@
package com.tencent.supersonic.chat.utils;
import org.junit.jupiter.api.Test;
/**
* Test class for QueryReqBuilder.
*
* NOTE(review): currently a placeholder; the only test method has an empty body
* and therefore verifies nothing.
*/
class QueryReqBuilderTest {
// TODO: add assertions; this test passes trivially because its body is empty.
@Test
void buildS2QLReq() {
}
}

View File

@@ -9,6 +9,7 @@ import com.hankcs.hanlp.seg.Segment;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.knowledge.dictionary.DictWord;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.Arrays;
@@ -30,10 +31,8 @@ import org.springframework.util.ResourceUtils;
@Slf4j
public class HanlpHelper {
public static final String FILE_SPILT = "/";
public static final String FILE_SPILT = File.separator;
public static final String SPACE_SPILT = "#";
public static final String DICT_MAIN_FILE_NAME = "CustomDictionary.txt";
public static final String DICT_CLASS = "classes";
private static volatile DynamicCustomDictionary CustomDictionary;
private static volatile Segment segment;

View File

@@ -7,6 +7,7 @@ import org.apache.commons.lang3.StringUtils;
* such as: metric, dimension, etc.
*/
public enum DictWordType {
METRIC("metric"),
DIMENSION("dimension"),