(feature)(chat)Add switch to enable LLM-based memory review, as it leads to more token costs. #1385

This commit is contained in:
jerryjzhang
2024-07-10 21:08:02 +08:00
parent 78a91ad8c2
commit 9bb95ca4be
7 changed files with 19 additions and 5 deletions

View File

@@ -24,6 +24,7 @@ public class Agent extends RecordInfo {
private Integer id; private Integer id;
private Integer enableSearch; private Integer enableSearch;
private Integer enableMemoryReview;
private String name; private String name;
private String description; private String description;
@@ -60,6 +61,10 @@ public class Agent extends RecordInfo {
return enableSearch != null && enableSearch == 1; return enableSearch != null && enableSearch == 1;
} }
public boolean enableMemoryReview() {
return enableMemoryReview != null && enableMemoryReview == 1;
}
public static boolean containsAllModel(Set<Long> detectViewIds) { public static boolean containsAllModel(Set<Long> detectViewIds) {
return !CollectionUtils.isEmpty(detectViewIds) && detectViewIds.contains(-1L); return !CollectionUtils.isEmpty(detectViewIds) && detectViewIds.contains(-1L);
} }

View File

@@ -32,7 +32,7 @@ public class MemoryReviewTask {
+ "please take a review and give your opinion.\n" + "please take a review and give your opinion.\n"
+ "#Rules: " + "#Rules: "
+ "1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),comment=(your comment)`." + "1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),comment=(your comment)`."
+ "2.DO NOT check the usage of `数据日期` field and `datediff()` function.\n" + "2.ALWAYS recognize `数据日期` as the date field.\n"
+ "#Question: %s\n" + "#Question: %s\n"
+ "#Schema: %s\n" + "#Schema: %s\n"
+ "#SQL: %s\n" + "#SQL: %s\n"
@@ -51,7 +51,7 @@ public class MemoryReviewTask {
memoryService.getMemoriesForLlmReview().stream() memoryService.getMemoriesForLlmReview().stream()
.forEach(m -> { .forEach(m -> {
Agent chatAgent = agentService.getAgent(m.getAgentId()); Agent chatAgent = agentService.getAgent(m.getAgentId());
if (Objects.nonNull(chatAgent)) { if (Objects.nonNull(chatAgent) && chatAgent.enableMemoryReview()) {
String promptStr = String.format(INSTRUCTION, m.getQuestion(), m.getDbSchema(), m.getS2sql()); String promptStr = String.format(INSTRUCTION, m.getQuestion(), m.getDbSchema(), m.getS2sql());
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP); Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
@@ -72,7 +72,7 @@ public class MemoryReviewTask {
log.debug("ChatLanguageModel not found for agent:{}", chatAgent.getId()); log.debug("ChatLanguageModel not found for agent:{}", chatAgent.getId());
} }
} else { } else {
log.debug("Agent not found for memory:{}", m.getAgentId()); log.debug("Agent id {} not found or memory review disabled", m.getAgentId());
} }
}); });
} }

View File

@@ -65,6 +65,7 @@ public class AgentDO {
* *
*/ */
private Integer enableSearch; private Integer enableSearch;
private Integer enableMemoryReview;
private String modelConfig; private String modelConfig;
private String multiTurnConfig; private String multiTurnConfig;

View File

@@ -56,8 +56,8 @@ public class ParseInfoProcessor implements ResultProcessor {
if (StringUtils.isBlank(correctS2SQL)) { if (StringUtils.isBlank(correctS2SQL)) {
return; return;
} }
List<FieldExpression> expressions = SqlSelectHelper.getFilterExpression(correctS2SQL); List<FieldExpression> expressions = SqlSelectHelper.getFilterExpression(correctS2SQL);
//set dataInfo //set dataInfo
try { try {
if (!org.apache.commons.collections.CollectionUtils.isEmpty(expressions)) { if (!org.apache.commons.collections.CollectionUtils.isEmpty(expressions)) {
@@ -70,6 +70,9 @@ public class ParseInfoProcessor implements ResultProcessor {
log.error("set dateInfo error :", e); log.error("set dateInfo error :", e);
} }
if (correctS2SQL.equals(sqlInfo.getParsedS2SQL())) {
return;
}
//set filter //set filter
Long dataSetId = parseInfo.getDataSetId(); Long dataSetId = parseInfo.getDataSetId();
try { try {

View File

@@ -354,4 +354,7 @@ alter table s2_agent add column `prompt_config` varchar(6000) COLLATE utf8_unico
--20240707 --20240707
alter table s2_agent add model_config varchar(6000) null; alter table s2_agent add model_config varchar(6000) null;
alter table s2_agent add `prompt_config` varchar(5000) COLLATE utf8_unicode_ci DEFAULT NULL; alter table s2_agent add `prompt_config` varchar(5000) COLLATE utf8_unicode_ci DEFAULT NULL;
--20240710
alter table s2_agent add enable_memory_review tinyint DEFAULT 0;

View File

@@ -383,6 +383,7 @@ CREATE TABLE IF NOT EXISTS s2_agent
updated_by varchar(100) null, updated_by varchar(100) null,
updated_at TIMESTAMP null, updated_at TIMESTAMP null,
enable_search int null, enable_search int null,
enable_memory_review int null,
PRIMARY KEY (`id`) PRIMARY KEY (`id`)
); COMMENT ON TABLE s2_agent IS 'agent information table'; ); COMMENT ON TABLE s2_agent IS 'agent information table';

View File

@@ -383,6 +383,7 @@ CREATE TABLE IF NOT EXISTS s2_agent
updated_by varchar(100) null, updated_by varchar(100) null,
updated_at TIMESTAMP null, updated_at TIMESTAMP null,
enable_search int null, enable_search int null,
enable_memory_review int null,
PRIMARY KEY (`id`) PRIMARY KEY (`id`)
); COMMENT ON TABLE s2_agent IS 'agent information table'; ); COMMENT ON TABLE s2_agent IS 'agent information table';