mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
(improvement)(Headless) Add integration testing for building data-model by LLM (#1848)
Co-authored-by: lxwcodemonkey
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.headless.api.pojo;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
|
||||
import lombok.Data;
|
||||
|
||||
@@ -14,5 +15,8 @@ public class FieldSchema {
|
||||
|
||||
private FieldType filedType;
|
||||
|
||||
private AggOperatorEnum agg;
|
||||
|
||||
private String name;
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.headless.api.pojo;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
@@ -15,4 +16,14 @@ public class ModelSchema {
|
||||
|
||||
private List<FieldSchema> filedSchemas;
|
||||
|
||||
@JsonIgnore
|
||||
public FieldSchema getFieldByName(String columnName) {
|
||||
for (FieldSchema fieldSchema : filedSchemas) {
|
||||
if (fieldSchema.getColumnName().equalsIgnoreCase(columnName)) {
|
||||
return fieldSchema;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import lombok.Data;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class ModelSchemaReq {
|
||||
public class ModelBuildReq {
|
||||
|
||||
private Long databaseId;
|
||||
|
||||
@@ -7,16 +7,19 @@ import com.tencent.supersonic.common.util.ChatAppManager;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.DbSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.service.AiServices;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class ModelIntelligentBuilder extends IntelligentBuilder {
|
||||
|
||||
@@ -33,10 +36,13 @@ public class ModelIntelligentBuilder extends IntelligentBuilder {
|
||||
+ "\n primary_key: This is a unique identifier for a record row in a database."
|
||||
+ "\n foreign_key: This is a key in a database whose value is derived from the primary key of another table."
|
||||
+ "\n data_time: This represents the time when data is generated in the data warehouse."
|
||||
+ "\n dimension: Usually a string type, used for grouping and filtering data."
|
||||
+ "\n dimension: Usually a string type, used for grouping and filtering data. No need to generate aggregate functions"
|
||||
+ "\n measure: Usually a numeric type, used to quantify data from a certain evaluative perspective. "
|
||||
|
||||
+ "\nDBSchema: {{DBSchema}}" + "\nExemplar: {{exemplar}}";
|
||||
+ " Also, you need to generate aggregate functions(Eg: MAX, MIN, AVG, SUM, COUNT) for the measure type. "
|
||||
+ "\nTip: I will also give you other related dbSchemas. If you determine that different dbSchemas have the same fields, "
|
||||
+ " they can be primary and foreign key relationships."
|
||||
+ "\nDBSchema: {{DBSchema}}" + "\nOtherRelatedDBSchema: {{otherRelatedDBSchema}}"
|
||||
+ "\nExemplar: {{exemplar}}";
|
||||
|
||||
public ModelIntelligentBuilder() {
|
||||
ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("构造数据语义模型")
|
||||
@@ -49,22 +55,29 @@ public class ModelIntelligentBuilder extends IntelligentBuilder {
|
||||
}
|
||||
|
||||
|
||||
public ModelSchema build(DbSchema dbSchema, ModelSchemaReq modelSchemaReq) {
|
||||
public ModelSchema build(DbSchema dbSchema, List<DbSchema> otherDbSchema,
|
||||
ModelBuildReq modelBuildReq) {
|
||||
Optional<ChatApp> chatApp = ChatAppManager.getApp(APP_KEY);
|
||||
if (!chatApp.isPresent() || !chatApp.get().isEnable()) {
|
||||
return null;
|
||||
}
|
||||
ChatModelConfig chatModelConfig = modelSchemaReq.getChatModelConfig();
|
||||
ChatModelConfig chatModelConfig = modelBuildReq.getChatModelConfig();
|
||||
ModelSchemaExtractor extractor =
|
||||
AiServices.create(ModelSchemaExtractor.class, getChatModel(chatModelConfig));
|
||||
Prompt prompt = generatePrompt(dbSchema, chatApp.get());
|
||||
return extractor.generateModelSchema(prompt.toUserMessage().singleText());
|
||||
Prompt prompt = generatePrompt(dbSchema, otherDbSchema, chatApp.get());
|
||||
ModelSchema modelSchema =
|
||||
extractor.generateModelSchema(prompt.toUserMessage().singleText());
|
||||
log.info("dbSchema: {} modelSchema: {}", JsonUtil.toString(dbSchema),
|
||||
JsonUtil.toString(modelSchema));
|
||||
return modelSchema;
|
||||
}
|
||||
|
||||
private Prompt generatePrompt(DbSchema dbSchema, ChatApp chatApp) {
|
||||
private Prompt generatePrompt(DbSchema dbSchema, List<DbSchema> otherDbSchema,
|
||||
ChatApp chatApp) {
|
||||
Map<String, Object> variable = new HashMap<>();
|
||||
variable.put("exemplar", loadExemplars());
|
||||
variable.put("DBSchema", JsonUtil.toString(dbSchema));
|
||||
variable.put("otherRelatedDBSchema", JsonUtil.toString(otherDbSchema));
|
||||
return PromptTemplate.from(chatApp.getPrompt()).apply(variable);
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.headless.api.pojo.DBColumn;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DatabaseReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SqlExecuteReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
@@ -96,9 +96,9 @@ public class DatabaseController {
|
||||
}
|
||||
|
||||
@PostMapping("/listColumnsBySql")
|
||||
public List<DBColumn> listColumnsBySql(@RequestBody ModelSchemaReq modelSchemaReq)
|
||||
public List<DBColumn> listColumnsBySql(@RequestBody ModelBuildReq modelBuildReq)
|
||||
throws SQLException {
|
||||
return databaseService.getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getSql());
|
||||
return databaseService.getColumns(modelBuildReq.getDatabaseId(), modelBuildReq.getSql());
|
||||
}
|
||||
|
||||
@GetMapping("/getDatabaseParameters")
|
||||
|
||||
@@ -10,8 +10,8 @@ import com.tencent.supersonic.common.pojo.enums.AuthType;
|
||||
import com.tencent.supersonic.headless.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.FieldRemovedReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ModelReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.UnAvailableItemResp;
|
||||
@@ -111,8 +111,8 @@ public class ModelController {
|
||||
}
|
||||
|
||||
@PostMapping("/buildModelSchema")
|
||||
public Map<String, ModelSchema> buildModelSchema(@RequestBody ModelSchemaReq modelSchemaReq)
|
||||
public Map<String, ModelSchema> buildModelSchema(@RequestBody ModelBuildReq modelBuildReq)
|
||||
throws SQLException {
|
||||
return modelService.buildModelSchema(modelSchemaReq);
|
||||
return modelService.buildModelSchema(modelBuildReq);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.server.service;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.headless.api.pojo.DBColumn;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DatabaseReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SqlExecuteReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
@@ -37,7 +37,7 @@ public interface DatabaseService {
|
||||
|
||||
List<String> getTables(Long id, String db) throws SQLException;
|
||||
|
||||
Map<String, List<DBColumn>> getDbColumns(ModelSchemaReq modelSchemaReq) throws SQLException;
|
||||
Map<String, List<DBColumn>> getDbColumns(ModelBuildReq modelBuildReq) throws SQLException;
|
||||
|
||||
List<DBColumn> getColumns(Long id, String db, String table) throws SQLException;
|
||||
|
||||
|
||||
@@ -8,8 +8,8 @@ import com.tencent.supersonic.headless.api.pojo.MetaFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.FieldRemovedReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ModelReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.UnAvailableItemResp;
|
||||
@@ -35,7 +35,7 @@ public interface ModelService {
|
||||
|
||||
UnAvailableItemResp getUnAvailableItem(FieldRemovedReq fieldRemovedReq);
|
||||
|
||||
Map<String, ModelSchema> buildModelSchema(ModelSchemaReq modelSchemaReq) throws SQLException;
|
||||
Map<String, ModelSchema> buildModelSchema(ModelBuildReq modelBuildReq) throws SQLException;
|
||||
|
||||
List<ModelResp> getModelListWithAuth(User user, Long domainId, AuthType authType);
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.EngineType;
|
||||
import com.tencent.supersonic.headless.api.pojo.DBColumn;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DatabaseReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SqlExecuteReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
|
||||
@@ -208,17 +208,17 @@ public class DatabaseServiceImpl extends ServiceImpl<DatabaseDOMapper, DatabaseD
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, List<DBColumn>> getDbColumns(ModelSchemaReq modelSchemaReq)
|
||||
public Map<String, List<DBColumn>> getDbColumns(ModelBuildReq modelBuildReq)
|
||||
throws SQLException {
|
||||
Map<String, List<DBColumn>> dbColumnMap = new HashMap<>();
|
||||
if (StringUtils.isNotBlank(modelSchemaReq.getSql())) {
|
||||
if (StringUtils.isNotBlank(modelBuildReq.getSql())) {
|
||||
List<DBColumn> columns =
|
||||
getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getSql());
|
||||
dbColumnMap.put(modelSchemaReq.getSql(), columns);
|
||||
getColumns(modelBuildReq.getDatabaseId(), modelBuildReq.getSql());
|
||||
dbColumnMap.put(modelBuildReq.getSql(), columns);
|
||||
} else {
|
||||
for (String table : modelSchemaReq.getTables()) {
|
||||
for (String table : modelBuildReq.getTables()) {
|
||||
List<DBColumn> columns =
|
||||
getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getDb(), table);
|
||||
getColumns(modelBuildReq.getDatabaseId(), modelBuildReq.getDb(), table);
|
||||
dbColumnMap.put(table, columns);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,8 +26,8 @@ import com.tencent.supersonic.headless.api.pojo.request.DimensionReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.FieldRemovedReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.MetricReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ModelReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
|
||||
@@ -41,7 +41,13 @@ import com.tencent.supersonic.headless.server.persistence.dataobject.ModelDO;
|
||||
import com.tencent.supersonic.headless.server.persistence.repository.DateInfoRepository;
|
||||
import com.tencent.supersonic.headless.server.persistence.repository.ModelRepository;
|
||||
import com.tencent.supersonic.headless.server.pojo.ModelFilter;
|
||||
import com.tencent.supersonic.headless.server.service.*;
|
||||
import com.tencent.supersonic.headless.server.service.DataSetService;
|
||||
import com.tencent.supersonic.headless.server.service.DatabaseService;
|
||||
import com.tencent.supersonic.headless.server.service.DimensionService;
|
||||
import com.tencent.supersonic.headless.server.service.DomainService;
|
||||
import com.tencent.supersonic.headless.server.service.MetricService;
|
||||
import com.tencent.supersonic.headless.server.service.ModelRelaService;
|
||||
import com.tencent.supersonic.headless.server.service.ModelService;
|
||||
import com.tencent.supersonic.headless.server.utils.ModelConverter;
|
||||
import com.tencent.supersonic.headless.server.utils.NameCheckUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -53,7 +59,21 @@ import org.springframework.transaction.annotation.Transactional;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.sql.SQLException;
|
||||
import java.util.*;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Comparator;
|
||||
import java.util.Date;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.LinkedBlockingQueue;
|
||||
import java.util.concurrent.ThreadPoolExecutor;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Service
|
||||
@@ -82,6 +102,9 @@ public class ModelServiceImpl implements ModelService {
|
||||
|
||||
private ModelRelaService modelRelaService;
|
||||
|
||||
ExecutorService executor =
|
||||
new ThreadPoolExecutor(0, 5, 100L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>());
|
||||
|
||||
public ModelServiceImpl(ModelRepository modelRepository, DatabaseService databaseService,
|
||||
@Lazy DimensionService dimensionService, @Lazy MetricService metricService,
|
||||
DomainService domainService, UserService userService, DataSetService dataSetService,
|
||||
@@ -194,29 +217,49 @@ public class ModelServiceImpl implements ModelService {
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, ModelSchema> buildModelSchema(ModelSchemaReq modelSchemaReq)
|
||||
public Map<String, ModelSchema> buildModelSchema(ModelBuildReq modelBuildReq)
|
||||
throws SQLException {
|
||||
Map<String, List<DBColumn>> dbColumnMap = databaseService.getDbColumns(modelSchemaReq);
|
||||
Map<String, ModelSchema> modelSchemaMap = new HashMap<>();
|
||||
if (modelSchemaReq.isBuildByLLM()) {
|
||||
ChatModel chatModel = chatModelService.getChatModel(modelSchemaReq.getChatModelId());
|
||||
modelSchemaReq.setChatModelConfig(chatModel.getConfig());
|
||||
}
|
||||
for (Map.Entry<String, List<DBColumn>> entry : dbColumnMap.entrySet()) {
|
||||
if (modelSchemaReq.isBuildByLLM()) {
|
||||
DbSchema dbSchema = convert(modelSchemaReq, entry.getKey(), entry.getValue());
|
||||
ModelSchema modelSchema = modelIntelligentBuilder.build(dbSchema, modelSchemaReq);
|
||||
if (modelSchema != null) {
|
||||
modelSchemaMap.put(entry.getKey(), modelSchema);
|
||||
}
|
||||
} else {
|
||||
modelSchemaMap.put(entry.getKey(), build(entry.getValue()));
|
||||
}
|
||||
Map<String, List<DBColumn>> dbColumnMap = databaseService.getDbColumns(modelBuildReq);
|
||||
if (modelBuildReq.isBuildByLLM() && modelBuildReq.getChatModelConfig() == null) {
|
||||
ChatModel chatModel = chatModelService.getChatModel(modelBuildReq.getChatModelId());
|
||||
modelBuildReq.setChatModelConfig(chatModel.getConfig());
|
||||
}
|
||||
List<DbSchema> dbSchemas = convert(dbColumnMap, modelBuildReq);
|
||||
Map<String, ModelSchema> modelSchemaMap = new ConcurrentHashMap<>();
|
||||
CompletableFuture.allOf(dbSchemas.stream()
|
||||
.map(dbSchema -> CompletableFuture.runAsync(
|
||||
() -> doBuild(modelBuildReq, dbSchema, dbSchemas, modelSchemaMap),
|
||||
executor))
|
||||
.toArray(CompletableFuture[]::new)).join();
|
||||
return modelSchemaMap;
|
||||
}
|
||||
|
||||
private DbSchema convert(ModelSchemaReq modelSchemaReq, String key, List<DBColumn> dbColumns) {
|
||||
private void doBuild(ModelBuildReq modelBuildReq, DbSchema curSchema, List<DbSchema> dbSchemas,
|
||||
Map<String, ModelSchema> modelSchemaMap) {
|
||||
if (modelBuildReq.isBuildByLLM()) {
|
||||
List<DbSchema> otherDbSchema = getOtherDbSchema(curSchema, dbSchemas);
|
||||
ModelSchema modelSchema =
|
||||
modelIntelligentBuilder.build(curSchema, otherDbSchema, modelBuildReq);
|
||||
modelSchemaMap.put(curSchema.getTable(), modelSchema);
|
||||
} else {
|
||||
modelSchemaMap.put(curSchema.getTable(), build(curSchema.getDbColumns()));
|
||||
}
|
||||
}
|
||||
|
||||
private List<DbSchema> getOtherDbSchema(DbSchema curSchema, List<DbSchema> dbSchemas) {
|
||||
return dbSchemas.stream()
|
||||
.filter(dbSchema -> !dbSchema.getTable().equals(curSchema.getTable()))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private List<DbSchema> convert(Map<String, List<DBColumn>> dbColumnMap,
|
||||
ModelBuildReq modelSchemaReq) {
|
||||
return dbColumnMap.keySet().stream()
|
||||
.map(key -> convert(modelSchemaReq, key, dbColumnMap.get(key)))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private DbSchema convert(ModelBuildReq modelSchemaReq, String key, List<DBColumn> dbColumns) {
|
||||
DbSchema dbSchema = new DbSchema();
|
||||
dbSchema.setDb(modelSchemaReq.getDb());
|
||||
dbSchema.setTable(key);
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
package com.tencent.supersonic.headless;
|
||||
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
|
||||
import com.tencent.supersonic.headless.server.service.ModelService;
|
||||
import com.tencent.supersonic.util.LLMConfigUtils;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
|
||||
import java.sql.SQLException;
|
||||
import java.util.Map;
|
||||
|
||||
@Disabled
|
||||
public class ModelIntelligentBuildTest extends BaseTest {
|
||||
|
||||
private LLMConfigUtils.LLMType llmType = LLMConfigUtils.LLMType.OLLAMA_LLAMA3;
|
||||
|
||||
@Autowired
|
||||
private ModelService modelService;
|
||||
|
||||
@Test
|
||||
public void testBuildModelBatch() throws SQLException {
|
||||
ChatModelConfig llmConfig = LLMConfigUtils.getLLMConfig(llmType);
|
||||
ModelBuildReq modelSchemaReq = new ModelBuildReq();
|
||||
modelSchemaReq.setChatModelConfig(llmConfig);
|
||||
modelSchemaReq.setBuildByLLM(true);
|
||||
modelSchemaReq.setDatabaseId(1L);
|
||||
modelSchemaReq.setDb("semantic");
|
||||
modelSchemaReq.setTables(Lists.newArrayList("s2_user_department", "s2_stay_time_statis"));
|
||||
Map<String, ModelSchema> modelSchemaMap = modelService.buildModelSchema(modelSchemaReq);
|
||||
|
||||
ModelSchema userModelSchema = modelSchemaMap.get("s2_user_department");
|
||||
Assertions.assertEquals(2, userModelSchema.getFiledSchemas().size());
|
||||
Assertions.assertEquals(FieldType.primary_key,
|
||||
userModelSchema.getFieldByName("user_name").getFiledType());
|
||||
Assertions.assertEquals(FieldType.dimension,
|
||||
userModelSchema.getFieldByName("department").getFiledType());
|
||||
|
||||
ModelSchema stayTimeModelSchema = modelSchemaMap.get("s2_stay_time_statis");
|
||||
Assertions.assertEquals(4, stayTimeModelSchema.getFiledSchemas().size());
|
||||
Assertions.assertEquals(FieldType.foreign_key,
|
||||
stayTimeModelSchema.getFieldByName("user_name").getFiledType());
|
||||
Assertions.assertEquals(FieldType.data_time,
|
||||
stayTimeModelSchema.getFieldByName("imp_date").getFiledType());
|
||||
Assertions.assertEquals(FieldType.dimension,
|
||||
stayTimeModelSchema.getFieldByName("page").getFiledType());
|
||||
Assertions.assertEquals(FieldType.measure,
|
||||
stayTimeModelSchema.getFieldByName("stay_hours").getFiledType());
|
||||
Assertions.assertEquals(AggOperatorEnum.SUM,
|
||||
stayTimeModelSchema.getFieldByName("stay_hours").getAgg());
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testBuildModelBySql() throws SQLException {
|
||||
ChatModelConfig llmConfig = LLMConfigUtils.getLLMConfig(llmType);
|
||||
ModelBuildReq modelSchemaReq = new ModelBuildReq();
|
||||
modelSchemaReq.setChatModelConfig(llmConfig);
|
||||
modelSchemaReq.setBuildByLLM(true);
|
||||
modelSchemaReq.setDatabaseId(1L);
|
||||
modelSchemaReq.setDb("semantic");
|
||||
modelSchemaReq.setSql(
|
||||
"SELECT imp_date, user_name, page, 1 as pv, user_name as uv FROM s2_pv_uv_statis");
|
||||
Map<String, ModelSchema> modelSchemaMap = modelService.buildModelSchema(modelSchemaReq);
|
||||
|
||||
ModelSchema pvModelSchema = modelSchemaMap.values().iterator().next();
|
||||
Assertions.assertEquals(5, pvModelSchema.getFiledSchemas().size());
|
||||
Assertions.assertEquals(FieldType.data_time,
|
||||
pvModelSchema.getFieldByName("imp_date").getFiledType());
|
||||
Assertions.assertEquals(FieldType.dimension,
|
||||
pvModelSchema.getFieldByName("user_name").getFiledType());
|
||||
Assertions.assertEquals(FieldType.dimension,
|
||||
pvModelSchema.getFieldByName("page").getFiledType());
|
||||
Assertions.assertEquals(FieldType.measure,
|
||||
pvModelSchema.getFieldByName("pv").getFiledType());
|
||||
Assertions.assertEquals(AggOperatorEnum.SUM, pvModelSchema.getFieldByName("pv").getAgg());
|
||||
Assertions.assertEquals(FieldType.measure,
|
||||
pvModelSchema.getFieldByName("uv").getFiledType());
|
||||
Assertions.assertEquals(AggOperatorEnum.COUNT_DISTINCT,
|
||||
pvModelSchema.getFieldByName("uv").getAgg());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
spring:
|
||||
datasource:
|
||||
driver-class-name: org.h2.Driver
|
||||
url: jdbc:h2:mem:semantic;DATABASE_TO_UPPER=false;QUERY_TIMEOUT=30
|
||||
url: jdbc:h2:mem:semantic;DATABASE_TO_UPPER=false;QUERY_TIMEOUT=100
|
||||
username: root
|
||||
password: semantic
|
||||
sql:
|
||||
|
||||
Reference in New Issue
Block a user