mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
(improvement)(chat) Integrate llm configuration into system settings. (#1403)
This commit is contained in:
@@ -4,9 +4,9 @@ package com.tencent.supersonic.chat.server.agent;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.config.VisualConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||
import lombok.Data;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.DeleteMapping;
|
||||
import org.springframework.web.bind.annotation.PathVariable;
|
||||
|
||||
@@ -12,9 +12,9 @@ import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.config.VisualConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.chat.server.util;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.provider.ModelProvider;
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
package com.tencent.supersonic.common.config;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.Parameter;
|
||||
import dev.langchain4j.provider.OllamaModelFactory;
|
||||
import dev.langchain4j.provider.OpenAiModelFactory;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Service("ChatModelParameterConfig")
|
||||
@Slf4j
|
||||
public class ChatModelParameterConfig extends ParameterConfig {
|
||||
|
||||
public static final Parameter CHAT_MODEL_PROVIDER =
|
||||
new Parameter("s2.chat.model.provider", OpenAiModelFactory.PROVIDER,
|
||||
"接口协议", "",
|
||||
"string", "对话模型配置",
|
||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER));
|
||||
|
||||
public static final Parameter CHAT_MODEL_BASE_URL =
|
||||
new Parameter("s2.chat.model.base.url", "",
|
||||
"BaseUrl", "",
|
||||
"string", "对话模型配置");
|
||||
|
||||
public static final Parameter CHAT_MODEL_API_KEY =
|
||||
new Parameter("s2.chat.model.api.key", "",
|
||||
"ApiKey", "",
|
||||
"string", "对话模型配置");
|
||||
|
||||
public static final Parameter CHAT_MODEL_NAME =
|
||||
new Parameter("s2.chat.model.name", "",
|
||||
"ModelName", "",
|
||||
"string", "对话模型配置");
|
||||
|
||||
public static final Parameter CHAT_MODEL_TEMPERATURE =
|
||||
new Parameter("s2.chat.model.temperature", "0.0",
|
||||
"Temperature", "",
|
||||
"number", "对话模型配置");
|
||||
|
||||
public static final Parameter CHAT_MODEL_TIMEOUT =
|
||||
new Parameter("s2.chat.model.timeout", "60",
|
||||
"超时时间(秒)", "",
|
||||
"number", "对话模型配置");
|
||||
|
||||
@Override
|
||||
public List<Parameter> getSysParameters() {
|
||||
return Lists.newArrayList(
|
||||
CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_API_KEY,
|
||||
CHAT_MODEL_NAME, CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT
|
||||
);
|
||||
}
|
||||
|
||||
public ChatModelConfig convert() {
|
||||
String chatModelProvider = getParameterValue(CHAT_MODEL_PROVIDER);
|
||||
String chatModelBaseUrl = getParameterValue(CHAT_MODEL_BASE_URL);
|
||||
String chatModelApiKey = getParameterValue(CHAT_MODEL_API_KEY);
|
||||
String chatModelName = getParameterValue(CHAT_MODEL_NAME);
|
||||
String chatModelTemperature = getParameterValue(CHAT_MODEL_TEMPERATURE);
|
||||
String chatModelTimeout = getParameterValue(CHAT_MODEL_TIMEOUT);
|
||||
|
||||
return ChatModelConfig.builder()
|
||||
.provider(chatModelProvider)
|
||||
.baseUrl(chatModelBaseUrl)
|
||||
.apiKey(chatModelApiKey)
|
||||
.modelName(chatModelName)
|
||||
.temperature(Double.valueOf(chatModelTemperature))
|
||||
.timeOut(Long.valueOf(chatModelTimeout))
|
||||
.build();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
package com.tencent.supersonic.common.config;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.Parameter;
|
||||
import dev.langchain4j.provider.AzureModelFactory;
|
||||
import dev.langchain4j.provider.DashscopeModelFactory;
|
||||
import dev.langchain4j.provider.InMemoryModelFactory;
|
||||
import dev.langchain4j.provider.OllamaModelFactory;
|
||||
import dev.langchain4j.provider.OpenAiModelFactory;
|
||||
import dev.langchain4j.provider.QianfanModelFactory;
|
||||
import dev.langchain4j.provider.ZhipuModelFactory;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Service("EmbeddingModelConfig")
|
||||
@Slf4j
|
||||
public class EmbeddingModelParameterConfig extends ParameterConfig {
|
||||
|
||||
public static final Parameter EMBEDDING_MODEL_PROVIDER =
|
||||
new Parameter("s2.embedding.model.provider", InMemoryModelFactory.PROVIDER,
|
||||
"接口协议", "",
|
||||
"string", "向量模型配置",
|
||||
Lists.newArrayList(InMemoryModelFactory.PROVIDER,
|
||||
OpenAiModelFactory.PROVIDER,
|
||||
OllamaModelFactory.PROVIDER,
|
||||
AzureModelFactory.PROVIDER,
|
||||
DashscopeModelFactory.PROVIDER,
|
||||
QianfanModelFactory.PROVIDER,
|
||||
ZhipuModelFactory.PROVIDER));
|
||||
|
||||
public static final Parameter EMBEDDING_MODEL_BASE_URL =
|
||||
new Parameter("s2.embedding.model.base.url", "",
|
||||
"BaseUrl", "",
|
||||
"string", "向量模型配置");
|
||||
|
||||
public static final Parameter EMBEDDING_MODEL_API_KEY =
|
||||
new Parameter("s2.embedding.model.api.key", "",
|
||||
"ApiKey", "",
|
||||
"string", "向量模型配置");
|
||||
|
||||
|
||||
public static final Parameter EMBEDDING_MODEL_NAME =
|
||||
new Parameter("s2.embedding.model.name", "",
|
||||
"ModelName", "",
|
||||
"string", "向量模型配置");
|
||||
|
||||
public static final Parameter EMBEDDING_MODEL_PATH =
|
||||
new Parameter("s2.embedding.model.path", "",
|
||||
"模型路径", "",
|
||||
"string", "向量模型配置");
|
||||
|
||||
public static final Parameter EMBEDDING_MODEL_VOCABULARY_PATH =
|
||||
new Parameter("s2.embedding.model.vocabulary.path", "",
|
||||
"词汇表路径", "",
|
||||
"string", "向量模型配置");
|
||||
|
||||
@Override
|
||||
public List<Parameter> getSysParameters() {
|
||||
return Lists.newArrayList(
|
||||
EMBEDDING_MODEL_PROVIDER, EMBEDDING_MODEL_BASE_URL, EMBEDDING_MODEL_API_KEY,
|
||||
EMBEDDING_MODEL_NAME, EMBEDDING_MODEL_PATH, EMBEDDING_MODEL_VOCABULARY_PATH
|
||||
);
|
||||
}
|
||||
|
||||
public EmbeddingModelConfig convert() {
|
||||
String provider = getParameterValue(EMBEDDING_MODEL_PROVIDER);
|
||||
String baseUrl = getParameterValue(EMBEDDING_MODEL_BASE_URL);
|
||||
String apiKey = getParameterValue(EMBEDDING_MODEL_API_KEY);
|
||||
String modelName = getParameterValue(EMBEDDING_MODEL_NAME);
|
||||
String modelPath = getParameterValue(EMBEDDING_MODEL_PATH);
|
||||
String vocabularyPath = getParameterValue(EMBEDDING_MODEL_VOCABULARY_PATH);
|
||||
|
||||
return EmbeddingModelConfig.builder()
|
||||
.provider(provider)
|
||||
.baseUrl(baseUrl)
|
||||
.apiKey(apiKey)
|
||||
.modelName(modelName)
|
||||
.modelPath(modelPath)
|
||||
.vocabularyPath(vocabularyPath)
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package com.tencent.supersonic.common.config;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
|
||||
import com.tencent.supersonic.common.pojo.Parameter;
|
||||
import dev.langchain4j.provider.InMemoryModelFactory;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Service("EmbeddingStoreParameterConfig")
|
||||
@Slf4j
|
||||
public class EmbeddingStoreParameterConfig extends ParameterConfig {
|
||||
public static final Parameter EMBEDDING_STORE_PROVIDER =
|
||||
new Parameter("s2.embedding.store.provider", InMemoryModelFactory.PROVIDER,
|
||||
"向量库类型", "",
|
||||
"string", "向量库配置");
|
||||
|
||||
public static final Parameter EMBEDDING_STORE_BASE_URL =
|
||||
new Parameter("s2.embedding.store.base.url", "",
|
||||
"BaseUrl", "",
|
||||
"string", "向量库配置");
|
||||
|
||||
public static final Parameter EMBEDDING_STORE_API_KEY =
|
||||
new Parameter("s2.embedding.store.api.key", "",
|
||||
"ApiKey", "",
|
||||
"string", "向量库配置");
|
||||
public static final Parameter EMBEDDING_STORE_PERSIST_PATH =
|
||||
new Parameter("s2.embedding.store.persist.path", "/tmp",
|
||||
"持久化路径", "",
|
||||
"string", "向量库配置");
|
||||
|
||||
public static final Parameter EMBEDDING_STORE_TIMEOUT =
|
||||
new Parameter("s2.embedding.store.timeout", "60",
|
||||
"超时时间(秒)", "",
|
||||
"number", "向量库配置");
|
||||
|
||||
@Override
|
||||
public List<Parameter> getSysParameters() {
|
||||
return Lists.newArrayList(
|
||||
EMBEDDING_STORE_PROVIDER, EMBEDDING_STORE_BASE_URL, EMBEDDING_STORE_API_KEY,
|
||||
EMBEDDING_STORE_PERSIST_PATH, EMBEDDING_STORE_TIMEOUT
|
||||
);
|
||||
}
|
||||
|
||||
public EmbeddingStoreConfig convert() {
|
||||
String provider = getParameterValue(EMBEDDING_STORE_PROVIDER);
|
||||
String baseUrl = getParameterValue(EMBEDDING_STORE_BASE_URL);
|
||||
String apiKey = getParameterValue(EMBEDDING_STORE_API_KEY);
|
||||
String persistPath = getParameterValue(EMBEDDING_STORE_PERSIST_PATH);
|
||||
String timeOut = getParameterValue(EMBEDDING_STORE_TIMEOUT);
|
||||
|
||||
return EmbeddingStoreConfig.builder()
|
||||
.provider(provider)
|
||||
.baseUrl(baseUrl)
|
||||
.apiKey(apiKey)
|
||||
.persistPath(persistPath)
|
||||
.timeOut(Long.valueOf(timeOut))
|
||||
.build();
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.common.config;
|
||||
package com.tencent.supersonic.common.pojo;
|
||||
|
||||
import com.tencent.supersonic.common.util.AESEncryptionUtil;
|
||||
import lombok.AllArgsConstructor;
|
||||
@@ -1,12 +1,14 @@
|
||||
package com.tencent.supersonic.common.config;
|
||||
package com.tencent.supersonic.common.pojo;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class EmbeddingModelConfig implements Serializable {
|
||||
@@ -1,10 +1,16 @@
|
||||
package com.tencent.supersonic.common.config;
|
||||
package com.tencent.supersonic.common.pojo;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class EmbeddingStoreConfig implements Serializable {
|
||||
private static final long serialVersionUID = 1L;
|
||||
private String provider;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.common.config;
|
||||
package com.tencent.supersonic.common.pojo;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
@@ -4,6 +4,7 @@ import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@@ -21,12 +22,7 @@ public class Parameter {
|
||||
|
||||
public Parameter(String name, String defaultValue, String comment,
|
||||
String description, String dataType, String module) {
|
||||
this.name = name;
|
||||
this.defaultValue = defaultValue;
|
||||
this.comment = comment;
|
||||
this.description = description;
|
||||
this.dataType = dataType;
|
||||
this.module = module;
|
||||
this(name, defaultValue, comment, description, dataType, module, null);
|
||||
}
|
||||
|
||||
public Parameter(String name, String defaultValue, String comment, String description,
|
||||
|
||||
@@ -2,10 +2,13 @@ package com.tencent.supersonic.common.service.impl;
|
||||
|
||||
import com.google.common.cache.Cache;
|
||||
import com.google.common.cache.CacheBuilder;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelParameterConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import com.tencent.supersonic.common.service.EmbeddingService;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.provider.ModelProvider;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
||||
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
|
||||
@@ -19,6 +22,12 @@ import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
|
||||
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.MapUtils;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
@@ -27,11 +36,6 @@ import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.MapUtils;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
@@ -41,7 +45,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
private EmbeddingStoreFactory embeddingStoreFactory;
|
||||
|
||||
@Autowired
|
||||
private EmbeddingModel embeddingModel;
|
||||
private EmbeddingModelParameterConfig embeddingModelParameterConfig;
|
||||
|
||||
private Cache<String, Boolean> cache = CacheBuilder.newBuilder()
|
||||
.maximumSize(10000)
|
||||
@@ -55,6 +59,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
for (TextSegment query : queries) {
|
||||
String question = query.text();
|
||||
try {
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel();
|
||||
Embedding embedding = embeddingModel.embed(question).content();
|
||||
boolean existSegment = existSegment(embeddingStore, query, embedding);
|
||||
if (existSegment) {
|
||||
@@ -122,6 +127,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
List<String> queryTextsList = retrieveQuery.getQueryTextsList();
|
||||
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
|
||||
for (String queryText : queryTextsList) {
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel();
|
||||
Embedding embeddedText = embeddingModel.embed(queryText).content();
|
||||
Filter filter = createCombinedFilter(filterCondition);
|
||||
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
|
||||
@@ -169,4 +175,9 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private EmbeddingModel getEmbeddingModel() {
|
||||
EmbeddingModelConfig embeddingModelConfig = embeddingModelParameterConfig.convert();
|
||||
return ModelProvider.getEmbeddingModel(embeddingModelConfig);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package dev.langchain4j.chroma.spring;
|
||||
|
||||
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
|
||||
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.chroma.ChromaEmbeddingStore;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package dev.langchain4j.inmemory.spring;
|
||||
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package dev.langchain4j.milvus.spring;
|
||||
|
||||
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.azure.AzureOpenAiChatModel;
|
||||
import dev.langchain4j.model.azure.AzureOpenAiEmbeddingModel;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.dashscope.QwenChatModel;
|
||||
import dev.langchain4j.model.dashscope.QwenEmbeddingModel;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
|
||||
public interface EmbeddingStoreFactory {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.localai.LocalAiChatModel;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
||||
import com.tencent.supersonic.common.config.ModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
@@ -10,7 +9,6 @@ import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
public class ModelProvider {
|
||||
private static final Map<String, ModelFactory> factories = new HashMap<>();
|
||||
@@ -33,14 +31,10 @@ public class ModelProvider {
|
||||
throw new RuntimeException("Unsupported ChatLanguageModel provider: " + modelConfig.getProvider());
|
||||
}
|
||||
|
||||
public static EmbeddingModel getEmbeddingModel(ModelConfig modelConfig) {
|
||||
if (modelConfig == null || Objects.isNull(modelConfig.getEmbeddingModel())
|
||||
|| StringUtils.isBlank(modelConfig.getEmbeddingModel().getBaseUrl())
|
||||
|| StringUtils.isBlank(modelConfig.getEmbeddingModel().getProvider())) {
|
||||
public static EmbeddingModel getEmbeddingModel(EmbeddingModelConfig embeddingModel) {
|
||||
if (embeddingModel == null || StringUtils.isBlank(embeddingModel.getProvider())) {
|
||||
return ContextUtils.getBean(EmbeddingModel.class);
|
||||
}
|
||||
EmbeddingModelConfig embeddingModel = modelConfig.getEmbeddingModel();
|
||||
|
||||
ModelFactory modelFactory = factories.get(embeddingModel.getProvider().toUpperCase());
|
||||
if (modelFactory != null) {
|
||||
return modelFactory.createEmbeddingModel(embeddingModel);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.ollama.OllamaChatModel;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiChatModel;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.qianfan.QianfanEmbeddingModel;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package dev.langchain4j.store.embedding;
|
||||
|
||||
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.chroma.spring.ChromaEmbeddingStoreFactory;
|
||||
import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory;
|
||||
|
||||
@@ -3,8 +3,8 @@ package com.tencent.supersonic.headless.api.pojo.request;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
||||
|
||||
@@ -2,17 +2,17 @@ package com.tencent.supersonic.headless.chat;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
@@ -21,7 +21,6 @@ import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
|
||||
@@ -57,15 +57,15 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
@Override
|
||||
protected void detectByBatch(ChatQueryContext chatQueryContext, Set<EmbeddingResult> results,
|
||||
Set<Long> detectDataSetIds, Set<String> detectSegments) {
|
||||
int embedddingMapperMin = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MIN));
|
||||
int embedddingMapperMax = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MAX));
|
||||
int embeddingMapperMin = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MIN));
|
||||
int embeddingMapperMax = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MAX));
|
||||
int embeddingMapperBatch = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH));
|
||||
|
||||
List<String> queryTextsList = detectSegments.stream()
|
||||
.map(detectSegment -> detectSegment.trim())
|
||||
.filter(detectSegment -> StringUtils.isNotBlank(detectSegment)
|
||||
&& detectSegment.length() >= embedddingMapperMin
|
||||
&& detectSegment.length() <= embedddingMapperMax)
|
||||
&& detectSegment.length() >= embeddingMapperMin
|
||||
&& detectSegment.length() <= embeddingMapperMax)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
List<List<String>> queryTextsSubList = Lists.partition(queryTextsList,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
|
||||
@@ -2,8 +2,8 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonValue;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import lombok.Data;
|
||||
|
||||
@@ -8,7 +8,7 @@ import com.tencent.supersonic.chat.server.agent.AgentConfig;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
|
||||
Reference in New Issue
Block a user