mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
(improvement)(chat) Integrate llm configuration into system settings. (#1403)
This commit is contained in:
@@ -4,9 +4,9 @@ package com.tencent.supersonic.chat.server.agent;
|
|||||||
import com.alibaba.fastjson.JSONObject;
|
import com.alibaba.fastjson.JSONObject;
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.google.common.collect.Sets;
|
import com.google.common.collect.Sets;
|
||||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
|
||||||
import com.tencent.supersonic.common.config.PromptConfig;
|
import com.tencent.supersonic.common.config.PromptConfig;
|
||||||
import com.tencent.supersonic.common.config.VisualConfig;
|
import com.tencent.supersonic.common.config.VisualConfig;
|
||||||
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.agent.Agent;
|
|||||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||||
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
|
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
|
||||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.web.bind.annotation.DeleteMapping;
|
import org.springframework.web.bind.annotation.DeleteMapping;
|
||||||
import org.springframework.web.bind.annotation.PathVariable;
|
import org.springframework.web.bind.annotation.PathVariable;
|
||||||
|
|||||||
@@ -12,9 +12,9 @@ import com.tencent.supersonic.chat.server.service.AgentService;
|
|||||||
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||||
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
|
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
|
||||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
|
||||||
import com.tencent.supersonic.common.config.PromptConfig;
|
import com.tencent.supersonic.common.config.PromptConfig;
|
||||||
import com.tencent.supersonic.common.config.VisualConfig;
|
import com.tencent.supersonic.common.config.VisualConfig;
|
||||||
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.BeanUtils;
|
import org.springframework.beans.BeanUtils;
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
package com.tencent.supersonic.chat.server.util;
|
package com.tencent.supersonic.chat.server.util;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.provider.ModelProvider;
|
import dev.langchain4j.provider.ModelProvider;
|
||||||
|
|||||||
@@ -0,0 +1,73 @@
|
|||||||
|
package com.tencent.supersonic.common.config;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
|
import com.tencent.supersonic.common.pojo.Parameter;
|
||||||
|
import dev.langchain4j.provider.OllamaModelFactory;
|
||||||
|
import dev.langchain4j.provider.OpenAiModelFactory;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@Service("ChatModelParameterConfig")
|
||||||
|
@Slf4j
|
||||||
|
public class ChatModelParameterConfig extends ParameterConfig {
|
||||||
|
|
||||||
|
public static final Parameter CHAT_MODEL_PROVIDER =
|
||||||
|
new Parameter("s2.chat.model.provider", OpenAiModelFactory.PROVIDER,
|
||||||
|
"接口协议", "",
|
||||||
|
"string", "对话模型配置",
|
||||||
|
Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER));
|
||||||
|
|
||||||
|
public static final Parameter CHAT_MODEL_BASE_URL =
|
||||||
|
new Parameter("s2.chat.model.base.url", "",
|
||||||
|
"BaseUrl", "",
|
||||||
|
"string", "对话模型配置");
|
||||||
|
|
||||||
|
public static final Parameter CHAT_MODEL_API_KEY =
|
||||||
|
new Parameter("s2.chat.model.api.key", "",
|
||||||
|
"ApiKey", "",
|
||||||
|
"string", "对话模型配置");
|
||||||
|
|
||||||
|
public static final Parameter CHAT_MODEL_NAME =
|
||||||
|
new Parameter("s2.chat.model.name", "",
|
||||||
|
"ModelName", "",
|
||||||
|
"string", "对话模型配置");
|
||||||
|
|
||||||
|
public static final Parameter CHAT_MODEL_TEMPERATURE =
|
||||||
|
new Parameter("s2.chat.model.temperature", "0.0",
|
||||||
|
"Temperature", "",
|
||||||
|
"number", "对话模型配置");
|
||||||
|
|
||||||
|
public static final Parameter CHAT_MODEL_TIMEOUT =
|
||||||
|
new Parameter("s2.chat.model.timeout", "60",
|
||||||
|
"超时时间(秒)", "",
|
||||||
|
"number", "对话模型配置");
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Parameter> getSysParameters() {
|
||||||
|
return Lists.newArrayList(
|
||||||
|
CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_API_KEY,
|
||||||
|
CHAT_MODEL_NAME, CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ChatModelConfig convert() {
|
||||||
|
String chatModelProvider = getParameterValue(CHAT_MODEL_PROVIDER);
|
||||||
|
String chatModelBaseUrl = getParameterValue(CHAT_MODEL_BASE_URL);
|
||||||
|
String chatModelApiKey = getParameterValue(CHAT_MODEL_API_KEY);
|
||||||
|
String chatModelName = getParameterValue(CHAT_MODEL_NAME);
|
||||||
|
String chatModelTemperature = getParameterValue(CHAT_MODEL_TEMPERATURE);
|
||||||
|
String chatModelTimeout = getParameterValue(CHAT_MODEL_TIMEOUT);
|
||||||
|
|
||||||
|
return ChatModelConfig.builder()
|
||||||
|
.provider(chatModelProvider)
|
||||||
|
.baseUrl(chatModelBaseUrl)
|
||||||
|
.apiKey(chatModelApiKey)
|
||||||
|
.modelName(chatModelName)
|
||||||
|
.temperature(Double.valueOf(chatModelTemperature))
|
||||||
|
.timeOut(Long.valueOf(chatModelTimeout))
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,86 @@
|
|||||||
|
package com.tencent.supersonic.common.config;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||||
|
import com.tencent.supersonic.common.pojo.Parameter;
|
||||||
|
import dev.langchain4j.provider.AzureModelFactory;
|
||||||
|
import dev.langchain4j.provider.DashscopeModelFactory;
|
||||||
|
import dev.langchain4j.provider.InMemoryModelFactory;
|
||||||
|
import dev.langchain4j.provider.OllamaModelFactory;
|
||||||
|
import dev.langchain4j.provider.OpenAiModelFactory;
|
||||||
|
import dev.langchain4j.provider.QianfanModelFactory;
|
||||||
|
import dev.langchain4j.provider.ZhipuModelFactory;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@Service("EmbeddingModelConfig")
|
||||||
|
@Slf4j
|
||||||
|
public class EmbeddingModelParameterConfig extends ParameterConfig {
|
||||||
|
|
||||||
|
public static final Parameter EMBEDDING_MODEL_PROVIDER =
|
||||||
|
new Parameter("s2.embedding.model.provider", InMemoryModelFactory.PROVIDER,
|
||||||
|
"接口协议", "",
|
||||||
|
"string", "向量模型配置",
|
||||||
|
Lists.newArrayList(InMemoryModelFactory.PROVIDER,
|
||||||
|
OpenAiModelFactory.PROVIDER,
|
||||||
|
OllamaModelFactory.PROVIDER,
|
||||||
|
AzureModelFactory.PROVIDER,
|
||||||
|
DashscopeModelFactory.PROVIDER,
|
||||||
|
QianfanModelFactory.PROVIDER,
|
||||||
|
ZhipuModelFactory.PROVIDER));
|
||||||
|
|
||||||
|
public static final Parameter EMBEDDING_MODEL_BASE_URL =
|
||||||
|
new Parameter("s2.embedding.model.base.url", "",
|
||||||
|
"BaseUrl", "",
|
||||||
|
"string", "向量模型配置");
|
||||||
|
|
||||||
|
public static final Parameter EMBEDDING_MODEL_API_KEY =
|
||||||
|
new Parameter("s2.embedding.model.api.key", "",
|
||||||
|
"ApiKey", "",
|
||||||
|
"string", "向量模型配置");
|
||||||
|
|
||||||
|
|
||||||
|
public static final Parameter EMBEDDING_MODEL_NAME =
|
||||||
|
new Parameter("s2.embedding.model.name", "",
|
||||||
|
"ModelName", "",
|
||||||
|
"string", "向量模型配置");
|
||||||
|
|
||||||
|
public static final Parameter EMBEDDING_MODEL_PATH =
|
||||||
|
new Parameter("s2.embedding.model.path", "",
|
||||||
|
"模型路径", "",
|
||||||
|
"string", "向量模型配置");
|
||||||
|
|
||||||
|
public static final Parameter EMBEDDING_MODEL_VOCABULARY_PATH =
|
||||||
|
new Parameter("s2.embedding.model.vocabulary.path", "",
|
||||||
|
"词汇表路径", "",
|
||||||
|
"string", "向量模型配置");
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Parameter> getSysParameters() {
|
||||||
|
return Lists.newArrayList(
|
||||||
|
EMBEDDING_MODEL_PROVIDER, EMBEDDING_MODEL_BASE_URL, EMBEDDING_MODEL_API_KEY,
|
||||||
|
EMBEDDING_MODEL_NAME, EMBEDDING_MODEL_PATH, EMBEDDING_MODEL_VOCABULARY_PATH
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public EmbeddingModelConfig convert() {
|
||||||
|
String provider = getParameterValue(EMBEDDING_MODEL_PROVIDER);
|
||||||
|
String baseUrl = getParameterValue(EMBEDDING_MODEL_BASE_URL);
|
||||||
|
String apiKey = getParameterValue(EMBEDDING_MODEL_API_KEY);
|
||||||
|
String modelName = getParameterValue(EMBEDDING_MODEL_NAME);
|
||||||
|
String modelPath = getParameterValue(EMBEDDING_MODEL_PATH);
|
||||||
|
String vocabularyPath = getParameterValue(EMBEDDING_MODEL_VOCABULARY_PATH);
|
||||||
|
|
||||||
|
return EmbeddingModelConfig.builder()
|
||||||
|
.provider(provider)
|
||||||
|
.baseUrl(baseUrl)
|
||||||
|
.apiKey(apiKey)
|
||||||
|
.modelName(modelName)
|
||||||
|
.modelPath(modelPath)
|
||||||
|
.vocabularyPath(vocabularyPath)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,62 @@
|
|||||||
|
package com.tencent.supersonic.common.config;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
|
||||||
|
import com.tencent.supersonic.common.pojo.Parameter;
|
||||||
|
import dev.langchain4j.provider.InMemoryModelFactory;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@Service("EmbeddingStoreParameterConfig")
|
||||||
|
@Slf4j
|
||||||
|
public class EmbeddingStoreParameterConfig extends ParameterConfig {
|
||||||
|
public static final Parameter EMBEDDING_STORE_PROVIDER =
|
||||||
|
new Parameter("s2.embedding.store.provider", InMemoryModelFactory.PROVIDER,
|
||||||
|
"向量库类型", "",
|
||||||
|
"string", "向量库配置");
|
||||||
|
|
||||||
|
public static final Parameter EMBEDDING_STORE_BASE_URL =
|
||||||
|
new Parameter("s2.embedding.store.base.url", "",
|
||||||
|
"BaseUrl", "",
|
||||||
|
"string", "向量库配置");
|
||||||
|
|
||||||
|
public static final Parameter EMBEDDING_STORE_API_KEY =
|
||||||
|
new Parameter("s2.embedding.store.api.key", "",
|
||||||
|
"ApiKey", "",
|
||||||
|
"string", "向量库配置");
|
||||||
|
public static final Parameter EMBEDDING_STORE_PERSIST_PATH =
|
||||||
|
new Parameter("s2.embedding.store.persist.path", "/tmp",
|
||||||
|
"持久化路径", "",
|
||||||
|
"string", "向量库配置");
|
||||||
|
|
||||||
|
public static final Parameter EMBEDDING_STORE_TIMEOUT =
|
||||||
|
new Parameter("s2.embedding.store.timeout", "60",
|
||||||
|
"超时时间(秒)", "",
|
||||||
|
"number", "向量库配置");
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Parameter> getSysParameters() {
|
||||||
|
return Lists.newArrayList(
|
||||||
|
EMBEDDING_STORE_PROVIDER, EMBEDDING_STORE_BASE_URL, EMBEDDING_STORE_API_KEY,
|
||||||
|
EMBEDDING_STORE_PERSIST_PATH, EMBEDDING_STORE_TIMEOUT
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public EmbeddingStoreConfig convert() {
|
||||||
|
String provider = getParameterValue(EMBEDDING_STORE_PROVIDER);
|
||||||
|
String baseUrl = getParameterValue(EMBEDDING_STORE_BASE_URL);
|
||||||
|
String apiKey = getParameterValue(EMBEDDING_STORE_API_KEY);
|
||||||
|
String persistPath = getParameterValue(EMBEDDING_STORE_PERSIST_PATH);
|
||||||
|
String timeOut = getParameterValue(EMBEDDING_STORE_TIMEOUT);
|
||||||
|
|
||||||
|
return EmbeddingStoreConfig.builder()
|
||||||
|
.provider(provider)
|
||||||
|
.baseUrl(baseUrl)
|
||||||
|
.apiKey(apiKey)
|
||||||
|
.persistPath(persistPath)
|
||||||
|
.timeOut(Long.valueOf(timeOut))
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.common.config;
|
package com.tencent.supersonic.common.pojo;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.util.AESEncryptionUtil;
|
import com.tencent.supersonic.common.util.AESEncryptionUtil;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
@@ -1,12 +1,14 @@
|
|||||||
package com.tencent.supersonic.common.config;
|
package com.tencent.supersonic.common.pojo;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@Builder
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
public class EmbeddingModelConfig implements Serializable {
|
public class EmbeddingModelConfig implements Serializable {
|
||||||
@@ -1,10 +1,16 @@
|
|||||||
package com.tencent.supersonic.common.config;
|
package com.tencent.supersonic.common.pojo;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@Builder
|
||||||
|
@AllArgsConstructor
|
||||||
|
@NoArgsConstructor
|
||||||
public class EmbeddingStoreConfig implements Serializable {
|
public class EmbeddingStoreConfig implements Serializable {
|
||||||
private static final long serialVersionUID = 1L;
|
private static final long serialVersionUID = 1L;
|
||||||
private String provider;
|
private String provider;
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.common.config;
|
package com.tencent.supersonic.common.pojo;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -4,6 +4,7 @@ import lombok.AllArgsConstructor;
|
|||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@@ -21,12 +22,7 @@ public class Parameter {
|
|||||||
|
|
||||||
public Parameter(String name, String defaultValue, String comment,
|
public Parameter(String name, String defaultValue, String comment,
|
||||||
String description, String dataType, String module) {
|
String description, String dataType, String module) {
|
||||||
this.name = name;
|
this(name, defaultValue, comment, description, dataType, module, null);
|
||||||
this.defaultValue = defaultValue;
|
|
||||||
this.comment = comment;
|
|
||||||
this.description = description;
|
|
||||||
this.dataType = dataType;
|
|
||||||
this.module = module;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public Parameter(String name, String defaultValue, String comment, String description,
|
public Parameter(String name, String defaultValue, String comment, String description,
|
||||||
|
|||||||
@@ -2,10 +2,13 @@ package com.tencent.supersonic.common.service.impl;
|
|||||||
|
|
||||||
import com.google.common.cache.Cache;
|
import com.google.common.cache.Cache;
|
||||||
import com.google.common.cache.CacheBuilder;
|
import com.google.common.cache.CacheBuilder;
|
||||||
|
import com.tencent.supersonic.common.config.EmbeddingModelParameterConfig;
|
||||||
|
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||||
import com.tencent.supersonic.common.service.EmbeddingService;
|
import com.tencent.supersonic.common.service.EmbeddingService;
|
||||||
import dev.langchain4j.data.embedding.Embedding;
|
import dev.langchain4j.data.embedding.Embedding;
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
|
import dev.langchain4j.provider.ModelProvider;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
|
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
|
||||||
@@ -19,6 +22,12 @@ import dev.langchain4j.store.embedding.filter.Filter;
|
|||||||
import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder;
|
import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder;
|
||||||
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
|
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
|
||||||
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
|
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections.MapUtils;
|
||||||
|
import org.apache.commons.collections4.CollectionUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
@@ -27,11 +36,6 @@ import java.util.Map;
|
|||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.collections.MapUtils;
|
|
||||||
import org.apache.commons.collections4.CollectionUtils;
|
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@@ -41,7 +45,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
|||||||
private EmbeddingStoreFactory embeddingStoreFactory;
|
private EmbeddingStoreFactory embeddingStoreFactory;
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private EmbeddingModel embeddingModel;
|
private EmbeddingModelParameterConfig embeddingModelParameterConfig;
|
||||||
|
|
||||||
private Cache<String, Boolean> cache = CacheBuilder.newBuilder()
|
private Cache<String, Boolean> cache = CacheBuilder.newBuilder()
|
||||||
.maximumSize(10000)
|
.maximumSize(10000)
|
||||||
@@ -55,6 +59,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
|||||||
for (TextSegment query : queries) {
|
for (TextSegment query : queries) {
|
||||||
String question = query.text();
|
String question = query.text();
|
||||||
try {
|
try {
|
||||||
|
EmbeddingModel embeddingModel = getEmbeddingModel();
|
||||||
Embedding embedding = embeddingModel.embed(question).content();
|
Embedding embedding = embeddingModel.embed(question).content();
|
||||||
boolean existSegment = existSegment(embeddingStore, query, embedding);
|
boolean existSegment = existSegment(embeddingStore, query, embedding);
|
||||||
if (existSegment) {
|
if (existSegment) {
|
||||||
@@ -122,6 +127,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
|||||||
List<String> queryTextsList = retrieveQuery.getQueryTextsList();
|
List<String> queryTextsList = retrieveQuery.getQueryTextsList();
|
||||||
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
|
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
|
||||||
for (String queryText : queryTextsList) {
|
for (String queryText : queryTextsList) {
|
||||||
|
EmbeddingModel embeddingModel = getEmbeddingModel();
|
||||||
Embedding embeddedText = embeddingModel.embed(queryText).content();
|
Embedding embeddedText = embeddingModel.embed(queryText).content();
|
||||||
Filter filter = createCombinedFilter(filterCondition);
|
Filter filter = createCombinedFilter(filterCondition);
|
||||||
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
|
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
|
||||||
@@ -169,4 +175,9 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
|||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private EmbeddingModel getEmbeddingModel() {
|
||||||
|
EmbeddingModelConfig embeddingModelConfig = embeddingModelParameterConfig.convert();
|
||||||
|
return ModelProvider.getEmbeddingModel(embeddingModelConfig);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
package dev.langchain4j.chroma.spring;
|
package dev.langchain4j.chroma.spring;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
|
import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
|
||||||
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
|
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||||
import dev.langchain4j.store.embedding.chroma.ChromaEmbeddingStore;
|
import dev.langchain4j.store.embedding.chroma.ChromaEmbeddingStore;
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package dev.langchain4j.inmemory.spring;
|
package dev.langchain4j.inmemory.spring;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
|
import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
|
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
package dev.langchain4j.milvus.spring;
|
package dev.langchain4j.milvus.spring;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
|
import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
|
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package dev.langchain4j.provider;
|
package dev.langchain4j.provider;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||||
import dev.langchain4j.model.azure.AzureOpenAiChatModel;
|
import dev.langchain4j.model.azure.AzureOpenAiChatModel;
|
||||||
import dev.langchain4j.model.azure.AzureOpenAiEmbeddingModel;
|
import dev.langchain4j.model.azure.AzureOpenAiEmbeddingModel;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package dev.langchain4j.provider;
|
package dev.langchain4j.provider;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.dashscope.QwenChatModel;
|
import dev.langchain4j.model.dashscope.QwenChatModel;
|
||||||
import dev.langchain4j.model.dashscope.QwenEmbeddingModel;
|
import dev.langchain4j.model.dashscope.QwenEmbeddingModel;
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
package dev.langchain4j.provider;
|
package dev.langchain4j.provider;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
|
import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||||
|
|
||||||
public interface EmbeddingStoreFactory {
|
public interface EmbeddingStoreFactory {
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package dev.langchain4j.provider;
|
package dev.langchain4j.provider;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
|
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||||
import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel;
|
import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel;
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package dev.langchain4j.provider;
|
package dev.langchain4j.provider;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
import dev.langchain4j.model.localai.LocalAiChatModel;
|
import dev.langchain4j.model.localai.LocalAiChatModel;
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package dev.langchain4j.provider;
|
package dev.langchain4j.provider;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
package dev.langchain4j.provider;
|
package dev.langchain4j.provider;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||||
import com.tencent.supersonic.common.config.ModelConfig;
|
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
@@ -10,7 +9,6 @@ import org.apache.commons.lang3.StringUtils;
|
|||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
|
||||||
|
|
||||||
public class ModelProvider {
|
public class ModelProvider {
|
||||||
private static final Map<String, ModelFactory> factories = new HashMap<>();
|
private static final Map<String, ModelFactory> factories = new HashMap<>();
|
||||||
@@ -33,14 +31,10 @@ public class ModelProvider {
|
|||||||
throw new RuntimeException("Unsupported ChatLanguageModel provider: " + modelConfig.getProvider());
|
throw new RuntimeException("Unsupported ChatLanguageModel provider: " + modelConfig.getProvider());
|
||||||
}
|
}
|
||||||
|
|
||||||
public static EmbeddingModel getEmbeddingModel(ModelConfig modelConfig) {
|
public static EmbeddingModel getEmbeddingModel(EmbeddingModelConfig embeddingModel) {
|
||||||
if (modelConfig == null || Objects.isNull(modelConfig.getEmbeddingModel())
|
if (embeddingModel == null || StringUtils.isBlank(embeddingModel.getProvider())) {
|
||||||
|| StringUtils.isBlank(modelConfig.getEmbeddingModel().getBaseUrl())
|
|
||||||
|| StringUtils.isBlank(modelConfig.getEmbeddingModel().getProvider())) {
|
|
||||||
return ContextUtils.getBean(EmbeddingModel.class);
|
return ContextUtils.getBean(EmbeddingModel.class);
|
||||||
}
|
}
|
||||||
EmbeddingModelConfig embeddingModel = modelConfig.getEmbeddingModel();
|
|
||||||
|
|
||||||
ModelFactory modelFactory = factories.get(embeddingModel.getProvider().toUpperCase());
|
ModelFactory modelFactory = factories.get(embeddingModel.getProvider().toUpperCase());
|
||||||
if (modelFactory != null) {
|
if (modelFactory != null) {
|
||||||
return modelFactory.createEmbeddingModel(embeddingModel);
|
return modelFactory.createEmbeddingModel(embeddingModel);
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package dev.langchain4j.provider;
|
package dev.langchain4j.provider;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
import dev.langchain4j.model.ollama.OllamaChatModel;
|
import dev.langchain4j.model.ollama.OllamaChatModel;
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package dev.langchain4j.provider;
|
package dev.langchain4j.provider;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
import dev.langchain4j.model.openai.OpenAiChatModel;
|
import dev.langchain4j.model.openai.OpenAiChatModel;
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package dev.langchain4j.provider;
|
package dev.langchain4j.provider;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
import dev.langchain4j.model.qianfan.QianfanEmbeddingModel;
|
import dev.langchain4j.model.qianfan.QianfanEmbeddingModel;
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package dev.langchain4j.provider;
|
package dev.langchain4j.provider;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
|
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
package dev.langchain4j.store.embedding;
|
package dev.langchain4j.store.embedding;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
|
import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import dev.langchain4j.chroma.spring.ChromaEmbeddingStoreFactory;
|
import dev.langchain4j.chroma.spring.ChromaEmbeddingStoreFactory;
|
||||||
import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory;
|
import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory;
|
||||||
|
|||||||
@@ -3,8 +3,8 @@ package com.tencent.supersonic.headless.api.pojo.request;
|
|||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.google.common.collect.Sets;
|
import com.google.common.collect.Sets;
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
|
||||||
import com.tencent.supersonic.common.config.PromptConfig;
|
import com.tencent.supersonic.common.config.PromptConfig;
|
||||||
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
||||||
|
|||||||
@@ -2,17 +2,17 @@ package com.tencent.supersonic.headless.chat;
|
|||||||
|
|
||||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
|
||||||
import com.tencent.supersonic.common.config.PromptConfig;
|
import com.tencent.supersonic.common.config.PromptConfig;
|
||||||
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
|
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState;
|
import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||||
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
||||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||||
@@ -21,7 +21,6 @@ import lombok.Builder;
|
|||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|||||||
@@ -57,15 +57,15 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
|||||||
@Override
|
@Override
|
||||||
protected void detectByBatch(ChatQueryContext chatQueryContext, Set<EmbeddingResult> results,
|
protected void detectByBatch(ChatQueryContext chatQueryContext, Set<EmbeddingResult> results,
|
||||||
Set<Long> detectDataSetIds, Set<String> detectSegments) {
|
Set<Long> detectDataSetIds, Set<String> detectSegments) {
|
||||||
int embedddingMapperMin = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MIN));
|
int embeddingMapperMin = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MIN));
|
||||||
int embedddingMapperMax = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MAX));
|
int embeddingMapperMax = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MAX));
|
||||||
int embeddingMapperBatch = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH));
|
int embeddingMapperBatch = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH));
|
||||||
|
|
||||||
List<String> queryTextsList = detectSegments.stream()
|
List<String> queryTextsList = detectSegments.stream()
|
||||||
.map(detectSegment -> detectSegment.trim())
|
.map(detectSegment -> detectSegment.trim())
|
||||||
.filter(detectSegment -> StringUtils.isNotBlank(detectSegment)
|
.filter(detectSegment -> StringUtils.isNotBlank(detectSegment)
|
||||||
&& detectSegment.length() >= embedddingMapperMin
|
&& detectSegment.length() >= embeddingMapperMin
|
||||||
&& detectSegment.length() <= embedddingMapperMax)
|
&& detectSegment.length() <= embeddingMapperMax)
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
List<List<String>> queryTextsSubList = Lists.partition(queryTextsList,
|
List<List<String>> queryTextsSubList = Lists.partition(queryTextsList,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
|
|||||||
@@ -2,8 +2,8 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql;
|
|||||||
|
|
||||||
import com.fasterxml.jackson.annotation.JsonValue;
|
import com.fasterxml.jackson.annotation.JsonValue;
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
|
||||||
import com.tencent.supersonic.common.config.PromptConfig;
|
import com.tencent.supersonic.common.config.PromptConfig;
|
||||||
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import com.tencent.supersonic.chat.server.agent.AgentConfig;
|
|||||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||||
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
|
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
|
||||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||||
import com.tencent.supersonic.util.DataUtils;
|
import com.tencent.supersonic.util.DataUtils;
|
||||||
import org.junit.jupiter.api.BeforeAll;
|
import org.junit.jupiter.api.BeforeAll;
|
||||||
|
|||||||
Reference in New Issue
Block a user