mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(headless | chat ) 向量数据被重置后,记忆不会再次添加到向量数据库 (#2164)
This commit is contained in:
@@ -161,9 +161,11 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
|
|||||||
JsonUtil.toMap(agentDO.getChatModelConfig(), String.class, ChatApp.class));
|
JsonUtil.toMap(agentDO.getChatModelConfig(), String.class, ChatApp.class));
|
||||||
agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class));
|
agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class));
|
||||||
agent.getChatAppConfig().values().forEach(c -> {
|
agent.getChatAppConfig().values().forEach(c -> {
|
||||||
ChatModel chatModel = chatModelService.getChatModel(c.getChatModelId());
|
if (c.isEnable()) {// 优化,减少访问数据库的次数
|
||||||
if (Objects.nonNull(chatModel)) {
|
ChatModel chatModel = chatModelService.getChatModel(c.getChatModelId());
|
||||||
c.setChatModelConfig(chatModelService.getChatModel(c.getChatModelId()).getConfig());
|
if (Objects.nonNull(chatModel)) {
|
||||||
|
c.setChatModelConfig(chatModel.getConfig());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
agent.setAdmins(JsonUtil.toList(agentDO.getAdmin(), String.class));
|
agent.setAdmins(JsonUtil.toList(agentDO.getAdmin(), String.class));
|
||||||
|
|||||||
@@ -18,19 +18,23 @@ import com.tencent.supersonic.common.config.EmbeddingConfig;
|
|||||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
import com.tencent.supersonic.common.pojo.User;
|
import com.tencent.supersonic.common.pojo.User;
|
||||||
import com.tencent.supersonic.common.service.ExemplarService;
|
import com.tencent.supersonic.common.service.ExemplarService;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.beans.BeanUtils;
|
import org.springframework.beans.BeanUtils;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.boot.CommandLineRunner;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
import java.util.Date;
|
import java.util.Date;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
public class MemoryServiceImpl implements MemoryService {
|
@Slf4j
|
||||||
|
public class MemoryServiceImpl implements MemoryService , CommandLineRunner {
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private ChatMemoryRepository chatMemoryRepository;
|
private ChatMemoryRepository chatMemoryRepository;
|
||||||
@@ -187,4 +191,23 @@ public class MemoryServiceImpl implements MemoryService {
|
|||||||
return memory;
|
return memory;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void run(String... args) { // 优化,启动时检查,向量数据,将记忆放到向量数据库
|
||||||
|
loadSysExemplars();
|
||||||
|
}
|
||||||
|
public void loadSysExemplars() {
|
||||||
|
try {
|
||||||
|
List<ChatMemory> memories =
|
||||||
|
this.getMemories(ChatMemoryFilter.builder().status(MemoryStatus.ENABLED).build());
|
||||||
|
for(ChatMemory memory:memories){
|
||||||
|
exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
||||||
|
Text2SQLExemplar.builder().question(memory.getQuestion())
|
||||||
|
.sideInfo(memory.getSideInfo()).dbSchema(memory.getDbSchema())
|
||||||
|
.sql(memory.getS2sql()).build());
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("Failed to load system exemplars", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
package com.tencent.supersonic.headless.chat.corrector;
|
package com.tencent.supersonic.headless.chat.corrector;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.jsqlparser.FieldExpression;
|
||||||
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||||
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
|
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
|
||||||
@@ -8,10 +9,7 @@ import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
|||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.*;
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Set;
|
|
||||||
|
|
||||||
/** Perform SQL corrections on the "Select" section in S2SQL. */
|
/** Perform SQL corrections on the "Select" section in S2SQL. */
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@@ -46,10 +44,28 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
|||||||
return correctS2SQL;
|
return correctS2SQL;
|
||||||
}
|
}
|
||||||
needAddFields.removeAll(selectFields);
|
needAddFields.removeAll(selectFields);
|
||||||
String addFieldsToSelectSql =
|
|
||||||
SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
|
if (!SqlSelectHelper.hasSubSelect(correctS2SQL)) { //优化内容 , 如果sql 条件包含了这个字段,而且是全等,则不再查询该字段
|
||||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(addFieldsToSelectSql);
|
List<FieldExpression> tmp4 = SqlSelectHelper.getWhereExpressions(correctS2SQL);
|
||||||
return addFieldsToSelectSql;
|
Iterator<String> it = needAddFields.iterator();
|
||||||
|
while (it.hasNext()) {
|
||||||
|
String field = it.next();
|
||||||
|
long size = tmp4.stream()
|
||||||
|
.filter(e -> e.getFieldName().equals(field) && "=".equals(e.getOperator()))
|
||||||
|
.count();
|
||||||
|
if (size == 1) {
|
||||||
|
it.remove();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (needAddFields.size() > 0) {
|
||||||
|
String addFieldsToSelectSql =
|
||||||
|
SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
|
||||||
|
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(addFieldsToSelectSql);
|
||||||
|
return addFieldsToSelectSql;
|
||||||
|
} else {
|
||||||
|
return correctS2SQL;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -75,6 +75,7 @@ public class KeywordMapper extends BaseMapper {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
Long elementID = NatureHelper.getElementID(nature);
|
Long elementID = NatureHelper.getElementID(nature);
|
||||||
|
if (elementID == null)continue; // 判空优化
|
||||||
SchemaElement element = getSchemaElement(dataSetId, elementType, elementID,
|
SchemaElement element = getSchemaElement(dataSetId, elementType, elementID,
|
||||||
chatQueryContext.getSemanticSchema());
|
chatQueryContext.getSemanticSchema());
|
||||||
if (Objects.isNull(element)) {
|
if (Objects.isNull(element)) {
|
||||||
|
|||||||
Reference in New Issue
Block a user