mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-14 13:47:09 +00:00
(improvement)(chat) Fix the collectionName issue in the memory agent level. (#1300)
This commit is contained in:
@@ -1,18 +1,21 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
|
||||
|
||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||
@@ -24,13 +27,6 @@ import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.model.provider.ChatLanguageModelProvider;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
@@ -38,8 +34,12 @@ import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class NL2SQLParser implements ChatParser {
|
||||
@@ -226,7 +226,9 @@ public class NL2SQLParser implements ChatParser {
|
||||
|
||||
private void addExemplars(Integer agentId, QueryReq queryReq) {
|
||||
ExemplarServiceImpl exemplarManager = ContextUtils.getBean(ExemplarServiceImpl.class);
|
||||
List<SqlExemplar> exemplars = exemplarManager.recallExemplars(agentId.toString(),
|
||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
|
||||
List<SqlExemplar> exemplars = exemplarManager.recallExemplars(memoryCollectionName,
|
||||
queryReq.getQueryText(), 5);
|
||||
queryReq.getExemplars().addAll(exemplars);
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.service.ExemplarService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
@@ -40,7 +41,9 @@ public class QueryRecommendProcessor implements ParseResultProcessor {
|
||||
|
||||
public List<SimilarQueryRecallResp> getSimilarQueries(String queryText, Integer agentId) {
|
||||
ExemplarService exemplarService = ContextUtils.getBean(ExemplarService.class);
|
||||
List<SqlExemplar> exemplars = exemplarService.recallExemplars(agentId.toString(), queryText, 5);
|
||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
|
||||
List<SqlExemplar> exemplars = exemplarService.recallExemplars(memoryCollectionName, queryText, 5);
|
||||
return exemplars.stream().map(sqlExemplar ->
|
||||
SimilarQueryRecallResp.builder().queryText(sqlExemplar.getQuestion()).build())
|
||||
.collect(Collectors.toList());
|
||||
|
||||
@@ -11,16 +11,16 @@ import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.service.ExemplarService;
|
||||
import com.tencent.supersonic.common.util.BeanMapper;
|
||||
import java.util.List;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Service
|
||||
public class MemoryServiceImpl implements MemoryService {
|
||||
|
||||
@@ -30,6 +30,9 @@ public class MemoryServiceImpl implements MemoryService {
|
||||
@Autowired
|
||||
private ExemplarService exemplarService;
|
||||
|
||||
@Autowired
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
|
||||
@Override
|
||||
public void createMemory(ChatMemoryDO memory) {
|
||||
chatMemoryRepository.createMemory(memory);
|
||||
@@ -94,7 +97,7 @@ public class MemoryServiceImpl implements MemoryService {
|
||||
}
|
||||
|
||||
private void enableMemory(ChatMemoryDO memory) {
|
||||
exemplarService.storeExemplar(memory.getAgentId().toString(),
|
||||
exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
||||
SqlExemplar.builder()
|
||||
.question(memory.getQuestion())
|
||||
.dbSchema(memory.getDbSchema())
|
||||
@@ -103,7 +106,7 @@ public class MemoryServiceImpl implements MemoryService {
|
||||
}
|
||||
|
||||
private void disableMemory(ChatMemoryDO memory) {
|
||||
exemplarService.removeExemplar(memory.getAgentId().toString(),
|
||||
exemplarService.removeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
||||
SqlExemplar.builder()
|
||||
.question(memory.getQuestion())
|
||||
.dbSchema(memory.getDbSchema())
|
||||
|
||||
Reference in New Issue
Block a user