(improvement)(chat) Support embedding store configuration. (#1363)

This commit is contained in:
lexluo09
2024-07-07 00:30:19 +08:00
committed by GitHub
parent 3f460429e6
commit 4d7bfe07aa
37 changed files with 185 additions and 119 deletions

View File

@@ -4,8 +4,6 @@ package com.tencent.supersonic.chat.server.agent;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
import com.tencent.supersonic.common.config.ModelConfig; import com.tencent.supersonic.common.config.ModelConfig;
import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.config.VisualConfig; import com.tencent.supersonic.common.config.VisualConfig;
@@ -35,9 +33,7 @@ public class Agent extends RecordInfo {
private Integer status; private Integer status;
private List<String> examples; private List<String> examples;
private String agentConfig; private String agentConfig;
private ChatModelConfig llmConfig;
private ModelConfig modelConfig; private ModelConfig modelConfig;
private EmbeddingStoreConfig embeddingStore;
private PromptConfig promptConfig; private PromptConfig promptConfig;
private MultiTurnConfig multiTurnConfig; private MultiTurnConfig multiTurnConfig;
private VisualConfig visualConfig; private VisualConfig visualConfig;

View File

@@ -46,7 +46,7 @@ public class PlainTextExecutor implements ChatExecutor {
AgentService agentService = ContextUtils.getBean(AgentService.class); AgentService agentService = ContextUtils.getBean(AgentService.class);
Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId()); Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId());
ChatLanguageModel chatLanguageModel = ModelProvider.provideChatModel(chatAgent.getLlmConfig()); ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatAgent.getModelConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage()); Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
QueryResult result = new QueryResult(); QueryResult result = new QueryResult();

View File

@@ -56,8 +56,8 @@ public class MemoryReviewTask {
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP); Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
keyPipelineLog.info("MemoryReviewTask reqPrompt:{}", promptStr); keyPipelineLog.info("MemoryReviewTask reqPrompt:{}", promptStr);
ChatLanguageModel chatLanguageModel = ModelProvider.provideChatModel( ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(
chatAgent.getLlmConfig()); chatAgent.getModelConfig());
if (Objects.nonNull(chatLanguageModel)) { if (Objects.nonNull(chatLanguageModel)) {
String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text(); String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text();
keyPipelineLog.info("MemoryReviewTask modelResp:{}", response); keyPipelineLog.info("MemoryReviewTask modelResp:{}", response);

View File

@@ -5,8 +5,8 @@ import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryReposi
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager; import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext; import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.util.QueryReqConverter; import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.config.ModelConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl; import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
@@ -168,7 +168,7 @@ public class NL2SQLParser implements ChatParser {
.curtSchema(curtMapStr) .curtSchema(curtMapStr)
.histSchema(histMapStr) .histSchema(histMapStr)
.histSQL(histSQL) .histSQL(histSQL)
.llmConfig(queryTextReq.getLlmConfig()) .modelConfig(queryTextReq.getModelConfig())
.build()); .build());
chatParseContext.setQueryText(rewrittenQuery); chatParseContext.setQueryText(rewrittenQuery);
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}", log.info("Last Query: {} Current Query: {}, Rewritten Query: {}",
@@ -181,7 +181,7 @@ public class NL2SQLParser implements ChatParser {
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP); Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
keyPipelineLog.info("NL2SQLParser reqPrompt:{}", promptStr); keyPipelineLog.info("NL2SQLParser reqPrompt:{}", promptStr);
ChatLanguageModel chatLanguageModel = ModelProvider.provideChatModel(context.getLlmConfig()); ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(context.getModelConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage()); Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
String result = response.content().text(); String result = response.content().text();
@@ -243,7 +243,7 @@ public class NL2SQLParser implements ChatParser {
private String curtSchema; private String curtSchema;
private String histSchema; private String histSchema;
private String histSQL; private String histSQL;
private ChatModelConfig llmConfig; private ModelConfig modelConfig;
} }
} }

View File

@@ -11,15 +11,18 @@ import java.util.Date;
@TableName("s2_agent") @TableName("s2_agent")
public class AgentDO { public class AgentDO {
/** /**
*
*/ */
@TableId(type = IdType.AUTO) @TableId(type = IdType.AUTO)
private Integer id; private Integer id;
/** /**
*
*/ */
private String name; private String name;
/** /**
*
*/ */
private String description; private String description;
@@ -29,35 +32,40 @@ public class AgentDO {
private Integer status; private Integer status;
/** /**
*
*/ */
private String examples; private String examples;
/** /**
*
*/ */
private String config; private String config;
/** /**
*
*/ */
private String createdBy; private String createdBy;
/** /**
*
*/ */
private Date createdAt; private Date createdAt;
/** /**
*
*/ */
private String updatedBy; private String updatedBy;
/** /**
*
*/ */
private Date updatedAt; private Date updatedAt;
/** /**
*
*/ */
private Integer enableSearch; private Integer enableSearch;
private String modelConfig;
private String llmConfig;
private String multiTurnConfig; private String multiTurnConfig;
private String visualConfig; private String visualConfig;

View File

@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentToolType; import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.util.LLMConnHelper; import com.tencent.supersonic.chat.server.util.LLMConnHelper;
import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.config.ModelConfig;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.DeleteMapping; import org.springframework.web.bind.annotation.DeleteMapping;
import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.PathVariable;
@@ -30,16 +30,16 @@ public class AgentController {
@PostMapping @PostMapping
public Agent createAgent(@RequestBody Agent agent, public Agent createAgent(@RequestBody Agent agent,
HttpServletRequest httpServletRequest, HttpServletRequest httpServletRequest,
HttpServletResponse httpServletResponse) { HttpServletResponse httpServletResponse) {
User user = UserHolder.findUser(httpServletRequest, httpServletResponse); User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
return agentService.createAgent(agent, user); return agentService.createAgent(agent, user);
} }
@PutMapping @PutMapping
public Agent updateAgent(@RequestBody Agent agent, public Agent updateAgent(@RequestBody Agent agent,
HttpServletRequest httpServletRequest, HttpServletRequest httpServletRequest,
HttpServletResponse httpServletResponse) { HttpServletResponse httpServletResponse) {
User user = UserHolder.findUser(httpServletRequest, httpServletResponse); User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
return agentService.updateAgent(agent, user); return agentService.updateAgent(agent, user);
} }
@@ -51,8 +51,8 @@ public class AgentController {
} }
@PostMapping("/testLLMConn") @PostMapping("/testLLMConn")
public boolean testLLMConn(@RequestBody ChatModelConfig llmConfig) { public boolean testLLMConn(@RequestBody ModelConfig modelConfig) {
return LLMConnHelper.testConnection(llmConfig); return LLMConnHelper.testConnection(modelConfig);
} }
@RequestMapping("/getAgentList") @RequestMapping("/getAgentList")

View File

@@ -12,7 +12,7 @@ import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatService; import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.chat.server.service.MemoryService; import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.chat.server.util.LLMConnHelper; import com.tencent.supersonic.chat.server.util.LLMConnHelper;
import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.config.ModelConfig;
import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.config.VisualConfig; import com.tencent.supersonic.common.config.VisualConfig;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
@@ -88,7 +88,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
} }
private synchronized void doExecuteAgentExamples(Agent agent) { private synchronized void doExecuteAgentExamples(Agent agent) {
if (!agent.containsLLMParserTool() || !LLMConnHelper.testConnection(agent.getLlmConfig()) if (!agent.containsLLMParserTool() || !LLMConnHelper.testConnection(agent.getModelConfig())
|| CollectionUtils.isEmpty(agent.getExamples())) { || CollectionUtils.isEmpty(agent.getExamples())) {
return; return;
} }
@@ -122,7 +122,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
BeanUtils.copyProperties(agentDO, agent); BeanUtils.copyProperties(agentDO, agent);
agent.setAgentConfig(agentDO.getConfig()); agent.setAgentConfig(agentDO.getConfig());
agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class)); agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class));
agent.setLlmConfig(JsonUtil.toObject(agentDO.getLlmConfig(), ChatModelConfig.class)); agent.setModelConfig(JsonUtil.toObject(agentDO.getModelConfig(), ModelConfig.class));
agent.setPromptConfig(JsonUtil.toObject(agentDO.getPromptConfig(), PromptConfig.class)); agent.setPromptConfig(JsonUtil.toObject(agentDO.getPromptConfig(), PromptConfig.class));
agent.setMultiTurnConfig(JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class)); agent.setMultiTurnConfig(JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class));
agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class)); agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class));
@@ -134,7 +134,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
BeanUtils.copyProperties(agent, agentDO); BeanUtils.copyProperties(agent, agentDO);
agentDO.setConfig(agent.getAgentConfig()); agentDO.setConfig(agent.getAgentConfig());
agentDO.setExamples(JsonUtil.toString(agent.getExamples())); agentDO.setExamples(JsonUtil.toString(agent.getExamples()));
agentDO.setLlmConfig(JsonUtil.toString(agent.getLlmConfig())); agentDO.setModelConfig(JsonUtil.toString(agent.getModelConfig()));
agentDO.setMultiTurnConfig(JsonUtil.toString(agent.getMultiTurnConfig())); agentDO.setMultiTurnConfig(JsonUtil.toString(agent.getMultiTurnConfig()));
agentDO.setVisualConfig(JsonUtil.toString(agent.getVisualConfig())); agentDO.setVisualConfig(JsonUtil.toString(agent.getVisualConfig()));
agentDO.setPromptConfig(JsonUtil.toString(agent.getPromptConfig())); agentDO.setPromptConfig(JsonUtil.toString(agent.getPromptConfig()));

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.server.util; package com.tencent.supersonic.chat.server.util;
import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.config.ModelConfig;
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException; import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.provider.ModelProvider; import dev.langchain4j.provider.ModelProvider;
@@ -9,12 +9,13 @@ import org.apache.commons.lang3.StringUtils;
@Slf4j @Slf4j
public class LLMConnHelper { public class LLMConnHelper {
public static boolean testConnection(ChatModelConfig chatModel) { public static boolean testConnection(ModelConfig modelConfig) {
try { try {
if (chatModel == null || StringUtils.isBlank(chatModel.getBaseUrl())) { if (modelConfig == null || modelConfig.getChatModel() == null
|| StringUtils.isBlank(modelConfig.getChatModel().getBaseUrl())) {
return false; return false;
} }
ChatLanguageModel chatLanguageModel = ModelProvider.provideChatModel(chatModel); ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(modelConfig);
String response = chatLanguageModel.generate("Hi there"); String response = chatLanguageModel.generate("Hi there");
return StringUtils.isNotEmpty(response) ? true : false; return StringUtils.isNotEmpty(response) ? true : false;
} catch (Exception e) { } catch (Exception e) {

View File

@@ -30,7 +30,7 @@ public class QueryReqConverter {
&& MapUtils.isNotEmpty(queryTextReq.getMapInfo().getDataSetElementMatches())) { && MapUtils.isNotEmpty(queryTextReq.getMapInfo().getDataSetElementMatches())) {
queryTextReq.setMapInfo(queryTextReq.getMapInfo()); queryTextReq.setMapInfo(queryTextReq.getMapInfo());
} }
queryTextReq.setLlmConfig(agent.getLlmConfig()); queryTextReq.setModelConfig(agent.getModelConfig());
queryTextReq.setPromptConfig(agent.getPromptConfig()); queryTextReq.setPromptConfig(agent.getPromptConfig());
return queryTextReq; return queryTextReq;
} }

View File

@@ -1,15 +1,23 @@
package dev.langchain4j.chroma.spring; package dev.langchain4j.chroma.spring;
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory; import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.chroma.ChromaEmbeddingStore; import dev.langchain4j.store.embedding.chroma.ChromaEmbeddingStore;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import java.time.Duration;
@Slf4j @Slf4j
public class ChromaEmbeddingStoreFactory extends BaseEmbeddingStoreFactory { public class ChromaEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
private Properties properties; private Properties properties;
public ChromaEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) {
this(createPropertiesFromConfig(storeConfig));
}
public ChromaEmbeddingStoreFactory(Properties properties) { public ChromaEmbeddingStoreFactory(Properties properties) {
this.properties = properties; this.properties = properties;
} }
@@ -23,4 +31,13 @@ public class ChromaEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
.timeout(storeProperties.getTimeout()) .timeout(storeProperties.getTimeout())
.build(); .build();
} }
private static Properties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) {
Properties properties = new Properties();
EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties();
BeanUtils.copyProperties(storeConfig, embeddingStore);
embeddingStore.setTimeout(Duration.ofSeconds(storeConfig.getTimeOut()));
properties.setEmbeddingStore(embeddingStore);
return properties;
}
} }

View File

@@ -7,7 +7,7 @@ import java.time.Duration;
@Getter @Getter
@Setter @Setter
class EmbeddingStoreProperties { public class EmbeddingStoreProperties {
private String baseUrl; private String baseUrl;
private Duration timeout; private Duration timeout;

View File

@@ -1,6 +1,7 @@
package dev.langchain4j.inmemory.spring; package dev.langchain4j.inmemory.spring;
import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory; import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
@@ -9,6 +10,7 @@ import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.MapUtils; import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import java.nio.file.Files; import java.nio.file.Files;
import java.nio.file.Path; import java.nio.file.Path;
@@ -23,11 +25,22 @@ public class InMemoryEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
public static final String PERSISTENT_FILE_PRE = "InMemory."; public static final String PERSISTENT_FILE_PRE = "InMemory.";
private Properties properties; private Properties properties;
public InMemoryEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) {
this(createPropertiesFromConfig(storeConfig));
}
public InMemoryEmbeddingStoreFactory(Properties properties) { public InMemoryEmbeddingStoreFactory(Properties properties) {
this.properties = properties; this.properties = properties;
} }
private static Properties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) {
Properties properties = new Properties();
EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties();
BeanUtils.copyProperties(storeConfig, embeddingStore);
properties.setEmbeddingStore(embeddingStore);
return properties;
}
@Override @Override
public synchronized EmbeddingStore createEmbeddingStore(String collectionName) { public synchronized EmbeddingStore createEmbeddingStore(String collectionName) {
InMemoryEmbeddingStore<TextSegment> embeddingStore = reloadFromPersistFile(collectionName); InMemoryEmbeddingStore<TextSegment> embeddingStore = reloadFromPersistFile(collectionName);

View File

@@ -1,17 +1,32 @@
package dev.langchain4j.milvus.spring; package dev.langchain4j.milvus.spring;
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory; import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore; import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore;
import org.springframework.beans.BeanUtils;
public class MilvusEmbeddingStoreFactory extends BaseEmbeddingStoreFactory { public class MilvusEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
private final Properties properties; private final Properties properties;
public MilvusEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) {
this(createPropertiesFromConfig(storeConfig));
}
public MilvusEmbeddingStoreFactory(Properties properties) { public MilvusEmbeddingStoreFactory(Properties properties) {
this.properties = properties; this.properties = properties;
} }
private static Properties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) {
Properties properties = new Properties();
EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties();
BeanUtils.copyProperties(storeConfig, embeddingStore);
embeddingStore.setUri(storeConfig.getBaseUrl());
properties.setEmbeddingStore(embeddingStore);
return properties;
}
@Override @Override
public EmbeddingStore<TextSegment> createEmbeddingStore(String collectionName) { public EmbeddingStore<TextSegment> createEmbeddingStore(String collectionName) {
EmbeddingStoreProperties storeProperties = properties.getEmbeddingStore(); EmbeddingStoreProperties storeProperties = properties.getEmbeddingStore();

View File

@@ -13,15 +13,16 @@ import java.time.Duration;
@Service @Service
public class AzureModelFactory implements ModelFactory, InitializingBean { public class AzureModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "AZURE";
@Override @Override
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) { public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder() AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder()
.endpoint(chatModel.getBaseUrl()) .endpoint(modelConfig.getBaseUrl())
.apiKey(chatModel.getApiKey()) .apiKey(modelConfig.getApiKey())
.deploymentName(chatModel.getModelName()) .deploymentName(modelConfig.getModelName())
.temperature(chatModel.getTemperature()) .temperature(modelConfig.getTemperature())
.timeout(Duration.ofSeconds(chatModel.getTimeOut() == null ? 0L : chatModel.getTimeOut())); .timeout(Duration.ofSeconds(modelConfig.getTimeOut() == null ? 0L : modelConfig.getTimeOut()));
return builder.build(); return builder.build();
} }
@@ -39,6 +40,6 @@ public class AzureModelFactory implements ModelFactory, InitializingBean {
@Override @Override
public void afterPropertiesSet() { public void afterPropertiesSet() {
ModelProvider.add(Provider.AZURE, this); ModelProvider.add(PROVIDER, this);
} }
} }

View File

@@ -11,14 +11,16 @@ import org.springframework.stereotype.Service;
@Service @Service
public class DashscopeModelFactory implements ModelFactory, InitializingBean { public class DashscopeModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "DASHSCOPE";
@Override @Override
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) { public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return QwenChatModel.builder() return QwenChatModel.builder()
.baseUrl(chatModel.getBaseUrl()) .baseUrl(modelConfig.getBaseUrl())
.apiKey(chatModel.getApiKey()) .apiKey(modelConfig.getApiKey())
.modelName(chatModel.getModelName()) .modelName(modelConfig.getModelName())
.temperature(chatModel.getTemperature() == null ? 0L : .temperature(modelConfig.getTemperature() == null ? 0L :
chatModel.getTemperature().floatValue()) modelConfig.getTemperature().floatValue())
.build(); .build();
} }
@@ -32,6 +34,6 @@ public class DashscopeModelFactory implements ModelFactory, InitializingBean {
@Override @Override
public void afterPropertiesSet() { public void afterPropertiesSet() {
ModelProvider.add(Provider.DASHSCOPE, this); ModelProvider.add(PROVIDER, this);
} }
} }

View File

@@ -16,9 +16,11 @@ import static dev.langchain4j.inmemory.spring.InMemoryAutoConfig.BGE_SMALL_ZH;
@Service @Service
public class InMemoryModelFactory implements ModelFactory, InitializingBean { public class InMemoryModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "IN_MEMORY";
@Override @Override
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) { public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return null; throw new UnsupportedOperationException("Not supported yet.");
} }
@Override @Override
@@ -40,6 +42,6 @@ public class InMemoryModelFactory implements ModelFactory, InitializingBean {
@Override @Override
public void afterPropertiesSet() { public void afterPropertiesSet() {
ModelProvider.add(Provider.IN_MEMORY, this); ModelProvider.add(PROVIDER, this);
} }
} }

View File

@@ -13,14 +13,16 @@ import java.time.Duration;
@Service @Service
public class LocalAiModelFactory implements ModelFactory, InitializingBean { public class LocalAiModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "LOCAL_AI";
@Override @Override
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) { public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return LocalAiChatModel return LocalAiChatModel
.builder() .builder()
.baseUrl(chatModel.getBaseUrl()) .baseUrl(modelConfig.getBaseUrl())
.modelName(chatModel.getModelName()) .modelName(modelConfig.getModelName())
.temperature(chatModel.getTemperature()) .temperature(modelConfig.getTemperature())
.timeout(Duration.ofSeconds(chatModel.getTimeOut())) .timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
.build(); .build();
} }
@@ -34,6 +36,6 @@ public class LocalAiModelFactory implements ModelFactory, InitializingBean {
@Override @Override
public void afterPropertiesSet() { public void afterPropertiesSet() {
ModelProvider.add(Provider.LOCAL_AI, this); ModelProvider.add(PROVIDER, this);
} }
} }

View File

@@ -6,7 +6,7 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel;
public interface ModelFactory { public interface ModelFactory {
ChatLanguageModel createChatModel(ChatModelConfig llmConfig); ChatLanguageModel createChatModel(ChatModelConfig modelConfig);
EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel); EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel);
} }

View File

@@ -15,24 +15,11 @@ import java.util.Objects;
public class ModelProvider { public class ModelProvider {
private static final Map<String, ModelFactory> factories = new HashMap<>(); private static final Map<String, ModelFactory> factories = new HashMap<>();
public static void add(Provider provider, ModelFactory modelFactory) { public static void add(String provider, ModelFactory modelFactory) {
factories.put(provider.name(), modelFactory); factories.put(provider, modelFactory);
} }
public static ChatLanguageModel provideChatModel(ChatModelConfig llmConfig) { public static ChatLanguageModel getChatModel(ModelConfig modelConfig) {
if (llmConfig == null || StringUtils.isBlank(llmConfig.getProvider())
|| StringUtils.isBlank(llmConfig.getBaseUrl())) {
return ContextUtils.getBean(ChatLanguageModel.class);
}
ModelFactory modelFactory = factories.get(llmConfig.getProvider().toUpperCase());
if (modelFactory != null) {
return modelFactory.createChatModel(llmConfig);
}
throw new RuntimeException("Unsupported ChatLanguageModel provider: " + llmConfig.getProvider());
}
public static ChatLanguageModel provideChatModelNew(ModelConfig modelConfig) {
if (modelConfig == null || modelConfig.getChatModel() == null if (modelConfig == null || modelConfig.getChatModel() == null
|| StringUtils.isBlank(modelConfig.getChatModel().getProvider()) || StringUtils.isBlank(modelConfig.getChatModel().getProvider())
|| StringUtils.isBlank(modelConfig.getChatModel().getBaseUrl())) { || StringUtils.isBlank(modelConfig.getChatModel().getBaseUrl())) {
@@ -47,7 +34,7 @@ public class ModelProvider {
throw new RuntimeException("Unsupported ChatLanguageModel provider: " + chatModel.getProvider()); throw new RuntimeException("Unsupported ChatLanguageModel provider: " + chatModel.getProvider());
} }
public static EmbeddingModel provideEmbeddingModel(ModelConfig modelConfig) { public static EmbeddingModel getEmbeddingModel(ModelConfig modelConfig) {
if (modelConfig == null || Objects.isNull(modelConfig.getEmbeddingModel()) if (modelConfig == null || Objects.isNull(modelConfig.getEmbeddingModel())
|| StringUtils.isBlank(modelConfig.getEmbeddingModel().getBaseUrl()) || StringUtils.isBlank(modelConfig.getEmbeddingModel().getBaseUrl())
|| StringUtils.isBlank(modelConfig.getEmbeddingModel().getProvider())) { || StringUtils.isBlank(modelConfig.getEmbeddingModel().getProvider())) {

View File

@@ -13,14 +13,16 @@ import java.time.Duration;
@Service @Service
public class OllamaModelFactory implements ModelFactory, InitializingBean { public class OllamaModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "OLLAMA";
@Override @Override
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) { public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return OllamaChatModel return OllamaChatModel
.builder() .builder()
.baseUrl(chatModel.getBaseUrl()) .baseUrl(modelConfig.getBaseUrl())
.modelName(chatModel.getModelName()) .modelName(modelConfig.getModelName())
.temperature(chatModel.getTemperature()) .temperature(modelConfig.getTemperature())
.timeout(Duration.ofSeconds(chatModel.getTimeOut())) .timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
.build(); .build();
} }
@@ -37,6 +39,6 @@ public class OllamaModelFactory implements ModelFactory, InitializingBean {
@Override @Override
public void afterPropertiesSet() { public void afterPropertiesSet() {
ModelProvider.add(Provider.OLLAMA, this); ModelProvider.add(PROVIDER, this);
} }
} }

View File

@@ -13,15 +13,17 @@ import java.time.Duration;
@Service @Service
public class OpenAiModelFactory implements ModelFactory, InitializingBean { public class OpenAiModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "OPEN_AI";
@Override @Override
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) { public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return OpenAiChatModel return OpenAiChatModel
.builder() .builder()
.baseUrl(chatModel.getBaseUrl()) .baseUrl(modelConfig.getBaseUrl())
.modelName(chatModel.getModelName()) .modelName(modelConfig.getModelName())
.apiKey(chatModel.keyDecrypt()) .apiKey(modelConfig.keyDecrypt())
.temperature(chatModel.getTemperature()) .temperature(modelConfig.getTemperature())
.timeout(Duration.ofSeconds(chatModel.getTimeOut())) .timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
.build(); .build();
} }
@@ -39,6 +41,6 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean {
@Override @Override
public void afterPropertiesSet() { public void afterPropertiesSet() {
ModelProvider.add(Provider.OPEN_AI, this); ModelProvider.add(PROVIDER, this);
} }
} }

View File

@@ -1,12 +0,0 @@
package dev.langchain4j.provider;
public enum Provider {
OPEN_AI,
OLLAMA,
LOCAL_AI,
IN_MEMORY,
ZHIPU,
AZURE,
QIANFAN,
DASHSCOPE
}

View File

@@ -10,8 +10,11 @@ import org.springframework.stereotype.Service;
@Service @Service
public class QianfanModelFactory implements ModelFactory, InitializingBean { public class QianfanModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "QIANFAN";
@Override @Override
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) { public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return null; return null;
} }
@@ -29,6 +32,6 @@ public class QianfanModelFactory implements ModelFactory, InitializingBean {
@Override @Override
public void afterPropertiesSet() { public void afterPropertiesSet() {
ModelProvider.add(Provider.QIANFAN, this); ModelProvider.add(PROVIDER, this);
} }
} }

View File

@@ -10,8 +10,10 @@ import org.springframework.stereotype.Service;
@Service @Service
public class ZhipuModelFactory implements ModelFactory, InitializingBean { public class ZhipuModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "ZHIPU";
@Override @Override
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) { public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return null; return null;
} }
@@ -29,6 +31,6 @@ public class ZhipuModelFactory implements ModelFactory, InitializingBean {
@Override @Override
public void afterPropertiesSet() { public void afterPropertiesSet() {
ModelProvider.add(Provider.ZHIPU, this); ModelProvider.add(PROVIDER, this);
} }
} }

View File

@@ -0,0 +1,26 @@
package dev.langchain4j.store.embedding;
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.chroma.spring.ChromaEmbeddingStoreFactory;
import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory;
import dev.langchain4j.milvus.spring.MilvusEmbeddingStoreFactory;
import org.apache.commons.lang3.StringUtils;
public class EmbeddingStoreFactoryProvider {
public static EmbeddingStoreFactory getFactory(EmbeddingStoreConfig storeConfig) {
if (storeConfig == null || StringUtils.isBlank(storeConfig.getProvider())) {
return ContextUtils.getBean(EmbeddingStoreFactory.class);
}
if (EmbeddingStoreType.CHROMA.name().equalsIgnoreCase(storeConfig.getProvider())) {
return new ChromaEmbeddingStoreFactory(storeConfig);
}
if (EmbeddingStoreType.MILVUS.name().equalsIgnoreCase(storeConfig.getProvider())) {
return new MilvusEmbeddingStoreFactory(storeConfig);
}
if (EmbeddingStoreType.IN_MEMORY.name().equalsIgnoreCase(storeConfig.getProvider())) {
return new InMemoryEmbeddingStoreFactory(storeConfig);
}
throw new RuntimeException("Unsupported EmbeddingStore provider: " + storeConfig.getProvider());
}
}

View File

@@ -1,4 +1,4 @@
package dev.langchain4j.provider; package dev.langchain4j.store.embedding;
public enum EmbeddingStoreType { public enum EmbeddingStoreType {
IN_MEMORY, IN_MEMORY,

View File

@@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.api.pojo.request;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.config.ModelConfig;
import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
@@ -27,7 +27,7 @@ public class QueryTextReq {
private MapModeEnum mapModeEnum = MapModeEnum.STRICT; private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
private SchemaMapInfo mapInfo = new SchemaMapInfo(); private SchemaMapInfo mapInfo = new SchemaMapInfo();
private QueryDataType queryDataType = QueryDataType.ALL; private QueryDataType queryDataType = QueryDataType.ALL;
private ChatModelConfig llmConfig; private ModelConfig modelConfig;
private PromptConfig promptConfig; private PromptConfig promptConfig;
private List<SqlExemplar> dynamicExemplars = Lists.newArrayList(); private List<SqlExemplar> dynamicExemplars = Lists.newArrayList();
} }

View File

@@ -2,15 +2,14 @@ package com.tencent.supersonic.headless.chat;
import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnore;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.ModelConfig; import com.tencent.supersonic.common.config.ModelConfig;
import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.QueryDataType; import com.tencent.supersonic.headless.api.pojo.QueryDataType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
import com.tencent.supersonic.headless.api.pojo.enums.WorkflowState; import com.tencent.supersonic.headless.api.pojo.enums.WorkflowState;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
@@ -21,7 +20,6 @@ import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Comparator; import java.util.Comparator;
import java.util.List; import java.util.List;
@@ -52,7 +50,6 @@ public class QueryContext {
private WorkflowState workflowState; private WorkflowState workflowState;
private QueryDataType queryDataType = QueryDataType.ALL; private QueryDataType queryDataType = QueryDataType.ALL;
private ModelConfig modelConfig; private ModelConfig modelConfig;
private ChatModelConfig llmConfig;
private PromptConfig promptConfig; private PromptConfig promptConfig;
private List<SqlExemplar> dynamicExemplars; private List<SqlExemplar> dynamicExemplars;

View File

@@ -103,7 +103,6 @@ public class LLMRequestService {
llmReq.setCurrentDate(DateUtils.getBeforeDate(0)); llmReq.setCurrentDate(DateUtils.getBeforeDate(0));
llmReq.setSqlGenType(LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE))); llmReq.setSqlGenType(LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
llmReq.setModelConfig(queryCtx.getModelConfig()); llmReq.setModelConfig(queryCtx.getModelConfig());
llmReq.setLlmConfig(queryCtx.getLlmConfig());
llmReq.setPromptConfig(queryCtx.getPromptConfig()); llmReq.setPromptConfig(queryCtx.getPromptConfig());
llmReq.setDynamicExemplars(queryCtx.getDynamicExemplars()); llmReq.setDynamicExemplars(queryCtx.getDynamicExemplars());

View File

@@ -55,7 +55,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
Map<Prompt, String> prompt2Output = new ConcurrentHashMap<>(); Map<Prompt, String> prompt2Output = new ConcurrentHashMap<>();
prompt2Exemplar.keySet().parallelStream().forEach(prompt -> { prompt2Exemplar.keySet().parallelStream().forEach(prompt -> {
keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:\n{}", prompt.toUserMessage()); keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:\n{}", prompt.toUserMessage());
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig()); ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getModelConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage()); Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
String result = response.content().text(); String result = response.content().text();
prompt2Output.put(prompt, result); prompt2Output.put(prompt, result);

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.headless.chat.parser.llm; package com.tencent.supersonic.headless.chat.parser.llm;
import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.config.ModelConfig;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
@@ -23,8 +23,8 @@ public abstract class SqlGenStrategy implements InitializingBean {
@Autowired @Autowired
protected PromptHelper promptHelper; protected PromptHelper promptHelper;
protected ChatLanguageModel getChatLanguageModel(ChatModelConfig llmConfig) { protected ChatLanguageModel getChatLanguageModel(ModelConfig modelConfig) {
return ModelProvider.provideChatModel(llmConfig); return ModelProvider.getChatModel(modelConfig);
} }
abstract LLMResp generate(LLMReq llmReq); abstract LLMResp generate(LLMReq llmReq);

View File

@@ -2,11 +2,10 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql;
import com.fasterxml.jackson.annotation.JsonValue; import com.fasterxml.jackson.annotation.JsonValue;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.ModelConfig; import com.tencent.supersonic.common.config.ModelConfig;
import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import lombok.Data; import lombok.Data;
import java.util.List; import java.util.List;
@@ -29,7 +28,6 @@ public class LLMReq {
private SqlGenType sqlGenType; private SqlGenType sqlGenType;
private ModelConfig modelConfig; private ModelConfig modelConfig;
private ChatModelConfig llmConfig;
private PromptConfig promptConfig; private PromptConfig promptConfig;
private List<SqlExemplar> dynamicExemplars; private List<SqlExemplar> dynamicExemplars;

View File

@@ -350,4 +350,7 @@ CREATE TABLE IF NOT EXISTS `s2_chat_memory` (
) ENGINE=InnoDB DEFAULT CHARSET=utf8; ) ENGINE=InnoDB DEFAULT CHARSET=utf8;
--20240705 --20240705
alter table s2_agent add column `prompt_config` varchar(6000) COLLATE utf8_unicode_ci DEFAULT NULL COMMENT '提示词配置'; alter table s2_agent add column `prompt_config` varchar(6000) COLLATE utf8_unicode_ci DEFAULT NULL COMMENT '提示词配置';
--20240707
alter table s2_agent add model_config varchar(6000) null;

View File

@@ -374,6 +374,7 @@ CREATE TABLE IF NOT EXISTS s2_agent
examples varchar(500) null, examples varchar(500) null,
config varchar(2000) null, config varchar(2000) null,
llm_config varchar(2000) null, llm_config varchar(2000) null,
model_config varchar(6000) null,
prompt_config varchar(5000) null, prompt_config varchar(5000) null,
multi_turn_config varchar(2000) null, multi_turn_config varchar(2000) null,
visual_config varchar(2000) null, visual_config varchar(2000) null,

View File

@@ -73,6 +73,7 @@ CREATE TABLE `s2_agent` (
`model` varchar(100) COLLATE utf8_unicode_ci DEFAULT NULL, `model` varchar(100) COLLATE utf8_unicode_ci DEFAULT NULL,
`config` varchar(6000) COLLATE utf8_unicode_ci DEFAULT NULL, `config` varchar(6000) COLLATE utf8_unicode_ci DEFAULT NULL,
`llm_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL, `llm_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL,
`model_config` varchar(6000) COLLATE utf8_unicode_ci DEFAULT NULL,
`multi_turn_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL, `multi_turn_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL,
`visual_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL, `visual_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL,
`created_by` varchar(100) COLLATE utf8_unicode_ci DEFAULT NULL, `created_by` varchar(100) COLLATE utf8_unicode_ci DEFAULT NULL,

View File

@@ -87,7 +87,6 @@ public class Text2SQLEval extends BaseTest {
agentConfig.getTools().add(getLLMQueryTool()); agentConfig.getTools().add(getLLMQueryTool());
agent.setAgentConfig(JSONObject.toJSONString(agentConfig)); agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
agent.setModelConfig(getLLMConfig(LLMType.GPT)); agent.setModelConfig(getLLMConfig(LLMType.GPT));
agent.setLlmConfig(getLLMConfig(LLMType.GPT).getChatModel());
MultiTurnConfig multiTurnConfig = new MultiTurnConfig(); MultiTurnConfig multiTurnConfig = new MultiTurnConfig();
multiTurnConfig.setEnableMultiTurn(enableMultiturn); multiTurnConfig.setEnableMultiTurn(enableMultiturn);
agent.setMultiTurnConfig(multiTurnConfig); agent.setMultiTurnConfig(multiTurnConfig);

View File

@@ -374,6 +374,7 @@ CREATE TABLE IF NOT EXISTS s2_agent
examples varchar(500) null, examples varchar(500) null,
config varchar(2000) null, config varchar(2000) null,
llm_config varchar(2000) null, llm_config varchar(2000) null,
model_config varchar(6000) null,
prompt_config varchar(5000) null, prompt_config varchar(5000) null,
multi_turn_config varchar(2000) null, multi_turn_config varchar(2000) null,
visual_config varchar(2000) null, visual_config varchar(2000) null,
@@ -383,7 +384,7 @@ CREATE TABLE IF NOT EXISTS s2_agent
updated_at TIMESTAMP null, updated_at TIMESTAMP null,
enable_search int null, enable_search int null,
PRIMARY KEY (`id`) PRIMARY KEY (`id`)
); COMMENT ON TABLE s2_agent IS 'agent information table'; ); COMMENT ON TABLE s2_agent IS 'agent information table';
-------demo for semantic and chat -------demo for semantic and chat