mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-14 05:43:51 +00:00
(improvement)(Chat) Put agent examples into ChatMemory automatically (#1249)
Co-authored-by: lxwcodemonkey
This commit is contained in:
@@ -0,0 +1,9 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.enums;
|
||||
|
||||
|
||||
public enum MemoryReviewResult {
|
||||
|
||||
POSITIVE,
|
||||
NEGATIVE
|
||||
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.enums;
|
||||
|
||||
|
||||
public enum MemoryStatus {
|
||||
|
||||
PENDING,
|
||||
ENABLED,
|
||||
DISABLED;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class ChatMemoryFilter {
|
||||
|
||||
private String question;
|
||||
|
||||
private List<String> questions;
|
||||
|
||||
private MemoryStatus status;
|
||||
|
||||
private MemoryReviewResult llmReviewRet;
|
||||
|
||||
private MemoryReviewResult humanReviewRet;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.PageBaseReq;
|
||||
import lombok.Data;
|
||||
|
||||
|
||||
@Data
|
||||
public class PageMemoryReq extends PageBaseReq {
|
||||
|
||||
private ChatMemoryFilter chatMemoryFilter = new ChatMemoryFilter();
|
||||
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.chat.server.executor;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
@@ -32,7 +33,7 @@ public class SqlExecutor implements ChatExecutor {
|
||||
MemoryService memoryService = ContextUtils.getBean(MemoryService.class);
|
||||
memoryService.createMemory(ChatMemoryDO.builder()
|
||||
.agentId(chatExecuteContext.getAgentId())
|
||||
.status(ChatMemoryDO.Status.PENDING)
|
||||
.status(MemoryStatus.PENDING)
|
||||
.question(chatExecuteContext.getQueryText())
|
||||
.s2sql(chatExecuteContext.getParseInfo().getSqlInfo().getS2SQL())
|
||||
.dbSchema(buildSchemaStr(chatExecuteContext.getParseInfo()))
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
package com.tencent.supersonic.chat.server.memory;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.service.ChatService;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.scheduling.annotation.Async;
|
||||
import org.springframework.stereotype.Component;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
public class AgentExample2MemoryTransformer {
|
||||
|
||||
@Autowired
|
||||
private ChatService chatService;
|
||||
|
||||
@Autowired
|
||||
private MemoryService memoryService;
|
||||
|
||||
@Async
|
||||
public void transform(Agent agent) {
|
||||
if (!agent.containsLLMParserTool() || agent.getLlmConfig() == null) {
|
||||
return;
|
||||
}
|
||||
List<String> examples = agent.getExamples();
|
||||
ChatMemoryFilter chatMemoryFilter = ChatMemoryFilter.builder().questions(examples).build();
|
||||
List<String> memoriesExisted = memoryService.getMemories(chatMemoryFilter)
|
||||
.stream().map(ChatMemoryDO::getQuestion).collect(Collectors.toList());
|
||||
for (String example : examples) {
|
||||
if (memoriesExisted.contains(example)) {
|
||||
continue;
|
||||
}
|
||||
ChatParseReq chatParseReq = new ChatParseReq();
|
||||
chatParseReq.setAgentId(agent.getId());
|
||||
chatParseReq.setQueryText(example);
|
||||
chatParseReq.setUser(User.getFakeUser());
|
||||
chatParseReq.setChatId(-1);
|
||||
chatService.parseAndExecute(chatParseReq);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,8 +1,7 @@
|
||||
package com.tencent.supersonic.chat.server.memory;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO.ReviewResult;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.common.util.S2ChatModelProvider;
|
||||
@@ -48,8 +47,7 @@ public class MemoryReviewTask {
|
||||
|
||||
@Scheduled(fixedDelay = 60 * 1000)
|
||||
public void review() {
|
||||
memoryService.getMemories().stream()
|
||||
.filter(c -> c.getStatus() == ChatMemoryDO.Status.PENDING)
|
||||
memoryService.getMemoriesForLlmReview().stream()
|
||||
.forEach(m -> {
|
||||
Agent chatAgent = agentService.getAgent(m.getAgentId());
|
||||
String promptStr = String.format(INSTRUCTION, m.getQuestion(), m.getDbSchema(), m.getS2sql());
|
||||
@@ -62,7 +60,7 @@ public class MemoryReviewTask {
|
||||
|
||||
Matcher matcher = OUTPUT_PATTERN.matcher(response);
|
||||
if (matcher.find()) {
|
||||
m.setLlmReviewRet(ReviewResult.valueOf(matcher.group(1)));
|
||||
m.setLlmReviewRet(MemoryReviewResult.valueOf(matcher.group(1)));
|
||||
m.setLlmReviewCmt(matcher.group(2));
|
||||
memoryService.updateMemory(m);
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ import com.baomidou.mybatisplus.annotation.IdType;
|
||||
import com.baomidou.mybatisplus.annotation.TableField;
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
@@ -31,16 +33,16 @@ public class ChatMemoryDO {
|
||||
private String s2sql;
|
||||
|
||||
@TableField("status")
|
||||
private Status status;
|
||||
private MemoryStatus status;
|
||||
|
||||
@TableField("llm_review")
|
||||
private ReviewResult llmReviewRet;
|
||||
private MemoryReviewResult llmReviewRet;
|
||||
|
||||
@TableField("llm_comment")
|
||||
private String llmReviewCmt;
|
||||
|
||||
@TableField("human_review")
|
||||
private ReviewResult humanReviewRet;
|
||||
private MemoryReviewResult humanReviewRet;
|
||||
|
||||
@TableField("human_comment")
|
||||
private String humanReviewCmt;
|
||||
@@ -57,14 +59,4 @@ public class ChatMemoryDO {
|
||||
@TableField("updated_at")
|
||||
private Date updatedAt;
|
||||
|
||||
public enum ReviewResult {
|
||||
POSITIVE,
|
||||
NEGATIVE
|
||||
}
|
||||
|
||||
public enum Status {
|
||||
PENDING,
|
||||
ENABLED,
|
||||
DISABLED;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.repository;
|
||||
|
||||
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
|
||||
import java.util.List;
|
||||
@@ -11,5 +12,5 @@ public interface ChatMemoryRepository {
|
||||
|
||||
ChatMemoryDO getMemory(Long id);
|
||||
|
||||
List<ChatMemoryDO> getMemories();
|
||||
List<ChatMemoryDO> getMemories(QueryWrapper<ChatMemoryDO> queryWrapper);
|
||||
}
|
||||
|
||||
@@ -35,8 +35,8 @@ public class ChatMemoryRepositoryImpl implements ChatMemoryRepository {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatMemoryDO> getMemories() {
|
||||
return chatMemoryMapper.selectList(new QueryWrapper<>());
|
||||
public List<ChatMemoryDO> getMemories(QueryWrapper<ChatMemoryDO> queryWrapper) {
|
||||
return chatMemoryMapper.selectList(queryWrapper);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
package com.tencent.supersonic.chat.server.rest;
|
||||
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.util.Date;
|
||||
|
||||
@RestController
|
||||
@RequestMapping({"/api/chat/memory"})
|
||||
public class MemoryController {
|
||||
|
||||
@Autowired
|
||||
private MemoryService memoryService;
|
||||
|
||||
@PostMapping("/createMemory")
|
||||
public Boolean createMemory(@RequestBody ChatMemoryDO memory,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
User user = UserHolder.findUser(request, response);
|
||||
memory.setCreatedBy(user.getName());
|
||||
memory.setUpdatedBy(user.getName());
|
||||
memory.setCreatedAt(new Date());
|
||||
memory.setUpdatedAt(new Date());
|
||||
memoryService.createMemory(memory);
|
||||
return true;
|
||||
}
|
||||
|
||||
@PostMapping("/updateMemory")
|
||||
public void updateMemory(ChatMemoryDO memory,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
User user = UserHolder.findUser(request, response);
|
||||
memory.setUpdatedBy(user.getName());
|
||||
memory.setUpdatedAt(new Date());
|
||||
memoryService.updateMemory(memory);
|
||||
}
|
||||
|
||||
@RequestMapping("/pageMemories")
|
||||
public PageInfo<ChatMemoryDO> pageMemories(@RequestBody PageMemoryReq pageMemoryReq) {
|
||||
return memoryService.pageMemories(pageMemoryReq);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -20,6 +20,8 @@ public interface ChatService {
|
||||
|
||||
QueryResult performExecution(ChatExecuteReq chatExecuteReq) throws Exception;
|
||||
|
||||
QueryResult parseAndExecute(ChatParseReq chatParseReq);
|
||||
|
||||
Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exception;
|
||||
|
||||
SemanticParseInfo queryContext(Integer chatId);
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
package com.tencent.supersonic.chat.server.service;
|
||||
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
|
||||
import java.util.List;
|
||||
@@ -9,5 +12,9 @@ public interface MemoryService {
|
||||
|
||||
void updateMemory(ChatMemoryDO memory);
|
||||
|
||||
List<ChatMemoryDO> getMemories();
|
||||
PageInfo<ChatMemoryDO> pageMemories(PageMemoryReq pageMemoryReq);
|
||||
|
||||
List<ChatMemoryDO> getMemories(ChatMemoryFilter chatMemoryFilter);
|
||||
|
||||
List<ChatMemoryDO> getMemoriesForLlmReview();
|
||||
}
|
||||
|
||||
@@ -4,13 +4,15 @@ import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||
import com.tencent.supersonic.chat.server.memory.AgentExample2MemoryTransformer;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import com.tencent.supersonic.common.config.VisualConfig;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
@@ -19,6 +21,9 @@ import java.util.stream.Collectors;
|
||||
public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
||||
implements AgentService {
|
||||
|
||||
@Autowired
|
||||
private AgentExample2MemoryTransformer agentExample2MemoryTransformer;
|
||||
|
||||
@Override
|
||||
public List<Agent> getAgents() {
|
||||
return getAgentDOList().stream()
|
||||
@@ -30,6 +35,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
||||
agent.createdBy(user.getName());
|
||||
AgentDO agentDO = convert(agent);
|
||||
save(agentDO);
|
||||
agentExample2MemoryTransformer.transform(agent);
|
||||
return agentDO.getId();
|
||||
}
|
||||
|
||||
@@ -37,6 +43,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
||||
public void updateAgent(Agent agent, User user) {
|
||||
agent.updatedBy(user.getName());
|
||||
updateById(convert(agent));
|
||||
agentExample2MemoryTransformer.transform(agent);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -95,6 +95,19 @@ public class ChatServiceImpl implements ChatService {
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
@Override
|
||||
public QueryResult parseAndExecute(ChatParseReq chatParseReq) {
|
||||
ParseResp parseResp = performParsing(chatParseReq);
|
||||
ChatExecuteReq chatExecuteReq = new ChatExecuteReq();
|
||||
chatExecuteReq.setQueryId(parseResp.getQueryId());
|
||||
chatExecuteReq.setChatId(chatParseReq.getChatId());
|
||||
chatExecuteReq.setUser(chatParseReq.getUser());
|
||||
chatExecuteReq.setAgentId(chatParseReq.getAgentId());
|
||||
chatExecuteReq.setQueryText(chatParseReq.getQueryText());
|
||||
chatExecuteReq.setParseId(parseResp.getSelectedParses().get(0).getId());
|
||||
return performExecution(chatExecuteReq);
|
||||
}
|
||||
|
||||
private ChatParseContext buildParseContext(ChatParseReq chatParseReq) {
|
||||
ChatParseContext chatParseContext = new ChatParseContext();
|
||||
BeanMapper.mapper(chatParseReq, chatParseContext);
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
package com.tencent.supersonic.chat.server.service.impl;
|
||||
|
||||
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
||||
import com.github.pagehelper.PageHelper;
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO.ReviewResult;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO.Status;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.service.ExemplarService;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@@ -23,24 +29,53 @@ public class MemoryServiceImpl implements MemoryService {
|
||||
|
||||
@Override
|
||||
public void createMemory(ChatMemoryDO memory) {
|
||||
if (ReviewResult.POSITIVE.equals(memory.getHumanReviewRet())) {
|
||||
enableMemory(memory);
|
||||
}
|
||||
chatMemoryRepository.createMemory(memory);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateMemory(ChatMemoryDO memory) {
|
||||
if (!ChatMemoryDO.Status.ENABLED.equals(memory.getStatus())
|
||||
&& ReviewResult.POSITIVE.equals(memory.getHumanReviewRet())) {
|
||||
if (MemoryStatus.ENABLED.equals(memory.getStatus())) {
|
||||
enableMemory(memory);
|
||||
} else if (MemoryStatus.DISABLED.equals(memory.getStatus())) {
|
||||
disableMemory(memory);
|
||||
}
|
||||
chatMemoryRepository.updateMemory(memory);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatMemoryDO> getMemories() {
|
||||
return chatMemoryRepository.getMemories();
|
||||
public PageInfo<ChatMemoryDO> pageMemories(PageMemoryReq pageMemoryReq) {
|
||||
return PageHelper.startPage(pageMemoryReq.getCurrent(),
|
||||
pageMemoryReq.getPageSize())
|
||||
.doSelectPageInfo(() -> getMemories(pageMemoryReq.getChatMemoryFilter()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatMemoryDO> getMemories(ChatMemoryFilter chatMemoryFilter) {
|
||||
QueryWrapper<ChatMemoryDO> queryWrapper = new QueryWrapper<>();
|
||||
if (StringUtils.isNotBlank(chatMemoryFilter.getQuestion())) {
|
||||
queryWrapper.lambda().like(ChatMemoryDO::getQuestion, chatMemoryFilter.getQuestion());
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(chatMemoryFilter.getQuestions())) {
|
||||
queryWrapper.lambda().in(ChatMemoryDO::getQuestion, chatMemoryFilter.getQuestions());
|
||||
}
|
||||
if (chatMemoryFilter.getStatus() != null) {
|
||||
queryWrapper.lambda().eq(ChatMemoryDO::getStatus, chatMemoryFilter.getStatus());
|
||||
}
|
||||
if (chatMemoryFilter.getHumanReviewRet() != null) {
|
||||
queryWrapper.lambda().eq(ChatMemoryDO::getHumanReviewRet, chatMemoryFilter.getHumanReviewRet());
|
||||
}
|
||||
if (chatMemoryFilter.getLlmReviewRet() != null) {
|
||||
queryWrapper.lambda().eq(ChatMemoryDO::getLlmReviewRet, chatMemoryFilter.getLlmReviewRet());
|
||||
}
|
||||
return chatMemoryRepository.getMemories(queryWrapper);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatMemoryDO> getMemoriesForLlmReview() {
|
||||
QueryWrapper<ChatMemoryDO> queryWrapper = new QueryWrapper<>();
|
||||
queryWrapper.lambda().eq(ChatMemoryDO::getStatus, MemoryStatus.PENDING)
|
||||
.isNull(ChatMemoryDO::getLlmReviewRet);
|
||||
return chatMemoryRepository.getMemories(queryWrapper);
|
||||
}
|
||||
|
||||
private void enableMemory(ChatMemoryDO memory) {
|
||||
@@ -50,6 +85,15 @@ public class MemoryServiceImpl implements MemoryService {
|
||||
.dbSchema(memory.getDbSchema())
|
||||
.sql(memory.getS2sql())
|
||||
.build());
|
||||
memory.setStatus(Status.ENABLED);
|
||||
}
|
||||
|
||||
private void disableMemory(ChatMemoryDO memory) {
|
||||
exemplarService.removeExemplar(memory.getAgentId().toString(),
|
||||
SqlExemplar.builder()
|
||||
.question(memory.getQuestion())
|
||||
.dbSchema(memory.getDbSchema())
|
||||
.sql(memory.getS2sql())
|
||||
.build());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@ import java.util.List;
|
||||
public interface ExemplarService {
|
||||
void storeExemplar(String collection, SqlExemplar exemplar);
|
||||
|
||||
void removeExemplar(String collection, SqlExemplar exemplar);
|
||||
|
||||
List<SqlExemplar> recallExemplars(String collection, String query, int num);
|
||||
|
||||
List<SqlExemplar> recallExemplars(String query, int num);
|
||||
|
||||
@@ -49,6 +49,14 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner {
|
||||
embeddingService.addQuery(collection, Lists.newArrayList(segment));
|
||||
}
|
||||
|
||||
public void removeExemplar(String collection, SqlExemplar exemplar) {
|
||||
Metadata metadata = Metadata.from(JsonUtil.toMap(JsonUtil.toString(exemplar),
|
||||
String.class, Object.class));
|
||||
TextSegment segment = TextSegment.from(exemplar.getQuestion(), metadata);
|
||||
|
||||
embeddingService.deleteQuery(collection, Lists.newArrayList(segment));
|
||||
}
|
||||
|
||||
public List<SqlExemplar> recallExemplars(String query, int num) {
|
||||
String collection = embeddingConfig.getText2sqlCollectionName();
|
||||
return recallExemplars(collection, query, num);
|
||||
|
||||
@@ -329,3 +329,22 @@ alter table s2_agent add column `visual_config` varchar(2000) COLLATE utf8_unic
|
||||
|
||||
alter table s2_term add column `related_metrics` varchar(1000) DEFAULT NULL COMMENT '术语关联的指标';
|
||||
alter table s2_term add column `related_dimensions` varchar(1000) DEFAULT NULL COMMENT '术语关联的维度';
|
||||
|
||||
--20240627
|
||||
CREATE TABLE IF NOT EXISTS `s2_chat_memory` (
|
||||
`id` INT NOT NULL AUTO_INCREMENT,
|
||||
`question` varchar(655) ,
|
||||
`agent_id` INT ,
|
||||
`db_schema` TEXT ,
|
||||
`s2_sql` TEXT ,
|
||||
`status` char(10) ,
|
||||
`llm_review` char(10) ,
|
||||
`llm_comment` TEXT,
|
||||
`human_review` char(10) ,
|
||||
`human_comment` TEXT ,
|
||||
`created_at` datetime NOT NULL ,
|
||||
`updated_at` datetime NOT NULL ,
|
||||
`created_by` varchar(100) NOT NULL ,
|
||||
`updated_by` varchar(100) NOT NULL ,
|
||||
PRIMARY KEY (`id`)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
|
||||
@@ -147,8 +147,8 @@ CREATE TABLE IF NOT EXISTS `s2_chat_memory` (
|
||||
`llm_comment` TEXT,
|
||||
`human_review` char(10) ,
|
||||
`human_comment` TEXT ,
|
||||
`created_at` TIMESTAMP NOT NULL ,
|
||||
`updated_at` TIMESTAMP NOT NULL ,
|
||||
`created_at` datetime NOT NULL ,
|
||||
`updated_at` datetime NOT NULL ,
|
||||
`created_by` varchar(100) NOT NULL ,
|
||||
`updated_by` varchar(100) NOT NULL ,
|
||||
PRIMARY KEY (`id`)
|
||||
|
||||
Reference in New Issue
Block a user