(improvement)(Headless) Add integration testing for building data-model by LLM (#1848)

Co-authored-by: lxwcodemonkey
This commit is contained in:
LXW
2024-10-27 18:16:59 +08:00
committed by GitHub
parent 3e0f724e97
commit bd82b0904b
12 changed files with 211 additions and 50 deletions

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.headless.api.pojo;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
import lombok.Data;
@@ -14,5 +15,8 @@ public class FieldSchema {
private FieldType filedType;
private AggOperatorEnum agg;
private String name;
}

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.headless.api.pojo;
import com.fasterxml.jackson.annotation.JsonIgnore;
import lombok.Data;
import java.util.List;
@@ -15,4 +16,14 @@ public class ModelSchema {
private List<FieldSchema> filedSchemas;
/**
 * Looks up a field schema by column name, ignoring case.
 *
 * @param columnName the column name to search for; may be {@code null}
 * @return the first element of {@code filedSchemas} whose column name matches, or
 *         {@code null} when nothing matches, when {@code columnName} is {@code null}, or
 *         when no field schemas have been set
 */
@JsonIgnore
public FieldSchema getFieldByName(String columnName) {
    // Nothing can match if the list was never populated or there is no name to look for.
    if (filedSchemas == null || columnName == null) {
        return null;
    }
    for (FieldSchema fieldSchema : filedSchemas) {
        // equalsIgnoreCase is invoked on the non-null argument so that an entry without a
        // column name is skipped instead of raising a NullPointerException.
        if (fieldSchema != null && columnName.equalsIgnoreCase(fieldSchema.getColumnName())) {
            return fieldSchema;
        }
    }
    return null;
}
}

View File

@@ -6,7 +6,7 @@ import lombok.Data;
import java.util.List;
@Data
public class ModelSchemaReq {
public class ModelBuildReq {
private Long databaseId;

View File

@@ -7,16 +7,19 @@ import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.DbSchema;
import com.tencent.supersonic.headless.api.pojo.ModelSchema;
import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.service.AiServices;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@Slf4j
@Component
public class ModelIntelligentBuilder extends IntelligentBuilder {
@@ -33,10 +36,13 @@ public class ModelIntelligentBuilder extends IntelligentBuilder {
+ "\n primary_key: This is a unique identifier for a record row in a database."
+ "\n foreign_key: This is a key in a database whose value is derived from the primary key of another table."
+ "\n data_time: This represents the time when data is generated in the data warehouse."
+ "\n dimension: Usually a string type, used for grouping and filtering data."
+ "\n dimension: Usually a string type, used for grouping and filtering data. No need to generate aggregate functions"
+ "\n measure: Usually a numeric type, used to quantify data from a certain evaluative perspective. "
+ "\nDBSchema: {{DBSchema}}" + "\nExemplar: {{exemplar}}";
+ " Also, you need to generate aggregate functions(Eg: MAX, MIN, AVG, SUM, COUNT) for the measure type. "
+ "\nTip: I will also give you other related dbSchemas. If you determine that different dbSchemas have the same fields, "
+ " they can be primary and foreign key relationships."
+ "\nDBSchema: {{DBSchema}}" + "\nOtherRelatedDBSchema: {{otherRelatedDBSchema}}"
+ "\nExemplar: {{exemplar}}";
public ModelIntelligentBuilder() {
ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("构造数据语义模型")
@@ -49,22 +55,29 @@ public class ModelIntelligentBuilder extends IntelligentBuilder {
}
public ModelSchema build(DbSchema dbSchema, ModelSchemaReq modelSchemaReq) {
public ModelSchema build(DbSchema dbSchema, List<DbSchema> otherDbSchema,
ModelBuildReq modelBuildReq) {
Optional<ChatApp> chatApp = ChatAppManager.getApp(APP_KEY);
if (!chatApp.isPresent() || !chatApp.get().isEnable()) {
return null;
}
ChatModelConfig chatModelConfig = modelSchemaReq.getChatModelConfig();
ChatModelConfig chatModelConfig = modelBuildReq.getChatModelConfig();
ModelSchemaExtractor extractor =
AiServices.create(ModelSchemaExtractor.class, getChatModel(chatModelConfig));
Prompt prompt = generatePrompt(dbSchema, chatApp.get());
return extractor.generateModelSchema(prompt.toUserMessage().singleText());
Prompt prompt = generatePrompt(dbSchema, otherDbSchema, chatApp.get());
ModelSchema modelSchema =
extractor.generateModelSchema(prompt.toUserMessage().singleText());
log.info("dbSchema: {} modelSchema: {}", JsonUtil.toString(dbSchema),
JsonUtil.toString(modelSchema));
return modelSchema;
}
private Prompt generatePrompt(DbSchema dbSchema, ChatApp chatApp) {
private Prompt generatePrompt(DbSchema dbSchema, List<DbSchema> otherDbSchema,
ChatApp chatApp) {
Map<String, Object> variable = new HashMap<>();
variable.put("exemplar", loadExemplars());
variable.put("DBSchema", JsonUtil.toString(dbSchema));
variable.put("otherRelatedDBSchema", JsonUtil.toString(otherDbSchema));
return PromptTemplate.from(chatApp.getPrompt()).apply(variable);
}

View File

@@ -7,7 +7,7 @@ import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.headless.api.pojo.DBColumn;
import com.tencent.supersonic.headless.api.pojo.request.DatabaseReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
import com.tencent.supersonic.headless.api.pojo.request.SqlExecuteReq;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
@@ -96,9 +96,9 @@ public class DatabaseController {
}
@PostMapping("/listColumnsBySql")
public List<DBColumn> listColumnsBySql(@RequestBody ModelSchemaReq modelSchemaReq)
public List<DBColumn> listColumnsBySql(@RequestBody ModelBuildReq modelBuildReq)
throws SQLException {
return databaseService.getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getSql());
return databaseService.getColumns(modelBuildReq.getDatabaseId(), modelBuildReq.getSql());
}
@GetMapping("/getDatabaseParameters")

View File

@@ -10,8 +10,8 @@ import com.tencent.supersonic.common.pojo.enums.AuthType;
import com.tencent.supersonic.headless.api.pojo.ModelSchema;
import com.tencent.supersonic.headless.api.pojo.request.FieldRemovedReq;
import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.api.pojo.response.UnAvailableItemResp;
@@ -111,8 +111,8 @@ public class ModelController {
}
@PostMapping("/buildModelSchema")
public Map<String, ModelSchema> buildModelSchema(@RequestBody ModelSchemaReq modelSchemaReq)
public Map<String, ModelSchema> buildModelSchema(@RequestBody ModelBuildReq modelBuildReq)
throws SQLException {
return modelService.buildModelSchema(modelSchemaReq);
return modelService.buildModelSchema(modelBuildReq);
}
}

View File

@@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.server.service;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.headless.api.pojo.DBColumn;
import com.tencent.supersonic.headless.api.pojo.request.DatabaseReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
import com.tencent.supersonic.headless.api.pojo.request.SqlExecuteReq;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
@@ -37,7 +37,7 @@ public interface DatabaseService {
List<String> getTables(Long id, String db) throws SQLException;
Map<String, List<DBColumn>> getDbColumns(ModelSchemaReq modelSchemaReq) throws SQLException;
Map<String, List<DBColumn>> getDbColumns(ModelBuildReq modelBuildReq) throws SQLException;
List<DBColumn> getColumns(Long id, String db, String table) throws SQLException;

View File

@@ -8,8 +8,8 @@ import com.tencent.supersonic.headless.api.pojo.MetaFilter;
import com.tencent.supersonic.headless.api.pojo.ModelSchema;
import com.tencent.supersonic.headless.api.pojo.request.FieldRemovedReq;
import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.api.pojo.response.UnAvailableItemResp;
@@ -35,7 +35,7 @@ public interface ModelService {
UnAvailableItemResp getUnAvailableItem(FieldRemovedReq fieldRemovedReq);
Map<String, ModelSchema> buildModelSchema(ModelSchemaReq modelSchemaReq) throws SQLException;
Map<String, ModelSchema> buildModelSchema(ModelBuildReq modelBuildReq) throws SQLException;
List<ModelResp> getModelListWithAuth(User user, Long domainId, AuthType authType);

View File

@@ -7,7 +7,7 @@ import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.pojo.enums.EngineType;
import com.tencent.supersonic.headless.api.pojo.DBColumn;
import com.tencent.supersonic.headless.api.pojo.request.DatabaseReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
import com.tencent.supersonic.headless.api.pojo.request.SqlExecuteReq;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
@@ -208,17 +208,17 @@ public class DatabaseServiceImpl extends ServiceImpl<DatabaseDOMapper, DatabaseD
}
@Override
public Map<String, List<DBColumn>> getDbColumns(ModelSchemaReq modelSchemaReq)
public Map<String, List<DBColumn>> getDbColumns(ModelBuildReq modelBuildReq)
throws SQLException {
Map<String, List<DBColumn>> dbColumnMap = new HashMap<>();
if (StringUtils.isNotBlank(modelSchemaReq.getSql())) {
if (StringUtils.isNotBlank(modelBuildReq.getSql())) {
List<DBColumn> columns =
getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getSql());
dbColumnMap.put(modelSchemaReq.getSql(), columns);
getColumns(modelBuildReq.getDatabaseId(), modelBuildReq.getSql());
dbColumnMap.put(modelBuildReq.getSql(), columns);
} else {
for (String table : modelSchemaReq.getTables()) {
for (String table : modelBuildReq.getTables()) {
List<DBColumn> columns =
getColumns(modelSchemaReq.getDatabaseId(), modelSchemaReq.getDb(), table);
getColumns(modelBuildReq.getDatabaseId(), modelBuildReq.getDb(), table);
dbColumnMap.put(table, columns);
}
}

View File

@@ -26,8 +26,8 @@ import com.tencent.supersonic.headless.api.pojo.request.DimensionReq;
import com.tencent.supersonic.headless.api.pojo.request.FieldRemovedReq;
import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq;
import com.tencent.supersonic.headless.api.pojo.request.MetricReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaReq;
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
@@ -41,7 +41,13 @@ import com.tencent.supersonic.headless.server.persistence.dataobject.ModelDO;
import com.tencent.supersonic.headless.server.persistence.repository.DateInfoRepository;
import com.tencent.supersonic.headless.server.persistence.repository.ModelRepository;
import com.tencent.supersonic.headless.server.pojo.ModelFilter;
import com.tencent.supersonic.headless.server.service.*;
import com.tencent.supersonic.headless.server.service.DataSetService;
import com.tencent.supersonic.headless.server.service.DatabaseService;
import com.tencent.supersonic.headless.server.service.DimensionService;
import com.tencent.supersonic.headless.server.service.DomainService;
import com.tencent.supersonic.headless.server.service.MetricService;
import com.tencent.supersonic.headless.server.service.ModelRelaService;
import com.tencent.supersonic.headless.server.service.ModelService;
import com.tencent.supersonic.headless.server.utils.ModelConverter;
import com.tencent.supersonic.headless.server.utils.NameCheckUtils;
import lombok.extern.slf4j.Slf4j;
@@ -53,7 +59,21 @@ import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.CollectionUtils;
import java.sql.SQLException;
import java.util.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
@Service
@@ -82,6 +102,9 @@ public class ModelServiceImpl implements ModelService {
private ModelRelaService modelRelaService;
// Worker pool used to build the model schemas of several tables concurrently.
// ThreadPoolExecutor only adds threads beyond corePoolSize when its queue rejects a task;
// an unbounded LinkedBlockingQueue never does, so a core size of 0 would leave one worker
// and maximumPoolSize would have no effect. Core and maximum are therefore both 5, and core
// threads are allowed to time out so an idle pool still shrinks back to zero threads.
ExecutorService executor = newBuildExecutor();

/** Creates the build pool: up to 5 workers, each released after 100 ms of idleness. */
private static ExecutorService newBuildExecutor() {
    ThreadPoolExecutor pool = new ThreadPoolExecutor(5, 5, 100L, TimeUnit.MILLISECONDS,
            new LinkedBlockingQueue<>());
    pool.allowCoreThreadTimeOut(true);
    return pool;
}
public ModelServiceImpl(ModelRepository modelRepository, DatabaseService databaseService,
@Lazy DimensionService dimensionService, @Lazy MetricService metricService,
DomainService domainService, UserService userService, DataSetService dataSetService,
@@ -194,29 +217,49 @@ public class ModelServiceImpl implements ModelService {
}
@Override
public Map<String, ModelSchema> buildModelSchema(ModelSchemaReq modelSchemaReq)
public Map<String, ModelSchema> buildModelSchema(ModelBuildReq modelBuildReq)
throws SQLException {
Map<String, List<DBColumn>> dbColumnMap = databaseService.getDbColumns(modelSchemaReq);
Map<String, ModelSchema> modelSchemaMap = new HashMap<>();
if (modelSchemaReq.isBuildByLLM()) {
ChatModel chatModel = chatModelService.getChatModel(modelSchemaReq.getChatModelId());
modelSchemaReq.setChatModelConfig(chatModel.getConfig());
}
for (Map.Entry<String, List<DBColumn>> entry : dbColumnMap.entrySet()) {
if (modelSchemaReq.isBuildByLLM()) {
DbSchema dbSchema = convert(modelSchemaReq, entry.getKey(), entry.getValue());
ModelSchema modelSchema = modelIntelligentBuilder.build(dbSchema, modelSchemaReq);
if (modelSchema != null) {
modelSchemaMap.put(entry.getKey(), modelSchema);
}
} else {
modelSchemaMap.put(entry.getKey(), build(entry.getValue()));
}
Map<String, List<DBColumn>> dbColumnMap = databaseService.getDbColumns(modelBuildReq);
if (modelBuildReq.isBuildByLLM() && modelBuildReq.getChatModelConfig() == null) {
ChatModel chatModel = chatModelService.getChatModel(modelBuildReq.getChatModelId());
modelBuildReq.setChatModelConfig(chatModel.getConfig());
}
List<DbSchema> dbSchemas = convert(dbColumnMap, modelBuildReq);
Map<String, ModelSchema> modelSchemaMap = new ConcurrentHashMap<>();
CompletableFuture.allOf(dbSchemas.stream()
.map(dbSchema -> CompletableFuture.runAsync(
() -> doBuild(modelBuildReq, dbSchema, dbSchemas, modelSchemaMap),
executor))
.toArray(CompletableFuture[]::new)).join();
return modelSchemaMap;
}
private DbSchema convert(ModelSchemaReq modelSchemaReq, String key, List<DBColumn> dbColumns) {
/**
 * Builds the model schema for a single table and records it in {@code modelSchemaMap} under
 * the table name of {@code curSchema}.
 *
 * <p>When the request asks for an LLM build, the schema comes from the intelligent builder,
 * which also receives the schemas of the other tables; otherwise it is derived from the
 * table's columns alone. A {@code null} result is not stored.
 *
 * @param modelBuildReq the build request; decides whether the LLM path is used
 * @param curSchema the schema of the table being built
 * @param dbSchemas the schemas of all tables in the request, including {@code curSchema}
 * @param modelSchemaMap the result map that receives the built schema
 */
private void doBuild(ModelBuildReq modelBuildReq, DbSchema curSchema, List<DbSchema> dbSchemas,
        Map<String, ModelSchema> modelSchemaMap) {
    ModelSchema modelSchema;
    if (modelBuildReq.isBuildByLLM()) {
        List<DbSchema> otherDbSchema = getOtherDbSchema(curSchema, dbSchemas);
        modelSchema = modelIntelligentBuilder.build(curSchema, otherDbSchema, modelBuildReq);
    } else {
        modelSchema = build(curSchema.getDbColumns());
    }
    // The intelligent builder returns null when its chat app is missing or disabled, and the
    // caller supplies a ConcurrentHashMap, which throws NullPointerException on null values.
    if (modelSchema != null) {
        modelSchemaMap.put(curSchema.getTable(), modelSchema);
    }
}
/**
 * Collects every schema in {@code dbSchemas} whose table differs from the table of
 * {@code curSchema}, keeping the original order.
 *
 * @param curSchema the schema to exclude (matched by table)
 * @param dbSchemas all candidate schemas
 * @return a new list holding the schemas of the other tables
 */
private List<DbSchema> getOtherDbSchema(DbSchema curSchema, List<DbSchema> dbSchemas) {
    List<DbSchema> others = new ArrayList<>();
    for (DbSchema candidate : dbSchemas) {
        boolean sameTable = candidate.getTable().equals(curSchema.getTable());
        if (!sameTable) {
            others.add(candidate);
        }
    }
    return others;
}
/**
 * Converts each entry of {@code dbColumnMap} into a {@link DbSchema} by delegating to the
 * three-argument {@code convert} overload with the entry's key and column list.
 *
 * @param dbColumnMap columns grouped by map key
 * @param modelSchemaReq the build request forwarded to the per-entry conversion
 * @return one schema per map entry, in the map's iteration order
 */
private List<DbSchema> convert(Map<String, List<DBColumn>> dbColumnMap,
        ModelBuildReq modelSchemaReq) {
    List<DbSchema> schemas = new ArrayList<>();
    for (Map.Entry<String, List<DBColumn>> entry : dbColumnMap.entrySet()) {
        schemas.add(convert(modelSchemaReq, entry.getKey(), entry.getValue()));
    }
    return schemas;
}
private DbSchema convert(ModelBuildReq modelSchemaReq, String key, List<DBColumn> dbColumns) {
DbSchema dbSchema = new DbSchema();
dbSchema.setDb(modelSchemaReq.getDb());
dbSchema.setTable(key);

View File

@@ -0,0 +1,90 @@
package com.tencent.supersonic.headless;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.headless.api.pojo.ModelSchema;
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
import com.tencent.supersonic.headless.server.service.ModelService;
import com.tencent.supersonic.util.LLMConfigUtils;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import java.sql.SQLException;
import java.util.Map;
@Disabled
public class ModelIntelligentBuildTest extends BaseTest {
private LLMConfigUtils.LLMType llmType = LLMConfigUtils.LLMType.OLLAMA_LLAMA3;
@Autowired
private ModelService modelService;
@Test
public void testBuildModelBatch() throws SQLException {
ChatModelConfig llmConfig = LLMConfigUtils.getLLMConfig(llmType);
ModelBuildReq modelSchemaReq = new ModelBuildReq();
modelSchemaReq.setChatModelConfig(llmConfig);
modelSchemaReq.setBuildByLLM(true);
modelSchemaReq.setDatabaseId(1L);
modelSchemaReq.setDb("semantic");
modelSchemaReq.setTables(Lists.newArrayList("s2_user_department", "s2_stay_time_statis"));
Map<String, ModelSchema> modelSchemaMap = modelService.buildModelSchema(modelSchemaReq);
ModelSchema userModelSchema = modelSchemaMap.get("s2_user_department");
Assertions.assertEquals(2, userModelSchema.getFiledSchemas().size());
Assertions.assertEquals(FieldType.primary_key,
userModelSchema.getFieldByName("user_name").getFiledType());
Assertions.assertEquals(FieldType.dimension,
userModelSchema.getFieldByName("department").getFiledType());
ModelSchema stayTimeModelSchema = modelSchemaMap.get("s2_stay_time_statis");
Assertions.assertEquals(4, stayTimeModelSchema.getFiledSchemas().size());
Assertions.assertEquals(FieldType.foreign_key,
stayTimeModelSchema.getFieldByName("user_name").getFiledType());
Assertions.assertEquals(FieldType.data_time,
stayTimeModelSchema.getFieldByName("imp_date").getFiledType());
Assertions.assertEquals(FieldType.dimension,
stayTimeModelSchema.getFieldByName("page").getFiledType());
Assertions.assertEquals(FieldType.measure,
stayTimeModelSchema.getFieldByName("stay_hours").getFiledType());
Assertions.assertEquals(AggOperatorEnum.SUM,
stayTimeModelSchema.getFieldByName("stay_hours").getAgg());
}
@Test
public void testBuildModelBySql() throws SQLException {
ChatModelConfig llmConfig = LLMConfigUtils.getLLMConfig(llmType);
ModelBuildReq modelSchemaReq = new ModelBuildReq();
modelSchemaReq.setChatModelConfig(llmConfig);
modelSchemaReq.setBuildByLLM(true);
modelSchemaReq.setDatabaseId(1L);
modelSchemaReq.setDb("semantic");
modelSchemaReq.setSql(
"SELECT imp_date, user_name, page, 1 as pv, user_name as uv FROM s2_pv_uv_statis");
Map<String, ModelSchema> modelSchemaMap = modelService.buildModelSchema(modelSchemaReq);
ModelSchema pvModelSchema = modelSchemaMap.values().iterator().next();
Assertions.assertEquals(5, pvModelSchema.getFiledSchemas().size());
Assertions.assertEquals(FieldType.data_time,
pvModelSchema.getFieldByName("imp_date").getFiledType());
Assertions.assertEquals(FieldType.dimension,
pvModelSchema.getFieldByName("user_name").getFiledType());
Assertions.assertEquals(FieldType.dimension,
pvModelSchema.getFieldByName("page").getFiledType());
Assertions.assertEquals(FieldType.measure,
pvModelSchema.getFieldByName("pv").getFiledType());
Assertions.assertEquals(AggOperatorEnum.SUM, pvModelSchema.getFieldByName("pv").getAgg());
Assertions.assertEquals(FieldType.measure,
pvModelSchema.getFieldByName("uv").getFiledType());
Assertions.assertEquals(AggOperatorEnum.COUNT_DISTINCT,
pvModelSchema.getFieldByName("uv").getAgg());
}
}

View File

@@ -1,7 +1,7 @@
spring:
datasource:
driver-class-name: org.h2.Driver
url: jdbc:h2:mem:semantic;DATABASE_TO_UPPER=false;QUERY_TIMEOUT=30
url: jdbc:h2:mem:semantic;DATABASE_TO_UPPER=false;QUERY_TIMEOUT=100
username: root
password: semantic
sql: