diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java index cdb38d81a..abb1b11a4 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java @@ -4,9 +4,9 @@ package com.tencent.supersonic.chat.server.agent; import com.alibaba.fastjson.JSONObject; import com.google.common.collect.Lists; import com.google.common.collect.Sets; -import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.config.VisualConfig; +import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.RecordInfo; import lombok.Data; import org.springframework.util.CollectionUtils; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/AgentController.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/AgentController.java index 5ef9f9867..3a95dd763 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/AgentController.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/AgentController.java @@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.AgentToolType; import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.util.LLMConnHelper; -import com.tencent.supersonic.common.config.ChatModelConfig; +import com.tencent.supersonic.common.pojo.ChatModelConfig; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.DeleteMapping; import org.springframework.web.bind.annotation.PathVariable; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java index 91d9ecdce..f7523d786 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java @@ -12,9 +12,9 @@ import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.service.ChatQueryService; import com.tencent.supersonic.chat.server.service.MemoryService; import com.tencent.supersonic.chat.server.util.LLMConnHelper; -import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.config.VisualConfig; +import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.util.JsonUtil; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.BeanUtils; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/LLMConnHelper.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/LLMConnHelper.java index 40c8de89f..e0506079d 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/LLMConnHelper.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/LLMConnHelper.java @@ -1,6 +1,6 @@ package com.tencent.supersonic.chat.server.util; -import com.tencent.supersonic.common.config.ChatModelConfig; +import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.provider.ModelProvider; diff --git a/common/src/main/java/com/tencent/supersonic/common/config/ChatModelParameterConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/ChatModelParameterConfig.java new file mode 100644 index 000000000..0d6157f92 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/config/ChatModelParameterConfig.java @@ -0,0 +1,73 @@ +package com.tencent.supersonic.common.config; + +import com.google.common.collect.Lists; +import com.tencent.supersonic.common.pojo.ChatModelConfig; +import com.tencent.supersonic.common.pojo.Parameter; +import dev.langchain4j.provider.OllamaModelFactory; +import dev.langchain4j.provider.OpenAiModelFactory; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Service; + +import java.util.List; + +@Service("ChatModelParameterConfig") +@Slf4j +public class ChatModelParameterConfig extends ParameterConfig { + + public static final Parameter CHAT_MODEL_PROVIDER = + new Parameter("s2.chat.model.provider", OpenAiModelFactory.PROVIDER, + "接口协议", "", + "string", "对话模型配置", + Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER)); + + public static final Parameter CHAT_MODEL_BASE_URL = + new Parameter("s2.chat.model.base.url", "", + "BaseUrl", "", + "string", "对话模型配置"); + + public static final Parameter CHAT_MODEL_API_KEY = + new Parameter("s2.chat.model.api.key", "", + "ApiKey", "", + "string", "对话模型配置"); + + public static final Parameter CHAT_MODEL_NAME = + new Parameter("s2.chat.model.name", "", + "ModelName", "", + "string", "对话模型配置"); + + public static final Parameter CHAT_MODEL_TEMPERATURE = + new Parameter("s2.chat.model.temperature", "0.0", + "Temperature", "", + "number", "对话模型配置"); + + public static final Parameter CHAT_MODEL_TIMEOUT = + new Parameter("s2.chat.model.timeout", "60", + "超时时间(秒)", "", + "number", "对话模型配置"); + + @Override + public List getSysParameters() { + return Lists.newArrayList( + CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_API_KEY, + CHAT_MODEL_NAME, CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT + ); + } + + public ChatModelConfig convert() { + String chatModelProvider = getParameterValue(CHAT_MODEL_PROVIDER); + String chatModelBaseUrl = getParameterValue(CHAT_MODEL_BASE_URL); + String chatModelApiKey = getParameterValue(CHAT_MODEL_API_KEY); + String chatModelName = getParameterValue(CHAT_MODEL_NAME); + String chatModelTemperature = getParameterValue(CHAT_MODEL_TEMPERATURE); + String chatModelTimeout = getParameterValue(CHAT_MODEL_TIMEOUT); + + return ChatModelConfig.builder() + .provider(chatModelProvider) + .baseUrl(chatModelBaseUrl) + .apiKey(chatModelApiKey) + .modelName(chatModelName) + .temperature(Double.valueOf(chatModelTemperature)) + .timeOut(Long.valueOf(chatModelTimeout)) + .build(); + } +} diff --git a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java new file mode 100644 index 000000000..c403ecb03 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java @@ -0,0 +1,86 @@ +package com.tencent.supersonic.common.config; + +import com.google.common.collect.Lists; +import com.tencent.supersonic.common.pojo.EmbeddingModelConfig; +import com.tencent.supersonic.common.pojo.Parameter; +import dev.langchain4j.provider.AzureModelFactory; +import dev.langchain4j.provider.DashscopeModelFactory; +import dev.langchain4j.provider.InMemoryModelFactory; +import dev.langchain4j.provider.OllamaModelFactory; +import dev.langchain4j.provider.OpenAiModelFactory; +import dev.langchain4j.provider.QianfanModelFactory; +import dev.langchain4j.provider.ZhipuModelFactory; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Service; + +import java.util.List; + +@Service("EmbeddingModelConfig") +@Slf4j +public class EmbeddingModelParameterConfig extends ParameterConfig { + + public static final Parameter EMBEDDING_MODEL_PROVIDER = + new Parameter("s2.embedding.model.provider", InMemoryModelFactory.PROVIDER, + "接口协议", "", + "string", "向量模型配置", + Lists.newArrayList(InMemoryModelFactory.PROVIDER, + OpenAiModelFactory.PROVIDER, + OllamaModelFactory.PROVIDER, + AzureModelFactory.PROVIDER, + DashscopeModelFactory.PROVIDER, + QianfanModelFactory.PROVIDER, + ZhipuModelFactory.PROVIDER)); + + public static final Parameter EMBEDDING_MODEL_BASE_URL = + new Parameter("s2.embedding.model.base.url", "", + "BaseUrl", "", + "string", "向量模型配置"); + + public static final Parameter EMBEDDING_MODEL_API_KEY = + new Parameter("s2.embedding.model.api.key", "", + "ApiKey", "", + "string", "向量模型配置"); + + + public static final Parameter EMBEDDING_MODEL_NAME = + new Parameter("s2.embedding.model.name", "", + "ModelName", "", + "string", "向量模型配置"); + + public static final Parameter EMBEDDING_MODEL_PATH = + new Parameter("s2.embedding.model.path", "", + "模型路径", "", + "string", "向量模型配置"); + + public static final Parameter EMBEDDING_MODEL_VOCABULARY_PATH = + new Parameter("s2.embedding.model.vocabulary.path", "", + "词汇表路径", "", + "string", "向量模型配置"); + + @Override + public List getSysParameters() { + return Lists.newArrayList( + EMBEDDING_MODEL_PROVIDER, EMBEDDING_MODEL_BASE_URL, EMBEDDING_MODEL_API_KEY, + EMBEDDING_MODEL_NAME, EMBEDDING_MODEL_PATH, EMBEDDING_MODEL_VOCABULARY_PATH + ); + } + + public EmbeddingModelConfig convert() { + String provider = getParameterValue(EMBEDDING_MODEL_PROVIDER); + String baseUrl = getParameterValue(EMBEDDING_MODEL_BASE_URL); + String apiKey = getParameterValue(EMBEDDING_MODEL_API_KEY); + String modelName = getParameterValue(EMBEDDING_MODEL_NAME); + String modelPath = getParameterValue(EMBEDDING_MODEL_PATH); + String vocabularyPath = getParameterValue(EMBEDDING_MODEL_VOCABULARY_PATH); + + return EmbeddingModelConfig.builder() + .provider(provider) + .baseUrl(baseUrl) + .apiKey(apiKey) + .modelName(modelName) + .modelPath(modelPath) + .vocabularyPath(vocabularyPath) + .build(); + } + +} diff --git a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java new file mode 100644 index 000000000..5884f28b6 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java @@ -0,0 +1,62 @@ +package com.tencent.supersonic.common.config; + +import com.google.common.collect.Lists; +import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig; +import com.tencent.supersonic.common.pojo.Parameter; +import dev.langchain4j.provider.InMemoryModelFactory; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Service; + +import java.util.List; + +@Service("EmbeddingStoreParameterConfig") +@Slf4j +public class EmbeddingStoreParameterConfig extends ParameterConfig { + public static final Parameter EMBEDDING_STORE_PROVIDER = + new Parameter("s2.embedding.store.provider", InMemoryModelFactory.PROVIDER, + "向量库类型", "", + "string", "向量库配置"); + + public static final Parameter EMBEDDING_STORE_BASE_URL = + new Parameter("s2.embedding.store.base.url", "", + "BaseUrl", "", + "string", "向量库配置"); + + public static final Parameter EMBEDDING_STORE_API_KEY = + new Parameter("s2.embedding.store.api.key", "", + "ApiKey", "", + "string", "向量库配置"); + public static final Parameter EMBEDDING_STORE_PERSIST_PATH = + new Parameter("s2.embedding.store.persist.path", "/tmp", + "持久化路径", "", + "string", "向量库配置"); + + public static final Parameter EMBEDDING_STORE_TIMEOUT = + new Parameter("s2.embedding.store.timeout", "60", + "超时时间(秒)", "", + "number", "向量库配置"); + + @Override + public List getSysParameters() { + return Lists.newArrayList( + EMBEDDING_STORE_PROVIDER, EMBEDDING_STORE_BASE_URL, EMBEDDING_STORE_API_KEY, + EMBEDDING_STORE_PERSIST_PATH, EMBEDDING_STORE_TIMEOUT + ); + } + + public EmbeddingStoreConfig convert() { + String provider = getParameterValue(EMBEDDING_STORE_PROVIDER); + String baseUrl = getParameterValue(EMBEDDING_STORE_BASE_URL); + String apiKey = getParameterValue(EMBEDDING_STORE_API_KEY); + String persistPath = getParameterValue(EMBEDDING_STORE_PERSIST_PATH); + String timeOut = getParameterValue(EMBEDDING_STORE_TIMEOUT); + + return EmbeddingStoreConfig.builder() + .provider(provider) + .baseUrl(baseUrl) + .apiKey(apiKey) + .persistPath(persistPath) + .timeOut(Long.valueOf(timeOut)) + .build(); + } +} \ No newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/config/ChatModelConfig.java b/common/src/main/java/com/tencent/supersonic/common/pojo/ChatModelConfig.java similarity index 93% rename from common/src/main/java/com/tencent/supersonic/common/config/ChatModelConfig.java rename to common/src/main/java/com/tencent/supersonic/common/pojo/ChatModelConfig.java index 4acd09a09..17711e0b5 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/ChatModelConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/ChatModelConfig.java @@ -1,4 +1,4 @@ -package com.tencent.supersonic.common.config; +package com.tencent.supersonic.common.pojo; import com.tencent.supersonic.common.util.AESEncryptionUtil; import lombok.AllArgsConstructor; diff --git a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelConfig.java b/common/src/main/java/com/tencent/supersonic/common/pojo/EmbeddingModelConfig.java similarity index 88% rename from common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelConfig.java rename to common/src/main/java/com/tencent/supersonic/common/pojo/EmbeddingModelConfig.java index 5a024df59..5ef3c8dd8 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/EmbeddingModelConfig.java @@ -1,12 +1,14 @@ -package com.tencent.supersonic.common.config; +package com.tencent.supersonic.common.pojo; import lombok.AllArgsConstructor; +import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; import java.io.Serializable; @Data +@Builder @AllArgsConstructor @NoArgsConstructor public class EmbeddingModelConfig implements Serializable { diff --git a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreConfig.java b/common/src/main/java/com/tencent/supersonic/common/pojo/EmbeddingStoreConfig.java similarity index 63% rename from common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreConfig.java rename to common/src/main/java/com/tencent/supersonic/common/pojo/EmbeddingStoreConfig.java index 5ff9f324c..e10b19e99 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/EmbeddingStoreConfig.java @@ -1,10 +1,16 @@ -package com.tencent.supersonic.common.config; +package com.tencent.supersonic.common.pojo; +import lombok.AllArgsConstructor; +import lombok.Builder; import lombok.Data; +import lombok.NoArgsConstructor; import java.io.Serializable; @Data +@Builder +@AllArgsConstructor +@NoArgsConstructor public class EmbeddingStoreConfig implements Serializable { private static final long serialVersionUID = 1L; private String provider; diff --git a/common/src/main/java/com/tencent/supersonic/common/config/ModelConfig.java b/common/src/main/java/com/tencent/supersonic/common/pojo/ModelConfig.java similarity index 88% rename from common/src/main/java/com/tencent/supersonic/common/config/ModelConfig.java rename to common/src/main/java/com/tencent/supersonic/common/pojo/ModelConfig.java index b1aea608f..ba5dddacc 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/ModelConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/ModelConfig.java @@ -1,4 +1,4 @@ -package com.tencent.supersonic.common.config; +package com.tencent.supersonic.common.pojo; import lombok.AllArgsConstructor; import lombok.Data; diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/Parameter.java b/common/src/main/java/com/tencent/supersonic/common/pojo/Parameter.java index 60c352705..570b588fc 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/Parameter.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/Parameter.java @@ -4,6 +4,7 @@ import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor; import org.apache.commons.lang3.StringUtils; + import java.util.List; @Data @@ -21,12 +22,7 @@ public class Parameter { public Parameter(String name, String defaultValue, String comment, String description, String dataType, String module) { - this.name = name; - this.defaultValue = defaultValue; - this.comment = comment; - this.description = description; - this.dataType = dataType; - this.module = module; + this(name, defaultValue, comment, description, dataType, module, null); } public Parameter(String name, String defaultValue, String comment, String description, diff --git a/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java b/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java index ac6a7333e..c9b203a55 100644 --- a/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java +++ b/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java @@ -2,10 +2,13 @@ package com.tencent.supersonic.common.service.impl; import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; +import com.tencent.supersonic.common.config.EmbeddingModelParameterConfig; +import com.tencent.supersonic.common.pojo.EmbeddingModelConfig; import com.tencent.supersonic.common.service.EmbeddingService; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.provider.ModelProvider; import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingSearchRequest; import dev.langchain4j.store.embedding.EmbeddingSearchResult; @@ -19,6 +22,12 @@ import dev.langchain4j.store.embedding.filter.Filter; import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder; import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo; import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections.MapUtils; +import org.apache.commons.collections4.CollectionUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; + import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; @@ -27,11 +36,6 @@ import java.util.Map; import java.util.Objects; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.collections.MapUtils; -import org.apache.commons.collections4.CollectionUtils; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.stereotype.Service; @Service @Slf4j @@ -41,7 +45,7 @@ public class EmbeddingServiceImpl implements EmbeddingService { private EmbeddingStoreFactory embeddingStoreFactory; @Autowired - private EmbeddingModel embeddingModel; + private EmbeddingModelParameterConfig embeddingModelParameterConfig; private Cache cache = CacheBuilder.newBuilder() .maximumSize(10000) @@ -55,6 +59,7 @@ public class EmbeddingServiceImpl implements EmbeddingService { for (TextSegment query : queries) { String question = query.text(); try { + EmbeddingModel embeddingModel = getEmbeddingModel(); Embedding embedding = embeddingModel.embed(question).content(); boolean existSegment = existSegment(embeddingStore, query, embedding); if (existSegment) { @@ -122,6 +127,7 @@ public class EmbeddingServiceImpl implements EmbeddingService { List queryTextsList = retrieveQuery.getQueryTextsList(); Map filterCondition = retrieveQuery.getFilterCondition(); for (String queryText : queryTextsList) { + EmbeddingModel embeddingModel = getEmbeddingModel(); Embedding embeddedText = embeddingModel.embed(queryText).content(); Filter filter = createCombinedFilter(filterCondition); EmbeddingSearchRequest request = EmbeddingSearchRequest.builder() @@ -169,4 +175,9 @@ public class EmbeddingServiceImpl implements EmbeddingService { } return result; } + + private EmbeddingModel getEmbeddingModel() { + EmbeddingModelConfig embeddingModelConfig = embeddingModelParameterConfig.convert(); + return ModelProvider.getEmbeddingModel(embeddingModelConfig); + } } diff --git a/common/src/main/java/dev/langchain4j/chroma/spring/ChromaEmbeddingStoreFactory.java b/common/src/main/java/dev/langchain4j/chroma/spring/ChromaEmbeddingStoreFactory.java index d22b19eed..d38270f0d 100644 --- a/common/src/main/java/dev/langchain4j/chroma/spring/ChromaEmbeddingStoreFactory.java +++ b/common/src/main/java/dev/langchain4j/chroma/spring/ChromaEmbeddingStoreFactory.java @@ -1,6 +1,6 @@ package dev.langchain4j.chroma.spring; -import com.tencent.supersonic.common.config.EmbeddingStoreConfig; +import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig; import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.chroma.ChromaEmbeddingStore; diff --git a/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryEmbeddingStoreFactory.java b/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryEmbeddingStoreFactory.java index 21e0690cf..91922074d 100644 --- a/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryEmbeddingStoreFactory.java +++ b/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryEmbeddingStoreFactory.java @@ -1,7 +1,7 @@ package dev.langchain4j.inmemory.spring; import com.tencent.supersonic.common.config.EmbeddingConfig; -import com.tencent.supersonic.common.config.EmbeddingStoreConfig; +import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig; import com.tencent.supersonic.common.util.ContextUtils; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory; diff --git a/common/src/main/java/dev/langchain4j/milvus/spring/MilvusEmbeddingStoreFactory.java b/common/src/main/java/dev/langchain4j/milvus/spring/MilvusEmbeddingStoreFactory.java index b3d86af5b..074a30f70 100644 --- a/common/src/main/java/dev/langchain4j/milvus/spring/MilvusEmbeddingStoreFactory.java +++ b/common/src/main/java/dev/langchain4j/milvus/spring/MilvusEmbeddingStoreFactory.java @@ -1,6 +1,6 @@ package dev.langchain4j.milvus.spring; -import com.tencent.supersonic.common.config.EmbeddingStoreConfig; +import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory; import dev.langchain4j.store.embedding.EmbeddingStore; diff --git a/common/src/main/java/dev/langchain4j/provider/AzureModelFactory.java b/common/src/main/java/dev/langchain4j/provider/AzureModelFactory.java index ecf381090..9f23d54c1 100644 --- a/common/src/main/java/dev/langchain4j/provider/AzureModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/AzureModelFactory.java @@ -1,7 +1,7 @@ package dev.langchain4j.provider; -import com.tencent.supersonic.common.config.ChatModelConfig; -import com.tencent.supersonic.common.config.EmbeddingModelConfig; +import com.tencent.supersonic.common.pojo.ChatModelConfig; +import com.tencent.supersonic.common.pojo.EmbeddingModelConfig; import dev.langchain4j.model.azure.AzureOpenAiChatModel; import dev.langchain4j.model.azure.AzureOpenAiEmbeddingModel; import dev.langchain4j.model.chat.ChatLanguageModel; diff --git a/common/src/main/java/dev/langchain4j/provider/DashscopeModelFactory.java b/common/src/main/java/dev/langchain4j/provider/DashscopeModelFactory.java index 6fcce7d8a..db391abbf 100644 --- a/common/src/main/java/dev/langchain4j/provider/DashscopeModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/DashscopeModelFactory.java @@ -1,7 +1,7 @@ package dev.langchain4j.provider; -import com.tencent.supersonic.common.config.ChatModelConfig; -import com.tencent.supersonic.common.config.EmbeddingModelConfig; +import com.tencent.supersonic.common.pojo.ChatModelConfig; +import com.tencent.supersonic.common.pojo.EmbeddingModelConfig; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.dashscope.QwenChatModel; import dev.langchain4j.model.dashscope.QwenEmbeddingModel; diff --git a/common/src/main/java/dev/langchain4j/provider/EmbeddingStoreFactory.java b/common/src/main/java/dev/langchain4j/provider/EmbeddingStoreFactory.java index 82a3e2b17..fea4cf7ad 100644 --- a/common/src/main/java/dev/langchain4j/provider/EmbeddingStoreFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/EmbeddingStoreFactory.java @@ -1,6 +1,6 @@ package dev.langchain4j.provider; -import com.tencent.supersonic.common.config.EmbeddingStoreConfig; +import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig; import dev.langchain4j.store.embedding.EmbeddingStore; public interface EmbeddingStoreFactory { diff --git a/common/src/main/java/dev/langchain4j/provider/InMemoryModelFactory.java b/common/src/main/java/dev/langchain4j/provider/InMemoryModelFactory.java index 444b0d836..e532cff40 100644 --- a/common/src/main/java/dev/langchain4j/provider/InMemoryModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/InMemoryModelFactory.java @@ -1,7 +1,7 @@ package dev.langchain4j.provider; -import com.tencent.supersonic.common.config.ChatModelConfig; -import com.tencent.supersonic.common.config.EmbeddingModelConfig; +import com.tencent.supersonic.common.pojo.ChatModelConfig; +import com.tencent.supersonic.common.pojo.EmbeddingModelConfig; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel; diff --git a/common/src/main/java/dev/langchain4j/provider/LocalAiModelFactory.java b/common/src/main/java/dev/langchain4j/provider/LocalAiModelFactory.java index 496035061..ea95287cd 100644 --- a/common/src/main/java/dev/langchain4j/provider/LocalAiModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/LocalAiModelFactory.java @@ -1,7 +1,7 @@ package dev.langchain4j.provider; -import com.tencent.supersonic.common.config.ChatModelConfig; -import com.tencent.supersonic.common.config.EmbeddingModelConfig; +import com.tencent.supersonic.common.pojo.ChatModelConfig; +import com.tencent.supersonic.common.pojo.EmbeddingModelConfig; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.localai.LocalAiChatModel; diff --git a/common/src/main/java/dev/langchain4j/provider/ModelFactory.java b/common/src/main/java/dev/langchain4j/provider/ModelFactory.java index 660ae13d6..b17e243bc 100644 --- a/common/src/main/java/dev/langchain4j/provider/ModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/ModelFactory.java @@ -1,7 +1,7 @@ package dev.langchain4j.provider; -import com.tencent.supersonic.common.config.ChatModelConfig; -import com.tencent.supersonic.common.config.EmbeddingModelConfig; +import com.tencent.supersonic.common.pojo.ChatModelConfig; +import com.tencent.supersonic.common.pojo.EmbeddingModelConfig; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.embedding.EmbeddingModel; diff --git a/common/src/main/java/dev/langchain4j/provider/ModelProvider.java b/common/src/main/java/dev/langchain4j/provider/ModelProvider.java index de0715506..ab3c8f142 100644 --- a/common/src/main/java/dev/langchain4j/provider/ModelProvider.java +++ b/common/src/main/java/dev/langchain4j/provider/ModelProvider.java @@ -1,8 +1,7 @@ package dev.langchain4j.provider; -import com.tencent.supersonic.common.config.ChatModelConfig; -import com.tencent.supersonic.common.config.EmbeddingModelConfig; -import com.tencent.supersonic.common.config.ModelConfig; +import com.tencent.supersonic.common.pojo.ChatModelConfig; +import com.tencent.supersonic.common.pojo.EmbeddingModelConfig; import com.tencent.supersonic.common.util.ContextUtils; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.embedding.EmbeddingModel; @@ -10,7 +9,6 @@ import org.apache.commons.lang3.StringUtils; import java.util.HashMap; import java.util.Map; -import java.util.Objects; public class ModelProvider { private static final Map factories = new HashMap<>(); @@ -33,14 +31,10 @@ public class ModelProvider { throw new RuntimeException("Unsupported ChatLanguageModel provider: " + modelConfig.getProvider()); } - public static EmbeddingModel getEmbeddingModel(ModelConfig modelConfig) { - if (modelConfig == null || Objects.isNull(modelConfig.getEmbeddingModel()) - || StringUtils.isBlank(modelConfig.getEmbeddingModel().getBaseUrl()) - || StringUtils.isBlank(modelConfig.getEmbeddingModel().getProvider())) { + public static EmbeddingModel getEmbeddingModel(EmbeddingModelConfig embeddingModel) { + if (embeddingModel == null || StringUtils.isBlank(embeddingModel.getProvider())) { return ContextUtils.getBean(EmbeddingModel.class); } - EmbeddingModelConfig embeddingModel = modelConfig.getEmbeddingModel(); - ModelFactory modelFactory = factories.get(embeddingModel.getProvider().toUpperCase()); if (modelFactory != null) { return modelFactory.createEmbeddingModel(embeddingModel); diff --git a/common/src/main/java/dev/langchain4j/provider/OllamaModelFactory.java b/common/src/main/java/dev/langchain4j/provider/OllamaModelFactory.java index e027eae83..4eb24aa98 100644 --- a/common/src/main/java/dev/langchain4j/provider/OllamaModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/OllamaModelFactory.java @@ -1,7 +1,7 @@ package dev.langchain4j.provider; -import com.tencent.supersonic.common.config.ChatModelConfig; -import com.tencent.supersonic.common.config.EmbeddingModelConfig; +import com.tencent.supersonic.common.pojo.ChatModelConfig; +import com.tencent.supersonic.common.pojo.EmbeddingModelConfig; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.ollama.OllamaChatModel; diff --git a/common/src/main/java/dev/langchain4j/provider/OpenAiModelFactory.java b/common/src/main/java/dev/langchain4j/provider/OpenAiModelFactory.java index b48bae951..481293b22 100644 --- a/common/src/main/java/dev/langchain4j/provider/OpenAiModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/OpenAiModelFactory.java @@ -1,7 +1,7 @@ package dev.langchain4j.provider; -import com.tencent.supersonic.common.config.ChatModelConfig; -import com.tencent.supersonic.common.config.EmbeddingModelConfig; +import com.tencent.supersonic.common.pojo.ChatModelConfig; +import com.tencent.supersonic.common.pojo.EmbeddingModelConfig; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.openai.OpenAiChatModel; diff --git a/common/src/main/java/dev/langchain4j/provider/QianfanModelFactory.java b/common/src/main/java/dev/langchain4j/provider/QianfanModelFactory.java index 9051e243d..60bd411cb 100644 --- a/common/src/main/java/dev/langchain4j/provider/QianfanModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/QianfanModelFactory.java @@ -1,7 +1,7 @@ package dev.langchain4j.provider; -import com.tencent.supersonic.common.config.ChatModelConfig; -import com.tencent.supersonic.common.config.EmbeddingModelConfig; +import com.tencent.supersonic.common.pojo.ChatModelConfig; +import com.tencent.supersonic.common.pojo.EmbeddingModelConfig; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.qianfan.QianfanEmbeddingModel; diff --git a/common/src/main/java/dev/langchain4j/provider/ZhipuModelFactory.java b/common/src/main/java/dev/langchain4j/provider/ZhipuModelFactory.java index f2c4bfa0b..d37d26408 100644 --- a/common/src/main/java/dev/langchain4j/provider/ZhipuModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/ZhipuModelFactory.java @@ -1,7 +1,7 @@ package dev.langchain4j.provider; -import com.tencent.supersonic.common.config.ChatModelConfig; -import com.tencent.supersonic.common.config.EmbeddingModelConfig; +import com.tencent.supersonic.common.pojo.ChatModelConfig; +import com.tencent.supersonic.common.pojo.EmbeddingModelConfig; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel; diff --git a/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactoryProvider.java b/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactoryProvider.java index f34f825e4..d3dc5876c 100644 --- a/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactoryProvider.java +++ b/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactoryProvider.java @@ -1,6 +1,6 @@ package dev.langchain4j.store.embedding; -import com.tencent.supersonic.common.config.EmbeddingStoreConfig; +import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig; import com.tencent.supersonic.common.util.ContextUtils; import dev.langchain4j.chroma.spring.ChromaEmbeddingStoreFactory; import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory; diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java index 1c9357b63..557fd03f3 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java @@ -3,8 +3,8 @@ package com.tencent.supersonic.headless.api.pojo.request; import com.google.common.collect.Lists; import com.google.common.collect.Sets; import com.tencent.supersonic.auth.api.authentication.pojo.User; -import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.config.PromptConfig; +import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.headless.api.pojo.QueryDataType; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java index e2ca5c6ea..c35abc0c0 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java @@ -2,17 +2,17 @@ package com.tencent.supersonic.headless.chat; import com.fasterxml.jackson.annotation.JsonIgnore; import com.tencent.supersonic.auth.api.authentication.pojo.User; -import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.config.PromptConfig; +import com.tencent.supersonic.common.pojo.ChatModelConfig; +import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.QueryDataType; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; -import com.tencent.supersonic.common.pojo.SqlExemplar; -import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState; +import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; import com.tencent.supersonic.headless.chat.parser.ParserConfig; import com.tencent.supersonic.headless.chat.query.SemanticQuery; @@ -21,7 +21,6 @@ import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; - import java.util.ArrayList; import java.util.Comparator; import java.util.List; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java index 4faa118ee..b3642d85e 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java @@ -57,15 +57,15 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy { @Override protected void detectByBatch(ChatQueryContext chatQueryContext, Set results, Set detectDataSetIds, Set detectSegments) { - int embedddingMapperMin = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MIN)); - int embedddingMapperMax = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MAX)); + int embeddingMapperMin = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MIN)); + int embeddingMapperMax = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MAX)); int embeddingMapperBatch = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH)); List queryTextsList = detectSegments.stream() .map(detectSegment -> detectSegment.trim()) .filter(detectSegment -> StringUtils.isNotBlank(detectSegment) - && detectSegment.length() >= embedddingMapperMin - && detectSegment.length() <= embedddingMapperMax) + && detectSegment.length() >= embeddingMapperMin + && detectSegment.length() <= embeddingMapperMax) .collect(Collectors.toList()); List> queryTextsSubList = Lists.partition(queryTextsList, diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/SqlGenStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/SqlGenStrategy.java index d761ca955..f65b55de9 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/SqlGenStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/SqlGenStrategy.java @@ -1,6 +1,6 @@ package com.tencent.supersonic.headless.chat.parser.llm; -import com.tencent.supersonic.common.config.ChatModelConfig; +import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; import dev.langchain4j.model.chat.ChatLanguageModel; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java index d9fe0b1de..901557216 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java @@ -2,8 +2,8 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql; import com.fasterxml.jackson.annotation.JsonValue; import com.google.common.collect.Lists; -import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.config.PromptConfig; +import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import lombok.Data; diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java index 60db2d439..bad4a948d 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java @@ -8,7 +8,7 @@ import com.tencent.supersonic.chat.server.agent.AgentConfig; import com.tencent.supersonic.chat.server.agent.AgentToolType; import com.tencent.supersonic.chat.server.agent.MultiTurnConfig; import com.tencent.supersonic.chat.server.agent.RuleParserTool; -import com.tencent.supersonic.common.config.ChatModelConfig; +import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.headless.api.pojo.response.QueryResult; import com.tencent.supersonic.util.DataUtils; import org.junit.jupiter.api.BeforeAll;