(feature)(chat&common)Introduce ChatMemory module to support dynamic few-shot exemplars.#1097

This commit is contained in:
jerryjzhang
2024-06-27 10:19:59 +08:00
parent 7c711f6105
commit a655110f5f
28 changed files with 561 additions and 153 deletions

View File

@@ -1,13 +1,18 @@
package com.tencent.supersonic.chat.server.executor;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.chat.server.util.ResultFormatter;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
import lombok.SneakyThrows;
import org.apache.commons.lang3.StringUtils;
public class SqlExecutor implements ChatExecutor {
@@ -21,6 +26,18 @@ public class SqlExecutor implements ChatExecutor {
String textResult = ResultFormatter.transform2TextNew(queryResult.getQueryColumns(),
queryResult.getQueryResults());
queryResult.setTextResult(textResult);
if (queryResult.getQueryState().equals(QueryState.SUCCESS)
&& queryResult.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
MemoryService memoryService = ContextUtils.getBean(MemoryService.class);
memoryService.createMemory(ChatMemoryDO.builder()
.agentId(chatExecuteContext.getAgentId())
.status(ChatMemoryDO.Status.PENDING)
.question(chatExecuteContext.getQueryText())
.s2sql(chatExecuteContext.getParseInfo().getSqlInfo().getS2SQL())
.schema(buildSchemaStr(chatExecuteContext.getParseInfo()))
.build());
}
}
return queryResult;
@@ -38,4 +55,36 @@ public class SqlExecutor implements ChatExecutor {
.build();
}
public String buildSchemaStr(SemanticParseInfo parseInfo) {
String tableStr = parseInfo.getDataSet().getName();
StringBuilder metricStr = new StringBuilder();
StringBuilder dimensionStr = new StringBuilder();
parseInfo.getMetrics().stream().forEach(
metric -> {
metricStr.append(metric.getName());
if (StringUtils.isNotEmpty(metric.getDescription())) {
metricStr.append(" COMMENT '" + metric.getDescription() + "'");
}
if (StringUtils.isNotEmpty(metric.getDefaultAgg())) {
metricStr.append(" AGGREGATE '" + metric.getDefaultAgg().toUpperCase() + "'");
}
metricStr.append(",");
}
);
parseInfo.getDimensions().stream().forEach(
dimension -> {
dimensionStr.append(dimension.getName());
if (StringUtils.isNotEmpty(dimension.getDescription())) {
dimensionStr.append(" COMMENT '" + dimension.getDescription() + "'");
}
dimensionStr.append(",");
}
);
String template = "Table: %s, Metrics: [%s], Dimensions: [%s]";
return String.format(template, tableStr, metricStr, dimensionStr);
}
}

View File

@@ -0,0 +1,71 @@
package com.tencent.supersonic.chat.server.memory;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO.ReviewResult;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.common.util.S2ChatModelProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
import java.util.Collections;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@Component
@Slf4j
public class MemoryReviewTask {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
private static final String INSTRUCTION = ""
+ "#Role: You are a senior data engineer experienced in writing SQL.\n"
+ "#Task: Your will be provided with a user question and the SQL written by junior engineer,"
+ "please take a review and give your opinion.\n"
+ "#Rules: "
+ "1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),comment=(your comment)`."
+ "2.DO NOT check the usage of `数据日期` field and `datediff()` function.\n"
+ "#Question: %s\n"
+ "#Schema: %s\n"
+ "#SQL: %s\n"
+ "#Response: ";
private static final Pattern OUTPUT_PATTERN = Pattern.compile("opinion=(.*),.*comment=(.*)");
@Autowired
private MemoryService memoryService;
@Autowired
private AgentService agentService;
@Scheduled(fixedDelay = 60 * 1000)
public void review() {
memoryService.getMemories().stream()
.filter(c -> c.getStatus() == ChatMemoryDO.Status.PENDING)
.forEach(m -> {
Agent chatAgent = agentService.getAgent(m.getAgentId());
String promptStr = String.format(INSTRUCTION, m.getQuestion(), m.getSchema(), m.getS2sql());
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
keyPipelineLog.info("MemoryReviewTask reqPrompt:{}", promptStr);
ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(chatAgent.getLlmConfig());
String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text();
keyPipelineLog.info("MemoryReviewTask modelResp:{}", response);
Matcher matcher = OUTPUT_PATTERN.matcher(response);
if (matcher.find()) {
m.setLlmReviewRet(ReviewResult.valueOf(matcher.group(1)));
m.setLlmReviewCmt(matcher.group(2));
memoryService.updateMemory(m);
}
});
}
}

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.chat.server.parser;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
@@ -12,6 +13,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
@@ -64,9 +66,11 @@ public class NL2SQLParser implements ChatParser {
if (!chatParseContext.enableNL2SQL() || checkSkip(parseResp)) {
return;
}
processMultiTurn(chatParseContext, parseResp);
processMultiTurn(chatParseContext);
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
addExemplars(chatParseContext.getAgent().getId(), queryReq);
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq);
if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) {
@@ -131,7 +135,7 @@ public class NL2SQLParser implements ChatParser {
parseInfo.setTextInfo(textBuilder.toString());
}
private void processMultiTurn(ChatParseContext chatParseContext, ParseResp parseResp) {
private void processMultiTurn(ChatParseContext chatParseContext) {
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
MultiTurnConfig agentMultiTurnConfig = chatParseContext.getAgent().getMultiTurnConfig();
Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
@@ -220,6 +224,13 @@ public class NL2SQLParser implements ChatParser {
return contextualList;
}
private void addExemplars(Integer agentId, QueryReq queryReq) {
ExemplarServiceImpl exemplarManager = ContextUtils.getBean(ExemplarServiceImpl.class);
List<SqlExemplar> exemplars = exemplarManager.recallExemplars(agentId.toString(),
queryReq.getQueryText(), 5);
queryReq.getExemplars().addAll(exemplars);
}
@Builder
@Data
public static class RewriteContext {

View File

@@ -0,0 +1,70 @@
package com.tencent.supersonic.chat.server.persistence.dataobject;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Builder;
import lombok.Data;
import lombok.ToString;
import java.util.Date;
@Data
@Builder
@ToString
@TableName("s2_chat_memory")
public class ChatMemoryDO {
@TableId(type = IdType.AUTO)
private Long id;
@TableField("question")
private String question;
@TableField("agent_id")
private Integer agentId;
@TableField("db_schema")
private String schema;
@TableField("s2_sql")
private String s2sql;
@TableField("status")
private Status status;
@TableField("llm_review")
private ReviewResult llmReviewRet;
@TableField("llm_comment")
private String llmReviewCmt;
@TableField("human_review")
private ReviewResult humanReviewRet;
@TableField("human_comment")
private String humanReviewCmt;
@TableField("created_by")
private String createdBy;
@TableField("created_at")
private Date createdAt;
@TableField("updated_by")
private String updatedBy;
@TableField("updated_at")
private Date updatedAt;
public enum ReviewResult {
POSITIVE,
NEGATIVE
}
public enum Status {
PENDING,
ENABLED,
DISABLED;
}
}

View File

@@ -0,0 +1,10 @@
package com.tencent.supersonic.chat.server.persistence.mapper;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import org.apache.ibatis.annotations.Mapper;
@Mapper
public interface ChatMemoryMapper extends BaseMapper<ChatMemoryDO> {
}

View File

@@ -1,11 +0,0 @@
package com.tencent.supersonic.chat.server.persistence.repository;
import com.tencent.supersonic.headless.chat.ChatContext;
public interface ChatContextRepository {
ChatContext getOrCreateContext(int chatId);
void updateContext(ChatContext chatCtx);
}

View File

@@ -0,0 +1,15 @@
package com.tencent.supersonic.chat.server.persistence.repository;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import java.util.List;
public interface ChatMemoryRepository {
void createMemory(ChatMemoryDO chatMemoryDO);
void updateMemory(ChatMemoryDO chatMemoryDO);
ChatMemoryDO getMemory(Long id);
List<ChatMemoryDO> getMemories();
}

View File

@@ -0,0 +1,42 @@
package com.tencent.supersonic.chat.server.persistence.repository.impl;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.persistence.mapper.ChatMemoryMapper;
import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Repository;
import java.util.List;
@Repository
@Primary
public class ChatMemoryRepositoryImpl implements ChatMemoryRepository {
private final ChatMemoryMapper chatMemoryMapper;
public ChatMemoryRepositoryImpl(ChatMemoryMapper chatMemoryMapper) {
this.chatMemoryMapper = chatMemoryMapper;
}
@Override
public void createMemory(ChatMemoryDO chatMemoryDO) {
chatMemoryMapper.insert(chatMemoryDO);
}
@Override
public void updateMemory(ChatMemoryDO chatMemoryDO) {
chatMemoryMapper.updateById(chatMemoryDO);
}
@Override
public ChatMemoryDO getMemory(Long id) {
return chatMemoryMapper.selectById(id);
}
@Override
public List<ChatMemoryDO> getMemories() {
return chatMemoryMapper.selectList(new QueryWrapper<>());
}
}

View File

@@ -0,0 +1,13 @@
package com.tencent.supersonic.chat.server.service;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import java.util.List;
public interface MemoryService {
void createMemory(ChatMemoryDO memory);
void updateMemory(ChatMemoryDO memory);
List<ChatMemoryDO> getMemories();
}

View File

@@ -0,0 +1,55 @@
package com.tencent.supersonic.chat.server.service.impl;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO.ReviewResult;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO.Status;
import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository;
import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.service.ExemplarService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.List;
@Service
public class MemoryServiceImpl implements MemoryService {
@Autowired
private ChatMemoryRepository chatMemoryRepository;
@Autowired
private ExemplarService exemplarService;
@Override
public void createMemory(ChatMemoryDO memory) {
if (ReviewResult.POSITIVE.equals(memory.getHumanReviewRet())) {
enableMemory(memory);
}
chatMemoryRepository.createMemory(memory);
}
@Override
public void updateMemory(ChatMemoryDO memory) {
if (!ChatMemoryDO.Status.ENABLED.equals(memory.getStatus())
&& ReviewResult.POSITIVE.equals(memory.getHumanReviewRet())) {
enableMemory(memory);
}
chatMemoryRepository.updateMemory(memory);
}
@Override
public List<ChatMemoryDO> getMemories() {
return chatMemoryRepository.getMemories();
}
private void enableMemory(ChatMemoryDO memory) {
exemplarService.storeExemplar(memory.getAgentId().toString(),
SqlExemplar.builder()
.question(memory.getQuestion())
.dbSchema(memory.getSchema())
.sql(memory.getS2sql())
.build());
memory.setStatus(Status.ENABLED);
}
}