mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 12:37:55 +00:00
(improvement)(chat) Fix the collectionName issue in the memory agent level. (#1300)
This commit is contained in:
@@ -1,18 +1,21 @@
|
|||||||
package com.tencent.supersonic.chat.server.parser;
|
package com.tencent.supersonic.chat.server.parser;
|
||||||
|
|
||||||
|
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||||
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
|
|
||||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||||
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.common.config.LLMConfig;
|
import com.tencent.supersonic.common.config.LLMConfig;
|
||||||
|
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||||
|
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||||
@@ -24,13 +27,6 @@ import dev.langchain4j.model.input.Prompt;
|
|||||||
import dev.langchain4j.model.input.PromptTemplate;
|
import dev.langchain4j.model.input.PromptTemplate;
|
||||||
import dev.langchain4j.model.output.Response;
|
import dev.langchain4j.model.output.Response;
|
||||||
import dev.langchain4j.model.provider.ChatLanguageModelProvider;
|
import dev.langchain4j.model.provider.ChatLanguageModelProvider;
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.slf4j.Logger;
|
|
||||||
import org.slf4j.LoggerFactory;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -38,8 +34,12 @@ import java.util.Objects;
|
|||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.Builder;
|
||||||
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
|
import lombok.Data;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class NL2SQLParser implements ChatParser {
|
public class NL2SQLParser implements ChatParser {
|
||||||
@@ -226,7 +226,9 @@ public class NL2SQLParser implements ChatParser {
|
|||||||
|
|
||||||
private void addExemplars(Integer agentId, QueryReq queryReq) {
|
private void addExemplars(Integer agentId, QueryReq queryReq) {
|
||||||
ExemplarServiceImpl exemplarManager = ContextUtils.getBean(ExemplarServiceImpl.class);
|
ExemplarServiceImpl exemplarManager = ContextUtils.getBean(ExemplarServiceImpl.class);
|
||||||
List<SqlExemplar> exemplars = exemplarManager.recallExemplars(agentId.toString(),
|
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||||
|
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
|
||||||
|
List<SqlExemplar> exemplars = exemplarManager.recallExemplars(memoryCollectionName,
|
||||||
queryReq.getQueryText(), 5);
|
queryReq.getQueryText(), 5);
|
||||||
queryReq.getExemplars().addAll(exemplars);
|
queryReq.getExemplars().addAll(exemplars);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
|
|||||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||||
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||||
import com.tencent.supersonic.common.service.ExemplarService;
|
import com.tencent.supersonic.common.service.ExemplarService;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
@@ -40,7 +41,9 @@ public class QueryRecommendProcessor implements ParseResultProcessor {
|
|||||||
|
|
||||||
public List<SimilarQueryRecallResp> getSimilarQueries(String queryText, Integer agentId) {
|
public List<SimilarQueryRecallResp> getSimilarQueries(String queryText, Integer agentId) {
|
||||||
ExemplarService exemplarService = ContextUtils.getBean(ExemplarService.class);
|
ExemplarService exemplarService = ContextUtils.getBean(ExemplarService.class);
|
||||||
List<SqlExemplar> exemplars = exemplarService.recallExemplars(agentId.toString(), queryText, 5);
|
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||||
|
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
|
||||||
|
List<SqlExemplar> exemplars = exemplarService.recallExemplars(memoryCollectionName, queryText, 5);
|
||||||
return exemplars.stream().map(sqlExemplar ->
|
return exemplars.stream().map(sqlExemplar ->
|
||||||
SimilarQueryRecallResp.builder().queryText(sqlExemplar.getQuestion()).build())
|
SimilarQueryRecallResp.builder().queryText(sqlExemplar.getQuestion()).build())
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
|
|||||||
@@ -11,16 +11,16 @@ import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
|
|||||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository;
|
import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository;
|
||||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||||
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||||
import com.tencent.supersonic.common.service.ExemplarService;
|
import com.tencent.supersonic.common.service.ExemplarService;
|
||||||
import com.tencent.supersonic.common.util.BeanMapper;
|
import com.tencent.supersonic.common.util.BeanMapper;
|
||||||
|
import java.util.List;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
public class MemoryServiceImpl implements MemoryService {
|
public class MemoryServiceImpl implements MemoryService {
|
||||||
|
|
||||||
@@ -30,6 +30,9 @@ public class MemoryServiceImpl implements MemoryService {
|
|||||||
@Autowired
|
@Autowired
|
||||||
private ExemplarService exemplarService;
|
private ExemplarService exemplarService;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private EmbeddingConfig embeddingConfig;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void createMemory(ChatMemoryDO memory) {
|
public void createMemory(ChatMemoryDO memory) {
|
||||||
chatMemoryRepository.createMemory(memory);
|
chatMemoryRepository.createMemory(memory);
|
||||||
@@ -94,7 +97,7 @@ public class MemoryServiceImpl implements MemoryService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private void enableMemory(ChatMemoryDO memory) {
|
private void enableMemory(ChatMemoryDO memory) {
|
||||||
exemplarService.storeExemplar(memory.getAgentId().toString(),
|
exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
||||||
SqlExemplar.builder()
|
SqlExemplar.builder()
|
||||||
.question(memory.getQuestion())
|
.question(memory.getQuestion())
|
||||||
.dbSchema(memory.getDbSchema())
|
.dbSchema(memory.getDbSchema())
|
||||||
@@ -103,7 +106,7 @@ public class MemoryServiceImpl implements MemoryService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private void disableMemory(ChatMemoryDO memory) {
|
private void disableMemory(ChatMemoryDO memory) {
|
||||||
exemplarService.removeExemplar(memory.getAgentId().toString(),
|
exemplarService.removeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
||||||
SqlExemplar.builder()
|
SqlExemplar.builder()
|
||||||
.question(memory.getQuestion())
|
.question(memory.getQuestion())
|
||||||
.dbSchema(memory.getDbSchema())
|
.dbSchema(memory.getDbSchema())
|
||||||
|
|||||||
@@ -7,6 +7,10 @@ import org.springframework.context.annotation.Configuration;
|
|||||||
@Configuration
|
@Configuration
|
||||||
@Data
|
@Data
|
||||||
public class EmbeddingConfig {
|
public class EmbeddingConfig {
|
||||||
|
|
||||||
|
@Value("${s2.embedding.memory.collection.prefix:memory_}")
|
||||||
|
private String memoryCollectionPrefix;
|
||||||
|
|
||||||
@Value("${s2.embedding.preset.collection:preset_query_collection}")
|
@Value("${s2.embedding.preset.collection:preset_query_collection}")
|
||||||
private String presetCollection;
|
private String presetCollection;
|
||||||
|
|
||||||
@@ -25,4 +29,8 @@ public class EmbeddingConfig {
|
|||||||
@Value("${s2.embedding.metric.analyzeQuery.nResult:5}")
|
@Value("${s2.embedding.metric.analyzeQuery.nResult:5}")
|
||||||
private int metricAnalyzeQueryResultNum;
|
private int metricAnalyzeQueryResultNum;
|
||||||
|
|
||||||
|
public String getMemoryCollectionName(Integer agentId) {
|
||||||
|
return memoryCollectionPrefix + agentId;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ WORKDIR /usr/src/app
|
|||||||
# Argument to pass in the supersonic version at build time
|
# Argument to pass in the supersonic version at build time
|
||||||
ARG SUPERSONIC_VERSION
|
ARG SUPERSONIC_VERSION
|
||||||
|
|
||||||
# Install necessary packages, including MySQL client
|
# Install the Vim editor.
|
||||||
RUN apt-get update && \
|
RUN apt-get update && \
|
||||||
apt-get install -y default-mysql-client unzip && \
|
apt-get install -y vim && \
|
||||||
rm -rf /var/lib/apt/lists/*
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# Copy the supersonic standalone zip file into the container
|
# Copy the supersonic standalone zip file into the container
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ services:
|
|||||||
- 8.8.4.4
|
- 8.8.4.4
|
||||||
|
|
||||||
chroma:
|
chroma:
|
||||||
image: chromadb/chroma:latest
|
image: chromadb/chroma:0.5.3
|
||||||
container_name: supersonic_chroma
|
container_name: supersonic_chroma
|
||||||
ports:
|
ports:
|
||||||
- "8000:8000"
|
- "8000:8000"
|
||||||
|
|||||||
@@ -12,8 +12,10 @@ langchain4j:
|
|||||||
in-memory:
|
in-memory:
|
||||||
embedding-model:
|
embedding-model:
|
||||||
model-name: bge-small-zh
|
model-name: bge-small-zh
|
||||||
|
chroma:
|
||||||
embedding-store:
|
embedding-store:
|
||||||
persist-path: /tmp
|
baseUrl: http://0.0.0.0:8000
|
||||||
|
timeout: 120s
|
||||||
# ollama:
|
# ollama:
|
||||||
# chat-model:
|
# chat-model:
|
||||||
# base-url: http://localhost:11434
|
# base-url: http://localhost:11434
|
||||||
|
|||||||
Reference in New Issue
Block a user