diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/builder/ModelIntelligentBuilder.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/builder/ModelIntelligentBuilder.java index d2c37dcb4..c3730be22 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/builder/ModelIntelligentBuilder.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/builder/ModelIntelligentBuilder.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.headless.server.builder; +import com.fasterxml.jackson.databind.ObjectMapper; import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.enums.AppModule; @@ -12,8 +13,11 @@ import dev.langchain4j.model.input.Prompt; import dev.langchain4j.model.input.PromptTemplate; import dev.langchain4j.service.AiServices; import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.core.io.ClassPathResource; import org.springframework.stereotype.Component; +import java.io.InputStream; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -25,6 +29,8 @@ public class ModelIntelligentBuilder extends IntelligentBuilder { public static final String APP_KEY = "BUILD_DATA_MODEL"; + private static final String SYS_EXEMPLAR_FILE = "s2-buildModel-exemplar.json"; + public static final String INSTRUCTION = "" + "Role: As an experienced data analyst with extensive modeling experience, " + " you are expected to have a deep understanding of data analysis and data modeling concepts." @@ -44,12 +50,16 @@ public class ModelIntelligentBuilder extends IntelligentBuilder { + "\nDBSchema: {{DBSchema}}" + "\nOtherRelatedDBSchema: {{otherRelatedDBSchema}}" + "\nExemplar: {{exemplar}}"; + private final ObjectMapper objectMapper = JsonUtil.INSTANCE.getObjectMapper(); + + @Value("${s2.model.building.exemplars.enabled:true}") + private Boolean enableExemplarLoading; + public ModelIntelligentBuilder() { ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("构造数据语义模型") .appModule(AppModule.HEADLESS).description("通过大模型来构造数据语义模型").enable(true).build()); } - interface ModelSchemaExtractor { ModelSchema generateModelSchema(String text); } @@ -67,7 +77,8 @@ public class ModelIntelligentBuilder extends IntelligentBuilder { Prompt prompt = generatePrompt(dbSchema, otherDbSchema, chatApp.get()); ModelSchema modelSchema = extractor.generateModelSchema(prompt.toUserMessage().singleText()); - log.info("dbSchema: {} modelSchema: {}", JsonUtil.toString(dbSchema), + log.info("dbSchema: {}\n otherRelatedDBSchema:{}\n modelSchema: {}", + JsonUtil.toString(dbSchema), JsonUtil.toString(otherDbSchema), JsonUtil.toString(modelSchema)); return modelSchema; } @@ -82,7 +93,20 @@ public class ModelIntelligentBuilder extends IntelligentBuilder { } private String loadExemplars() { - // to add + if (!enableExemplarLoading) { + log.info("Not enable load model-building exemplars"); + return ""; + } + try { + ClassPathResource resource = new ClassPathResource(SYS_EXEMPLAR_FILE); + if (resource.exists()) { + InputStream inputStream = resource.getInputStream(); + return objectMapper + .writeValueAsString(objectMapper.readValue(inputStream, Object.class)); + } + } catch (Exception e) { + log.error("Failed to load model-building system exemplars", e); + } return ""; } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java index b2dc364e7..c8189e63e 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java @@ -103,7 +103,7 @@ public class ModelServiceImpl implements ModelService { private ModelRelaService modelRelaService; ExecutorService executor = - new ThreadPoolExecutor(0, 5, 100L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>()); + new ThreadPoolExecutor(0, 5, 5L, TimeUnit.SECONDS, new LinkedBlockingQueue<>()); public ModelServiceImpl(ModelRepository modelRepository, DatabaseService databaseService, @Lazy DimensionService dimensionService, @Lazy MetricService metricService, @@ -219,12 +219,11 @@ public class ModelServiceImpl implements ModelService { @Override public Map buildModelSchema(ModelBuildReq modelBuildReq) throws SQLException { - Map> dbColumnMap = databaseService.getDbColumns(modelBuildReq); if (modelBuildReq.isBuildByLLM() && modelBuildReq.getChatModelConfig() == null) { ChatModel chatModel = chatModelService.getChatModel(modelBuildReq.getChatModelId()); modelBuildReq.setChatModelConfig(chatModel.getConfig()); } - List dbSchemas = convert(dbColumnMap, modelBuildReq); + List dbSchemas = getDbSchemes(modelBuildReq); Map modelSchemaMap = new ConcurrentHashMap<>(); CompletableFuture.allOf(dbSchemas.stream() .map(dbSchema -> CompletableFuture.runAsync( @@ -246,6 +245,11 @@ public class ModelServiceImpl implements ModelService { } } + private List getDbSchemes(ModelBuildReq modelBuildReq) throws SQLException { + Map> dbColumnMap = databaseService.getDbColumns(modelBuildReq); + return convert(dbColumnMap, modelBuildReq); + } + private List getOtherDbSchema(DbSchema curSchema, List dbSchemas) { return dbSchemas.stream() .filter(dbSchema -> !dbSchema.getTable().equals(curSchema.getTable())) diff --git a/launchers/standalone/src/main/resources/s2-buildModel-exemplar.json b/launchers/standalone/src/main/resources/s2-buildModel-exemplar.json new file mode 100644 index 000000000..6b18da740 --- /dev/null +++ b/launchers/standalone/src/main/resources/s2-buildModel-exemplar.json @@ -0,0 +1,234 @@ +[ + { + "dbSchema": { + "db": "semantic", + "table": "s2_stay_time_statis", + "dbColumns": [ + { + "columnName": "imp_date", + "dataType": "VARCHAR", + "comment": "" + }, + { + "columnName": "user_name", + "dataType": "VARCHAR", + "comment": "" + }, + { + "columnName": "stay_hours", + "dataType": "DOUBLE", + "comment": "" + }, + { + "columnName": "page", + "dataType": "VARCHAR", + "comment": "" + } + ] + }, + + "otherRelatedDBSchema":[ + { + "db": "semantic", + "table": "s2_user_department", + "dbColumns": [ + { + "columnName": "user_name", + "dataType": "VARCHAR", + "comment": "" + }, + { + "columnName": "department", + "dataType": "VARCHAR", + "comment": "" + } + ] + } + ], + + "modelSchema":{ + "name": "停留时间统计", + "bizName": "StayTimeStatistics", + "description": "记录用户在页面上的停留时间统计信息", + "filedSchemas": [ + { + "columnName": "imp_date", + "dataType": "VARCHAR", + "comment": "", + "filedType": "data_time", + "name": "导入日期" + }, + { + "columnName": "user_name", + "dataType": "VARCHAR", + "comment": "", + "filedType": "foreign_key", + "name": "用户名" + }, + { + "columnName": "stay_hours", + "dataType": "DOUBLE", + "comment": "", + "filedType": "measure", + "agg": "SUM", + "name": "停留小时数" + }, + { + "columnName": "page", + "dataType": "VARCHAR", + "comment": "", + "filedType": "dimension", + "name": "页面" + } + ] + } + }, + + { + "dbSchema":{ + "db": "semantic", + "table": "s2_user_department", + "dbColumns": [ + { + "columnName": "user_name", + "dataType": "VARCHAR", + "comment": "" + }, + { + "columnName": "department", + "dataType": "VARCHAR", + "comment": "" + } + ] + }, + "otherRelatedDBSchema":[ + { + "db": "semantic", + "table": "s2_stay_time_statis", + "dbColumns": [ + { + "columnName": "imp_date", + "dataType": "VARCHAR", + "comment": "" + }, + { + "columnName": "user_name", + "dataType": "VARCHAR", + "comment": "" + }, + { + "columnName": "stay_hours", + "dataType": "DOUBLE", + "comment": "" + }, + { + "columnName": "page", + "dataType": "VARCHAR", + "comment": "" + } + ] + } + ], + + "modelSchema":{ + "name": "用户部门信息", + "bizName": "UserDepartmentInfo", + "description": "记录用户所属的部门信息", + "filedSchemas": [ + { + "columnName": "user_name", + "dataType": "VARCHAR", + "comment": "", + "filedType": "primary_key", + "name": "用户名" + }, + { + "columnName": "department", + "dataType": "VARCHAR", + "comment": "", + "filedType": "dimension", + "name": "部门" + } + ] + } + }, + + { + "dbSchema": { + "db": "semantic", + "table": "SELECT imp_date, user_name, page, 1 as pv, user_name as uv FROM s2_pv_uv_statis", + "sql": "SELECT imp_date, user_name, page, 1 as pv, user_name as uv FROM s2_pv_uv_statis", + "dbColumns": [ + { + "columnName": "imp_date", + "dataType": "VARCHAR" + }, + { + "columnName": "user_name", + "dataType": "VARCHAR" + }, + { + "columnName": "page", + "dataType": "VARCHAR" + }, + { + "columnName": "pv", + "dataType": "INTEGER" + }, + { + "columnName": "uv", + "dataType": "VARCHAR" + } + ] + }, + + "otherRelatedDBSchema": [ + + ], + + "modelSchema":{ + "name": "页面访问统计", + "bizName": "PageVisitStatistics", + "description": "该模型用于统计用户在不同页面的访问情况,包括页面访问次数和独立用户数。", + "filedSchemas": [ + { + "columnName": "imp_date", + "dataType": "VARCHAR", + "comment": "数据生成时间", + "filedType": "data_time", + "name": "数据生成时间" + }, + { + "columnName": "user_name", + "dataType": "VARCHAR", + "comment": "用户名", + "filedType": "dimension", + "name": "用户名" + }, + { + "columnName": "page", + "dataType": "VARCHAR", + "comment": "页面名称", + "filedType": "dimension", + "name": "页面名称" + }, + { + "columnName": "pv", + "dataType": "INTEGER", + "comment": "页面访问次数", + "filedType": "measure", + "agg": "SUM", + "name": "页面访问次数" + }, + { + "columnName": "uv", + "dataType": "VARCHAR", + "comment": "独立用户数", + "filedType": "measure", + "agg": "COUNT_DISTINCT", + "name": "独立用户数" + } + ] + } + } +] \ No newline at end of file diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/ModelIntelligentBuildTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/ModelIntelligentBuildTest.java index b383c1ad4..11b68f27e 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/ModelIntelligentBuildTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/ModelIntelligentBuildTest.java @@ -13,11 +13,13 @@ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.context.TestPropertySource; import java.sql.SQLException; import java.util.Map; @Disabled +@TestPropertySource(properties = {"s2.model.building.exemplars.enabled = false"}) public class ModelIntelligentBuildTest extends BaseTest { private LLMConfigUtils.LLMType llmType = LLMConfigUtils.LLMType.OLLAMA_LLAMA3; @@ -57,7 +59,6 @@ public class ModelIntelligentBuildTest extends BaseTest { stayTimeModelSchema.getFieldByName("stay_hours").getAgg()); } - @Test public void testBuildModelBySql() throws SQLException { ChatModelConfig llmConfig = LLMConfigUtils.getLLMConfig(llmType);