mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(headless | chat ) 向量数据被重置后,记忆不会再次添加到向量数据库 (#2164)
This commit is contained in:
@@ -161,9 +161,11 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
|
||||
JsonUtil.toMap(agentDO.getChatModelConfig(), String.class, ChatApp.class));
|
||||
agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class));
|
||||
agent.getChatAppConfig().values().forEach(c -> {
|
||||
if (c.isEnable()) {// 优化,减少访问数据库的次数
|
||||
ChatModel chatModel = chatModelService.getChatModel(c.getChatModelId());
|
||||
if (Objects.nonNull(chatModel)) {
|
||||
c.setChatModelConfig(chatModelService.getChatModel(c.getChatModelId()).getConfig());
|
||||
c.setChatModelConfig(chatModel.getConfig());
|
||||
}
|
||||
}
|
||||
});
|
||||
agent.setAdmins(JsonUtil.toList(agentDO.getAdmin(), String.class));
|
||||
|
||||
@@ -18,19 +18,23 @@ import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.service.ExemplarService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.CommandLineRunner;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Service
|
||||
public class MemoryServiceImpl implements MemoryService {
|
||||
@Slf4j
|
||||
public class MemoryServiceImpl implements MemoryService , CommandLineRunner {
|
||||
|
||||
@Autowired
|
||||
private ChatMemoryRepository chatMemoryRepository;
|
||||
@@ -187,4 +191,23 @@ public class MemoryServiceImpl implements MemoryService {
|
||||
return memory;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run(String... args) { // 优化,启动时检查,向量数据,将记忆放到向量数据库
|
||||
loadSysExemplars();
|
||||
}
|
||||
public void loadSysExemplars() {
|
||||
try {
|
||||
List<ChatMemory> memories =
|
||||
this.getMemories(ChatMemoryFilter.builder().status(MemoryStatus.ENABLED).build());
|
||||
for(ChatMemory memory:memories){
|
||||
exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
||||
Text2SQLExemplar.builder().question(memory.getQuestion())
|
||||
.sideInfo(memory.getSideInfo()).dbSchema(memory.getDbSchema())
|
||||
.sql(memory.getS2sql()).build());
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("Failed to load system exemplars", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.headless.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.common.jsqlparser.FieldExpression;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
|
||||
@@ -8,10 +9,7 @@ import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.*;
|
||||
|
||||
/** Perform SQL corrections on the "Select" section in S2SQL. */
|
||||
@Slf4j
|
||||
@@ -46,10 +44,28 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
||||
return correctS2SQL;
|
||||
}
|
||||
needAddFields.removeAll(selectFields);
|
||||
|
||||
if (!SqlSelectHelper.hasSubSelect(correctS2SQL)) { //优化内容 , 如果sql 条件包含了这个字段,而且是全等,则不再查询该字段
|
||||
List<FieldExpression> tmp4 = SqlSelectHelper.getWhereExpressions(correctS2SQL);
|
||||
Iterator<String> it = needAddFields.iterator();
|
||||
while (it.hasNext()) {
|
||||
String field = it.next();
|
||||
long size = tmp4.stream()
|
||||
.filter(e -> e.getFieldName().equals(field) && "=".equals(e.getOperator()))
|
||||
.count();
|
||||
if (size == 1) {
|
||||
it.remove();
|
||||
}
|
||||
}
|
||||
}
|
||||
if (needAddFields.size() > 0) {
|
||||
String addFieldsToSelectSql =
|
||||
SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(addFieldsToSelectSql);
|
||||
return addFieldsToSelectSql;
|
||||
} else {
|
||||
return correctS2SQL;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -75,6 +75,7 @@ public class KeywordMapper extends BaseMapper {
|
||||
continue;
|
||||
}
|
||||
Long elementID = NatureHelper.getElementID(nature);
|
||||
if (elementID == null)continue; // 判空优化
|
||||
SchemaElement element = getSchemaElement(dataSetId, elementType, elementID,
|
||||
chatQueryContext.getSemanticSchema());
|
||||
if (Objects.isNull(element)) {
|
||||
|
||||
Reference in New Issue
Block a user