mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(chat) Support embedding store configuration. (#1363)
This commit is contained in:
@@ -4,8 +4,6 @@ package com.tencent.supersonic.chat.server.agent;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
|
||||
import com.tencent.supersonic.common.config.ModelConfig;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.config.VisualConfig;
|
||||
@@ -35,9 +33,7 @@ public class Agent extends RecordInfo {
|
||||
private Integer status;
|
||||
private List<String> examples;
|
||||
private String agentConfig;
|
||||
private ChatModelConfig llmConfig;
|
||||
private ModelConfig modelConfig;
|
||||
private EmbeddingStoreConfig embeddingStore;
|
||||
private PromptConfig promptConfig;
|
||||
private MultiTurnConfig multiTurnConfig;
|
||||
private VisualConfig visualConfig;
|
||||
|
||||
@@ -46,7 +46,7 @@ public class PlainTextExecutor implements ChatExecutor {
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId());
|
||||
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.provideChatModel(chatAgent.getLlmConfig());
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatAgent.getModelConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||
|
||||
QueryResult result = new QueryResult();
|
||||
|
||||
@@ -56,8 +56,8 @@ public class MemoryReviewTask {
|
||||
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
||||
|
||||
keyPipelineLog.info("MemoryReviewTask reqPrompt:{}", promptStr);
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.provideChatModel(
|
||||
chatAgent.getLlmConfig());
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(
|
||||
chatAgent.getModelConfig());
|
||||
if (Objects.nonNull(chatLanguageModel)) {
|
||||
String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text();
|
||||
keyPipelineLog.info("MemoryReviewTask modelResp:{}", response);
|
||||
|
||||
@@ -5,8 +5,8 @@ import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryReposi
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.config.ModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
@@ -168,7 +168,7 @@ public class NL2SQLParser implements ChatParser {
|
||||
.curtSchema(curtMapStr)
|
||||
.histSchema(histMapStr)
|
||||
.histSQL(histSQL)
|
||||
.llmConfig(queryTextReq.getLlmConfig())
|
||||
.modelConfig(queryTextReq.getModelConfig())
|
||||
.build());
|
||||
chatParseContext.setQueryText(rewrittenQuery);
|
||||
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}",
|
||||
@@ -181,7 +181,7 @@ public class NL2SQLParser implements ChatParser {
|
||||
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
||||
keyPipelineLog.info("NL2SQLParser reqPrompt:{}", promptStr);
|
||||
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.provideChatModel(context.getLlmConfig());
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(context.getModelConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||
|
||||
String result = response.content().text();
|
||||
@@ -243,7 +243,7 @@ public class NL2SQLParser implements ChatParser {
|
||||
private String curtSchema;
|
||||
private String histSchema;
|
||||
private String histSQL;
|
||||
private ChatModelConfig llmConfig;
|
||||
private ModelConfig modelConfig;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -11,15 +11,18 @@ import java.util.Date;
|
||||
@TableName("s2_agent")
|
||||
public class AgentDO {
|
||||
/**
|
||||
*
|
||||
*/
|
||||
@TableId(type = IdType.AUTO)
|
||||
private Integer id;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String name;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String description;
|
||||
|
||||
@@ -29,35 +32,40 @@ public class AgentDO {
|
||||
private Integer status;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String examples;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String config;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String createdBy;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private Date createdAt;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String updatedBy;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private Date updatedAt;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private Integer enableSearch;
|
||||
|
||||
private String llmConfig;
|
||||
|
||||
private String modelConfig;
|
||||
private String multiTurnConfig;
|
||||
|
||||
private String visualConfig;
|
||||
|
||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.ModelConfig;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.DeleteMapping;
|
||||
import org.springframework.web.bind.annotation.PathVariable;
|
||||
@@ -30,16 +30,16 @@ public class AgentController {
|
||||
|
||||
@PostMapping
|
||||
public Agent createAgent(@RequestBody Agent agent,
|
||||
HttpServletRequest httpServletRequest,
|
||||
HttpServletResponse httpServletResponse) {
|
||||
HttpServletRequest httpServletRequest,
|
||||
HttpServletResponse httpServletResponse) {
|
||||
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
|
||||
return agentService.createAgent(agent, user);
|
||||
}
|
||||
|
||||
@PutMapping
|
||||
public Agent updateAgent(@RequestBody Agent agent,
|
||||
HttpServletRequest httpServletRequest,
|
||||
HttpServletResponse httpServletResponse) {
|
||||
HttpServletRequest httpServletRequest,
|
||||
HttpServletResponse httpServletResponse) {
|
||||
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
|
||||
return agentService.updateAgent(agent, user);
|
||||
}
|
||||
@@ -51,8 +51,8 @@ public class AgentController {
|
||||
}
|
||||
|
||||
@PostMapping("/testLLMConn")
|
||||
public boolean testLLMConn(@RequestBody ChatModelConfig llmConfig) {
|
||||
return LLMConnHelper.testConnection(llmConfig);
|
||||
public boolean testLLMConn(@RequestBody ModelConfig modelConfig) {
|
||||
return LLMConnHelper.testConnection(modelConfig);
|
||||
}
|
||||
|
||||
@RequestMapping("/getAgentList")
|
||||
|
||||
@@ -12,7 +12,7 @@ import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatService;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.ModelConfig;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.config.VisualConfig;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
@@ -88,7 +88,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
||||
}
|
||||
|
||||
private synchronized void doExecuteAgentExamples(Agent agent) {
|
||||
if (!agent.containsLLMParserTool() || !LLMConnHelper.testConnection(agent.getLlmConfig())
|
||||
if (!agent.containsLLMParserTool() || !LLMConnHelper.testConnection(agent.getModelConfig())
|
||||
|| CollectionUtils.isEmpty(agent.getExamples())) {
|
||||
return;
|
||||
}
|
||||
@@ -122,7 +122,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
||||
BeanUtils.copyProperties(agentDO, agent);
|
||||
agent.setAgentConfig(agentDO.getConfig());
|
||||
agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class));
|
||||
agent.setLlmConfig(JsonUtil.toObject(agentDO.getLlmConfig(), ChatModelConfig.class));
|
||||
agent.setModelConfig(JsonUtil.toObject(agentDO.getModelConfig(), ModelConfig.class));
|
||||
agent.setPromptConfig(JsonUtil.toObject(agentDO.getPromptConfig(), PromptConfig.class));
|
||||
agent.setMultiTurnConfig(JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class));
|
||||
agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class));
|
||||
@@ -134,7 +134,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
||||
BeanUtils.copyProperties(agent, agentDO);
|
||||
agentDO.setConfig(agent.getAgentConfig());
|
||||
agentDO.setExamples(JsonUtil.toString(agent.getExamples()));
|
||||
agentDO.setLlmConfig(JsonUtil.toString(agent.getLlmConfig()));
|
||||
agentDO.setModelConfig(JsonUtil.toString(agent.getModelConfig()));
|
||||
agentDO.setMultiTurnConfig(JsonUtil.toString(agent.getMultiTurnConfig()));
|
||||
agentDO.setVisualConfig(JsonUtil.toString(agent.getVisualConfig()));
|
||||
agentDO.setPromptConfig(JsonUtil.toString(agent.getPromptConfig()));
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.chat.server.util;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.ModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.provider.ModelProvider;
|
||||
@@ -9,12 +9,13 @@ import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
@Slf4j
|
||||
public class LLMConnHelper {
|
||||
public static boolean testConnection(ChatModelConfig chatModel) {
|
||||
public static boolean testConnection(ModelConfig modelConfig) {
|
||||
try {
|
||||
if (chatModel == null || StringUtils.isBlank(chatModel.getBaseUrl())) {
|
||||
if (modelConfig == null || modelConfig.getChatModel() == null
|
||||
|| StringUtils.isBlank(modelConfig.getChatModel().getBaseUrl())) {
|
||||
return false;
|
||||
}
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.provideChatModel(chatModel);
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(modelConfig);
|
||||
String response = chatLanguageModel.generate("Hi there");
|
||||
return StringUtils.isNotEmpty(response) ? true : false;
|
||||
} catch (Exception e) {
|
||||
|
||||
@@ -30,7 +30,7 @@ public class QueryReqConverter {
|
||||
&& MapUtils.isNotEmpty(queryTextReq.getMapInfo().getDataSetElementMatches())) {
|
||||
queryTextReq.setMapInfo(queryTextReq.getMapInfo());
|
||||
}
|
||||
queryTextReq.setLlmConfig(agent.getLlmConfig());
|
||||
queryTextReq.setModelConfig(agent.getModelConfig());
|
||||
queryTextReq.setPromptConfig(agent.getPromptConfig());
|
||||
return queryTextReq;
|
||||
}
|
||||
|
||||
@@ -1,15 +1,23 @@
|
||||
package dev.langchain4j.chroma.spring;
|
||||
|
||||
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
|
||||
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.chroma.ChromaEmbeddingStore;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
@Slf4j
|
||||
public class ChromaEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
|
||||
|
||||
private Properties properties;
|
||||
|
||||
/**
 * Creates a factory from an {@link EmbeddingStoreConfig}.
 *
 * <p>The config is converted to a {@code Properties} instance by
 * {@code createPropertiesFromConfig} and passed to the {@code Properties}-based constructor.
 *
 * @param storeConfig the embedding store configuration to convert
 */
public ChromaEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) {
|
||||
this(createPropertiesFromConfig(storeConfig));
|
||||
}
|
||||
|
||||
/**
 * Creates a factory that keeps the given properties in its {@code properties} field.
 *
 * @param properties the properties stored by this factory
 */
public ChromaEmbeddingStoreFactory(Properties properties) {
|
||||
this.properties = properties;
|
||||
}
|
||||
@@ -23,4 +31,13 @@ public class ChromaEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
|
||||
.timeout(storeProperties.getTimeout())
|
||||
.build();
|
||||
}
|
||||
|
||||
/**
 * Builds a {@code Properties} object from an {@link EmbeddingStoreConfig}.
 *
 * <p>Values are copied from the config into a new {@code EmbeddingStoreProperties} with
 * {@code BeanUtils.copyProperties}; the timeout is then set explicitly as a {@link Duration}
 * in seconds.
 *
 * @param storeConfig the configuration to read from
 * @return a new {@code Properties} whose embedding store section is populated from the config
 */
private static Properties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) {
|
||||
Properties properties = new Properties();
|
||||
EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties();
|
||||
BeanUtils.copyProperties(storeConfig, embeddingStore);
|
||||
// The timeout is converted by hand: the config exposes getTimeOut() as a number of seconds,
// while the target setter takes a Duration.
// NOTE(review): if getTimeOut() returns a boxed value that can be null, unboxing here
// would throw a NullPointerException - confirm against EmbeddingStoreConfig.
embeddingStore.setTimeout(Duration.ofSeconds(storeConfig.getTimeOut()));
|
||||
properties.setEmbeddingStore(embeddingStore);
|
||||
return properties;
|
||||
}
|
||||
}
|
||||
@@ -7,7 +7,7 @@ import java.time.Duration;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
class EmbeddingStoreProperties {
|
||||
public class EmbeddingStoreProperties {
|
||||
|
||||
private String baseUrl;
|
||||
private Duration timeout;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package dev.langchain4j.inmemory.spring;
|
||||
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
|
||||
@@ -9,6 +10,7 @@ import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
@@ -23,11 +25,22 @@ public class InMemoryEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
|
||||
public static final String PERSISTENT_FILE_PRE = "InMemory.";
|
||||
private Properties properties;
|
||||
|
||||
/**
 * Creates a factory from an {@link EmbeddingStoreConfig}.
 *
 * <p>The config is converted to a {@code Properties} instance by
 * {@code createPropertiesFromConfig} and passed to the {@code Properties}-based constructor.
 *
 * @param storeConfig the embedding store configuration to convert
 */
public InMemoryEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) {
|
||||
this(createPropertiesFromConfig(storeConfig));
|
||||
}
|
||||
|
||||
/**
 * Creates a factory that keeps the given properties in its {@code properties} field.
 *
 * @param properties the properties stored by this factory
 */
public InMemoryEmbeddingStoreFactory(Properties properties) {
|
||||
this.properties = properties;
|
||||
}
|
||||
|
||||
/**
 * Builds a {@code Properties} object from an {@link EmbeddingStoreConfig}.
 *
 * <p>Values are copied from the config into a new {@code EmbeddingStoreProperties} with
 * {@code BeanUtils.copyProperties}; no field is set by hand here.
 *
 * @param storeConfig the configuration to read from
 * @return a new {@code Properties} whose embedding store section is populated from the config
 */
private static Properties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) {
|
||||
Properties properties = new Properties();
|
||||
EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties();
|
||||
BeanUtils.copyProperties(storeConfig, embeddingStore);
|
||||
properties.setEmbeddingStore(embeddingStore);
|
||||
return properties;
|
||||
}
|
||||
|
||||
@Override
|
||||
public synchronized EmbeddingStore createEmbeddingStore(String collectionName) {
|
||||
InMemoryEmbeddingStore<TextSegment> embeddingStore = reloadFromPersistFile(collectionName);
|
||||
|
||||
@@ -1,17 +1,32 @@
|
||||
package dev.langchain4j.milvus.spring;
|
||||
|
||||
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
|
||||
public class MilvusEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
|
||||
private final Properties properties;
|
||||
|
||||
/**
 * Creates a factory from an {@link EmbeddingStoreConfig}.
 *
 * <p>The config is converted to a {@code Properties} instance by
 * {@code createPropertiesFromConfig} and passed to the {@code Properties}-based constructor.
 *
 * @param storeConfig the embedding store configuration to convert
 */
public MilvusEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) {
|
||||
this(createPropertiesFromConfig(storeConfig));
|
||||
}
|
||||
|
||||
/**
 * Creates a factory that keeps the given properties in its final {@code properties} field.
 *
 * @param properties the properties stored by this factory
 */
public MilvusEmbeddingStoreFactory(Properties properties) {
|
||||
this.properties = properties;
|
||||
}
|
||||
|
||||
/**
 * Builds a {@code Properties} object from an {@link EmbeddingStoreConfig}.
 *
 * <p>Values are copied from the config into a new {@code EmbeddingStoreProperties} with
 * {@code BeanUtils.copyProperties}; the store URI is then set explicitly from the config's
 * base URL.
 *
 * @param storeConfig the configuration to read from
 * @return a new {@code Properties} whose embedding store section is populated from the config
 */
private static Properties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) {
|
||||
Properties properties = new Properties();
|
||||
EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties();
|
||||
BeanUtils.copyProperties(storeConfig, embeddingStore);
|
||||
// The config calls this value "baseUrl" while the target property is "uri",
// so it is assigned explicitly.
embeddingStore.setUri(storeConfig.getBaseUrl());
|
||||
properties.setEmbeddingStore(embeddingStore);
|
||||
return properties;
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingStore<TextSegment> createEmbeddingStore(String collectionName) {
|
||||
EmbeddingStoreProperties storeProperties = properties.getEmbeddingStore();
|
||||
|
||||
@@ -13,15 +13,16 @@ import java.time.Duration;
|
||||
|
||||
@Service
|
||||
public class AzureModelFactory implements ModelFactory, InitializingBean {
|
||||
public static final String PROVIDER = "AZURE";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder()
|
||||
.endpoint(chatModel.getBaseUrl())
|
||||
.apiKey(chatModel.getApiKey())
|
||||
.deploymentName(chatModel.getModelName())
|
||||
.temperature(chatModel.getTemperature())
|
||||
.timeout(Duration.ofSeconds(chatModel.getTimeOut() == null ? 0L : chatModel.getTimeOut()));
|
||||
.endpoint(modelConfig.getBaseUrl())
|
||||
.apiKey(modelConfig.getApiKey())
|
||||
.deploymentName(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature())
|
||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut() == null ? 0L : modelConfig.getTimeOut()));
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
@@ -39,6 +40,6 @@ public class AzureModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.AZURE, this);
|
||||
ModelProvider.add(PROVIDER, this);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,14 +11,16 @@ import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
public class DashscopeModelFactory implements ModelFactory, InitializingBean {
|
||||
public static final String PROVIDER = "DASHSCOPE";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
return QwenChatModel.builder()
|
||||
.baseUrl(chatModel.getBaseUrl())
|
||||
.apiKey(chatModel.getApiKey())
|
||||
.modelName(chatModel.getModelName())
|
||||
.temperature(chatModel.getTemperature() == null ? 0L :
|
||||
chatModel.getTemperature().floatValue())
|
||||
.baseUrl(modelConfig.getBaseUrl())
|
||||
.apiKey(modelConfig.getApiKey())
|
||||
.modelName(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature() == null ? 0L :
|
||||
modelConfig.getTemperature().floatValue())
|
||||
.build();
|
||||
}
|
||||
|
||||
@@ -32,6 +34,6 @@ public class DashscopeModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.DASHSCOPE, this);
|
||||
ModelProvider.add(PROVIDER, this);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,9 +16,11 @@ import static dev.langchain4j.inmemory.spring.InMemoryAutoConfig.BGE_SMALL_ZH;
|
||||
|
||||
@Service
|
||||
public class InMemoryModelFactory implements ModelFactory, InitializingBean {
|
||||
public static final String PROVIDER = "IN_MEMORY";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
return null;
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
throw new UnsupportedOperationException("Not supported yet.");
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -40,6 +42,6 @@ public class InMemoryModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.IN_MEMORY, this);
|
||||
ModelProvider.add(PROVIDER, this);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,14 +13,16 @@ import java.time.Duration;
|
||||
|
||||
@Service
|
||||
public class LocalAiModelFactory implements ModelFactory, InitializingBean {
|
||||
public static final String PROVIDER = "LOCAL_AI";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
return LocalAiChatModel
|
||||
.builder()
|
||||
.baseUrl(chatModel.getBaseUrl())
|
||||
.modelName(chatModel.getModelName())
|
||||
.temperature(chatModel.getTemperature())
|
||||
.timeout(Duration.ofSeconds(chatModel.getTimeOut()))
|
||||
.baseUrl(modelConfig.getBaseUrl())
|
||||
.modelName(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature())
|
||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
|
||||
.build();
|
||||
}
|
||||
|
||||
@@ -34,6 +36,6 @@ public class LocalAiModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.LOCAL_AI, this);
|
||||
ModelProvider.add(PROVIDER, this);
|
||||
}
|
||||
}
|
||||
@@ -6,7 +6,7 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
|
||||
public interface ModelFactory {
|
||||
ChatLanguageModel createChatModel(ChatModelConfig llmConfig);
|
||||
ChatLanguageModel createChatModel(ChatModelConfig modelConfig);
|
||||
|
||||
EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel);
|
||||
}
|
||||
|
||||
@@ -15,24 +15,11 @@ import java.util.Objects;
|
||||
public class ModelProvider {
|
||||
private static final Map<String, ModelFactory> factories = new HashMap<>();
|
||||
|
||||
public static void add(Provider provider, ModelFactory modelFactory) {
|
||||
factories.put(provider.name(), modelFactory);
|
||||
public static void add(String provider, ModelFactory modelFactory) {
|
||||
factories.put(provider, modelFactory);
|
||||
}
|
||||
|
||||
public static ChatLanguageModel provideChatModel(ChatModelConfig llmConfig) {
|
||||
if (llmConfig == null || StringUtils.isBlank(llmConfig.getProvider())
|
||||
|| StringUtils.isBlank(llmConfig.getBaseUrl())) {
|
||||
return ContextUtils.getBean(ChatLanguageModel.class);
|
||||
}
|
||||
ModelFactory modelFactory = factories.get(llmConfig.getProvider().toUpperCase());
|
||||
if (modelFactory != null) {
|
||||
return modelFactory.createChatModel(llmConfig);
|
||||
}
|
||||
|
||||
throw new RuntimeException("Unsupported ChatLanguageModel provider: " + llmConfig.getProvider());
|
||||
}
|
||||
|
||||
public static ChatLanguageModel provideChatModelNew(ModelConfig modelConfig) {
|
||||
public static ChatLanguageModel getChatModel(ModelConfig modelConfig) {
|
||||
if (modelConfig == null || modelConfig.getChatModel() == null
|
||||
|| StringUtils.isBlank(modelConfig.getChatModel().getProvider())
|
||||
|| StringUtils.isBlank(modelConfig.getChatModel().getBaseUrl())) {
|
||||
@@ -47,7 +34,7 @@ public class ModelProvider {
|
||||
throw new RuntimeException("Unsupported ChatLanguageModel provider: " + chatModel.getProvider());
|
||||
}
|
||||
|
||||
public static EmbeddingModel provideEmbeddingModel(ModelConfig modelConfig) {
|
||||
public static EmbeddingModel getEmbeddingModel(ModelConfig modelConfig) {
|
||||
if (modelConfig == null || Objects.isNull(modelConfig.getEmbeddingModel())
|
||||
|| StringUtils.isBlank(modelConfig.getEmbeddingModel().getBaseUrl())
|
||||
|| StringUtils.isBlank(modelConfig.getEmbeddingModel().getProvider())) {
|
||||
|
||||
@@ -13,14 +13,16 @@ import java.time.Duration;
|
||||
|
||||
@Service
|
||||
public class OllamaModelFactory implements ModelFactory, InitializingBean {
|
||||
public static final String PROVIDER = "OLLAMA";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
return OllamaChatModel
|
||||
.builder()
|
||||
.baseUrl(chatModel.getBaseUrl())
|
||||
.modelName(chatModel.getModelName())
|
||||
.temperature(chatModel.getTemperature())
|
||||
.timeout(Duration.ofSeconds(chatModel.getTimeOut()))
|
||||
.baseUrl(modelConfig.getBaseUrl())
|
||||
.modelName(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature())
|
||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
|
||||
.build();
|
||||
}
|
||||
|
||||
@@ -37,6 +39,6 @@ public class OllamaModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.OLLAMA, this);
|
||||
ModelProvider.add(PROVIDER, this);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,15 +13,17 @@ import java.time.Duration;
|
||||
|
||||
@Service
|
||||
public class OpenAiModelFactory implements ModelFactory, InitializingBean {
|
||||
public static final String PROVIDER = "OPEN_AI";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
return OpenAiChatModel
|
||||
.builder()
|
||||
.baseUrl(chatModel.getBaseUrl())
|
||||
.modelName(chatModel.getModelName())
|
||||
.apiKey(chatModel.keyDecrypt())
|
||||
.temperature(chatModel.getTemperature())
|
||||
.timeout(Duration.ofSeconds(chatModel.getTimeOut()))
|
||||
.baseUrl(modelConfig.getBaseUrl())
|
||||
.modelName(modelConfig.getModelName())
|
||||
.apiKey(modelConfig.keyDecrypt())
|
||||
.temperature(modelConfig.getTemperature())
|
||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
|
||||
.build();
|
||||
}
|
||||
|
||||
@@ -39,6 +41,6 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.OPEN_AI, this);
|
||||
ModelProvider.add(PROVIDER, this);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
/**
 * Identifiers of the model providers.
 *
 * <p>Each constant name matches the {@code PROVIDER} string constant declared by the
 * corresponding model factory class.
 */
public enum Provider {
|
||||
OPEN_AI,
|
||||
OLLAMA,
|
||||
LOCAL_AI,
|
||||
IN_MEMORY,
|
||||
ZHIPU,
|
||||
AZURE,
|
||||
QIANFAN,
|
||||
DASHSCOPE
|
||||
}
|
||||
@@ -10,8 +10,11 @@ import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
public class QianfanModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
public static final String PROVIDER = "QIANFAN";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -29,6 +32,6 @@ public class QianfanModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.QIANFAN, this);
|
||||
ModelProvider.add(PROVIDER, this);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,8 +10,10 @@ import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
public class ZhipuModelFactory implements ModelFactory, InitializingBean {
|
||||
public static final String PROVIDER = "ZHIPU";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -29,6 +31,6 @@ public class ZhipuModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.ZHIPU, this);
|
||||
ModelProvider.add(PROVIDER, this);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
package dev.langchain4j.store.embedding;
|
||||
|
||||
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.chroma.spring.ChromaEmbeddingStoreFactory;
|
||||
import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory;
|
||||
import dev.langchain4j.milvus.spring.MilvusEmbeddingStoreFactory;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
/**
 * Selects an {@link EmbeddingStoreFactory} for a given {@link EmbeddingStoreConfig}.
 */
public class EmbeddingStoreFactoryProvider {
|
||||
/**
 * Returns the factory matching the provider named in the configuration.
 *
 * <p>When the config is {@code null} or its provider is blank, the
 * {@code EmbeddingStoreFactory} bean obtained from {@code ContextUtils} is returned.
 * Otherwise the provider is compared case-insensitively with the {@code EmbeddingStoreType}
 * names {@code CHROMA}, {@code MILVUS} and {@code IN_MEMORY}, and a new factory instance of
 * the matching kind is constructed on every call.
 *
 * @param storeConfig the embedding store configuration; may be {@code null}
 * @return the factory for the configured provider, or the default bean
 * @throws RuntimeException if the provider matches none of the supported types
 */
public static EmbeddingStoreFactory getFactory(EmbeddingStoreConfig storeConfig) {
|
||||
// No explicit provider configured: fall back to the factory bean from the context.
if (storeConfig == null || StringUtils.isBlank(storeConfig.getProvider())) {
|
||||
return ContextUtils.getBean(EmbeddingStoreFactory.class);
|
||||
}
|
||||
if (EmbeddingStoreType.CHROMA.name().equalsIgnoreCase(storeConfig.getProvider())) {
|
||||
return new ChromaEmbeddingStoreFactory(storeConfig);
|
||||
}
|
||||
if (EmbeddingStoreType.MILVUS.name().equalsIgnoreCase(storeConfig.getProvider())) {
|
||||
return new MilvusEmbeddingStoreFactory(storeConfig);
|
||||
}
|
||||
if (EmbeddingStoreType.IN_MEMORY.name().equalsIgnoreCase(storeConfig.getProvider())) {
|
||||
return new InMemoryEmbeddingStoreFactory(storeConfig);
|
||||
}
|
||||
throw new RuntimeException("Unsupported EmbeddingStore provider: " + storeConfig.getProvider());
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package dev.langchain4j.provider;
|
||||
package dev.langchain4j.store.embedding;
|
||||
|
||||
public enum EmbeddingStoreType {
|
||||
IN_MEMORY,
|
||||
@@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.api.pojo.request;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.ModelConfig;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
@@ -27,7 +27,7 @@ public class QueryTextReq {
|
||||
private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
|
||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||
private QueryDataType queryDataType = QueryDataType.ALL;
|
||||
private ChatModelConfig llmConfig;
|
||||
private ModelConfig modelConfig;
|
||||
private PromptConfig promptConfig;
|
||||
private List<SqlExemplar> dynamicExemplars = Lists.newArrayList();
|
||||
}
|
||||
|
||||
@@ -2,15 +2,14 @@ package com.tencent.supersonic.headless.chat;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.ModelConfig;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.WorkflowState;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
@@ -21,7 +20,6 @@ import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
@@ -52,7 +50,6 @@ public class QueryContext {
|
||||
private WorkflowState workflowState;
|
||||
private QueryDataType queryDataType = QueryDataType.ALL;
|
||||
private ModelConfig modelConfig;
|
||||
private ChatModelConfig llmConfig;
|
||||
private PromptConfig promptConfig;
|
||||
private List<SqlExemplar> dynamicExemplars;
|
||||
|
||||
|
||||
@@ -103,7 +103,6 @@ public class LLMRequestService {
|
||||
llmReq.setCurrentDate(DateUtils.getBeforeDate(0));
|
||||
llmReq.setSqlGenType(LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
|
||||
llmReq.setModelConfig(queryCtx.getModelConfig());
|
||||
llmReq.setLlmConfig(queryCtx.getLlmConfig());
|
||||
llmReq.setPromptConfig(queryCtx.getPromptConfig());
|
||||
llmReq.setDynamicExemplars(queryCtx.getDynamicExemplars());
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
Map<Prompt, String> prompt2Output = new ConcurrentHashMap<>();
|
||||
prompt2Exemplar.keySet().parallelStream().forEach(prompt -> {
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:\n{}", prompt.toUserMessage());
|
||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
|
||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getModelConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||
String result = response.content().text();
|
||||
prompt2Output.put(prompt, result);
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.ModelConfig;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
@@ -23,8 +23,8 @@ public abstract class SqlGenStrategy implements InitializingBean {
|
||||
@Autowired
|
||||
protected PromptHelper promptHelper;
|
||||
|
||||
protected ChatLanguageModel getChatLanguageModel(ChatModelConfig llmConfig) {
|
||||
return ModelProvider.provideChatModel(llmConfig);
|
||||
protected ChatLanguageModel getChatLanguageModel(ModelConfig modelConfig) {
|
||||
return ModelProvider.getChatModel(modelConfig);
|
||||
}
|
||||
|
||||
abstract LLMResp generate(LLMReq llmReq);
|
||||
|
||||
@@ -2,11 +2,10 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonValue;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.ModelConfig;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
@@ -29,7 +28,6 @@ public class LLMReq {
|
||||
private SqlGenType sqlGenType;
|
||||
|
||||
private ModelConfig modelConfig;
|
||||
private ChatModelConfig llmConfig;
|
||||
private PromptConfig promptConfig;
|
||||
|
||||
private List<SqlExemplar> dynamicExemplars;
|
||||
|
||||
@@ -350,4 +350,7 @@ CREATE TABLE IF NOT EXISTS `s2_chat_memory` (
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
|
||||
|
||||
--20240705
|
||||
alter table s2_agent add column `prompt_config` varchar(6000) COLLATE utf8_unicode_ci DEFAULT NULL COMMENT '提示词配置';
|
||||
alter table s2_agent add column `prompt_config` varchar(6000) COLLATE utf8_unicode_ci DEFAULT NULL COMMENT '提示词配置';
|
||||
|
||||
--20240707
|
||||
alter table s2_agent add model_config varchar(6000) null;
|
||||
@@ -374,6 +374,7 @@ CREATE TABLE IF NOT EXISTS s2_agent
|
||||
examples varchar(500) null,
|
||||
config varchar(2000) null,
|
||||
llm_config varchar(2000) null,
|
||||
model_config varchar(6000) null,
|
||||
prompt_config varchar(5000) null,
|
||||
multi_turn_config varchar(2000) null,
|
||||
visual_config varchar(2000) null,
|
||||
|
||||
@@ -73,6 +73,7 @@ CREATE TABLE `s2_agent` (
|
||||
`model` varchar(100) COLLATE utf8_unicode_ci DEFAULT NULL,
|
||||
`config` varchar(6000) COLLATE utf8_unicode_ci DEFAULT NULL,
|
||||
`llm_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL,
|
||||
`model_config` varchar(6000) COLLATE utf8_unicode_ci DEFAULT NULL,
|
||||
`multi_turn_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL,
|
||||
`visual_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL,
|
||||
`created_by` varchar(100) COLLATE utf8_unicode_ci DEFAULT NULL,
|
||||
|
||||
@@ -87,7 +87,6 @@ public class Text2SQLEval extends BaseTest {
|
||||
agentConfig.getTools().add(getLLMQueryTool());
|
||||
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
|
||||
agent.setModelConfig(getLLMConfig(LLMType.GPT));
|
||||
agent.setLlmConfig(getLLMConfig(LLMType.GPT).getChatModel());
|
||||
MultiTurnConfig multiTurnConfig = new MultiTurnConfig();
|
||||
multiTurnConfig.setEnableMultiTurn(enableMultiturn);
|
||||
agent.setMultiTurnConfig(multiTurnConfig);
|
||||
|
||||
@@ -374,6 +374,7 @@ CREATE TABLE IF NOT EXISTS s2_agent
|
||||
examples varchar(500) null,
|
||||
config varchar(2000) null,
|
||||
llm_config varchar(2000) null,
|
||||
model_config varchar(6000) null,
|
||||
prompt_config varchar(5000) null,
|
||||
multi_turn_config varchar(2000) null,
|
||||
visual_config varchar(2000) null,
|
||||
@@ -383,7 +384,7 @@ CREATE TABLE IF NOT EXISTS s2_agent
|
||||
updated_at TIMESTAMP null,
|
||||
enable_search int null,
|
||||
PRIMARY KEY (`id`)
|
||||
); COMMENT ON TABLE s2_agent IS 'agent information table';
|
||||
); COMMENT ON TABLE s2_agent IS 'agent information table';
|
||||
|
||||
|
||||
-------demo for semantic and chat
|
||||
|
||||
Reference in New Issue
Block a user