(improvement)(chat) Fix the collectionName issue in the memory agent level. (#1300)

This commit is contained in:
lexluo09
2024-06-30 17:55:43 +08:00
committed by GitHub
parent 8bfd80c2c0
commit b56abd7348
7 changed files with 39 additions and 21 deletions

View File

@@ -1,18 +1,21 @@
package com.tencent.supersonic.chat.server.parser; package com.tencent.supersonic.chat.server.parser;
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig; import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository; import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager; import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext; import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.util.QueryReqConverter; import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.config.LLMConfig; import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq; import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp; import com.tencent.supersonic.headless.api.pojo.response.MapResp;
@@ -24,13 +27,6 @@ import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate; import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.provider.ChatLanguageModelProvider; import dev.langchain4j.model.provider.ChatLanguageModelProvider;
import lombok.Builder;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
@@ -38,8 +34,12 @@ import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.Builder;
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE; import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;
@Slf4j @Slf4j
public class NL2SQLParser implements ChatParser { public class NL2SQLParser implements ChatParser {
@@ -226,7 +226,9 @@ public class NL2SQLParser implements ChatParser {
private void addExemplars(Integer agentId, QueryReq queryReq) { private void addExemplars(Integer agentId, QueryReq queryReq) {
ExemplarServiceImpl exemplarManager = ContextUtils.getBean(ExemplarServiceImpl.class); ExemplarServiceImpl exemplarManager = ContextUtils.getBean(ExemplarServiceImpl.class);
List<SqlExemplar> exemplars = exemplarManager.recallExemplars(agentId.toString(), EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
List<SqlExemplar> exemplars = exemplarManager.recallExemplars(memoryCollectionName,
queryReq.getQueryText(), 5); queryReq.getQueryText(), 5);
queryReq.getExemplars().addAll(exemplars); queryReq.getExemplars().addAll(exemplars);
} }

View File

@@ -6,6 +6,7 @@ import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository; import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext; import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.service.ExemplarService; import com.tencent.supersonic.common.service.ExemplarService;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
@@ -40,7 +41,9 @@ public class QueryRecommendProcessor implements ParseResultProcessor {
public List<SimilarQueryRecallResp> getSimilarQueries(String queryText, Integer agentId) { public List<SimilarQueryRecallResp> getSimilarQueries(String queryText, Integer agentId) {
ExemplarService exemplarService = ContextUtils.getBean(ExemplarService.class); ExemplarService exemplarService = ContextUtils.getBean(ExemplarService.class);
List<SqlExemplar> exemplars = exemplarService.recallExemplars(agentId.toString(), queryText, 5); EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
List<SqlExemplar> exemplars = exemplarService.recallExemplars(memoryCollectionName, queryText, 5);
return exemplars.stream().map(sqlExemplar -> return exemplars.stream().map(sqlExemplar ->
SimilarQueryRecallResp.builder().queryText(sqlExemplar.getQuestion()).build()) SimilarQueryRecallResp.builder().queryText(sqlExemplar.getQuestion()).build())
.collect(Collectors.toList()); .collect(Collectors.toList());

View File

@@ -11,16 +11,16 @@ import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository; import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository;
import com.tencent.supersonic.chat.server.service.MemoryService; import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.service.ExemplarService; import com.tencent.supersonic.common.service.ExemplarService;
import com.tencent.supersonic.common.util.BeanMapper; import com.tencent.supersonic.common.util.BeanMapper;
import java.util.List;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.List;
@Service @Service
public class MemoryServiceImpl implements MemoryService { public class MemoryServiceImpl implements MemoryService {
@@ -30,6 +30,9 @@ public class MemoryServiceImpl implements MemoryService {
@Autowired @Autowired
private ExemplarService exemplarService; private ExemplarService exemplarService;
@Autowired
private EmbeddingConfig embeddingConfig;
@Override @Override
public void createMemory(ChatMemoryDO memory) { public void createMemory(ChatMemoryDO memory) {
chatMemoryRepository.createMemory(memory); chatMemoryRepository.createMemory(memory);
@@ -94,7 +97,7 @@ public class MemoryServiceImpl implements MemoryService {
} }
private void enableMemory(ChatMemoryDO memory) { private void enableMemory(ChatMemoryDO memory) {
exemplarService.storeExemplar(memory.getAgentId().toString(), exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
SqlExemplar.builder() SqlExemplar.builder()
.question(memory.getQuestion()) .question(memory.getQuestion())
.dbSchema(memory.getDbSchema()) .dbSchema(memory.getDbSchema())
@@ -103,7 +106,7 @@ public class MemoryServiceImpl implements MemoryService {
} }
private void disableMemory(ChatMemoryDO memory) { private void disableMemory(ChatMemoryDO memory) {
exemplarService.removeExemplar(memory.getAgentId().toString(), exemplarService.removeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
SqlExemplar.builder() SqlExemplar.builder()
.question(memory.getQuestion()) .question(memory.getQuestion())
.dbSchema(memory.getDbSchema()) .dbSchema(memory.getDbSchema())

View File

@@ -7,6 +7,10 @@ import org.springframework.context.annotation.Configuration;
@Configuration @Configuration
@Data @Data
public class EmbeddingConfig { public class EmbeddingConfig {
@Value("${s2.embedding.memory.collection.prefix:memory_}")
private String memoryCollectionPrefix;
@Value("${s2.embedding.preset.collection:preset_query_collection}") @Value("${s2.embedding.preset.collection:preset_query_collection}")
private String presetCollection; private String presetCollection;
@@ -25,4 +29,8 @@ public class EmbeddingConfig {
@Value("${s2.embedding.metric.analyzeQuery.nResult:5}") @Value("${s2.embedding.metric.analyzeQuery.nResult:5}")
private int metricAnalyzeQueryResultNum; private int metricAnalyzeQueryResultNum;
public String getMemoryCollectionName(Integer agentId) {
return memoryCollectionPrefix + agentId;
}
} }

View File

@@ -7,9 +7,9 @@ WORKDIR /usr/src/app
# Argument to pass in the supersonic version at build time # Argument to pass in the supersonic version at build time
ARG SUPERSONIC_VERSION ARG SUPERSONIC_VERSION
# Install necessary packages, including MySQL client # Install the Vim editor.
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y default-mysql-client unzip && \ apt-get install -y vim && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
# Copy the supersonic standalone zip file into the container # Copy the supersonic standalone zip file into the container

View File

@@ -42,7 +42,7 @@ services:
- 8.8.4.4 - 8.8.4.4
chroma: chroma:
image: chromadb/chroma:latest image: chromadb/chroma:0.5.3
container_name: supersonic_chroma container_name: supersonic_chroma
ports: ports:
- "8000:8000" - "8000:8000"

View File

@@ -12,8 +12,10 @@ langchain4j:
in-memory: in-memory:
embedding-model: embedding-model:
model-name: bge-small-zh model-name: bge-small-zh
chroma:
embedding-store: embedding-store:
persist-path: /tmp baseUrl: http://0.0.0.0:8000
timeout: 120s
# ollama: # ollama:
# chat-model: # chat-model:
# base-url: http://localhost:11434 # base-url: http://localhost:11434