From a655110f5f1d80bb41442c2c2c5f0afcf89e893d Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Thu, 27 Jun 2024 10:19:59 +0800 Subject: [PATCH] (feature)(chat&common)Introduce ChatMemory module to support dynamic few-shot exemplars.#1097 --- .../chat/server/executor/SqlExecutor.java | 49 +++++++++ .../chat/server/memory/MemoryReviewTask.java | 71 ++++++++++++ .../chat/server/parser/NL2SQLParser.java | 15 ++- .../persistence/dataobject/ChatMemoryDO.java | 70 ++++++++++++ .../persistence/mapper/ChatMemoryMapper.java | 10 ++ .../repository/ChatContextRepository.java | 11 -- .../repository/ChatMemoryRepository.java | 15 +++ .../impl/ChatMemoryRepositoryImpl.java | 42 ++++++++ .../chat/server/service/MemoryService.java | 13 +++ .../service/impl/MemoryServiceImpl.java | 55 ++++++++++ .../supersonic/common/pojo/SqlExemplar.java | 20 ++++ .../common/service/ExemplarService.java | 14 +++ .../service/impl/ExemplarServiceImpl.java | 87 +++++++++++++++ .../headless/api/pojo/request/QueryReq.java | 4 + .../headless/chat/QueryContext.java | 2 + .../headless/chat/parser/llm/Exemplar.java | 16 --- .../chat/parser/llm/ExemplarManager.java | 101 ------------------ .../chat/parser/llm/LLMRequestService.java | 2 + .../parser/llm/OnePassSCSqlGenStrategy.java | 22 ++-- .../chat/parser/llm/OutputFormat.java | 3 +- .../chat/parser/llm/PromptHelper.java | 25 +++-- .../headless/chat/query/llm/s2sql/LLMReq.java | 4 + .../chat/query/llm/s2sql/LLMResp.java | 2 + .../chat/query/llm/s2sql/LLMSqlResp.java | 5 +- .../tencent/supersonic/demo/S2BaseDemo.java | 1 + .../src/main/resources/db/schema-h2.sql | 19 ++++ .../src/main/resources/db/schema-mysql.sql | 18 ++++ .../src/test/resources/db/schema-h2.sql | 18 ++++ 28 files changed, 561 insertions(+), 153 deletions(-) create mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java create mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/ChatMemoryDO.java create mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/ChatMemoryMapper.java delete mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/ChatContextRepository.java create mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/ChatMemoryRepository.java create mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/ChatMemoryRepositoryImpl.java create mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/service/MemoryService.java create mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java create mode 100644 common/src/main/java/com/tencent/supersonic/common/pojo/SqlExemplar.java create mode 100644 common/src/main/java/com/tencent/supersonic/common/service/ExemplarService.java create mode 100644 common/src/main/java/com/tencent/supersonic/common/service/impl/ExemplarServiceImpl.java delete mode 100644 headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/Exemplar.java delete mode 100644 headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/ExemplarManager.java diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java index db173ed91..a91077128 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java @@ -1,13 +1,18 @@ package com.tencent.supersonic.chat.server.executor; +import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext; +import com.tencent.supersonic.chat.server.service.MemoryService; import com.tencent.supersonic.chat.server.util.ResultFormatter; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.headless.api.pojo.response.QueryResult; +import com.tencent.supersonic.headless.api.pojo.response.QueryState; +import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery; import com.tencent.supersonic.headless.server.facade.service.ChatQueryService; import lombok.SneakyThrows; +import org.apache.commons.lang3.StringUtils; public class SqlExecutor implements ChatExecutor { @@ -21,6 +26,18 @@ public class SqlExecutor implements ChatExecutor { String textResult = ResultFormatter.transform2TextNew(queryResult.getQueryColumns(), queryResult.getQueryResults()); queryResult.setTextResult(textResult); + + if (queryResult.getQueryState().equals(QueryState.SUCCESS) + && queryResult.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) { + MemoryService memoryService = ContextUtils.getBean(MemoryService.class); + memoryService.createMemory(ChatMemoryDO.builder() + .agentId(chatExecuteContext.getAgentId()) + .status(ChatMemoryDO.Status.PENDING) + .question(chatExecuteContext.getQueryText()) + .s2sql(chatExecuteContext.getParseInfo().getSqlInfo().getS2SQL()) + .schema(buildSchemaStr(chatExecuteContext.getParseInfo())) + .build()); + } } return queryResult; @@ -38,4 +55,36 @@ public class SqlExecutor implements ChatExecutor { .build(); } + public String buildSchemaStr(SemanticParseInfo parseInfo) { + String tableStr = parseInfo.getDataSet().getName(); + StringBuilder metricStr = new StringBuilder(); + StringBuilder dimensionStr = new StringBuilder(); + + parseInfo.getMetrics().stream().forEach( + metric -> { + metricStr.append(metric.getName()); + if (StringUtils.isNotEmpty(metric.getDescription())) { + metricStr.append(" COMMENT '" + metric.getDescription() + "'"); + } + if (StringUtils.isNotEmpty(metric.getDefaultAgg())) { + metricStr.append(" AGGREGATE '" + metric.getDefaultAgg().toUpperCase() + "'"); + } + metricStr.append(","); + } + ); + + parseInfo.getDimensions().stream().forEach( + dimension -> { + dimensionStr.append(dimension.getName()); + if (StringUtils.isNotEmpty(dimension.getDescription())) { + dimensionStr.append(" COMMENT '" + dimension.getDescription() + "'"); + } + dimensionStr.append(","); + } + ); + + String template = "Table: %s, Metrics: [%s], Dimensions: [%s]"; + return String.format(template, tableStr, metricStr, dimensionStr); + } + } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java new file mode 100644 index 000000000..74482da4e --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java @@ -0,0 +1,71 @@ +package com.tencent.supersonic.chat.server.memory; + +import com.tencent.supersonic.chat.server.agent.Agent; +import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; +import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO.ReviewResult; +import com.tencent.supersonic.chat.server.service.AgentService; +import com.tencent.supersonic.chat.server.service.MemoryService; +import com.tencent.supersonic.common.util.S2ChatModelProvider; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.input.Prompt; +import dev.langchain4j.model.input.PromptTemplate; +import lombok.extern.slf4j.Slf4j; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.scheduling.annotation.Scheduled; +import org.springframework.stereotype.Component; + +import java.util.Collections; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +@Component +@Slf4j +public class MemoryReviewTask { + + private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); + + private static final String INSTRUCTION = "" + + "#Role: You are a senior data engineer experienced in writing SQL.\n" + + "#Task: Your will be provided with a user question and the SQL written by junior engineer," + + "please take a review and give your opinion.\n" + + "#Rules: " + + "1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),comment=(your comment)`." + + "2.DO NOT check the usage of `数据日期` field and `datediff()` function.\n" + + "#Question: %s\n" + + "#Schema: %s\n" + + "#SQL: %s\n" + + "#Response: "; + + private static final Pattern OUTPUT_PATTERN = Pattern.compile("opinion=(.*),.*comment=(.*)"); + + @Autowired + private MemoryService memoryService; + + @Autowired + private AgentService agentService; + + @Scheduled(fixedDelay = 60 * 1000) + public void review() { + memoryService.getMemories().stream() + .filter(c -> c.getStatus() == ChatMemoryDO.Status.PENDING) + .forEach(m -> { + Agent chatAgent = agentService.getAgent(m.getAgentId()); + String promptStr = String.format(INSTRUCTION, m.getQuestion(), m.getSchema(), m.getS2sql()); + Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP); + + keyPipelineLog.info("MemoryReviewTask reqPrompt:{}", promptStr); + ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(chatAgent.getLlmConfig()); + String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text(); + keyPipelineLog.info("MemoryReviewTask modelResp:{}", response); + + Matcher matcher = OUTPUT_PATTERN.matcher(response); + if (matcher.find()) { + m.setLlmReviewRet(ReviewResult.valueOf(matcher.group(1))); + m.setLlmReviewCmt(matcher.group(2)); + memoryService.updateMemory(m); + } + }); + } +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index d229f276b..cb8628cf1 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.chat.server.parser; import com.tencent.supersonic.chat.server.agent.MultiTurnConfig; +import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl; import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository; import com.tencent.supersonic.chat.server.plugin.PluginQueryManager; import com.tencent.supersonic.chat.server.pojo.ChatParseContext; @@ -12,6 +13,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryReq; import com.tencent.supersonic.headless.api.pojo.response.MapResp; @@ -64,9 +66,11 @@ public class NL2SQLParser implements ChatParser { if (!chatParseContext.enableNL2SQL() || checkSkip(parseResp)) { return; } - processMultiTurn(chatParseContext, parseResp); + processMultiTurn(chatParseContext); QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext); + addExemplars(chatParseContext.getAgent().getId(), queryReq); + ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class); ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq); if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) { @@ -131,7 +135,7 @@ public class NL2SQLParser implements ChatParser { parseInfo.setTextInfo(textBuilder.toString()); } - private void processMultiTurn(ChatParseContext chatParseContext, ParseResp parseResp) { + private void processMultiTurn(ChatParseContext chatParseContext) { ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class); MultiTurnConfig agentMultiTurnConfig = chatParseContext.getAgent().getMultiTurnConfig(); Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE)); @@ -220,6 +224,13 @@ public class NL2SQLParser implements ChatParser { return contextualList; } + private void addExemplars(Integer agentId, QueryReq queryReq) { + ExemplarServiceImpl exemplarManager = ContextUtils.getBean(ExemplarServiceImpl.class); + List exemplars = exemplarManager.recallExemplars(agentId.toString(), + queryReq.getQueryText(), 5); + queryReq.getExemplars().addAll(exemplars); + } + @Builder @Data public static class RewriteContext { diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/ChatMemoryDO.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/ChatMemoryDO.java new file mode 100644 index 000000000..984eb3405 --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/ChatMemoryDO.java @@ -0,0 +1,70 @@ +package com.tencent.supersonic.chat.server.persistence.dataobject; + +import com.baomidou.mybatisplus.annotation.IdType; +import com.baomidou.mybatisplus.annotation.TableField; +import com.baomidou.mybatisplus.annotation.TableId; +import com.baomidou.mybatisplus.annotation.TableName; +import lombok.Builder; +import lombok.Data; +import lombok.ToString; + +import java.util.Date; + +@Data +@Builder +@ToString +@TableName("s2_chat_memory") +public class ChatMemoryDO { + @TableId(type = IdType.AUTO) + private Long id; + + @TableField("question") + private String question; + + @TableField("agent_id") + private Integer agentId; + + @TableField("db_schema") + private String schema; + + @TableField("s2_sql") + private String s2sql; + + @TableField("status") + private Status status; + + @TableField("llm_review") + private ReviewResult llmReviewRet; + + @TableField("llm_comment") + private String llmReviewCmt; + + @TableField("human_review") + private ReviewResult humanReviewRet; + + @TableField("human_comment") + private String humanReviewCmt; + + @TableField("created_by") + private String createdBy; + + @TableField("created_at") + private Date createdAt; + + @TableField("updated_by") + private String updatedBy; + + @TableField("updated_at") + private Date updatedAt; + + public enum ReviewResult { + POSITIVE, + NEGATIVE + } + + public enum Status { + PENDING, + ENABLED, + DISABLED; + } +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/ChatMemoryMapper.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/ChatMemoryMapper.java new file mode 100644 index 000000000..b9be150f4 --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/ChatMemoryMapper.java @@ -0,0 +1,10 @@ +package com.tencent.supersonic.chat.server.persistence.mapper; + +import com.baomidou.mybatisplus.core.mapper.BaseMapper; +import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; +import org.apache.ibatis.annotations.Mapper; + +@Mapper +public interface ChatMemoryMapper extends BaseMapper { + +} \ No newline at end of file diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/ChatContextRepository.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/ChatContextRepository.java deleted file mode 100644 index 501ae3499..000000000 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/ChatContextRepository.java +++ /dev/null @@ -1,11 +0,0 @@ -package com.tencent.supersonic.chat.server.persistence.repository; - -import com.tencent.supersonic.headless.chat.ChatContext; - -public interface ChatContextRepository { - - ChatContext getOrCreateContext(int chatId); - - void updateContext(ChatContext chatCtx); - -} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/ChatMemoryRepository.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/ChatMemoryRepository.java new file mode 100644 index 000000000..053699a52 --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/ChatMemoryRepository.java @@ -0,0 +1,15 @@ +package com.tencent.supersonic.chat.server.persistence.repository; + +import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; + +import java.util.List; + +public interface ChatMemoryRepository { + void createMemory(ChatMemoryDO chatMemoryDO); + + void updateMemory(ChatMemoryDO chatMemoryDO); + + ChatMemoryDO getMemory(Long id); + + List getMemories(); +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/ChatMemoryRepositoryImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/ChatMemoryRepositoryImpl.java new file mode 100644 index 000000000..4e2033b0a --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/ChatMemoryRepositoryImpl.java @@ -0,0 +1,42 @@ +package com.tencent.supersonic.chat.server.persistence.repository.impl; + +import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; +import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; +import com.tencent.supersonic.chat.server.persistence.mapper.ChatMemoryMapper; +import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository; +import org.springframework.context.annotation.Primary; +import org.springframework.stereotype.Repository; + +import java.util.List; + +@Repository +@Primary +public class ChatMemoryRepositoryImpl implements ChatMemoryRepository { + + private final ChatMemoryMapper chatMemoryMapper; + + public ChatMemoryRepositoryImpl(ChatMemoryMapper chatMemoryMapper) { + this.chatMemoryMapper = chatMemoryMapper; + } + + @Override + public void createMemory(ChatMemoryDO chatMemoryDO) { + chatMemoryMapper.insert(chatMemoryDO); + } + + @Override + public void updateMemory(ChatMemoryDO chatMemoryDO) { + chatMemoryMapper.updateById(chatMemoryDO); + } + + @Override + public ChatMemoryDO getMemory(Long id) { + return chatMemoryMapper.selectById(id); + } + + @Override + public List getMemories() { + return chatMemoryMapper.selectList(new QueryWrapper<>()); + } + +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/MemoryService.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/MemoryService.java new file mode 100644 index 000000000..4b1965145 --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/MemoryService.java @@ -0,0 +1,13 @@ +package com.tencent.supersonic.chat.server.service; + +import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; + +import java.util.List; + +public interface MemoryService { + void createMemory(ChatMemoryDO memory); + + void updateMemory(ChatMemoryDO memory); + + List getMemories(); +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java new file mode 100644 index 000000000..26cca016e --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java @@ -0,0 +1,55 @@ +package com.tencent.supersonic.chat.server.service.impl; + +import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; +import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO.ReviewResult; +import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO.Status; +import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository; +import com.tencent.supersonic.chat.server.service.MemoryService; +import com.tencent.supersonic.common.pojo.SqlExemplar; +import com.tencent.supersonic.common.service.ExemplarService; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; + +import java.util.List; + +@Service +public class MemoryServiceImpl implements MemoryService { + + @Autowired + private ChatMemoryRepository chatMemoryRepository; + + @Autowired + private ExemplarService exemplarService; + + @Override + public void createMemory(ChatMemoryDO memory) { + if (ReviewResult.POSITIVE.equals(memory.getHumanReviewRet())) { + enableMemory(memory); + } + chatMemoryRepository.createMemory(memory); + } + + @Override + public void updateMemory(ChatMemoryDO memory) { + if (!ChatMemoryDO.Status.ENABLED.equals(memory.getStatus()) + && ReviewResult.POSITIVE.equals(memory.getHumanReviewRet())) { + enableMemory(memory); + } + chatMemoryRepository.updateMemory(memory); + } + + @Override + public List getMemories() { + return chatMemoryRepository.getMemories(); + } + + private void enableMemory(ChatMemoryDO memory) { + exemplarService.storeExemplar(memory.getAgentId().toString(), + SqlExemplar.builder() + .question(memory.getQuestion()) + .dbSchema(memory.getSchema()) + .sql(memory.getS2sql()) + .build()); + memory.setStatus(Status.ENABLED); + } +} diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/SqlExemplar.java b/common/src/main/java/com/tencent/supersonic/common/pojo/SqlExemplar.java new file mode 100644 index 000000000..4b75cb8aa --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/SqlExemplar.java @@ -0,0 +1,20 @@ +package com.tencent.supersonic.common.pojo; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@Builder +@NoArgsConstructor +@AllArgsConstructor +public class SqlExemplar { + + private String question; + + private String dbSchema; + + private String sql; + +} \ No newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/service/ExemplarService.java b/common/src/main/java/com/tencent/supersonic/common/service/ExemplarService.java new file mode 100644 index 000000000..f08839682 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/service/ExemplarService.java @@ -0,0 +1,14 @@ +package com.tencent.supersonic.common.service; + +import com.tencent.supersonic.common.pojo.SqlExemplar; + +import java.util.List; + +public interface ExemplarService { + void storeExemplar(String collection, SqlExemplar exemplar); + + List recallExemplars(String collection, String query, int num); + + List recallExemplars(String query, int num); + +} diff --git a/common/src/main/java/com/tencent/supersonic/common/service/impl/ExemplarServiceImpl.java b/common/src/main/java/com/tencent/supersonic/common/service/impl/ExemplarServiceImpl.java new file mode 100644 index 000000000..e774978e8 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/service/impl/ExemplarServiceImpl.java @@ -0,0 +1,87 @@ +package com.tencent.supersonic.common.service.impl; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.Lists; +import com.tencent.supersonic.common.config.EmbeddingConfig; +import com.tencent.supersonic.common.service.EmbeddingService; +import com.tencent.supersonic.common.service.ExemplarService; +import com.tencent.supersonic.common.util.JsonUtil; +import com.tencent.supersonic.common.pojo.SqlExemplar; +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.store.embedding.RetrieveQuery; +import dev.langchain4j.store.embedding.RetrieveQueryResult; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.CommandLineRunner; +import org.springframework.core.io.ClassPathResource; +import org.springframework.stereotype.Service; + +import java.io.IOException; +import java.io.InputStream; +import java.util.List; + +@Service +@Slf4j +public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner { + + private static final String SYS_EXEMPLAR_FILE = "s2ql_exemplar.json"; + + private TypeReference> valueTypeRef = new TypeReference>() { + }; + + private final ObjectMapper objectMapper = JsonUtil.INSTANCE.getObjectMapper(); + + @Autowired + private EmbeddingConfig embeddingConfig; + + @Autowired + private EmbeddingService embeddingService; + + public void storeExemplar(String collection, SqlExemplar exemplar) { + Metadata metadata = Metadata.from(JsonUtil.toMap(JsonUtil.toString(exemplar), + String.class, Object.class)); + TextSegment segment = TextSegment.from(exemplar.getQuestion(), metadata); + + embeddingService.addQuery(collection, Lists.newArrayList(segment)); + } + + public List recallExemplars(String query, int num) { + String collection = embeddingConfig.getText2sqlCollectionName(); + return recallExemplars(collection, query, num); + } + + public List recallExemplars(String collection, String query, int num) { + List exemplars = Lists.newArrayList(); + RetrieveQuery retrieveQuery = RetrieveQuery.builder() + .queryTextsList(Lists.newArrayList(query)) + .build(); + List results = embeddingService.retrieveQuery(collection, retrieveQuery, num); + results.stream().forEach(ret -> { + ret.getRetrieval().stream().forEach(r -> { + exemplars.add(JsonUtil.mapToObject(r.getMetadata(), SqlExemplar.class)); + }); + }); + + return exemplars; + } + + @Override + public void run(String... args) { + try { + loadSysExemplars(); + } catch (IOException e) { + log.error("Failed to load system exemplars", e); + } + } + + private void loadSysExemplars() throws IOException { + ClassPathResource resource = new ClassPathResource(SYS_EXEMPLAR_FILE); + InputStream inputStream = resource.getInputStream(); + List exemplars = objectMapper.readValue(inputStream, valueTypeRef); + String collection = embeddingConfig.getText2sqlCollectionName(); + exemplars.stream().forEach(e -> storeExemplar(collection, e)); + } + +} diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryReq.java index ddadea9cb..020f04894 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryReq.java @@ -1,14 +1,17 @@ package com.tencent.supersonic.headless.api.pojo.request; +import com.google.common.collect.Lists; import com.google.common.collect.Sets; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.config.LLMConfig; import com.tencent.supersonic.headless.api.pojo.QueryDataType; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; +import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; import lombok.Data; +import java.util.List; import java.util.Set; @Data @@ -24,4 +27,5 @@ public class QueryReq { private SchemaMapInfo mapInfo = new SchemaMapInfo(); private QueryDataType queryDataType = QueryDataType.ALL; private LLMConfig llmConfig; + private List exemplars = Lists.newArrayList(); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/QueryContext.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/QueryContext.java index 0a6ce24cc..da2e60c6d 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/QueryContext.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/QueryContext.java @@ -8,6 +8,7 @@ import com.tencent.supersonic.common.config.LLMConfig; import com.tencent.supersonic.headless.api.pojo.QueryDataType; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; +import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; import com.tencent.supersonic.headless.api.pojo.enums.WorkflowState; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; @@ -49,6 +50,7 @@ public class QueryContext { private WorkflowState workflowState; private QueryDataType queryDataType = QueryDataType.ALL; private LLMConfig llmConfig; + private List exemplars; public List getCandidateQueries() { ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/Exemplar.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/Exemplar.java deleted file mode 100644 index 7c64f9a1a..000000000 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/Exemplar.java +++ /dev/null @@ -1,16 +0,0 @@ -package com.tencent.supersonic.headless.chat.parser.llm; - -import lombok.Data; - -@Data -public class Exemplar { - - private String question; - - private String questionAugmented; - - private String dbSchema; - - private String sql; - -} \ No newline at end of file diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/ExemplarManager.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/ExemplarManager.java deleted file mode 100644 index cb90965f6..000000000 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/ExemplarManager.java +++ /dev/null @@ -1,101 +0,0 @@ -package com.tencent.supersonic.headless.chat.parser.llm; - - -import com.fasterxml.jackson.core.type.TypeReference; -import com.tencent.supersonic.common.config.EmbeddingConfig; -import com.tencent.supersonic.common.service.EmbeddingService; -import com.tencent.supersonic.common.util.JsonUtil; -import com.tencent.supersonic.headless.chat.utils.ComponentFactory; -import dev.langchain4j.data.document.Metadata; -import dev.langchain4j.data.segment.TextSegment; -import dev.langchain4j.store.embedding.Retrieval; -import dev.langchain4j.store.embedding.RetrieveQuery; -import dev.langchain4j.store.embedding.RetrieveQueryResult; -import dev.langchain4j.store.embedding.TextSegmentConvert; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.collections.CollectionUtils; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.boot.CommandLineRunner; -import org.springframework.core.annotation.Order; -import org.springframework.core.io.ClassPathResource; -import org.springframework.stereotype.Component; - -import java.io.IOException; -import java.io.InputStream; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.stream.Collectors; - -@Slf4j -@Component -@Order(0) -public class ExemplarManager implements CommandLineRunner { - - private static final String EXAMPLE_JSON_FILE = "s2ql_exemplar.json"; - - @Autowired - private EmbeddingService embeddingService; - - @Autowired - private EmbeddingConfig embeddingConfig; - - private TypeReference> valueTypeRef = new TypeReference>() { - }; - - @Override - public void run(String... args) { - try { - if (ComponentFactory.getLLMProxy() instanceof JavaLLMProxy) { - loadDefaultExemplars(); - } - } catch (Exception e) { - log.error("Failed to init examples", e); - } - } - - public void addExemplars(List exemplars, String collectionName) { - List queries = new ArrayList<>(); - for (int i = 0; i < exemplars.size(); i++) { - Exemplar exemplar = exemplars.get(i); - String question = exemplar.getQuestion(); - Map metaDataMap = JsonUtil.toMap(JsonUtil.toString(exemplar), String.class, Object.class); - TextSegment embeddingQuery = TextSegment.from(question, new Metadata(metaDataMap)); - TextSegmentConvert.addQueryId(embeddingQuery, String.valueOf(i)); - queries.add(embeddingQuery); - } - embeddingService.addQuery(collectionName, queries); - } - - public List> recallExemplars(String queryText, int maxResults) { - String collectionName = embeddingConfig.getText2sqlCollectionName(); - RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(Collections.singletonList(queryText)) - .queryEmbeddings(null).build(); - - List resultList = embeddingService.retrieveQuery(collectionName, retrieveQuery, - maxResults); - List> result = new ArrayList<>(); - if (CollectionUtils.isEmpty(resultList)) { - return result; - } - for (Retrieval retrieval : resultList.get(0).getRetrieval()) { - if (Objects.nonNull(retrieval.getMetadata()) && !retrieval.getMetadata().isEmpty()) { - Map convertedMap = retrieval.getMetadata().entrySet().stream() - .collect(Collectors.toMap(Map.Entry::getKey, entry -> String.valueOf(entry.getValue()))); - result.add(convertedMap); - } - } - return result; - } - - private void loadDefaultExemplars() throws IOException { - ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE); - InputStream inputStream = resource.getInputStream(); - List examples = JsonUtil.INSTANCE.getObjectMapper().readValue(inputStream, valueTypeRef); - String collectionName = embeddingConfig.getText2sqlCollectionName(); - addExemplars(examples, collectionName); - } - -} diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java index e8d128839..c01e22ab0 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java @@ -109,6 +109,8 @@ public class LLMRequestService { llmReq.setSqlGenType(LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE))); llmReq.setLlmConfig(queryCtx.getLlmConfig()); + llmReq.setExemplars(queryCtx.getExemplars()); + return llmReq; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java index 86b3d7537..f8166d15b 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.headless.chat.parser.llm; import com.google.common.collect.Lists; +import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; import dev.langchain4j.data.message.AiMessage; @@ -42,11 +43,11 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { public LLMResp generate(LLMReq llmReq) { //1.recall exemplars keyPipelineLog.info("OnePassSCSqlGenStrategy llmReq:\n{}", llmReq); - List>> exemplarsList = promptHelper.getFewShotExemplars(llmReq); + List> exemplarsList = promptHelper.getFewShotExemplars(llmReq); //2.generate sql generation prompt for each self-consistency inference - Map>> prompt2Exemplar = new HashMap<>(); - for (List> exemplars : exemplarsList) { + Map> prompt2Exemplar = new HashMap<>(); + for (List exemplars : exemplarsList) { Prompt prompt = generatePrompt(llmReq, exemplars); prompt2Exemplar.put(prompt, exemplars); } @@ -67,25 +68,24 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { Pair> sqlMapPair = OutputFormat.selfConsistencyVote( Lists.newArrayList(prompt2Output.values())); LLMResp llmResp = new LLMResp(); - llmResp.setQuery(llmReq.getQueryText()); + llmResp.setQuery(promptHelper.buildAugmentedQuestion(llmReq)); + llmResp.setDbSchema(promptHelper.buildSchemaStr(llmReq)); + llmResp.setSqlOutput(sqlMapPair.getLeft()); //TODO: should use the same few-shot exemplars as the one chose by self-consistency vote llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(exemplarsList.get(0), sqlMapPair.getRight())); return llmResp; } - private Prompt generatePrompt(LLMReq llmReq, List> fewshotExampleList) { + private Prompt generatePrompt(LLMReq llmReq, List fewshotExampleList) { StringBuilder exemplarsStr = new StringBuilder(); - for (Map example : fewshotExampleList) { - String metadata = example.get("dbSchema"); - String question = example.get("questionAugmented"); - String sql = example.get("sql"); + for (SqlExemplar exemplar : fewshotExampleList) { String exemplarStr = String.format("#UserQuery: %s #Schema: %s #SQL: %s\n", - question, metadata, sql); + exemplar.getQuestion(), exemplar.getDbSchema(), exemplar.getSql()); exemplarsStr.append(exemplarStr); } - String dataSemanticsStr = promptHelper.buildMetadataStr(llmReq); + String dataSemanticsStr = promptHelper.buildSchemaStr(llmReq); String questionAugmented = promptHelper.buildAugmentedQuestion(llmReq); String promptStr = String.format(INSTRUCTION, exemplarsStr, questionAugmented, dataSemanticsStr); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OutputFormat.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OutputFormat.java index d9849dc4e..41c00390e 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OutputFormat.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OutputFormat.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.headless.chat.parser.llm; +import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.tuple.Pair; @@ -108,7 +109,7 @@ public class OutputFormat { return results; } - public static Map buildSqlRespMap(List> sqlExamples, + public static Map buildSqlRespMap(List sqlExamples, Map sqlMap) { if (sqlMap == null) { return new HashMap<>(); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java index da2a83dc7..ca3f0fa77 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java @@ -1,5 +1,8 @@ package com.tencent.supersonic.headless.chat.parser.llm; +import com.google.common.collect.Lists; +import com.tencent.supersonic.common.pojo.SqlExemplar; +import com.tencent.supersonic.common.service.ExemplarService; import com.tencent.supersonic.headless.chat.parser.ParserConfig; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import lombok.extern.slf4j.Slf4j; @@ -11,7 +14,6 @@ import org.springframework.util.CollectionUtils; import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.util.Map; import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER; import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_FEW_SHOT_NUMBER; @@ -25,20 +27,27 @@ public class PromptHelper { private ParserConfig parserConfig; @Autowired - private ExemplarManager exemplarManager; + private ExemplarService exemplarService; - public List>> getFewShotExemplars(LLMReq llmReq) { + public List> getFewShotExemplars(LLMReq llmReq) { int exemplarRecallNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER)); int fewShotNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_FEW_SHOT_NUMBER)); int selfConsistencyNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER)); - List> exemplars = exemplarManager.recallExemplars(llmReq.getQueryText(), - exemplarRecallNumber); - List>> results = new ArrayList<>(); + List exemplars = Lists.newArrayList(); + llmReq.getExemplars().stream().forEach(e -> { + exemplars.add(e); + }); + int recallSize = exemplarRecallNumber - llmReq.getExemplars().size(); + if (recallSize > 0) { + exemplars.addAll(exemplarService.recallExemplars(llmReq.getQueryText(), recallSize)); + } + + List> results = new ArrayList<>(); // use random collection of exemplars for each self-consistency inference for (int i = 0; i < selfConsistencyNumber; i++) { - List> shuffledList = new ArrayList<>(exemplars); + List shuffledList = new ArrayList<>(exemplars); Collections.shuffle(shuffledList); results.add(shuffledList.subList(0, fewShotNumber)); } @@ -64,7 +73,7 @@ public class PromptHelper { linkingListStr, currentDataStr, termStr, priorExts); } - public String buildMetadataStr(LLMReq llmReq) { + public String buildSchemaStr(LLMReq llmReq) { String tableStr = llmReq.getSchema().getDataSetName(); StringBuilder metricStr = new StringBuilder(); StringBuilder dimensionStr = new StringBuilder(); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java index 014923b7c..57fc09d9e 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java @@ -4,6 +4,7 @@ import com.fasterxml.jackson.annotation.JsonValue; import com.google.common.collect.Lists; import com.tencent.supersonic.common.config.LLMConfig; import com.tencent.supersonic.headless.api.pojo.SchemaElement; +import com.tencent.supersonic.common.pojo.SqlExemplar; import lombok.Data; import java.util.List; @@ -26,6 +27,9 @@ public class LLMReq { private SqlGenType sqlGenType; private LLMConfig llmConfig; + + private List exemplars; + @Data public static class ElementValue { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMResp.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMResp.java index f7e1a672c..f1702bf40 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMResp.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMResp.java @@ -12,6 +12,8 @@ public class LLMResp { private String modelName; + private String dbSchema; + private String sqlOutput; private List fields; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMSqlResp.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMSqlResp.java index 0ee93feb2..f4efe51bc 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMSqlResp.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMSqlResp.java @@ -1,13 +1,12 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql; +import com.tencent.supersonic.common.pojo.SqlExemplar; import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; import java.util.List; -import java.util.Map; - @Data @Builder @@ -17,6 +16,6 @@ public class LLMSqlResp { private double sqlWeight; - private List> fewShots; + private List fewShots; } diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2BaseDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2BaseDemo.java index be59115a1..aacca9922 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2BaseDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2BaseDemo.java @@ -172,6 +172,7 @@ public abstract class S2BaseDemo implements CommandLineRunner { executeReq.setQueryText(queryText); executeReq.setChatId(parseResp.getChatId()); executeReq.setUser(User.getFakeUser()); + executeReq.setAgentId(agentId); executeReq.setSaveAnswer(true); chatService.performExecution(executeReq); } diff --git a/launchers/standalone/src/main/resources/db/schema-h2.sql b/launchers/standalone/src/main/resources/db/schema-h2.sql index 8f5a6ddc5..b4a361aa3 100644 --- a/launchers/standalone/src/main/resources/db/schema-h2.sql +++ b/launchers/standalone/src/main/resources/db/schema-h2.sql @@ -83,6 +83,25 @@ CREATE TABLE IF NOT EXISTS `s2_chat_config` ( COMMENT ON TABLE s2_chat_config IS 'chat config information table '; +CREATE TABLE IF NOT EXISTS `s2_chat_memory` ( + `id` INT NOT NULL AUTO_INCREMENT, + `question` varchar(655) , + `agent_id` INT , + `db_schema` TEXT , + `s2_sql` TEXT , + `status` char(10) , + `llm_review` char(10) , + `llm_comment` TEXT, + `human_review` char(10) , + `human_comment` TEXT , + `created_at` TIMESTAMP , + `updated_at` TIMESTAMP , + `created_by` varchar(100) , + `updated_by` varchar(100) , + PRIMARY KEY (`id`) + ) ; +COMMENT ON TABLE s2_chat_memory IS 'chat memory table '; + create table IF NOT EXISTS s2_user ( id INT AUTO_INCREMENT, diff --git a/launchers/standalone/src/main/resources/db/schema-mysql.sql b/launchers/standalone/src/main/resources/db/schema-mysql.sql index 1e97d5f38..f93479420 100644 --- a/launchers/standalone/src/main/resources/db/schema-mysql.sql +++ b/launchers/standalone/src/main/resources/db/schema-mysql.sql @@ -136,6 +136,24 @@ CREATE TABLE `s2_chat_config` ( PRIMARY KEY (`id`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='主题域扩展信息表'; +CREATE TABLE IF NOT EXISTS `s2_chat_memory` ( + `id` INT NOT NULL AUTO_INCREMENT, + `question` varchar(655) , + `agent_id` INT , + `db_schema` TEXT , + `s2_sql` TEXT , + `status` char(10) , + `llm_review` char(10) , + `llm_comment` TEXT, + `human_review` char(10) , + `human_comment` TEXT , + `created_at` TIMESTAMP NOT NULL , + `updated_at` TIMESTAMP NOT NULL , + `created_by` varchar(100) NOT NULL , + `updated_by` varchar(100) NOT NULL , + PRIMARY KEY (`id`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + CREATE TABLE `s2_chat_context` ( `chat_id` bigint(20) NOT NULL COMMENT 'context chat id', `modified_at` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'row modify time', diff --git a/launchers/standalone/src/test/resources/db/schema-h2.sql b/launchers/standalone/src/test/resources/db/schema-h2.sql index b06848f95..2c8a73d99 100644 --- a/launchers/standalone/src/test/resources/db/schema-h2.sql +++ b/launchers/standalone/src/test/resources/db/schema-h2.sql @@ -82,6 +82,24 @@ CREATE TABLE IF NOT EXISTS `s2_chat_config` ( ) ; COMMENT ON TABLE s2_chat_config IS 'chat config information table '; +CREATE TABLE IF NOT EXISTS `s2_chat_memory` ( + `id` INT NOT NULL AUTO_INCREMENT, + `question` varchar(655) , + `agent_id` INT , + `db_schema` TEXT , + `s2_sql` TEXT , + `status` char(10) , + `llm_review` char(10) , + `llm_comment` TEXT, + `human_review` char(10) , + `human_comment` TEXT , + `created_at` TIMESTAMP , + `updated_at` TIMESTAMP , + `created_by` varchar(100) , + `updated_by` varchar(100) , + PRIMARY KEY (`id`) + ) ; +COMMENT ON TABLE s2_chat_memory IS 'chat memory table '; create table IF NOT EXISTS s2_user (