diff --git a/common/src/main/java/dev/langchain4j/model/embedding/S2OnnxEmbeddingModel.java b/common/src/main/java/dev/langchain4j/model/embedding/S2OnnxEmbeddingModel.java index 4cba5b02e..5a376786f 100644 --- a/common/src/main/java/dev/langchain4j/model/embedding/S2OnnxEmbeddingModel.java +++ b/common/src/main/java/dev/langchain4j/model/embedding/S2OnnxEmbeddingModel.java @@ -52,7 +52,7 @@ public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel { try { return new OnnxBertBiEncoder( Files.newInputStream(pathToModel), - vocabularyFile.openStream(), + vocabularyFile, PoolingMode.MEAN ); } catch (IOException e) { diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java index b84779db8..011bb3bba 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java @@ -38,6 +38,7 @@ public class MetricTest extends BaseTest { expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE); expectedParseInfo.setAggType(NONE); + expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数")); @@ -70,7 +71,7 @@ public class MetricTest extends BaseTest { expectedResult.setQueryMode(MetricModelQuery.QUERY_MODE); expectedParseInfo.setAggType(NONE); - + expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数")); expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay)); @@ -99,6 +100,7 @@ public class MetricTest extends BaseTest { expectedResult.setQueryMode(MetricGroupByQuery.QUERY_MODE); expectedParseInfo.setAggType(NONE); + expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数")); expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("部门")); @@ -119,6 +121,7 @@ public class MetricTest extends BaseTest { expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE); expectedParseInfo.setAggType(NONE); + expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数")); List list = new ArrayList<>(); @@ -165,6 +168,7 @@ public class MetricTest extends BaseTest { expectedResult.setQueryMode(MetricGroupByQuery.QUERY_MODE); expectedParseInfo.setAggType(SUM); + expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数")); expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("部门")); @@ -190,6 +194,7 @@ public class MetricTest extends BaseTest { expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE); expectedParseInfo.setAggType(NONE); + expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数")); expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name", diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MultiTurnsTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MultiTurnsTest.java index 55737e5df..839aeb8d6 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MultiTurnsTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MultiTurnsTest.java @@ -27,6 +27,7 @@ public class MultiTurnsTest extends BaseTest { expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE); expectedParseInfo.setAggType(NONE); + expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数")); diff --git a/pom.xml b/pom.xml index a75cebc72..9baa0cdcf 100644 --- a/pom.xml +++ b/pom.xml @@ -67,6 +67,7 @@ 2.2.6 3.17 0.31.0 + 0.27.1 42.7.1 4.0.8 0.10.0 @@ -131,12 +132,12 @@ dev.langchain4j langchain4j-embeddings - ${langchain4j.version} + ${langchain4j.embedding.version} dev.langchain4j langchain4j-embeddings-bge-small-zh - ${langchain4j.version} + ${langchain4j.embedding.version} dev.langchain4j @@ -176,7 +177,7 @@ dev.langchain4j langchain4j-embeddings-all-minilm-l6-v2-q - ${langchain4j.version} + ${langchain4j.embedding.version} dev.langchain4j