(improvement)(Headless) Add integration testing for building data-model by LLM (#1848)

Co-authored-by: lxwcodemonkey
This commit is contained in:
LXW
2024-10-27 18:16:59 +08:00
committed by GitHub
parent 3e0f724e97
commit bd82b0904b
12 changed files with 211 additions and 50 deletions

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.headless.api.pojo; package com.tencent.supersonic.headless.api.pojo;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.headless.api.pojo.enums.FieldType; import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
import lombok.Data; import lombok.Data;
@@ -14,5 +15,8 @@ public class FieldSchema {
private FieldType filedType; private FieldType filedType;
private AggOperatorEnum agg;
private String name; private String name;
} }

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.headless.api.pojo; package com.tencent.supersonic.headless.api.pojo;
import com.fasterxml.jackson.annotation.JsonIgnore;
import lombok.Data; import lombok.Data;
import java.util.List; import java.util.List;
@@ -15,4 +16,14 @@ public class ModelSchema {
private List<FieldSchema> filedSchemas; private List<FieldSchema> filedSchemas;
@JsonIgnore
public FieldSchema getFieldByName(String columnName) {
for (FieldSchema fieldSchema : filedSchemas) {
if (fieldSchema.getColumnName().equalsIgnoreCase(columnName)) {
return fieldSchema;
}
}
return null;
}
} }

View File

@@ -6,7 +6,7 @@ import lombok.Data;
import java.util.List; import java.util.List;
@Data @Data
public class ModelSchemaReq { public class ModelBuildReq {
private Long databaseId; private Long databaseId;

View File

@@ -7,16 +7,19 @@ import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.DbSchema; import com.tencent.supersonic.headless.api.pojo.DbSchema;
import com.tencent.supersonic.headless.api.pojo.ModelSchema; import com.tencent.supersonic.headless.api.pojo.ModelSchema;
import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq; import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
import dev.langchain4j.model.input.Prompt; import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate; import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.service.AiServices; import dev.langchain4j.service.AiServices;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import java.util.HashMap; import java.util.HashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
@Slf4j
@Component @Component
public class ModelIntelligentBuilder extends IntelligentBuilder { public class ModelIntelligentBuilder extends IntelligentBuilder {
@@ -33,10 +36,13 @@ public class ModelIntelligentBuilder extends IntelligentBuilder {
+ "\n primary_key: This is a unique identifier for a record row in a database." + "\n primary_key: This is a unique identifier for a record row in a database."
+ "\n foreign_key: This is a key in a database whose value is derived from the primary key of another table." + "\n foreign_key: This is a key in a database whose value is derived from the primary key of another table."
+ "\n data_time: This represents the time when data is generated in the data warehouse." + "\n data_time: This represents the time when data is generated in the data warehouse."
+ "\n dimension: Usually a string type, used for grouping and filtering data." + "\n dimension: Usually a string type, used for grouping and filtering data. No need to generate aggregate functions"
+ "\n measure: Usually a numeric type, used to quantify data from a certain evaluative perspective." + "\n measure: Usually a numeric type, used to quantify data from a certain evaluative perspective. "
+ " Also, you need to generate aggregate functions(Eg: MAX, MIN, AVG, SUM, COUNT) for the measure type. "
+ "\nDBSchema: {{DBSchema}}" + "\nExemplar: {{exemplar}}"; + "\nTip: I will also give you other related dbSchemas. If you determine that different dbSchemas have the same fields, "
+ " they can be primary and foreign key relationships."
+ "\nDBSchema: {{DBSchema}}" + "\nOtherRelatedDBSchema: {{otherRelatedDBSchema}}"
+ "\nExemplar: {{exemplar}}";
public ModelIntelligentBuilder() { public ModelIntelligentBuilder() {
ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("构造数据语义模型") ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("构造数据语义模型")
@@ -49,22 +55,29 @@ public class ModelIntelligentBuilder extends IntelligentBuilder {
} }
public ModelSchema build(DbSchema dbSchema, ModelSchemaReq modelSchemaReq) { public ModelSchema build(DbSchema dbSchema, List<DbSchema> otherDbSchema,
ModelBuildReq modelBuildReq) {
Optional<ChatApp> chatApp = ChatAppManager.getApp(APP_KEY); Optional<ChatApp> chatApp = ChatAppManager.getApp(APP_KEY);
if (!chatApp.isPresent() || !chatApp.get().isEnable()) { if (!chatApp.isPresent() || !chatApp.get().isEnable()) {
return null; return null;
} }
ChatModelConfig chatModelConfig = modelSchemaReq.getChatModelConfig(); ChatModelConfig chatModelConfig = modelBuildReq.getChatModelConfig();
ModelSchemaExtractor extractor = ModelSchemaExtractor extractor =
AiServices.create(ModelSchemaExtractor.class, getChatModel(chatModelConfig)); AiServices.create(ModelSchemaExtractor.class, getChatModel(chatModelConfig));
Prompt prompt = generatePrompt(dbSchema, chatApp.get()); Prompt prompt = generatePrompt(dbSchema, otherDbSchema, chatApp.get());
return extractor.generateModelSchema(prompt.toUserMessage().singleText()); ModelSchema modelSchema =
extractor.generateModelSchema(prompt.toUserMessage().singleText());
log.info("dbSchema: {} modelSchema: {}", JsonUtil.toString(dbSchema),
JsonUtil.toString(modelSchema));
return modelSchema;
} }
private Prompt generatePrompt(DbSchema dbSchema, ChatApp chatApp) { private Prompt generatePrompt(DbSchema dbSchema, List<DbSchema> otherDbSchema,
ChatApp chatApp) {
Map<String, Object> variable = new HashMap<>(); Map<String, Object> variable = new HashMap<>();
variable.put("exemplar", loadExemplars()); variable.put("exemplar", loadExemplars());
variable.put("DBSchema", JsonUtil.toString(dbSchema)); variable.put("DBSchema", JsonUtil.toString(dbSchema));
variable.put("otherRelatedDBSchema", JsonUtil.toString(otherDbSchema));
return PromptTemplate.from(chatApp.getPrompt()).apply(variable); return PromptTemplate.from(chatApp.getPrompt()).apply(variable);
} }

View File

@@ -7,7 +7,7 @@ import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.headless.api.pojo.DBColumn; import com.tencent.supersonic.headless.api.pojo.DBColumn;
import com.tencent.supersonic.headless.api.pojo.request.DatabaseReq; import com.tencent.supersonic.headless.api.pojo.request.DatabaseReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq; import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
import com.tencent.supersonic.headless.api.pojo.request.SqlExecuteReq; import com.tencent.supersonic.headless.api.pojo.request.SqlExecuteReq;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp; import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
@@ -96,9 +96,9 @@ public class DatabaseController {
} }
@PostMapping("/listColumnsBySql") @PostMapping("/listColumnsBySql")
public List<DBColumn> listColumnsBySql(@RequestBody ModelSchemaReq modelSchemaReq) public List<DBColumn> listColumnsBySql(@RequestBody ModelBuildReq modelBuildReq)
throws SQLException { throws SQLException {
return databaseService.getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getSql()); return databaseService.getColumns(modelBuildReq.getDatabaseId(), modelBuildReq.getSql());
} }
@GetMapping("/getDatabaseParameters") @GetMapping("/getDatabaseParameters")

View File

@@ -10,8 +10,8 @@ import com.tencent.supersonic.common.pojo.enums.AuthType;
import com.tencent.supersonic.headless.api.pojo.ModelSchema; import com.tencent.supersonic.headless.api.pojo.ModelSchema;
import com.tencent.supersonic.headless.api.pojo.request.FieldRemovedReq; import com.tencent.supersonic.headless.api.pojo.request.FieldRemovedReq;
import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq; import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelReq; import com.tencent.supersonic.headless.api.pojo.request.ModelReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp; import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp; import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.api.pojo.response.UnAvailableItemResp; import com.tencent.supersonic.headless.api.pojo.response.UnAvailableItemResp;
@@ -111,8 +111,8 @@ public class ModelController {
} }
@PostMapping("/buildModelSchema") @PostMapping("/buildModelSchema")
public Map<String, ModelSchema> buildModelSchema(@RequestBody ModelSchemaReq modelSchemaReq) public Map<String, ModelSchema> buildModelSchema(@RequestBody ModelBuildReq modelBuildReq)
throws SQLException { throws SQLException {
return modelService.buildModelSchema(modelSchemaReq); return modelService.buildModelSchema(modelBuildReq);
} }
} }

View File

@@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.server.service;
import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.headless.api.pojo.DBColumn; import com.tencent.supersonic.headless.api.pojo.DBColumn;
import com.tencent.supersonic.headless.api.pojo.request.DatabaseReq; import com.tencent.supersonic.headless.api.pojo.request.DatabaseReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq; import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
import com.tencent.supersonic.headless.api.pojo.request.SqlExecuteReq; import com.tencent.supersonic.headless.api.pojo.request.SqlExecuteReq;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp; import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
@@ -37,7 +37,7 @@ public interface DatabaseService {
List<String> getTables(Long id, String db) throws SQLException; List<String> getTables(Long id, String db) throws SQLException;
Map<String, List<DBColumn>> getDbColumns(ModelSchemaReq modelSchemaReq) throws SQLException; Map<String, List<DBColumn>> getDbColumns(ModelBuildReq modelBuildReq) throws SQLException;
List<DBColumn> getColumns(Long id, String db, String table) throws SQLException; List<DBColumn> getColumns(Long id, String db, String table) throws SQLException;

View File

@@ -8,8 +8,8 @@ import com.tencent.supersonic.headless.api.pojo.MetaFilter;
import com.tencent.supersonic.headless.api.pojo.ModelSchema; import com.tencent.supersonic.headless.api.pojo.ModelSchema;
import com.tencent.supersonic.headless.api.pojo.request.FieldRemovedReq; import com.tencent.supersonic.headless.api.pojo.request.FieldRemovedReq;
import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq; import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelReq; import com.tencent.supersonic.headless.api.pojo.request.ModelReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp; import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp; import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.api.pojo.response.UnAvailableItemResp; import com.tencent.supersonic.headless.api.pojo.response.UnAvailableItemResp;
@@ -35,7 +35,7 @@ public interface ModelService {
UnAvailableItemResp getUnAvailableItem(FieldRemovedReq fieldRemovedReq); UnAvailableItemResp getUnAvailableItem(FieldRemovedReq fieldRemovedReq);
Map<String, ModelSchema> buildModelSchema(ModelSchemaReq modelSchemaReq) throws SQLException; Map<String, ModelSchema> buildModelSchema(ModelBuildReq modelBuildReq) throws SQLException;
List<ModelResp> getModelListWithAuth(User user, Long domainId, AuthType authType); List<ModelResp> getModelListWithAuth(User user, Long domainId, AuthType authType);

View File

@@ -7,7 +7,7 @@ import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.pojo.enums.EngineType; import com.tencent.supersonic.common.pojo.enums.EngineType;
import com.tencent.supersonic.headless.api.pojo.DBColumn; import com.tencent.supersonic.headless.api.pojo.DBColumn;
import com.tencent.supersonic.headless.api.pojo.request.DatabaseReq; import com.tencent.supersonic.headless.api.pojo.request.DatabaseReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq; import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
import com.tencent.supersonic.headless.api.pojo.request.SqlExecuteReq; import com.tencent.supersonic.headless.api.pojo.request.SqlExecuteReq;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp; import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp; import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
@@ -208,17 +208,17 @@ public class DatabaseServiceImpl extends ServiceImpl<DatabaseDOMapper, DatabaseD
} }
@Override @Override
public Map<String, List<DBColumn>> getDbColumns(ModelSchemaReq modelSchemaReq) public Map<String, List<DBColumn>> getDbColumns(ModelBuildReq modelBuildReq)
throws SQLException { throws SQLException {
Map<String, List<DBColumn>> dbColumnMap = new HashMap<>(); Map<String, List<DBColumn>> dbColumnMap = new HashMap<>();
if (StringUtils.isNotBlank(modelSchemaReq.getSql())) { if (StringUtils.isNotBlank(modelBuildReq.getSql())) {
List<DBColumn> columns = List<DBColumn> columns =
getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getSql()); getColumns(modelBuildReq.getDatabaseId(), modelBuildReq.getSql());
dbColumnMap.put(modelSchemaReq.getSql(), columns); dbColumnMap.put(modelBuildReq.getSql(), columns);
} else { } else {
for (String table : modelSchemaReq.getTables()) { for (String table : modelBuildReq.getTables()) {
List<DBColumn> columns = List<DBColumn> columns =
getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getDb(), table); getColumns(modelBuildReq.getDatabaseId(), modelBuildReq.getDb(), table);
dbColumnMap.put(table, columns); dbColumnMap.put(table, columns);
} }
} }

View File

@@ -26,8 +26,8 @@ import com.tencent.supersonic.headless.api.pojo.request.DimensionReq;
import com.tencent.supersonic.headless.api.pojo.request.FieldRemovedReq; import com.tencent.supersonic.headless.api.pojo.request.FieldRemovedReq;
import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq; import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq;
import com.tencent.supersonic.headless.api.pojo.request.MetricReq; import com.tencent.supersonic.headless.api.pojo.request.MetricReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelReq; import com.tencent.supersonic.headless.api.pojo.request.ModelReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq;
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp; import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp; import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp; import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
@@ -41,7 +41,13 @@ import com.tencent.supersonic.headless.server.persistence.dataobject.ModelDO;
import com.tencent.supersonic.headless.server.persistence.repository.DateInfoRepository; import com.tencent.supersonic.headless.server.persistence.repository.DateInfoRepository;
import com.tencent.supersonic.headless.server.persistence.repository.ModelRepository; import com.tencent.supersonic.headless.server.persistence.repository.ModelRepository;
import com.tencent.supersonic.headless.server.pojo.ModelFilter; import com.tencent.supersonic.headless.server.pojo.ModelFilter;
import com.tencent.supersonic.headless.server.service.*; import com.tencent.supersonic.headless.server.service.DataSetService;
import com.tencent.supersonic.headless.server.service.DatabaseService;
import com.tencent.supersonic.headless.server.service.DimensionService;
import com.tencent.supersonic.headless.server.service.DomainService;
import com.tencent.supersonic.headless.server.service.MetricService;
import com.tencent.supersonic.headless.server.service.ModelRelaService;
import com.tencent.supersonic.headless.server.service.ModelService;
import com.tencent.supersonic.headless.server.utils.ModelConverter; import com.tencent.supersonic.headless.server.utils.ModelConverter;
import com.tencent.supersonic.headless.server.utils.NameCheckUtils; import com.tencent.supersonic.headless.server.utils.NameCheckUtils;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@@ -53,7 +59,21 @@ import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.sql.SQLException; import java.sql.SQLException;
import java.util.*; import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@Service @Service
@@ -82,6 +102,9 @@ public class ModelServiceImpl implements ModelService {
private ModelRelaService modelRelaService; private ModelRelaService modelRelaService;
ExecutorService executor =
new ThreadPoolExecutor(0, 5, 100L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>());
public ModelServiceImpl(ModelRepository modelRepository, DatabaseService databaseService, public ModelServiceImpl(ModelRepository modelRepository, DatabaseService databaseService,
@Lazy DimensionService dimensionService, @Lazy MetricService metricService, @Lazy DimensionService dimensionService, @Lazy MetricService metricService,
DomainService domainService, UserService userService, DataSetService dataSetService, DomainService domainService, UserService userService, DataSetService dataSetService,
@@ -194,29 +217,49 @@ public class ModelServiceImpl implements ModelService {
} }
@Override @Override
public Map<String, ModelSchema> buildModelSchema(ModelSchemaReq modelSchemaReq) public Map<String, ModelSchema> buildModelSchema(ModelBuildReq modelBuildReq)
throws SQLException { throws SQLException {
Map<String, List<DBColumn>> dbColumnMap = databaseService.getDbColumns(modelSchemaReq); Map<String, List<DBColumn>> dbColumnMap = databaseService.getDbColumns(modelBuildReq);
Map<String, ModelSchema> modelSchemaMap = new HashMap<>(); if (modelBuildReq.isBuildByLLM() && modelBuildReq.getChatModelConfig() == null) {
if (modelSchemaReq.isBuildByLLM()) { ChatModel chatModel = chatModelService.getChatModel(modelBuildReq.getChatModelId());
ChatModel chatModel = chatModelService.getChatModel(modelSchemaReq.getChatModelId()); modelBuildReq.setChatModelConfig(chatModel.getConfig());
modelSchemaReq.setChatModelConfig(chatModel.getConfig());
}
for (Map.Entry<String, List<DBColumn>> entry : dbColumnMap.entrySet()) {
if (modelSchemaReq.isBuildByLLM()) {
DbSchema dbSchema = convert(modelSchemaReq, entry.getKey(), entry.getValue());
ModelSchema modelSchema = modelIntelligentBuilder.build(dbSchema, modelSchemaReq);
if (modelSchema != null) {
modelSchemaMap.put(entry.getKey(), modelSchema);
}
} else {
modelSchemaMap.put(entry.getKey(), build(entry.getValue()));
}
} }
List<DbSchema> dbSchemas = convert(dbColumnMap, modelBuildReq);
Map<String, ModelSchema> modelSchemaMap = new ConcurrentHashMap<>();
CompletableFuture.allOf(dbSchemas.stream()
.map(dbSchema -> CompletableFuture.runAsync(
() -> doBuild(modelBuildReq, dbSchema, dbSchemas, modelSchemaMap),
executor))
.toArray(CompletableFuture[]::new)).join();
return modelSchemaMap; return modelSchemaMap;
} }
private DbSchema convert(ModelSchemaReq modelSchemaReq, String key, List<DBColumn> dbColumns) { private void doBuild(ModelBuildReq modelBuildReq, DbSchema curSchema, List<DbSchema> dbSchemas,
Map<String, ModelSchema> modelSchemaMap) {
if (modelBuildReq.isBuildByLLM()) {
List<DbSchema> otherDbSchema = getOtherDbSchema(curSchema, dbSchemas);
ModelSchema modelSchema =
modelIntelligentBuilder.build(curSchema, otherDbSchema, modelBuildReq);
modelSchemaMap.put(curSchema.getTable(), modelSchema);
} else {
modelSchemaMap.put(curSchema.getTable(), build(curSchema.getDbColumns()));
}
}
private List<DbSchema> getOtherDbSchema(DbSchema curSchema, List<DbSchema> dbSchemas) {
return dbSchemas.stream()
.filter(dbSchema -> !dbSchema.getTable().equals(curSchema.getTable()))
.collect(Collectors.toList());
}
private List<DbSchema> convert(Map<String, List<DBColumn>> dbColumnMap,
ModelBuildReq modelSchemaReq) {
return dbColumnMap.keySet().stream()
.map(key -> convert(modelSchemaReq, key, dbColumnMap.get(key)))
.collect(Collectors.toList());
}
private DbSchema convert(ModelBuildReq modelSchemaReq, String key, List<DBColumn> dbColumns) {
DbSchema dbSchema = new DbSchema(); DbSchema dbSchema = new DbSchema();
dbSchema.setDb(modelSchemaReq.getDb()); dbSchema.setDb(modelSchemaReq.getDb());
dbSchema.setTable(key); dbSchema.setTable(key);

View File

@@ -0,0 +1,90 @@
package com.tencent.supersonic.headless;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.headless.api.pojo.ModelSchema;
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
import com.tencent.supersonic.headless.server.service.ModelService;
import com.tencent.supersonic.util.LLMConfigUtils;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import java.sql.SQLException;
import java.util.Map;
@Disabled
public class ModelIntelligentBuildTest extends BaseTest {
private LLMConfigUtils.LLMType llmType = LLMConfigUtils.LLMType.OLLAMA_LLAMA3;
@Autowired
private ModelService modelService;
@Test
public void testBuildModelBatch() throws SQLException {
ChatModelConfig llmConfig = LLMConfigUtils.getLLMConfig(llmType);
ModelBuildReq modelSchemaReq = new ModelBuildReq();
modelSchemaReq.setChatModelConfig(llmConfig);
modelSchemaReq.setBuildByLLM(true);
modelSchemaReq.setDatabaseId(1L);
modelSchemaReq.setDb("semantic");
modelSchemaReq.setTables(Lists.newArrayList("s2_user_department", "s2_stay_time_statis"));
Map<String, ModelSchema> modelSchemaMap = modelService.buildModelSchema(modelSchemaReq);
ModelSchema userModelSchema = modelSchemaMap.get("s2_user_department");
Assertions.assertEquals(2, userModelSchema.getFiledSchemas().size());
Assertions.assertEquals(FieldType.primary_key,
userModelSchema.getFieldByName("user_name").getFiledType());
Assertions.assertEquals(FieldType.dimension,
userModelSchema.getFieldByName("department").getFiledType());
ModelSchema stayTimeModelSchema = modelSchemaMap.get("s2_stay_time_statis");
Assertions.assertEquals(4, stayTimeModelSchema.getFiledSchemas().size());
Assertions.assertEquals(FieldType.foreign_key,
stayTimeModelSchema.getFieldByName("user_name").getFiledType());
Assertions.assertEquals(FieldType.data_time,
stayTimeModelSchema.getFieldByName("imp_date").getFiledType());
Assertions.assertEquals(FieldType.dimension,
stayTimeModelSchema.getFieldByName("page").getFiledType());
Assertions.assertEquals(FieldType.measure,
stayTimeModelSchema.getFieldByName("stay_hours").getFiledType());
Assertions.assertEquals(AggOperatorEnum.SUM,
stayTimeModelSchema.getFieldByName("stay_hours").getAgg());
}
@Test
public void testBuildModelBySql() throws SQLException {
ChatModelConfig llmConfig = LLMConfigUtils.getLLMConfig(llmType);
ModelBuildReq modelSchemaReq = new ModelBuildReq();
modelSchemaReq.setChatModelConfig(llmConfig);
modelSchemaReq.setBuildByLLM(true);
modelSchemaReq.setDatabaseId(1L);
modelSchemaReq.setDb("semantic");
modelSchemaReq.setSql(
"SELECT imp_date, user_name, page, 1 as pv, user_name as uv FROM s2_pv_uv_statis");
Map<String, ModelSchema> modelSchemaMap = modelService.buildModelSchema(modelSchemaReq);
ModelSchema pvModelSchema = modelSchemaMap.values().iterator().next();
Assertions.assertEquals(5, pvModelSchema.getFiledSchemas().size());
Assertions.assertEquals(FieldType.data_time,
pvModelSchema.getFieldByName("imp_date").getFiledType());
Assertions.assertEquals(FieldType.dimension,
pvModelSchema.getFieldByName("user_name").getFiledType());
Assertions.assertEquals(FieldType.dimension,
pvModelSchema.getFieldByName("page").getFiledType());
Assertions.assertEquals(FieldType.measure,
pvModelSchema.getFieldByName("pv").getFiledType());
Assertions.assertEquals(AggOperatorEnum.SUM, pvModelSchema.getFieldByName("pv").getAgg());
Assertions.assertEquals(FieldType.measure,
pvModelSchema.getFieldByName("uv").getFiledType());
Assertions.assertEquals(AggOperatorEnum.COUNT_DISTINCT,
pvModelSchema.getFieldByName("uv").getAgg());
}
}

View File

@@ -1,7 +1,7 @@
spring: spring:
datasource: datasource:
driver-class-name: org.h2.Driver driver-class-name: org.h2.Driver
url: jdbc:h2:mem:semantic;DATABASE_TO_UPPER=false;QUERY_TIMEOUT=30 url: jdbc:h2:mem:semantic;DATABASE_TO_UPPER=false;QUERY_TIMEOUT=100
username: root username: root
password: semantic password: semantic
sql: sql: