(improvement)(chat) Fix the collectionName issue in the memory agent level. (#1300)

This commit is contained in:
lexluo09
2024-06-30 17:55:43 +08:00
committed by GitHub
parent 8bfd80c2c0
commit b56abd7348
7 changed files with 39 additions and 21 deletions

View File

@@ -1,18 +1,21 @@
package com.tencent.supersonic.chat.server.parser;
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
@@ -24,13 +27,6 @@ import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.provider.ChatLanguageModelProvider;
import lombok.Builder;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@@ -38,8 +34,12 @@ import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
import lombok.Builder;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;
@Slf4j
public class NL2SQLParser implements ChatParser {
@@ -226,7 +226,9 @@ public class NL2SQLParser implements ChatParser {
private void addExemplars(Integer agentId, QueryReq queryReq) {
ExemplarServiceImpl exemplarManager = ContextUtils.getBean(ExemplarServiceImpl.class);
List<SqlExemplar> exemplars = exemplarManager.recallExemplars(agentId.toString(),
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
List<SqlExemplar> exemplars = exemplarManager.recallExemplars(memoryCollectionName,
queryReq.getQueryText(), 5);
queryReq.getExemplars().addAll(exemplars);
}

View File

@@ -6,6 +6,7 @@ import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.service.ExemplarService;
import com.tencent.supersonic.common.util.ContextUtils;
@@ -40,7 +41,9 @@ public class QueryRecommendProcessor implements ParseResultProcessor {
public List<SimilarQueryRecallResp> getSimilarQueries(String queryText, Integer agentId) {
ExemplarService exemplarService = ContextUtils.getBean(ExemplarService.class);
List<SqlExemplar> exemplars = exemplarService.recallExemplars(agentId.toString(), queryText, 5);
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
List<SqlExemplar> exemplars = exemplarService.recallExemplars(memoryCollectionName, queryText, 5);
return exemplars.stream().map(sqlExemplar ->
SimilarQueryRecallResp.builder().queryText(sqlExemplar.getQuestion()).build())
.collect(Collectors.toList());

View File

@@ -11,16 +11,16 @@ import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository;
import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.service.ExemplarService;
import com.tencent.supersonic.common.util.BeanMapper;
import java.util.List;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.List;
@Service
public class MemoryServiceImpl implements MemoryService {
@@ -30,6 +30,9 @@ public class MemoryServiceImpl implements MemoryService {
@Autowired
private ExemplarService exemplarService;
@Autowired
private EmbeddingConfig embeddingConfig;
@Override
public void createMemory(ChatMemoryDO memory) {
chatMemoryRepository.createMemory(memory);
@@ -94,7 +97,7 @@ public class MemoryServiceImpl implements MemoryService {
}
private void enableMemory(ChatMemoryDO memory) {
exemplarService.storeExemplar(memory.getAgentId().toString(),
exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
SqlExemplar.builder()
.question(memory.getQuestion())
.dbSchema(memory.getDbSchema())
@@ -103,7 +106,7 @@ public class MemoryServiceImpl implements MemoryService {
}
private void disableMemory(ChatMemoryDO memory) {
exemplarService.removeExemplar(memory.getAgentId().toString(),
exemplarService.removeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
SqlExemplar.builder()
.question(memory.getQuestion())
.dbSchema(memory.getDbSchema())

View File

@@ -7,6 +7,10 @@ import org.springframework.context.annotation.Configuration;
@Configuration
@Data
public class EmbeddingConfig {
@Value("${s2.embedding.memory.collection.prefix:memory_}")
private String memoryCollectionPrefix;
@Value("${s2.embedding.preset.collection:preset_query_collection}")
private String presetCollection;
@@ -25,4 +29,8 @@ public class EmbeddingConfig {
@Value("${s2.embedding.metric.analyzeQuery.nResult:5}")
private int metricAnalyzeQueryResultNum;
public String getMemoryCollectionName(Integer agentId) {
return memoryCollectionPrefix + agentId;
}
}

View File

@@ -7,9 +7,9 @@ WORKDIR /usr/src/app
# Argument to pass in the supersonic version at build time
ARG SUPERSONIC_VERSION
# Install necessary packages, including MySQL client
# Install the Vim editor.
RUN apt-get update && \
apt-get install -y default-mysql-client unzip && \
apt-get install -y vim && \
rm -rf /var/lib/apt/lists/*
# Copy the supersonic standalone zip file into the container

View File

@@ -42,7 +42,7 @@ services:
- 8.8.4.4
chroma:
image: chromadb/chroma:latest
image: chromadb/chroma:0.5.3
container_name: supersonic_chroma
ports:
- "8000:8000"

View File

@@ -12,8 +12,10 @@ langchain4j:
in-memory:
embedding-model:
model-name: bge-small-zh
chroma:
embedding-store:
persist-path: /tmp
baseUrl: http://0.0.0.0:8000
timeout: 120s
# ollama:
# chat-model:
# base-url: http://localhost:11434