mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-15 22:46:49 +00:00
Merge a number of fixes and improvements (#1896)
This commit is contained in:
@@ -46,7 +46,7 @@ com.tencent.supersonic.headless.core.cache.QueryCache=\
|
||||
### headless-server SPIs
|
||||
|
||||
com.tencent.supersonic.headless.server.modeller.SemanticModeller=\
|
||||
com.tencent.supersonic.headless.server.modeller.RuleSemanticModeller
|
||||
com.tencent.supersonic.headless.server.modeller.LLMSemanticModeller
|
||||
|
||||
### chat-server SPIs
|
||||
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
package com.tencent.supersonic.headless;
|
||||
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
|
||||
import com.tencent.supersonic.headless.server.service.ModelService;
|
||||
import com.tencent.supersonic.util.LLMConfigUtils;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.test.context.TestPropertySource;
|
||||
|
||||
import java.sql.SQLException;
|
||||
import java.util.Map;
|
||||
|
||||
@Disabled
|
||||
@TestPropertySource(properties = {"s2.model.building.exemplars.enabled = false"})
|
||||
public class LLMSemanticModellerTest extends BaseTest {
|
||||
|
||||
private LLMConfigUtils.LLMType llmType = LLMConfigUtils.LLMType.OLLAMA_LLAMA3;
|
||||
|
||||
@Autowired
|
||||
private ModelService modelService;
|
||||
|
||||
@Test
|
||||
public void testBuildModelBatch() throws SQLException {
|
||||
ChatModelConfig llmConfig = LLMConfigUtils.getLLMConfig(llmType);
|
||||
ModelBuildReq modelSchemaReq = new ModelBuildReq();
|
||||
modelSchemaReq.setChatModelConfig(llmConfig);
|
||||
modelSchemaReq.setBuildByLLM(true);
|
||||
modelSchemaReq.setDatabaseId(1L);
|
||||
modelSchemaReq.setDb("semantic");
|
||||
modelSchemaReq.setTables(Lists.newArrayList("s2_user_department", "s2_stay_time_statis"));
|
||||
Map<String, ModelSchema> modelSchemaMap = modelService.buildModelSchema(modelSchemaReq);
|
||||
|
||||
ModelSchema userModelSchema = modelSchemaMap.get("s2_user_department");
|
||||
Assertions.assertEquals(2, userModelSchema.getColumnSchemas().size());
|
||||
Assertions.assertEquals(FieldType.primary_key,
|
||||
userModelSchema.getColumnByName("user_name").getFiledType());
|
||||
Assertions.assertEquals(FieldType.dimension,
|
||||
userModelSchema.getColumnByName("department").getFiledType());
|
||||
|
||||
ModelSchema stayTimeModelSchema = modelSchemaMap.get("s2_stay_time_statis");
|
||||
Assertions.assertEquals(4, stayTimeModelSchema.getColumnSchemas().size());
|
||||
Assertions.assertEquals(FieldType.foreign_key,
|
||||
stayTimeModelSchema.getColumnByName("user_name").getFiledType());
|
||||
Assertions.assertEquals(FieldType.data_time,
|
||||
stayTimeModelSchema.getColumnByName("imp_date").getFiledType());
|
||||
Assertions.assertEquals(FieldType.dimension,
|
||||
stayTimeModelSchema.getColumnByName("page").getFiledType());
|
||||
Assertions.assertEquals(FieldType.measure,
|
||||
stayTimeModelSchema.getColumnByName("stay_hours").getFiledType());
|
||||
Assertions.assertEquals(AggOperatorEnum.SUM,
|
||||
stayTimeModelSchema.getColumnByName("stay_hours").getAgg());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBuildModelBySql() throws SQLException {
|
||||
ChatModelConfig llmConfig = LLMConfigUtils.getLLMConfig(llmType);
|
||||
ModelBuildReq modelSchemaReq = new ModelBuildReq();
|
||||
modelSchemaReq.setChatModelConfig(llmConfig);
|
||||
modelSchemaReq.setBuildByLLM(true);
|
||||
modelSchemaReq.setDatabaseId(1L);
|
||||
modelSchemaReq.setDb("semantic");
|
||||
modelSchemaReq.setSql(
|
||||
"SELECT imp_date, user_name, page, 1 as pv, user_name as uv FROM s2_pv_uv_statis");
|
||||
Map<String, ModelSchema> modelSchemaMap = modelService.buildModelSchema(modelSchemaReq);
|
||||
|
||||
ModelSchema pvModelSchema = modelSchemaMap.values().iterator().next();
|
||||
Assertions.assertEquals(5, pvModelSchema.getColumnSchemas().size());
|
||||
Assertions.assertEquals(FieldType.data_time,
|
||||
pvModelSchema.getColumnByName("imp_date").getFiledType());
|
||||
Assertions.assertEquals(FieldType.dimension,
|
||||
pvModelSchema.getColumnByName("user_name").getFiledType());
|
||||
Assertions.assertEquals(FieldType.dimension,
|
||||
pvModelSchema.getColumnByName("page").getFiledType());
|
||||
Assertions.assertEquals(FieldType.measure,
|
||||
pvModelSchema.getColumnByName("pv").getFiledType());
|
||||
Assertions.assertEquals(AggOperatorEnum.SUM, pvModelSchema.getColumnByName("pv").getAgg());
|
||||
Assertions.assertEquals(FieldType.measure,
|
||||
pvModelSchema.getColumnByName("uv").getFiledType());
|
||||
Assertions.assertEquals(AggOperatorEnum.COUNT_DISTINCT,
|
||||
pvModelSchema.getColumnByName("uv").getAgg());
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user