diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2ql/HeuristicModelResolver.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2ql/HeuristicModelResolver.java index b3a827f28..5a3f3b5c8 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2ql/HeuristicModelResolver.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2ql/HeuristicModelResolver.java @@ -66,9 +66,10 @@ public class HeuristicModelResolver implements ModelResolver { if (Objects.nonNull(modelElementMatches)) { for (Entry> modelElementMatch : modelElementMatches.entrySet()) { Long modelId = modelElementMatch.getKey(); - List modelMatchesScore = modelElementMatch.getValue().stream().filter( - elementMatch -> SchemaElementType.MODEL.equals(elementMatch.getElement().getType())) - .map(elementMatch -> elementMatch.isInherited() ? 0.5 : 1.0).collect(Collectors.toList()); + List modelMatchesScore = modelElementMatch.getValue().stream() + .filter(elementMatch -> elementMatch.getSimilarity() >= 1) + .filter(elementMatch -> SchemaElementType.MODEL.equals(elementMatch.getElement().getType())) + .map(elementMatch -> elementMatch.isInherited() ? 0.5 : 1.0).collect(Collectors.toList()); if (!CollectionUtils.isEmpty(modelMatchesScore)) { // get sum of model match score diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/RuleSemanticQuery.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/RuleSemanticQuery.java index 1c1189a94..dcdd3c83f 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/RuleSemanticQuery.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/RuleSemanticQuery.java @@ -213,7 +213,6 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable { return queryResult; } - @Override public ExplainResp explain(User user) { ExplainSqlReq explainSqlReq = null; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/utils/QueryReqBuilder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/utils/QueryReqBuilder.java index 6f61dc241..83a0dd8f4 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/utils/QueryReqBuilder.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/utils/QueryReqBuilder.java @@ -12,12 +12,14 @@ import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum; import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum; import com.tencent.supersonic.semantic.api.query.pojo.Filter; -import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq; import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq; +import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq; import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; import java.time.LocalDate; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.LinkedHashSet; import java.util.List; @@ -25,6 +27,19 @@ import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; +import net.sf.jsqlparser.JSQLParserException; +import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.expression.LongValue; +import net.sf.jsqlparser.parser.CCJSqlParserUtil; +import net.sf.jsqlparser.schema.Column; +import net.sf.jsqlparser.schema.Table; +import net.sf.jsqlparser.statement.select.GroupByElement; +import net.sf.jsqlparser.statement.select.Limit; +import net.sf.jsqlparser.statement.select.OrderByElement; +import net.sf.jsqlparser.statement.select.PlainSelect; +import net.sf.jsqlparser.statement.select.Select; +import net.sf.jsqlparser.statement.select.SelectExpressionItem; +import net.sf.jsqlparser.statement.select.SelectItem; import org.apache.logging.log4j.util.Strings; import org.springframework.beans.BeanUtils; import org.springframework.util.CollectionUtils; @@ -138,6 +153,65 @@ public class QueryReqBuilder { return queryS2QLReq; } + /** + * convert queryStructReq to QueryS2QLReq + * + * @param queryStructReq + * @return + */ + public static QueryS2QLReq buildS2QLReq(QueryStructReq queryStructReq) { + + Select select = new Select(); + + // Set the select items (columns) + PlainSelect plainSelect = new PlainSelect(); + List selectItems = new ArrayList<>(); + + if(queryStructReq.getNativeQuery()){ + + + } + + selectItems.add(new SelectExpressionItem(new Column("column1"))); + selectItems.add(new SelectExpressionItem(new Column("column2"))); + plainSelect.setSelectItems(selectItems); + + // Set the table name + Table table = new Table("table1"); + plainSelect.setFromItem(table); + + // Set the order by clause + OrderByElement orderByElement = new OrderByElement(); + orderByElement.setExpression(new Column("column1")); + plainSelect.setOrderByElements(Collections.singletonList(orderByElement)); + + // Set the group by clause + GroupByElement groupByElement = new GroupByElement(); + groupByElement.addGroupByExpression(new Column("column1")); + plainSelect.setGroupByElement(groupByElement); + + // Set the having clause + Expression havingExpression = null; + try { + havingExpression = CCJSqlParserUtil.parseCondExpression("condition2"); + } catch (JSQLParserException e) { + log.error(""); + } + plainSelect.setHaving(havingExpression); + + // Set the limit clause + Limit limit = new Limit(); + limit.setRowCount(new LongValue(10)); + plainSelect.setLimit(limit); + + select.setSelectBody(plainSelect); + + QueryS2QLReq result = new QueryS2QLReq(); + result.setSql(select.toString()); + result.setModelId(queryStructReq.getModelId()); + result.setVariables(new HashMap<>()); + return result; + } private static List getAggregatorByMetric(AggregateTypeEnum aggregateType, SchemaElement metric) { List aggregators = new ArrayList<>(); @@ -233,4 +307,6 @@ public class QueryReqBuilder { queryStructCmd.setAggregators(aggregators); return queryStructCmd; } + + } diff --git a/chat/core/src/test/java/com/tencent/supersonic/chat/utils/QueryReqBuilderTest.java b/chat/core/src/test/java/com/tencent/supersonic/chat/utils/QueryReqBuilderTest.java new file mode 100644 index 000000000..42ba7c11b --- /dev/null +++ b/chat/core/src/test/java/com/tencent/supersonic/chat/utils/QueryReqBuilderTest.java @@ -0,0 +1,17 @@ +package com.tencent.supersonic.chat.utils; + + +import org.junit.jupiter.api.Test; + +/** + * QueryReqBuilderTest + */ +class QueryReqBuilderTest { + + @Test + void buildS2QLReq() { + + + + } +} \ No newline at end of file diff --git a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/utils/HanlpHelper.java b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/utils/HanlpHelper.java index b549a121c..07051e499 100644 --- a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/utils/HanlpHelper.java +++ b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/utils/HanlpHelper.java @@ -9,6 +9,7 @@ import com.hankcs.hanlp.seg.Segment; import com.hankcs.hanlp.seg.common.Term; import com.tencent.supersonic.common.pojo.enums.DictWordType; import com.tencent.supersonic.knowledge.dictionary.DictWord; +import java.io.File; import java.io.FileNotFoundException; import java.io.IOException; import java.util.Arrays; @@ -30,10 +31,8 @@ import org.springframework.util.ResourceUtils; @Slf4j public class HanlpHelper { - public static final String FILE_SPILT = "/"; + public static final String FILE_SPILT = File.separator; public static final String SPACE_SPILT = "#"; - public static final String DICT_MAIN_FILE_NAME = "CustomDictionary.txt"; - public static final String DICT_CLASS = "classes"; private static volatile DynamicCustomDictionary CustomDictionary; private static volatile Segment segment; diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/DictWordType.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/DictWordType.java index 7322555b7..b21e2c72d 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/DictWordType.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/DictWordType.java @@ -7,6 +7,7 @@ import org.apache.commons.lang3.StringUtils; * such as : metric、dimension etc. */ public enum DictWordType { + METRIC("metric"), DIMENSION("dimension"),