diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java index 301debbdd..feba255a1 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java @@ -66,6 +66,10 @@ public class MemoryReviewTask { if (matcher.find()) { m.setLlmReviewRet(MemoryReviewResult.valueOf(matcher.group(1))); m.setLlmReviewCmt(matcher.group(2)); + // directly enable memory if the LLM determines it positive + if (MemoryReviewResult.POSITIVE.equals(m.getLlmReviewRet())) { + memoryService.enableMemory(m); + } memoryService.updateMemory(m); } } else { diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/MemoryService.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/MemoryService.java index 8ff65bb73..0affcfdb1 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/MemoryService.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/MemoryService.java @@ -16,6 +16,10 @@ public interface MemoryService { void updateMemory(ChatMemoryDO memory); + void enableMemory(ChatMemoryDO memory); + + void disableMemory(ChatMemoryDO memory); + PageInfo pageMemories(PageMemoryReq pageMemoryReq); List getMemories(ChatMemoryFilter chatMemoryFilter); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java index a666a705b..0574c315c 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java @@ -96,7 +96,9 @@ public class MemoryServiceImpl implements MemoryService { return chatMemoryRepository.getMemories(queryWrapper); } - private void enableMemory(ChatMemoryDO memory) { + @Override + public void enableMemory(ChatMemoryDO memory) { + memory.setStatus(MemoryStatus.ENABLED); exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()), SqlExemplar.builder() .question(memory.getQuestion()) @@ -105,7 +107,9 @@ public class MemoryServiceImpl implements MemoryService { .build()); } - private void disableMemory(ChatMemoryDO memory) { + @Override + public void disableMemory(ChatMemoryDO memory) { + memory.setStatus(MemoryStatus.DISABLED); exemplarService.removeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()), SqlExemplar.builder() .question(memory.getQuestion()) diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java index 83395f425..4629f0e40 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java @@ -5,7 +5,6 @@ import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.jsqlparser.SqlAddHelper; import com.tencent.supersonic.common.jsqlparser.SqlDateSelectHelper; -import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper; import com.tencent.supersonic.common.jsqlparser.DateVisitor.DateBoundInfo; @@ -38,8 +37,6 @@ public class TimeCorrector extends BaseSemanticCorrector { removeDateIfExist(chatQueryContext, semanticParseInfo); - parserDateDiffFunction(semanticParseInfo); - addLowerBoundDate(semanticParseInfo); } @@ -112,10 +109,4 @@ public class TimeCorrector extends BaseSemanticCorrector { } } - private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) { - String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); - correctS2SQL = SqlReplaceHelper.replaceFunction(correctS2SQL); - semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL); - } - } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java index 7f8da30ef..be732bcbb 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java @@ -28,13 +28,14 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { private static final String INSTRUCTION = "" + "#Role: You are a data analyst experienced in SQL languages.\n" + "#Task: You will be provided a natural language question asked by users," - + "please convert it to a SQL query so that relevant data could be returned to the user " + + "please convert it to a SQL query so that relevant data could be returned " + "by executing the SQL query against underlying database.\n" + "#Rules:" + "1.ALWAYS use `数据日期` as the date field." + "2.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator." - + "3.DO NOT include date filter in the where clause if not explicitly expressed in the query." - + "4.ONLY respond with the converted SQL statement.\n" + + "3.ALWAYS calculate the absolute date range by yourself." + + "4.DO NOT include date filter in the where clause if not explicitly expressed in the question." + + "5.ONLY respond with the converted SQL statement.\n" + "#Exemplars:\n{{exemplar}}" + "#Question:{{question}} #Schema:{{schema}} #SQL:"; diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java index 5600218e4..561b130f5 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java @@ -141,7 +141,7 @@ public class S2VisitsDemo extends S2BaseDemo { chatService.parseAndExecute(chatId.intValue(), agentId, "按部门统计"); chatService.parseAndExecute(chatId.intValue(), agentId, "查询近30天"); chatService.parseAndExecute(chatId.intValue(), agentId, "alice 停留时长"); - chatService.parseAndExecute(chatId.intValue(), agentId, "对比alice和lucy的访问次数"); + chatService.parseAndExecute(chatId.intValue(), agentId, "对比alice和lucy访问次数"); chatService.parseAndExecute(chatId.intValue(), agentId, "访问次数最高的部门"); } @@ -151,8 +151,12 @@ public class S2VisitsDemo extends S2BaseDemo { agent.setDescription("帮助您用自然语言查询指标,支持时间限定、条件筛选、下钻维度以及聚合统计"); agent.setStatus(1); agent.setEnableSearch(1); - agent.setExamples(Lists.newArrayList("超音数访问次数", "近15天超音数访问次数汇总", "按部门统计超音数的访问人数", - "对比alice和lucy的停留时长", "超音数访问次数最高的部门")); + agent.setExamples(Lists.newArrayList( + "超音数访问次数", + "近15天超音数访问次数汇总", + "按部门统计超音数的访问人数", + "对比alice和lucy的停留时长", + "超音数访问次数最高的部门")); AgentConfig agentConfig = new AgentConfig(); RuleParserTool ruleQueryTool = new RuleParserTool(); ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);