(improvement)(chat) Integrate llm configuration into system settings. (#1403)

This commit is contained in:
lexluo09
2024-07-14 14:47:17 +08:00
committed by GitHub
parent 407c8d4702
commit 4eb6193699
34 changed files with 294 additions and 65 deletions

View File

@@ -4,9 +4,9 @@ package com.tencent.supersonic.chat.server.agent;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.config.VisualConfig; import com.tencent.supersonic.common.config.VisualConfig;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.RecordInfo; import com.tencent.supersonic.common.pojo.RecordInfo;
import lombok.Data; import lombok.Data;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;

View File

@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentToolType; import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.util.LLMConnHelper; import com.tencent.supersonic.chat.server.util.LLMConnHelper;
import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.DeleteMapping; import org.springframework.web.bind.annotation.DeleteMapping;
import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.PathVariable;

View File

@@ -12,9 +12,9 @@ import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatQueryService; import com.tencent.supersonic.chat.server.service.ChatQueryService;
import com.tencent.supersonic.chat.server.service.MemoryService; import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.chat.server.util.LLMConnHelper; import com.tencent.supersonic.chat.server.util.LLMConnHelper;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.config.VisualConfig; import com.tencent.supersonic.common.config.VisualConfig;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.server.util; package com.tencent.supersonic.chat.server.util;
import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException; import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.provider.ModelProvider; import dev.langchain4j.provider.ModelProvider;

View File

@@ -0,0 +1,73 @@
package com.tencent.supersonic.common.config;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Parameter;
import dev.langchain4j.provider.OllamaModelFactory;
import dev.langchain4j.provider.OpenAiModelFactory;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.List;
@Service("ChatModelParameterConfig")
@Slf4j
public class ChatModelParameterConfig extends ParameterConfig {
public static final Parameter CHAT_MODEL_PROVIDER =
new Parameter("s2.chat.model.provider", OpenAiModelFactory.PROVIDER,
"接口协议", "",
"string", "对话模型配置",
Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER));
public static final Parameter CHAT_MODEL_BASE_URL =
new Parameter("s2.chat.model.base.url", "",
"BaseUrl", "",
"string", "对话模型配置");
public static final Parameter CHAT_MODEL_API_KEY =
new Parameter("s2.chat.model.api.key", "",
"ApiKey", "",
"string", "对话模型配置");
public static final Parameter CHAT_MODEL_NAME =
new Parameter("s2.chat.model.name", "",
"ModelName", "",
"string", "对话模型配置");
public static final Parameter CHAT_MODEL_TEMPERATURE =
new Parameter("s2.chat.model.temperature", "0.0",
"Temperature", "",
"number", "对话模型配置");
public static final Parameter CHAT_MODEL_TIMEOUT =
new Parameter("s2.chat.model.timeout", "60",
"超时时间(秒)", "",
"number", "对话模型配置");
@Override
public List<Parameter> getSysParameters() {
return Lists.newArrayList(
CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_API_KEY,
CHAT_MODEL_NAME, CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT
);
}
public ChatModelConfig convert() {
String chatModelProvider = getParameterValue(CHAT_MODEL_PROVIDER);
String chatModelBaseUrl = getParameterValue(CHAT_MODEL_BASE_URL);
String chatModelApiKey = getParameterValue(CHAT_MODEL_API_KEY);
String chatModelName = getParameterValue(CHAT_MODEL_NAME);
String chatModelTemperature = getParameterValue(CHAT_MODEL_TEMPERATURE);
String chatModelTimeout = getParameterValue(CHAT_MODEL_TIMEOUT);
return ChatModelConfig.builder()
.provider(chatModelProvider)
.baseUrl(chatModelBaseUrl)
.apiKey(chatModelApiKey)
.modelName(chatModelName)
.temperature(Double.valueOf(chatModelTemperature))
.timeOut(Long.valueOf(chatModelTimeout))
.build();
}
}

View File

@@ -0,0 +1,86 @@
package com.tencent.supersonic.common.config;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import com.tencent.supersonic.common.pojo.Parameter;
import dev.langchain4j.provider.AzureModelFactory;
import dev.langchain4j.provider.DashscopeModelFactory;
import dev.langchain4j.provider.InMemoryModelFactory;
import dev.langchain4j.provider.OllamaModelFactory;
import dev.langchain4j.provider.OpenAiModelFactory;
import dev.langchain4j.provider.QianfanModelFactory;
import dev.langchain4j.provider.ZhipuModelFactory;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.List;
@Service("EmbeddingModelConfig")
@Slf4j
public class EmbeddingModelParameterConfig extends ParameterConfig {
public static final Parameter EMBEDDING_MODEL_PROVIDER =
new Parameter("s2.embedding.model.provider", InMemoryModelFactory.PROVIDER,
"接口协议", "",
"string", "向量模型配置",
Lists.newArrayList(InMemoryModelFactory.PROVIDER,
OpenAiModelFactory.PROVIDER,
OllamaModelFactory.PROVIDER,
AzureModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER));
public static final Parameter EMBEDDING_MODEL_BASE_URL =
new Parameter("s2.embedding.model.base.url", "",
"BaseUrl", "",
"string", "向量模型配置");
public static final Parameter EMBEDDING_MODEL_API_KEY =
new Parameter("s2.embedding.model.api.key", "",
"ApiKey", "",
"string", "向量模型配置");
public static final Parameter EMBEDDING_MODEL_NAME =
new Parameter("s2.embedding.model.name", "",
"ModelName", "",
"string", "向量模型配置");
public static final Parameter EMBEDDING_MODEL_PATH =
new Parameter("s2.embedding.model.path", "",
"模型路径", "",
"string", "向量模型配置");
public static final Parameter EMBEDDING_MODEL_VOCABULARY_PATH =
new Parameter("s2.embedding.model.vocabulary.path", "",
"词汇表路径", "",
"string", "向量模型配置");
@Override
public List<Parameter> getSysParameters() {
return Lists.newArrayList(
EMBEDDING_MODEL_PROVIDER, EMBEDDING_MODEL_BASE_URL, EMBEDDING_MODEL_API_KEY,
EMBEDDING_MODEL_NAME, EMBEDDING_MODEL_PATH, EMBEDDING_MODEL_VOCABULARY_PATH
);
}
public EmbeddingModelConfig convert() {
String provider = getParameterValue(EMBEDDING_MODEL_PROVIDER);
String baseUrl = getParameterValue(EMBEDDING_MODEL_BASE_URL);
String apiKey = getParameterValue(EMBEDDING_MODEL_API_KEY);
String modelName = getParameterValue(EMBEDDING_MODEL_NAME);
String modelPath = getParameterValue(EMBEDDING_MODEL_PATH);
String vocabularyPath = getParameterValue(EMBEDDING_MODEL_VOCABULARY_PATH);
return EmbeddingModelConfig.builder()
.provider(provider)
.baseUrl(baseUrl)
.apiKey(apiKey)
.modelName(modelName)
.modelPath(modelPath)
.vocabularyPath(vocabularyPath)
.build();
}
}

View File

@@ -0,0 +1,62 @@
package com.tencent.supersonic.common.config;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
import com.tencent.supersonic.common.pojo.Parameter;
import dev.langchain4j.provider.InMemoryModelFactory;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.List;
@Service("EmbeddingStoreParameterConfig")
@Slf4j
public class EmbeddingStoreParameterConfig extends ParameterConfig {
public static final Parameter EMBEDDING_STORE_PROVIDER =
new Parameter("s2.embedding.store.provider", InMemoryModelFactory.PROVIDER,
"向量库类型", "",
"string", "向量库配置");
public static final Parameter EMBEDDING_STORE_BASE_URL =
new Parameter("s2.embedding.store.base.url", "",
"BaseUrl", "",
"string", "向量库配置");
public static final Parameter EMBEDDING_STORE_API_KEY =
new Parameter("s2.embedding.store.api.key", "",
"ApiKey", "",
"string", "向量库配置");
public static final Parameter EMBEDDING_STORE_PERSIST_PATH =
new Parameter("s2.embedding.store.persist.path", "/tmp",
"持久化路径", "",
"string", "向量库配置");
public static final Parameter EMBEDDING_STORE_TIMEOUT =
new Parameter("s2.embedding.store.timeout", "60",
"超时时间(秒)", "",
"number", "向量库配置");
@Override
public List<Parameter> getSysParameters() {
return Lists.newArrayList(
EMBEDDING_STORE_PROVIDER, EMBEDDING_STORE_BASE_URL, EMBEDDING_STORE_API_KEY,
EMBEDDING_STORE_PERSIST_PATH, EMBEDDING_STORE_TIMEOUT
);
}
public EmbeddingStoreConfig convert() {
String provider = getParameterValue(EMBEDDING_STORE_PROVIDER);
String baseUrl = getParameterValue(EMBEDDING_STORE_BASE_URL);
String apiKey = getParameterValue(EMBEDDING_STORE_API_KEY);
String persistPath = getParameterValue(EMBEDDING_STORE_PERSIST_PATH);
String timeOut = getParameterValue(EMBEDDING_STORE_TIMEOUT);
return EmbeddingStoreConfig.builder()
.provider(provider)
.baseUrl(baseUrl)
.apiKey(apiKey)
.persistPath(persistPath)
.timeOut(Long.valueOf(timeOut))
.build();
}
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.common.config; package com.tencent.supersonic.common.pojo;
import com.tencent.supersonic.common.util.AESEncryptionUtil; import com.tencent.supersonic.common.util.AESEncryptionUtil;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;

View File

@@ -1,12 +1,14 @@
package com.tencent.supersonic.common.config; package com.tencent.supersonic.common.pojo;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import java.io.Serializable; import java.io.Serializable;
@Data @Data
@Builder
@AllArgsConstructor @AllArgsConstructor
@NoArgsConstructor @NoArgsConstructor
public class EmbeddingModelConfig implements Serializable { public class EmbeddingModelConfig implements Serializable {

View File

@@ -1,10 +1,16 @@
package com.tencent.supersonic.common.config; package com.tencent.supersonic.common.pojo;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor;
import java.io.Serializable; import java.io.Serializable;
@Data @Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class EmbeddingStoreConfig implements Serializable { public class EmbeddingStoreConfig implements Serializable {
private static final long serialVersionUID = 1L; private static final long serialVersionUID = 1L;
private String provider; private String provider;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.common.config; package com.tencent.supersonic.common.pojo;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;

View File

@@ -4,6 +4,7 @@ import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import java.util.List; import java.util.List;
@Data @Data
@@ -21,12 +22,7 @@ public class Parameter {
public Parameter(String name, String defaultValue, String comment, public Parameter(String name, String defaultValue, String comment,
String description, String dataType, String module) { String description, String dataType, String module) {
this.name = name; this(name, defaultValue, comment, description, dataType, module, null);
this.defaultValue = defaultValue;
this.comment = comment;
this.description = description;
this.dataType = dataType;
this.module = module;
} }
public Parameter(String name, String defaultValue, String comment, String description, public Parameter(String name, String defaultValue, String comment, String description,

View File

@@ -2,10 +2,13 @@ package com.tencent.supersonic.common.service.impl;
import com.google.common.cache.Cache; import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheBuilder;
import com.tencent.supersonic.common.config.EmbeddingModelParameterConfig;
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import com.tencent.supersonic.common.service.EmbeddingService; import com.tencent.supersonic.common.service.EmbeddingService;
import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.provider.ModelProvider;
import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest; import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult; import dev.langchain4j.store.embedding.EmbeddingSearchResult;
@@ -19,6 +22,12 @@ import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder; import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder;
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo; import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore; import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.collections4.CollectionUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Comparator; import java.util.Comparator;
import java.util.HashMap; import java.util.HashMap;
@@ -27,11 +36,6 @@ import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.collections4.CollectionUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@Service @Service
@Slf4j @Slf4j
@@ -41,7 +45,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
private EmbeddingStoreFactory embeddingStoreFactory; private EmbeddingStoreFactory embeddingStoreFactory;
@Autowired @Autowired
private EmbeddingModel embeddingModel; private EmbeddingModelParameterConfig embeddingModelParameterConfig;
private Cache<String, Boolean> cache = CacheBuilder.newBuilder() private Cache<String, Boolean> cache = CacheBuilder.newBuilder()
.maximumSize(10000) .maximumSize(10000)
@@ -55,6 +59,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
for (TextSegment query : queries) { for (TextSegment query : queries) {
String question = query.text(); String question = query.text();
try { try {
EmbeddingModel embeddingModel = getEmbeddingModel();
Embedding embedding = embeddingModel.embed(question).content(); Embedding embedding = embeddingModel.embed(question).content();
boolean existSegment = existSegment(embeddingStore, query, embedding); boolean existSegment = existSegment(embeddingStore, query, embedding);
if (existSegment) { if (existSegment) {
@@ -122,6 +127,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
List<String> queryTextsList = retrieveQuery.getQueryTextsList(); List<String> queryTextsList = retrieveQuery.getQueryTextsList();
Map<String, String> filterCondition = retrieveQuery.getFilterCondition(); Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
for (String queryText : queryTextsList) { for (String queryText : queryTextsList) {
EmbeddingModel embeddingModel = getEmbeddingModel();
Embedding embeddedText = embeddingModel.embed(queryText).content(); Embedding embeddedText = embeddingModel.embed(queryText).content();
Filter filter = createCombinedFilter(filterCondition); Filter filter = createCombinedFilter(filterCondition);
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder() EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
@@ -169,4 +175,9 @@ public class EmbeddingServiceImpl implements EmbeddingService {
} }
return result; return result;
} }
private EmbeddingModel getEmbeddingModel() {
EmbeddingModelConfig embeddingModelConfig = embeddingModelParameterConfig.convert();
return ModelProvider.getEmbeddingModel(embeddingModelConfig);
}
} }

View File

@@ -1,6 +1,6 @@
package dev.langchain4j.chroma.spring; package dev.langchain4j.chroma.spring;
import com.tencent.supersonic.common.config.EmbeddingStoreConfig; import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory; import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.chroma.ChromaEmbeddingStore; import dev.langchain4j.store.embedding.chroma.ChromaEmbeddingStore;

View File

@@ -1,7 +1,7 @@
package dev.langchain4j.inmemory.spring; package dev.langchain4j.inmemory.spring;
import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.config.EmbeddingStoreConfig; import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory; import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;

View File

@@ -1,6 +1,6 @@
package dev.langchain4j.milvus.spring; package dev.langchain4j.milvus.spring;
import com.tencent.supersonic.common.config.EmbeddingStoreConfig; import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory; import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStore;

View File

@@ -1,7 +1,7 @@
package dev.langchain4j.provider; package dev.langchain4j.provider;
import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.config.EmbeddingModelConfig; import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import dev.langchain4j.model.azure.AzureOpenAiChatModel; import dev.langchain4j.model.azure.AzureOpenAiChatModel;
import dev.langchain4j.model.azure.AzureOpenAiEmbeddingModel; import dev.langchain4j.model.azure.AzureOpenAiEmbeddingModel;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;

View File

@@ -1,7 +1,7 @@
package dev.langchain4j.provider; package dev.langchain4j.provider;
import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.config.EmbeddingModelConfig; import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.dashscope.QwenChatModel; import dev.langchain4j.model.dashscope.QwenChatModel;
import dev.langchain4j.model.dashscope.QwenEmbeddingModel; import dev.langchain4j.model.dashscope.QwenEmbeddingModel;

View File

@@ -1,6 +1,6 @@
package dev.langchain4j.provider; package dev.langchain4j.provider;
import com.tencent.supersonic.common.config.EmbeddingStoreConfig; import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStore;
public interface EmbeddingStoreFactory { public interface EmbeddingStoreFactory {

View File

@@ -1,7 +1,7 @@
package dev.langchain4j.provider; package dev.langchain4j.provider;
import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.config.EmbeddingModelConfig; import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel; import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel;

View File

@@ -1,7 +1,7 @@
package dev.langchain4j.provider; package dev.langchain4j.provider;
import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.config.EmbeddingModelConfig; import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.localai.LocalAiChatModel; import dev.langchain4j.model.localai.LocalAiChatModel;

View File

@@ -1,7 +1,7 @@
package dev.langchain4j.provider; package dev.langchain4j.provider;
import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.config.EmbeddingModelConfig; import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel;

View File

@@ -1,8 +1,7 @@
package dev.langchain4j.provider; package dev.langchain4j.provider;
import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.config.EmbeddingModelConfig; import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import com.tencent.supersonic.common.config.ModelConfig;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel;
@@ -10,7 +9,6 @@ import org.apache.commons.lang3.StringUtils;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Objects;
public class ModelProvider { public class ModelProvider {
private static final Map<String, ModelFactory> factories = new HashMap<>(); private static final Map<String, ModelFactory> factories = new HashMap<>();
@@ -33,14 +31,10 @@ public class ModelProvider {
throw new RuntimeException("Unsupported ChatLanguageModel provider: " + modelConfig.getProvider()); throw new RuntimeException("Unsupported ChatLanguageModel provider: " + modelConfig.getProvider());
} }
public static EmbeddingModel getEmbeddingModel(ModelConfig modelConfig) { public static EmbeddingModel getEmbeddingModel(EmbeddingModelConfig embeddingModel) {
if (modelConfig == null || Objects.isNull(modelConfig.getEmbeddingModel()) if (embeddingModel == null || StringUtils.isBlank(embeddingModel.getProvider())) {
|| StringUtils.isBlank(modelConfig.getEmbeddingModel().getBaseUrl())
|| StringUtils.isBlank(modelConfig.getEmbeddingModel().getProvider())) {
return ContextUtils.getBean(EmbeddingModel.class); return ContextUtils.getBean(EmbeddingModel.class);
} }
EmbeddingModelConfig embeddingModel = modelConfig.getEmbeddingModel();
ModelFactory modelFactory = factories.get(embeddingModel.getProvider().toUpperCase()); ModelFactory modelFactory = factories.get(embeddingModel.getProvider().toUpperCase());
if (modelFactory != null) { if (modelFactory != null) {
return modelFactory.createEmbeddingModel(embeddingModel); return modelFactory.createEmbeddingModel(embeddingModel);

View File

@@ -1,7 +1,7 @@
package dev.langchain4j.provider; package dev.langchain4j.provider;
import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.config.EmbeddingModelConfig; import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.ollama.OllamaChatModel; import dev.langchain4j.model.ollama.OllamaChatModel;

View File

@@ -1,7 +1,7 @@
package dev.langchain4j.provider; package dev.langchain4j.provider;
import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.config.EmbeddingModelConfig; import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.openai.OpenAiChatModel; import dev.langchain4j.model.openai.OpenAiChatModel;

View File

@@ -1,7 +1,7 @@
package dev.langchain4j.provider; package dev.langchain4j.provider;
import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.config.EmbeddingModelConfig; import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.qianfan.QianfanEmbeddingModel; import dev.langchain4j.model.qianfan.QianfanEmbeddingModel;

View File

@@ -1,7 +1,7 @@
package dev.langchain4j.provider; package dev.langchain4j.provider;
import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.config.EmbeddingModelConfig; import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel; import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;

View File

@@ -1,6 +1,6 @@
package dev.langchain4j.store.embedding; package dev.langchain4j.store.embedding;
import com.tencent.supersonic.common.config.EmbeddingStoreConfig; import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.chroma.spring.ChromaEmbeddingStoreFactory; import dev.langchain4j.chroma.spring.ChromaEmbeddingStoreFactory;
import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory; import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory;

View File

@@ -3,8 +3,8 @@ package com.tencent.supersonic.headless.api.pojo.request;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.headless.api.pojo.QueryDataType; import com.tencent.supersonic.headless.api.pojo.QueryDataType;

View File

@@ -2,17 +2,17 @@ package com.tencent.supersonic.headless.chat;
import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnore;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.QueryDataType; import com.tencent.supersonic.headless.api.pojo.QueryDataType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState; import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState;
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import com.tencent.supersonic.headless.chat.parser.ParserConfig; import com.tencent.supersonic.headless.chat.parser.ParserConfig;
import com.tencent.supersonic.headless.chat.query.SemanticQuery; import com.tencent.supersonic.headless.chat.query.SemanticQuery;
@@ -21,7 +21,6 @@ import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Comparator; import java.util.Comparator;
import java.util.List; import java.util.List;

View File

@@ -57,15 +57,15 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
@Override @Override
protected void detectByBatch(ChatQueryContext chatQueryContext, Set<EmbeddingResult> results, protected void detectByBatch(ChatQueryContext chatQueryContext, Set<EmbeddingResult> results,
Set<Long> detectDataSetIds, Set<String> detectSegments) { Set<Long> detectDataSetIds, Set<String> detectSegments) {
int embedddingMapperMin = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MIN)); int embeddingMapperMin = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MIN));
int embedddingMapperMax = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MAX)); int embeddingMapperMax = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MAX));
int embeddingMapperBatch = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH)); int embeddingMapperBatch = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH));
List<String> queryTextsList = detectSegments.stream() List<String> queryTextsList = detectSegments.stream()
.map(detectSegment -> detectSegment.trim()) .map(detectSegment -> detectSegment.trim())
.filter(detectSegment -> StringUtils.isNotBlank(detectSegment) .filter(detectSegment -> StringUtils.isNotBlank(detectSegment)
&& detectSegment.length() >= embedddingMapperMin && detectSegment.length() >= embeddingMapperMin
&& detectSegment.length() <= embedddingMapperMax) && detectSegment.length() <= embeddingMapperMax)
.collect(Collectors.toList()); .collect(Collectors.toList());
List<List<String>> queryTextsSubList = Lists.partition(queryTextsList, List<List<String>> queryTextsSubList = Lists.partition(queryTextsList,

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.headless.chat.parser.llm; package com.tencent.supersonic.headless.chat.parser.llm;
import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;

View File

@@ -2,8 +2,8 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql;
import com.fasterxml.jackson.annotation.JsonValue; import com.fasterxml.jackson.annotation.JsonValue;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import lombok.Data; import lombok.Data;

View File

@@ -8,7 +8,7 @@ import com.tencent.supersonic.chat.server.agent.AgentConfig;
import com.tencent.supersonic.chat.server.agent.AgentToolType; import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig; import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.agent.RuleParserTool; import com.tencent.supersonic.chat.server.agent.RuleParserTool;
import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult; import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.util.DataUtils; import com.tencent.supersonic.util.DataUtils;
import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeAll;