(improvement)(chat)Enable memory directly if the review result by LLM is positive.

This commit is contained in:
jerryjzhang
2024-07-12 09:56:51 +08:00
parent 41ad1ada6c
commit 37da1ac2ae
6 changed files with 25 additions and 17 deletions

View File

@@ -66,6 +66,10 @@ public class MemoryReviewTask {
if (matcher.find()) {
m.setLlmReviewRet(MemoryReviewResult.valueOf(matcher.group(1)));
m.setLlmReviewCmt(matcher.group(2));
// directly enable memory if the LLM determines it positive
if (MemoryReviewResult.POSITIVE.equals(m.getLlmReviewRet())) {
memoryService.enableMemory(m);
}
memoryService.updateMemory(m);
}
} else {

View File

@@ -16,6 +16,10 @@ public interface MemoryService {
void updateMemory(ChatMemoryDO memory);
void enableMemory(ChatMemoryDO memory);
void disableMemory(ChatMemoryDO memory);
PageInfo<ChatMemoryDO> pageMemories(PageMemoryReq pageMemoryReq);
List<ChatMemoryDO> getMemories(ChatMemoryFilter chatMemoryFilter);

View File

@@ -96,7 +96,9 @@ public class MemoryServiceImpl implements MemoryService {
return chatMemoryRepository.getMemories(queryWrapper);
}
private void enableMemory(ChatMemoryDO memory) {
@Override
public void enableMemory(ChatMemoryDO memory) {
memory.setStatus(MemoryStatus.ENABLED);
exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
SqlExemplar.builder()
.question(memory.getQuestion())
@@ -105,7 +107,9 @@ public class MemoryServiceImpl implements MemoryService {
.build());
}
private void disableMemory(ChatMemoryDO memory) {
@Override
public void disableMemory(ChatMemoryDO memory) {
memory.setStatus(MemoryStatus.DISABLED);
exemplarService.removeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
SqlExemplar.builder()
.question(memory.getQuestion())

View File

@@ -5,7 +5,6 @@ import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlDateSelectHelper;
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.jsqlparser.DateVisitor.DateBoundInfo;
@@ -38,8 +37,6 @@ public class TimeCorrector extends BaseSemanticCorrector {
removeDateIfExist(chatQueryContext, semanticParseInfo);
parserDateDiffFunction(semanticParseInfo);
addLowerBoundDate(semanticParseInfo);
}
@@ -112,10 +109,4 @@ public class TimeCorrector extends BaseSemanticCorrector {
}
}
private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
correctS2SQL = SqlReplaceHelper.replaceFunction(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
}
}

View File

@@ -28,13 +28,14 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
private static final String INSTRUCTION = ""
+ "#Role: You are a data analyst experienced in SQL languages.\n"
+ "#Task: You will be provided a natural language question asked by users,"
+ "please convert it to a SQL query so that relevant data could be returned to the user "
+ "please convert it to a SQL query so that relevant data could be returned "
+ "by executing the SQL query against underlying database.\n"
+ "#Rules:"
+ "1.ALWAYS use `数据日期` as the date field."
+ "2.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator."
+ "3.DO NOT include date filter in the where clause if not explicitly expressed in the query."
+ "4.ONLY respond with the converted SQL statement.\n"
+ "3.ALWAYS calculate the absolute date range by yourself."
+ "4.DO NOT include date filter in the where clause if not explicitly expressed in the question."
+ "5.ONLY respond with the converted SQL statement.\n"
+ "#Exemplars:\n{{exemplar}}"
+ "#Question:{{question}} #Schema:{{schema}} #SQL:";

View File

@@ -141,7 +141,7 @@ public class S2VisitsDemo extends S2BaseDemo {
chatService.parseAndExecute(chatId.intValue(), agentId, "按部门统计");
chatService.parseAndExecute(chatId.intValue(), agentId, "查询近30天");
chatService.parseAndExecute(chatId.intValue(), agentId, "alice 停留时长");
chatService.parseAndExecute(chatId.intValue(), agentId, "对比alice和lucy访问次数");
chatService.parseAndExecute(chatId.intValue(), agentId, "对比alice和lucy访问次数");
chatService.parseAndExecute(chatId.intValue(), agentId, "访问次数最高的部门");
}
@@ -151,8 +151,12 @@ public class S2VisitsDemo extends S2BaseDemo {
agent.setDescription("帮助您用自然语言查询指标,支持时间限定、条件筛选、下钻维度以及聚合统计");
agent.setStatus(1);
agent.setEnableSearch(1);
agent.setExamples(Lists.newArrayList("超音数访问次数", "近15天超音数访问次数汇总", "按部门统计超音数的访问人数",
"对比alice和lucy的停留时长", "超音数访问次数最高的部门"));
agent.setExamples(Lists.newArrayList(
"超音数访问次数",
"近15天超音数访问次数汇总",
"按部门统计超音数的访问人数",
"对比alice和lucy的停留时长",
"超音数访问次数最高的部门"));
AgentConfig agentConfig = new AgentConfig();
RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);