diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java index 19781d2b6..f6e895549 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java @@ -34,7 +34,7 @@ import java.util.stream.Collectors; @Service @Slf4j -public class MemoryServiceImpl implements MemoryService , CommandLineRunner { +public class MemoryServiceImpl implements MemoryService, CommandLineRunner { @Autowired private ChatMemoryRepository chatMemoryRepository; @@ -195,12 +195,14 @@ public class MemoryServiceImpl implements MemoryService , CommandLineRunner { public void run(String... args) { // 优化,启动时检查,向量数据,将记忆放到向量数据库 loadSysExemplars(); } + public void loadSysExemplars() { try { - List memories = - this.getMemories(ChatMemoryFilter.builder().status(MemoryStatus.ENABLED).build()); - for(ChatMemory memory:memories){ - exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()), + List memories = this + .getMemories(ChatMemoryFilter.builder().status(MemoryStatus.ENABLED).build()); + for (ChatMemory memory : memories) { + exemplarService.storeExemplar( + embeddingConfig.getMemoryCollectionName(memory.getAgentId()), Text2SQLExemplar.builder().question(memory.getQuestion()) .sideInfo(memory.getSideInfo()).dbSchema(memory.getDbSchema()) .sql(memory.getS2sql()).build()); diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/Dimension.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/Dimension.java index 1e0539d98..4e79a4a20 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/Dimension.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/Dimension.java @@ -3,10 +3,12 @@ package com.tencent.supersonic.headless.api.pojo; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.headless.api.pojo.enums.DimensionType; import lombok.AllArgsConstructor; +import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; @Data +@Builder @AllArgsConstructor @NoArgsConstructor public class Dimension { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java index e13ef7912..4302b57a0 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java @@ -45,7 +45,7 @@ public class SelectCorrector extends BaseSemanticCorrector { } needAddFields.removeAll(selectFields); - if (!SqlSelectHelper.hasSubSelect(correctS2SQL)) { //优化内容 , 如果sql 条件包含了这个字段,而且是全等,则不再查询该字段 + if (!SqlSelectHelper.hasSubSelect(correctS2SQL)) { // 优化内容 , 如果sql 条件包含了这个字段,而且是全等,则不再查询该字段 List tmp4 = SqlSelectHelper.getWhereExpressions(correctS2SQL); Iterator it = needAddFields.iterator(); while (it.hasNext()) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java index 7dedc8a82..3fe87477f 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java @@ -75,7 +75,9 @@ public class KeywordMapper extends BaseMapper { continue; } Long elementID = NatureHelper.getElementID(nature); - if (elementID == null)continue; // 判空优化 + if (elementID == null) { + continue; + } SchemaElement element = getSchemaElement(dataSetId, elementType, elementID, chatQueryContext.getSemanticSchema()); if (Objects.isNull(element)) { diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ModelService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ModelService.java index 471bf84d6..6bf019387 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ModelService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ModelService.java @@ -3,13 +3,11 @@ package com.tencent.supersonic.headless.server.service; import com.tencent.supersonic.common.pojo.ItemDateResp; import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.pojo.enums.AuthType; +import com.tencent.supersonic.headless.api.pojo.Dimension; import com.tencent.supersonic.headless.api.pojo.ItemDateFilter; import com.tencent.supersonic.headless.api.pojo.MetaFilter; import com.tencent.supersonic.headless.api.pojo.ModelSchema; -import com.tencent.supersonic.headless.api.pojo.request.FieldRemovedReq; -import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq; -import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq; -import com.tencent.supersonic.headless.api.pojo.request.ModelReq; +import com.tencent.supersonic.headless.api.pojo.request.*; import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp; import com.tencent.supersonic.headless.api.pojo.response.ModelResp; import com.tencent.supersonic.headless.api.pojo.response.UnAvailableItemResp; @@ -54,4 +52,6 @@ public interface ModelService { DatabaseResp getDatabaseByModelId(Long modelId); void batchUpdateStatus(MetaBatchReq metaBatchReq, User user); + + Dimension updateDimension(DimensionReq dimensionReq, User user); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DimensionServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DimensionServiceImpl.java index 20589cfc2..3bf6a7692 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DimensionServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DimensionServiceImpl.java @@ -84,6 +84,10 @@ public class DimensionServiceImpl extends ServiceImpl dimOptional = modelDetail.getDimensions().stream() + .filter(dimension -> dimension.getBizName().equals(dimensionReq.getBizName())) + .findFirst(); + Dimension result; + if (dimOptional.isPresent()) { + Dimension dimension = dimOptional.get(); + dimension.setExpr(dimensionReq.getExpr()); + dimension.setName(dimensionReq.getName()); + dimension.setType(DimensionType.valueOf(dimensionReq.getType())); + dimension.setDescription(dimensionReq.getDescription()); + result = dimension; + } else { + Dimension dimension = Dimension.builder().name(dimensionReq.getName()) + .bizName(dimensionReq.getBizName()).expr(dimensionReq.getExpr()) + .type(DimensionType.valueOf(dimensionReq.getType())) + .description(dimensionReq.getDescription()).build(); + modelDetail.getDimensions().add(dimension); + result = dimension; + } + + modelDO.setModelDetail(JsonUtil.toString(modelDetail)); + modelRepository.updateModel(modelDO); + return result; + } + protected ModelDO getModelDO(Long id) { return modelRepository.getModelById(id); } diff --git a/headless/server/src/test/java/com/tencent/supersonic/headless/server/utils/SqlVariableParseUtilsTest.java b/headless/server/src/test/java/com/tencent/supersonic/headless/server/utils/SqlVariableParseUtilsTest.java index 36f799154..d74401a87 100644 --- a/headless/server/src/test/java/com/tencent/supersonic/headless/server/utils/SqlVariableParseUtilsTest.java +++ b/headless/server/src/test/java/com/tencent/supersonic/headless/server/utils/SqlVariableParseUtilsTest.java @@ -36,7 +36,8 @@ public class SqlVariableParseUtilsTest { @Test void testParseSql_if() { - String sql = "select * from t_$interval$ where id = $id$ $if(name)$and name = $name$$endif$"; + String sql = + "select * from t_$interval$ where id = $id$ $if(name)$and name = $name$$endif$"; List variables = Lists.newArrayList(mockNumSqlVariable(), mockExprSqlVariable(), mockStrSqlVariable()); List params =