From 21e213fb1901cf62f2f689ea10db96f3834d20f0 Mon Sep 17 00:00:00 2001 From: guilinlewis <185641548@qq.com> Date: Wed, 12 Mar 2025 22:19:51 +0800 Subject: [PATCH] =?UTF-8?q?(improvement)(headless=20|=20chat=20)=20?= =?UTF-8?q?=E5=90=91=E9=87=8F=E6=95=B0=E6=8D=AE=E8=A2=AB=E9=87=8D=E7=BD=AE?= =?UTF-8?q?=E5=90=8E=EF=BC=8C=E8=AE=B0=E5=BF=86=E4=B8=8D=E4=BC=9A=E5=86=8D?= =?UTF-8?q?=E6=AC=A1=E6=B7=BB=E5=8A=A0=E5=88=B0=E5=90=91=E9=87=8F=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=20(#2164)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../server/service/impl/AgentServiceImpl.java | 8 +++-- .../service/impl/MemoryServiceImpl.java | 25 ++++++++++++++- .../chat/corrector/SelectCorrector.java | 32 ++++++++++++++----- .../headless/chat/mapper/KeywordMapper.java | 1 + 4 files changed, 54 insertions(+), 12 deletions(-) diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java index 43eb6dcff..2d76875bb 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java @@ -161,9 +161,11 @@ public class AgentServiceImpl extends ServiceImpl implem JsonUtil.toMap(agentDO.getChatModelConfig(), String.class, ChatApp.class)); agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class)); agent.getChatAppConfig().values().forEach(c -> { - ChatModel chatModel = chatModelService.getChatModel(c.getChatModelId()); - if (Objects.nonNull(chatModel)) { - c.setChatModelConfig(chatModelService.getChatModel(c.getChatModelId()).getConfig()); + if (c.isEnable()) {// 优化,减少访问数据库的次数 + ChatModel chatModel = chatModelService.getChatModel(c.getChatModelId()); + if (Objects.nonNull(chatModel)) { + c.setChatModelConfig(chatModel.getConfig()); + } } }); agent.setAdmins(JsonUtil.toList(agentDO.getAdmin(), String.class)); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java index abb45c438..19781d2b6 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java @@ -18,19 +18,23 @@ import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.service.ExemplarService; +import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.CommandLineRunner; import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; +import java.util.Arrays; import java.util.Date; import java.util.List; import java.util.Objects; import java.util.stream.Collectors; @Service -public class MemoryServiceImpl implements MemoryService { +@Slf4j +public class MemoryServiceImpl implements MemoryService , CommandLineRunner { @Autowired private ChatMemoryRepository chatMemoryRepository; @@ -187,4 +191,23 @@ public class MemoryServiceImpl implements MemoryService { return memory; } + @Override + public void run(String... args) { // 优化,启动时检查,向量数据,将记忆放到向量数据库 + loadSysExemplars(); + } + public void loadSysExemplars() { + try { + List memories = + this.getMemories(ChatMemoryFilter.builder().status(MemoryStatus.ENABLED).build()); + for(ChatMemory memory:memories){ + exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()), + Text2SQLExemplar.builder().question(memory.getQuestion()) + .sideInfo(memory.getSideInfo()).dbSchema(memory.getDbSchema()) + .sql(memory.getS2sql()).build()); + } + + } catch (Exception e) { + log.error("Failed to load system exemplars", e); + } + } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java index 87f00c615..e13ef7912 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.headless.chat.corrector; +import com.tencent.supersonic.common.jsqlparser.FieldExpression; import com.tencent.supersonic.common.jsqlparser.SqlAddHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.common.jsqlparser.SqlValidHelper; @@ -8,10 +9,7 @@ import com.tencent.supersonic.headless.chat.ChatQueryContext; import lombok.extern.slf4j.Slf4j; import org.springframework.util.CollectionUtils; -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Set; +import java.util.*; /** Perform SQL corrections on the "Select" section in S2SQL. */ @Slf4j @@ -46,10 +44,28 @@ public class SelectCorrector extends BaseSemanticCorrector { return correctS2SQL; } needAddFields.removeAll(selectFields); - String addFieldsToSelectSql = - SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields)); - semanticParseInfo.getSqlInfo().setCorrectedS2SQL(addFieldsToSelectSql); - return addFieldsToSelectSql; + + if (!SqlSelectHelper.hasSubSelect(correctS2SQL)) { //优化内容 , 如果sql 条件包含了这个字段,而且是全等,则不再查询该字段 + List tmp4 = SqlSelectHelper.getWhereExpressions(correctS2SQL); + Iterator it = needAddFields.iterator(); + while (it.hasNext()) { + String field = it.next(); + long size = tmp4.stream() + .filter(e -> e.getFieldName().equals(field) && "=".equals(e.getOperator())) + .count(); + if (size == 1) { + it.remove(); + } + } + } + if (needAddFields.size() > 0) { + String addFieldsToSelectSql = + SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields)); + semanticParseInfo.getSqlInfo().setCorrectedS2SQL(addFieldsToSelectSql); + return addFieldsToSelectSql; + } else { + return correctS2SQL; + } } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java index a7df51021..7dedc8a82 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java @@ -75,6 +75,7 @@ public class KeywordMapper extends BaseMapper { continue; } Long elementID = NatureHelper.getElementID(nature); + if (elementID == null)continue; // 判空优化 SchemaElement element = getSchemaElement(dataSetId, elementType, elementID, chatQueryContext.getSemanticSchema()); if (Objects.isNull(element)) {