diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index ab13a879a..6ad549d98 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -1,18 +1,21 @@ package com.tencent.supersonic.chat.server.parser; +import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE; + import com.tencent.supersonic.chat.server.agent.MultiTurnConfig; -import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl; import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository; import com.tencent.supersonic.chat.server.plugin.PluginQueryManager; import com.tencent.supersonic.chat.server.pojo.ChatParseContext; import com.tencent.supersonic.chat.server.util.QueryReqConverter; +import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.config.LLMConfig; +import com.tencent.supersonic.common.pojo.SqlExemplar; +import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryReq; import com.tencent.supersonic.headless.api.pojo.response.MapResp; @@ -24,13 +27,6 @@ import dev.langchain4j.model.input.Prompt; import dev.langchain4j.model.input.PromptTemplate; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.provider.ChatLanguageModelProvider; -import lombok.Builder; -import lombok.Data; -import lombok.extern.slf4j.Slf4j; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.util.CollectionUtils; - import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -38,8 +34,12 @@ import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; - -import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE; +import lombok.Builder; +import lombok.Data; +import lombok.extern.slf4j.Slf4j; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.util.CollectionUtils; @Slf4j public class NL2SQLParser implements ChatParser { @@ -226,7 +226,9 @@ public class NL2SQLParser implements ChatParser { private void addExemplars(Integer agentId, QueryReq queryReq) { ExemplarServiceImpl exemplarManager = ContextUtils.getBean(ExemplarServiceImpl.class); - List exemplars = exemplarManager.recallExemplars(agentId.toString(), + EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class); + String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId); + List exemplars = exemplarManager.recallExemplars(memoryCollectionName, queryReq.getQueryText(), 5); queryReq.getExemplars().addAll(exemplars); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/QueryRecommendProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/QueryRecommendProcessor.java index 7a95d237f..dc7c8b7a8 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/QueryRecommendProcessor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/QueryRecommendProcessor.java @@ -6,6 +6,7 @@ import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO; import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository; import com.tencent.supersonic.chat.server.pojo.ChatParseContext; +import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.common.service.ExemplarService; import com.tencent.supersonic.common.util.ContextUtils; @@ -40,7 +41,9 @@ public class QueryRecommendProcessor implements ParseResultProcessor { public List getSimilarQueries(String queryText, Integer agentId) { ExemplarService exemplarService = ContextUtils.getBean(ExemplarService.class); - List exemplars = exemplarService.recallExemplars(agentId.toString(), queryText, 5); + EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class); + String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId); + List exemplars = exemplarService.recallExemplars(memoryCollectionName, queryText, 5); return exemplars.stream().map(sqlExemplar -> SimilarQueryRecallResp.builder().queryText(sqlExemplar.getQuestion()).build()) .collect(Collectors.toList()); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java index 49fb9aa4d..a666a705b 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java @@ -11,16 +11,16 @@ import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository; import com.tencent.supersonic.chat.server.service.MemoryService; +import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.common.service.ExemplarService; import com.tencent.supersonic.common.util.BeanMapper; +import java.util.List; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; -import java.util.List; - @Service public class MemoryServiceImpl implements MemoryService { @@ -30,6 +30,9 @@ public class MemoryServiceImpl implements MemoryService { @Autowired private ExemplarService exemplarService; + @Autowired + private EmbeddingConfig embeddingConfig; + @Override public void createMemory(ChatMemoryDO memory) { chatMemoryRepository.createMemory(memory); @@ -94,7 +97,7 @@ public class MemoryServiceImpl implements MemoryService { } private void enableMemory(ChatMemoryDO memory) { - exemplarService.storeExemplar(memory.getAgentId().toString(), + exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()), SqlExemplar.builder() .question(memory.getQuestion()) .dbSchema(memory.getDbSchema()) @@ -103,7 +106,7 @@ public class MemoryServiceImpl implements MemoryService { } private void disableMemory(ChatMemoryDO memory) { - exemplarService.removeExemplar(memory.getAgentId().toString(), + exemplarService.removeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()), SqlExemplar.builder() .question(memory.getQuestion()) .dbSchema(memory.getDbSchema()) diff --git a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingConfig.java index 6b47eb4cc..3d1134605 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingConfig.java @@ -7,6 +7,10 @@ import org.springframework.context.annotation.Configuration; @Configuration @Data public class EmbeddingConfig { + + @Value("${s2.embedding.memory.collection.prefix:memory_}") + private String memoryCollectionPrefix; + @Value("${s2.embedding.preset.collection:preset_query_collection}") private String presetCollection; @@ -25,4 +29,8 @@ public class EmbeddingConfig { @Value("${s2.embedding.metric.analyzeQuery.nResult:5}") private int metricAnalyzeQueryResultNum; + public String getMemoryCollectionName(Integer agentId) { + return memoryCollectionPrefix + agentId; + } + } diff --git a/docker/Dockerfile b/docker/Dockerfile index 432a9a299..eb9c22e93 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -7,9 +7,9 @@ WORKDIR /usr/src/app # Argument to pass in the supersonic version at build time ARG SUPERSONIC_VERSION -# Install necessary packages, including MySQL client +# Install the Vim editor. RUN apt-get update && \ - apt-get install -y default-mysql-client unzip && \ + apt-get install -y vim && \ rm -rf /var/lib/apt/lists/* # Copy the supersonic standalone zip file into the container diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index d3e5c1c62..aa057bd5d 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -42,7 +42,7 @@ services: - 8.8.4.4 chroma: - image: chromadb/chroma:latest + image: chromadb/chroma:0.5.3 container_name: supersonic_chroma ports: - "8000:8000" diff --git a/launchers/standalone/src/main/resources/langchain4j-config.yaml b/launchers/standalone/src/main/resources/langchain4j-config.yaml index 38e2da7a9..4aa14dc5a 100644 --- a/launchers/standalone/src/main/resources/langchain4j-config.yaml +++ b/launchers/standalone/src/main/resources/langchain4j-config.yaml @@ -12,8 +12,10 @@ langchain4j: in-memory: embedding-model: model-name: bge-small-zh + chroma: embedding-store: - persist-path: /tmp + baseUrl: http://0.0.0.0:8000 + timeout: 120s # ollama: # chat-model: # base-url: http://localhost:11434