mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-13 21:17:08 +00:00
(feature)(chat&common)Introduce ChatMemory module to support dynamic few-shot exemplars.#1097
This commit is contained in:
@@ -1,13 +1,18 @@
|
||||
package com.tencent.supersonic.chat.server.executor;
|
||||
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.chat.server.util.ResultFormatter;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
|
||||
import lombok.SneakyThrows;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
public class SqlExecutor implements ChatExecutor {
|
||||
|
||||
@@ -21,6 +26,18 @@ public class SqlExecutor implements ChatExecutor {
|
||||
String textResult = ResultFormatter.transform2TextNew(queryResult.getQueryColumns(),
|
||||
queryResult.getQueryResults());
|
||||
queryResult.setTextResult(textResult);
|
||||
|
||||
if (queryResult.getQueryState().equals(QueryState.SUCCESS)
|
||||
&& queryResult.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
|
||||
MemoryService memoryService = ContextUtils.getBean(MemoryService.class);
|
||||
memoryService.createMemory(ChatMemoryDO.builder()
|
||||
.agentId(chatExecuteContext.getAgentId())
|
||||
.status(ChatMemoryDO.Status.PENDING)
|
||||
.question(chatExecuteContext.getQueryText())
|
||||
.s2sql(chatExecuteContext.getParseInfo().getSqlInfo().getS2SQL())
|
||||
.schema(buildSchemaStr(chatExecuteContext.getParseInfo()))
|
||||
.build());
|
||||
}
|
||||
}
|
||||
|
||||
return queryResult;
|
||||
@@ -38,4 +55,36 @@ public class SqlExecutor implements ChatExecutor {
|
||||
.build();
|
||||
}
|
||||
|
||||
public String buildSchemaStr(SemanticParseInfo parseInfo) {
|
||||
String tableStr = parseInfo.getDataSet().getName();
|
||||
StringBuilder metricStr = new StringBuilder();
|
||||
StringBuilder dimensionStr = new StringBuilder();
|
||||
|
||||
parseInfo.getMetrics().stream().forEach(
|
||||
metric -> {
|
||||
metricStr.append(metric.getName());
|
||||
if (StringUtils.isNotEmpty(metric.getDescription())) {
|
||||
metricStr.append(" COMMENT '" + metric.getDescription() + "'");
|
||||
}
|
||||
if (StringUtils.isNotEmpty(metric.getDefaultAgg())) {
|
||||
metricStr.append(" AGGREGATE '" + metric.getDefaultAgg().toUpperCase() + "'");
|
||||
}
|
||||
metricStr.append(",");
|
||||
}
|
||||
);
|
||||
|
||||
parseInfo.getDimensions().stream().forEach(
|
||||
dimension -> {
|
||||
dimensionStr.append(dimension.getName());
|
||||
if (StringUtils.isNotEmpty(dimension.getDescription())) {
|
||||
dimensionStr.append(" COMMENT '" + dimension.getDescription() + "'");
|
||||
}
|
||||
dimensionStr.append(",");
|
||||
}
|
||||
);
|
||||
|
||||
String template = "Table: %s, Metrics: [%s], Dimensions: [%s]";
|
||||
return String.format(template, tableStr, metricStr, dimensionStr);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
package com.tencent.supersonic.chat.server.memory;
|
||||
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO.ReviewResult;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.common.util.S2ChatModelProvider;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.scheduling.annotation.Scheduled;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
public class MemoryReviewTask {
|
||||
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
|
||||
private static final String INSTRUCTION = ""
|
||||
+ "#Role: You are a senior data engineer experienced in writing SQL.\n"
|
||||
+ "#Task: Your will be provided with a user question and the SQL written by junior engineer,"
|
||||
+ "please take a review and give your opinion.\n"
|
||||
+ "#Rules: "
|
||||
+ "1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),comment=(your comment)`."
|
||||
+ "2.DO NOT check the usage of `数据日期` field and `datediff()` function.\n"
|
||||
+ "#Question: %s\n"
|
||||
+ "#Schema: %s\n"
|
||||
+ "#SQL: %s\n"
|
||||
+ "#Response: ";
|
||||
|
||||
private static final Pattern OUTPUT_PATTERN = Pattern.compile("opinion=(.*),.*comment=(.*)");
|
||||
|
||||
@Autowired
|
||||
private MemoryService memoryService;
|
||||
|
||||
@Autowired
|
||||
private AgentService agentService;
|
||||
|
||||
@Scheduled(fixedDelay = 60 * 1000)
|
||||
public void review() {
|
||||
memoryService.getMemories().stream()
|
||||
.filter(c -> c.getStatus() == ChatMemoryDO.Status.PENDING)
|
||||
.forEach(m -> {
|
||||
Agent chatAgent = agentService.getAgent(m.getAgentId());
|
||||
String promptStr = String.format(INSTRUCTION, m.getQuestion(), m.getSchema(), m.getS2sql());
|
||||
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
||||
|
||||
keyPipelineLog.info("MemoryReviewTask reqPrompt:{}", promptStr);
|
||||
ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(chatAgent.getLlmConfig());
|
||||
String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text();
|
||||
keyPipelineLog.info("MemoryReviewTask modelResp:{}", response);
|
||||
|
||||
Matcher matcher = OUTPUT_PATTERN.matcher(response);
|
||||
if (matcher.find()) {
|
||||
m.setLlmReviewRet(ReviewResult.valueOf(matcher.group(1)));
|
||||
m.setLlmReviewCmt(matcher.group(2));
|
||||
memoryService.updateMemory(m);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
@@ -12,6 +13,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||
@@ -64,9 +66,11 @@ public class NL2SQLParser implements ChatParser {
|
||||
if (!chatParseContext.enableNL2SQL() || checkSkip(parseResp)) {
|
||||
return;
|
||||
}
|
||||
processMultiTurn(chatParseContext, parseResp);
|
||||
processMultiTurn(chatParseContext);
|
||||
|
||||
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
addExemplars(chatParseContext.getAgent().getId(), queryReq);
|
||||
|
||||
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
|
||||
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq);
|
||||
if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) {
|
||||
@@ -131,7 +135,7 @@ public class NL2SQLParser implements ChatParser {
|
||||
parseInfo.setTextInfo(textBuilder.toString());
|
||||
}
|
||||
|
||||
private void processMultiTurn(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
private void processMultiTurn(ChatParseContext chatParseContext) {
|
||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||
MultiTurnConfig agentMultiTurnConfig = chatParseContext.getAgent().getMultiTurnConfig();
|
||||
Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
|
||||
@@ -220,6 +224,13 @@ public class NL2SQLParser implements ChatParser {
|
||||
return contextualList;
|
||||
}
|
||||
|
||||
private void addExemplars(Integer agentId, QueryReq queryReq) {
|
||||
ExemplarServiceImpl exemplarManager = ContextUtils.getBean(ExemplarServiceImpl.class);
|
||||
List<SqlExemplar> exemplars = exemplarManager.recallExemplars(agentId.toString(),
|
||||
queryReq.getQueryText(), 5);
|
||||
queryReq.getExemplars().addAll(exemplars);
|
||||
}
|
||||
|
||||
@Builder
|
||||
@Data
|
||||
public static class RewriteContext {
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
||||
|
||||
import com.baomidou.mybatisplus.annotation.IdType;
|
||||
import com.baomidou.mybatisplus.annotation.TableField;
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.util.Date;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@ToString
|
||||
@TableName("s2_chat_memory")
|
||||
public class ChatMemoryDO {
|
||||
@TableId(type = IdType.AUTO)
|
||||
private Long id;
|
||||
|
||||
@TableField("question")
|
||||
private String question;
|
||||
|
||||
@TableField("agent_id")
|
||||
private Integer agentId;
|
||||
|
||||
@TableField("db_schema")
|
||||
private String schema;
|
||||
|
||||
@TableField("s2_sql")
|
||||
private String s2sql;
|
||||
|
||||
@TableField("status")
|
||||
private Status status;
|
||||
|
||||
@TableField("llm_review")
|
||||
private ReviewResult llmReviewRet;
|
||||
|
||||
@TableField("llm_comment")
|
||||
private String llmReviewCmt;
|
||||
|
||||
@TableField("human_review")
|
||||
private ReviewResult humanReviewRet;
|
||||
|
||||
@TableField("human_comment")
|
||||
private String humanReviewCmt;
|
||||
|
||||
@TableField("created_by")
|
||||
private String createdBy;
|
||||
|
||||
@TableField("created_at")
|
||||
private Date createdAt;
|
||||
|
||||
@TableField("updated_by")
|
||||
private String updatedBy;
|
||||
|
||||
@TableField("updated_at")
|
||||
private Date updatedAt;
|
||||
|
||||
public enum ReviewResult {
|
||||
POSITIVE,
|
||||
NEGATIVE
|
||||
}
|
||||
|
||||
public enum Status {
|
||||
PENDING,
|
||||
ENABLED,
|
||||
DISABLED;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.mapper;
|
||||
|
||||
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
|
||||
@Mapper
|
||||
public interface ChatMemoryMapper extends BaseMapper<ChatMemoryDO> {
|
||||
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.repository;
|
||||
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
|
||||
public interface ChatContextRepository {
|
||||
|
||||
ChatContext getOrCreateContext(int chatId);
|
||||
|
||||
void updateContext(ChatContext chatCtx);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.repository;
|
||||
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface ChatMemoryRepository {
|
||||
void createMemory(ChatMemoryDO chatMemoryDO);
|
||||
|
||||
void updateMemory(ChatMemoryDO chatMemoryDO);
|
||||
|
||||
ChatMemoryDO getMemory(Long id);
|
||||
|
||||
List<ChatMemoryDO> getMemories();
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.repository.impl;
|
||||
|
||||
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.mapper.ChatMemoryMapper;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository;
|
||||
import org.springframework.context.annotation.Primary;
|
||||
import org.springframework.stereotype.Repository;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Repository
|
||||
@Primary
|
||||
public class ChatMemoryRepositoryImpl implements ChatMemoryRepository {
|
||||
|
||||
private final ChatMemoryMapper chatMemoryMapper;
|
||||
|
||||
public ChatMemoryRepositoryImpl(ChatMemoryMapper chatMemoryMapper) {
|
||||
this.chatMemoryMapper = chatMemoryMapper;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void createMemory(ChatMemoryDO chatMemoryDO) {
|
||||
chatMemoryMapper.insert(chatMemoryDO);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateMemory(ChatMemoryDO chatMemoryDO) {
|
||||
chatMemoryMapper.updateById(chatMemoryDO);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatMemoryDO getMemory(Long id) {
|
||||
return chatMemoryMapper.selectById(id);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatMemoryDO> getMemories() {
|
||||
return chatMemoryMapper.selectList(new QueryWrapper<>());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package com.tencent.supersonic.chat.server.service;
|
||||
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface MemoryService {
|
||||
void createMemory(ChatMemoryDO memory);
|
||||
|
||||
void updateMemory(ChatMemoryDO memory);
|
||||
|
||||
List<ChatMemoryDO> getMemories();
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package com.tencent.supersonic.chat.server.service.impl;
|
||||
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO.ReviewResult;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO.Status;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.service.ExemplarService;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Service
|
||||
public class MemoryServiceImpl implements MemoryService {
|
||||
|
||||
@Autowired
|
||||
private ChatMemoryRepository chatMemoryRepository;
|
||||
|
||||
@Autowired
|
||||
private ExemplarService exemplarService;
|
||||
|
||||
@Override
|
||||
public void createMemory(ChatMemoryDO memory) {
|
||||
if (ReviewResult.POSITIVE.equals(memory.getHumanReviewRet())) {
|
||||
enableMemory(memory);
|
||||
}
|
||||
chatMemoryRepository.createMemory(memory);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateMemory(ChatMemoryDO memory) {
|
||||
if (!ChatMemoryDO.Status.ENABLED.equals(memory.getStatus())
|
||||
&& ReviewResult.POSITIVE.equals(memory.getHumanReviewRet())) {
|
||||
enableMemory(memory);
|
||||
}
|
||||
chatMemoryRepository.updateMemory(memory);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatMemoryDO> getMemories() {
|
||||
return chatMemoryRepository.getMemories();
|
||||
}
|
||||
|
||||
private void enableMemory(ChatMemoryDO memory) {
|
||||
exemplarService.storeExemplar(memory.getAgentId().toString(),
|
||||
SqlExemplar.builder()
|
||||
.question(memory.getQuestion())
|
||||
.dbSchema(memory.getSchema())
|
||||
.sql(memory.getS2sql())
|
||||
.build());
|
||||
memory.setStatus(Status.ENABLED);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user