mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
Merge a number of fixes and improvements (#1896)
This commit is contained in:
@@ -86,16 +86,6 @@ public class ConfigServiceImpl implements ConfigService {
|
||||
return configEditCmd.getId();
|
||||
}
|
||||
|
||||
public ItemNameVisibilityInfo getVisibilityByModelId(Long modelId) {
|
||||
ChatConfigResp chatConfigResp = fetchConfigByModelId(modelId);
|
||||
ChatConfig chatConfig = new ChatConfig();
|
||||
chatConfig.setModelId(modelId);
|
||||
chatConfig.setChatAggConfig(chatConfigResp.getChatAggConfig());
|
||||
chatConfig.setChatDetailConfig(chatConfigResp.getChatDetailConfig());
|
||||
ItemNameVisibilityInfo itemNameVisibility = getItemNameVisibility(chatConfig);
|
||||
return itemNameVisibility;
|
||||
}
|
||||
|
||||
public ItemNameVisibilityInfo getItemNameVisibility(ChatConfig chatConfig) {
|
||||
Long modelId = chatConfig.getModelId();
|
||||
|
||||
@@ -312,7 +302,7 @@ public class ConfigServiceImpl implements ConfigService {
|
||||
}
|
||||
Map<Long, SchemaElement> dimIdAndRespPair = modelSchema.getDimensions().stream().collect(
|
||||
Collectors.toMap(SchemaElement::getId, Function.identity(), (k1, k2) -> k1));
|
||||
knowledgeInfos.stream().forEach(knowledgeInfo -> {
|
||||
knowledgeInfos.forEach(knowledgeInfo -> {
|
||||
if (Objects.nonNull(knowledgeInfo)) {
|
||||
SchemaElement dimSchemaResp = dimIdAndRespPair.get(knowledgeInfo.getItemId());
|
||||
if (Objects.nonNull(dimSchemaResp)) {
|
||||
|
||||
@@ -138,8 +138,7 @@ public class SemanticParseInfo implements Serializable {
|
||||
public long getDetailLimit() {
|
||||
long limit = DEFAULT_DETAIL_LIMIT;
|
||||
if (Objects.nonNull(queryConfig)
|
||||
&& Objects.nonNull(queryConfig.getDetailTypeDefaultConfig())
|
||||
&& Objects.nonNull(queryConfig.getDetailTypeDefaultConfig().getLimit())) {
|
||||
&& Objects.nonNull(queryConfig.getDetailTypeDefaultConfig())) {
|
||||
limit = queryConfig.getDetailTypeDefaultConfig().getLimit();
|
||||
}
|
||||
return limit;
|
||||
@@ -148,8 +147,7 @@ public class SemanticParseInfo implements Serializable {
|
||||
public long getMetricLimit() {
|
||||
long limit = DEFAULT_METRIC_LIMIT;
|
||||
if (Objects.nonNull(queryConfig)
|
||||
&& Objects.nonNull(queryConfig.getAggregateTypeDefaultConfig())
|
||||
&& Objects.nonNull(queryConfig.getAggregateTypeDefaultConfig().getLimit())) {
|
||||
&& Objects.nonNull(queryConfig.getAggregateTypeDefaultConfig())) {
|
||||
limit = queryConfig.getAggregateTypeDefaultConfig().getLimit();
|
||||
}
|
||||
return limit;
|
||||
|
||||
@@ -13,7 +13,7 @@ import java.util.stream.Collectors;
|
||||
|
||||
public class SemanticSchema implements Serializable {
|
||||
|
||||
private List<DataSetSchema> dataSetSchemaList;
|
||||
private final List<DataSetSchema> dataSetSchemaList;
|
||||
|
||||
public SemanticSchema(List<DataSetSchema> dataSetSchemaList) {
|
||||
this.dataSetSchemaList = dataSetSchemaList;
|
||||
@@ -48,11 +48,7 @@ public class SemanticSchema implements Serializable {
|
||||
default:
|
||||
}
|
||||
|
||||
if (element.isPresent()) {
|
||||
return element.get();
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
return element.orElse(null);
|
||||
}
|
||||
|
||||
public Map<Long, String> getDataSetIdToName() {
|
||||
@@ -62,13 +58,13 @@ public class SemanticSchema implements Serializable {
|
||||
|
||||
public List<SchemaElement> getDimensionValues() {
|
||||
List<SchemaElement> dimensionValues = new ArrayList<>();
|
||||
dataSetSchemaList.stream().forEach(d -> dimensionValues.addAll(d.getDimensionValues()));
|
||||
dataSetSchemaList.forEach(d -> dimensionValues.addAll(d.getDimensionValues()));
|
||||
return dimensionValues;
|
||||
}
|
||||
|
||||
public List<SchemaElement> getDimensions() {
|
||||
List<SchemaElement> dimensions = new ArrayList<>();
|
||||
dataSetSchemaList.stream().forEach(d -> dimensions.addAll(d.getDimensions()));
|
||||
dataSetSchemaList.forEach(d -> dimensions.addAll(d.getDimensions()));
|
||||
return dimensions;
|
||||
}
|
||||
|
||||
@@ -96,13 +92,13 @@ public class SemanticSchema implements Serializable {
|
||||
|
||||
public List<SchemaElement> getTags() {
|
||||
List<SchemaElement> tags = new ArrayList<>();
|
||||
dataSetSchemaList.stream().forEach(d -> tags.addAll(d.getTags()));
|
||||
dataSetSchemaList.forEach(d -> tags.addAll(d.getTags()));
|
||||
return tags;
|
||||
}
|
||||
|
||||
public List<SchemaElement> getTerms() {
|
||||
List<SchemaElement> terms = new ArrayList<>();
|
||||
dataSetSchemaList.stream().forEach(d -> terms.addAll(d.getTerms()));
|
||||
dataSetSchemaList.forEach(d -> terms.addAll(d.getTerms()));
|
||||
return terms;
|
||||
}
|
||||
|
||||
@@ -135,7 +131,7 @@ public class SemanticSchema implements Serializable {
|
||||
|
||||
public List<SchemaElement> getDataSets() {
|
||||
List<SchemaElement> dataSets = new ArrayList<>();
|
||||
dataSetSchemaList.stream().forEach(d -> dataSets.add(d.getDataSet()));
|
||||
dataSetSchemaList.forEach(d -> dataSets.add(d.getDataSet()));
|
||||
return dataSets;
|
||||
}
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -117,7 +117,6 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
parseInfo.setDataSet(semanticSchema.getDataSet(dataSetId));
|
||||
parseInfo.setQueryConfig(semanticSchema.getQueryConfig(dataSetId));
|
||||
Map<Long, List<SchemaElementMatch>> dim2Values = new HashMap<>();
|
||||
Map<Long, List<SchemaElementMatch>> id2Values = new HashMap<>();
|
||||
|
||||
for (SchemaElementMatch schemaMatch : parseInfo.getElementMatches()) {
|
||||
SchemaElement element = schemaMatch.getElement();
|
||||
@@ -131,7 +130,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
dim2Values.get(element.getId()).add(schemaMatch);
|
||||
} else {
|
||||
dim2Values.put(element.getId(),
|
||||
new ArrayList<>(Arrays.asList(schemaMatch)));
|
||||
new ArrayList<>(Collections.singletonList(schemaMatch)));
|
||||
}
|
||||
}
|
||||
break;
|
||||
@@ -170,7 +169,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
} else {
|
||||
QueryFilter dimensionFilter = new QueryFilter();
|
||||
List<String> values = new ArrayList<>();
|
||||
entry.getValue().stream().forEach(i -> values.add(i.getWord()));
|
||||
entry.getValue().forEach(i -> values.add(i.getWord()));
|
||||
dimensionFilter.setValue(values);
|
||||
dimensionFilter.setBizName(dimension.getBizName());
|
||||
dimensionFilter.setName(dimension.getName());
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
package com.tencent.supersonic.headless.chat.query.rule.detail;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.DIMENSION;
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.METRIC;
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.VALUE;
|
||||
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.OptionType.OPTIONAL;
|
||||
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.RequireNumberType.AT_MOST;
|
||||
|
||||
@Component
|
||||
public class DetailValueQuery extends DetailSemanticQuery {
|
||||
|
||||
public static final String QUERY_MODE = "DETAIL_VALUE";
|
||||
|
||||
public DetailValueQuery() {
|
||||
super();
|
||||
queryMatcher.addOption(VALUE, REQUIRED, AT_LEAST, 1);
|
||||
queryMatcher.addOption(DIMENSION, OPTIONAL, AT_MOST, 0);
|
||||
queryMatcher.addOption(METRIC, OPTIONAL, AT_MOST, 0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getQueryMode() {
|
||||
return QUERY_MODE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext) {
|
||||
super.fillParseInfo(chatQueryContext);
|
||||
|
||||
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||
parseInfo.getDimensions().addAll(semanticSchema.getDimensions());
|
||||
parseInfo.getDimensions().forEach(d -> {
|
||||
parseInfo.getElementMatches()
|
||||
.add(SchemaElementMatch.builder().element(d).word(d.getName()).similarity(0)
|
||||
.isInherited(false).detectWord(d.getName()).build());
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,133 @@
|
||||
package com.tencent.supersonic.headless.server.modeller;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.tencent.supersonic.common.config.ChatModel;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.enums.AppModule;
|
||||
import com.tencent.supersonic.common.service.ChatModelService;
|
||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.DbSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.provider.ModelProvider;
|
||||
import dev.langchain4j.service.AiServices;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.core.env.Environment;
|
||||
import org.springframework.core.io.ClassPathResource;
|
||||
|
||||
import java.io.InputStream;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
public class LLMSemanticModeller implements SemanticModeller {
|
||||
|
||||
public static final String APP_KEY = "BUILD_DATA_MODEL";
|
||||
|
||||
private static final String SYS_EXEMPLAR_FILE = "s2-buildModel-exemplar.json";
|
||||
|
||||
public static final String INSTRUCTION = ""
|
||||
+ "Role: As an experienced data analyst with extensive modeling experience, "
|
||||
+ " you are expected to have a deep understanding of data analysis and data modeling concepts."
|
||||
+ "\nJob: You will be given a database table structure, which includes the database table name, field name,"
|
||||
+ " field type, and field comments. Your task is to utilize this information for data modeling."
|
||||
+ "\nTask:"
|
||||
+ "\n1. Generate a name and description for the model. Please note, 'bizName' refers to the English name, while 'name' is the Chinese name."
|
||||
+ "\n2. Create a Chinese name for the field and categorize the field into one of the following five types:"
|
||||
+ "\n primary_key: This is a unique identifier for a record row in a database."
|
||||
+ "\n foreign_key: This is a key in a database whose value is derived from the primary key of another table."
|
||||
+ "\n data_time: This represents the time when data is generated in the data warehouse."
|
||||
+ "\n dimension: Usually a string type, used for grouping and filtering data. No need to generate aggregate functions"
|
||||
+ "\n measure: Usually a numeric type, used to quantify data from a certain evaluative perspective. "
|
||||
+ " Also, you need to generate aggregate functions(Eg: MAX, MIN, AVG, SUM, COUNT) for the measure type. "
|
||||
+ "\nTip: I will also give you other related dbSchemas. If you determine that different dbSchemas have the same fields, "
|
||||
+ " they can be primary and foreign key relationships."
|
||||
+ "\nDBSchema: {{DBSchema}}" + "\nOtherRelatedDBSchema: {{otherRelatedDBSchema}}"
|
||||
+ "\nExemplar: {{exemplar}}";
|
||||
|
||||
private final ObjectMapper objectMapper = JsonUtil.INSTANCE.getObjectMapper();
|
||||
|
||||
public LLMSemanticModeller() {
|
||||
ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("构造数据语义模型")
|
||||
.appModule(AppModule.HEADLESS).description("通过大模型来构造数据语义模型").enable(true).build());
|
||||
}
|
||||
|
||||
interface ModelSchemaExtractor {
|
||||
ModelSchema generateModelSchema(String text);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ModelSchema build(DbSchema dbSchema, List<DbSchema> dbSchemas,
|
||||
ModelBuildReq modelBuildReq) {
|
||||
Optional<ChatApp> chatApp = ChatAppManager.getApp(APP_KEY);
|
||||
if (!chatApp.isPresent() || !chatApp.get().isEnable()) {
|
||||
return null;
|
||||
}
|
||||
List<DbSchema> otherDbSchema = getOtherDbSchema(dbSchema, dbSchemas);
|
||||
ModelSchemaExtractor extractor =
|
||||
AiServices.create(ModelSchemaExtractor.class, getChatModel(modelBuildReq));
|
||||
Prompt prompt = generatePrompt(dbSchema, otherDbSchema, chatApp.get());
|
||||
ModelSchema modelSchema =
|
||||
extractor.generateModelSchema(prompt.toUserMessage().singleText());
|
||||
log.info("dbSchema: {}\n otherRelatedDBSchema:{}\n modelSchema: {}",
|
||||
JsonUtil.toString(dbSchema), JsonUtil.toString(otherDbSchema),
|
||||
JsonUtil.toString(modelSchema));
|
||||
return modelSchema;
|
||||
}
|
||||
|
||||
private List<DbSchema> getOtherDbSchema(DbSchema curSchema, List<DbSchema> dbSchemas) {
|
||||
return dbSchemas.stream()
|
||||
.filter(dbSchema -> !dbSchema.getTable().equals(curSchema.getTable()))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private ChatLanguageModel getChatModel(ModelBuildReq modelBuildReq) {
|
||||
ChatModelConfig chatModelConfig = modelBuildReq.getChatModelConfig();
|
||||
if (chatModelConfig == null) {
|
||||
ChatModelService chatModelService = ContextUtils.getBean(ChatModelService.class);
|
||||
ChatModel chatModel = chatModelService.getChatModel(modelBuildReq.getChatModelId());
|
||||
chatModelConfig = chatModel.getConfig();
|
||||
}
|
||||
return ModelProvider.getChatModel(chatModelConfig);
|
||||
}
|
||||
|
||||
private Prompt generatePrompt(DbSchema dbSchema, List<DbSchema> otherDbSchema,
|
||||
ChatApp chatApp) {
|
||||
Map<String, Object> variable = new HashMap<>();
|
||||
variable.put("exemplar", loadExemplars());
|
||||
variable.put("DBSchema", JsonUtil.toString(dbSchema));
|
||||
variable.put("otherRelatedDBSchema", JsonUtil.toString(otherDbSchema));
|
||||
return PromptTemplate.from(chatApp.getPrompt()).apply(variable);
|
||||
}
|
||||
|
||||
private String loadExemplars() {
|
||||
Environment environment = ContextUtils.getBean(Environment.class);
|
||||
String enableExemplarLoading =
|
||||
environment.getProperty("s2.model.building.exemplars.enabled");
|
||||
if (Boolean.TRUE.equals(Boolean.parseBoolean(enableExemplarLoading))) {
|
||||
log.info("Not enable load model-building exemplars");
|
||||
return "";
|
||||
}
|
||||
try {
|
||||
ClassPathResource resource = new ClassPathResource(SYS_EXEMPLAR_FILE);
|
||||
if (resource.exists()) {
|
||||
InputStream inputStream = resource.getInputStream();
|
||||
return objectMapper
|
||||
.writeValueAsString(objectMapper.readValue(inputStream, Object.class));
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("Failed to load model-building system exemplars", e);
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import com.tencent.supersonic.headless.chat.knowledge.DictWord;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.KnowledgeBaseService;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.builder.WordBuilderFactory;
|
||||
import com.tencent.supersonic.headless.server.service.SchemaService;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
@@ -21,6 +22,7 @@ import java.util.stream.Collectors;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
@Data
|
||||
public class DictWordService {
|
||||
|
||||
@Autowired
|
||||
@@ -80,14 +82,6 @@ public class DictWordService {
|
||||
natures.addAll(natureList);
|
||||
}
|
||||
|
||||
public List<DictWord> getPreDictWords() {
|
||||
return preDictWords;
|
||||
}
|
||||
|
||||
public void setPreDictWords(List<DictWord> preDictWords) {
|
||||
this.preDictWords = preDictWords;
|
||||
}
|
||||
|
||||
private List<SchemaElement> distinct(List<SchemaElement> metas) {
|
||||
if (CollectionUtils.isEmpty(metas)) {
|
||||
return metas;
|
||||
|
||||
@@ -46,7 +46,7 @@ com.tencent.supersonic.headless.core.cache.QueryCache=\
|
||||
### headless-server SPIs
|
||||
|
||||
com.tencent.supersonic.headless.server.modeller.SemanticModeller=\
|
||||
com.tencent.supersonic.headless.server.modeller.RuleSemanticModeller
|
||||
com.tencent.supersonic.headless.server.modeller.LLMSemanticModeller
|
||||
|
||||
### chat-server SPIs
|
||||
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
package com.tencent.supersonic.headless;
|
||||
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
|
||||
import com.tencent.supersonic.headless.server.service.ModelService;
|
||||
import com.tencent.supersonic.util.LLMConfigUtils;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.test.context.TestPropertySource;
|
||||
|
||||
import java.sql.SQLException;
|
||||
import java.util.Map;
|
||||
|
||||
@Disabled
|
||||
@TestPropertySource(properties = {"s2.model.building.exemplars.enabled = false"})
|
||||
public class LLMSemanticModellerTest extends BaseTest {
|
||||
|
||||
private LLMConfigUtils.LLMType llmType = LLMConfigUtils.LLMType.OLLAMA_LLAMA3;
|
||||
|
||||
@Autowired
|
||||
private ModelService modelService;
|
||||
|
||||
@Test
|
||||
public void testBuildModelBatch() throws SQLException {
|
||||
ChatModelConfig llmConfig = LLMConfigUtils.getLLMConfig(llmType);
|
||||
ModelBuildReq modelSchemaReq = new ModelBuildReq();
|
||||
modelSchemaReq.setChatModelConfig(llmConfig);
|
||||
modelSchemaReq.setBuildByLLM(true);
|
||||
modelSchemaReq.setDatabaseId(1L);
|
||||
modelSchemaReq.setDb("semantic");
|
||||
modelSchemaReq.setTables(Lists.newArrayList("s2_user_department", "s2_stay_time_statis"));
|
||||
Map<String, ModelSchema> modelSchemaMap = modelService.buildModelSchema(modelSchemaReq);
|
||||
|
||||
ModelSchema userModelSchema = modelSchemaMap.get("s2_user_department");
|
||||
Assertions.assertEquals(2, userModelSchema.getColumnSchemas().size());
|
||||
Assertions.assertEquals(FieldType.primary_key,
|
||||
userModelSchema.getColumnByName("user_name").getFiledType());
|
||||
Assertions.assertEquals(FieldType.dimension,
|
||||
userModelSchema.getColumnByName("department").getFiledType());
|
||||
|
||||
ModelSchema stayTimeModelSchema = modelSchemaMap.get("s2_stay_time_statis");
|
||||
Assertions.assertEquals(4, stayTimeModelSchema.getColumnSchemas().size());
|
||||
Assertions.assertEquals(FieldType.foreign_key,
|
||||
stayTimeModelSchema.getColumnByName("user_name").getFiledType());
|
||||
Assertions.assertEquals(FieldType.data_time,
|
||||
stayTimeModelSchema.getColumnByName("imp_date").getFiledType());
|
||||
Assertions.assertEquals(FieldType.dimension,
|
||||
stayTimeModelSchema.getColumnByName("page").getFiledType());
|
||||
Assertions.assertEquals(FieldType.measure,
|
||||
stayTimeModelSchema.getColumnByName("stay_hours").getFiledType());
|
||||
Assertions.assertEquals(AggOperatorEnum.SUM,
|
||||
stayTimeModelSchema.getColumnByName("stay_hours").getAgg());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBuildModelBySql() throws SQLException {
|
||||
ChatModelConfig llmConfig = LLMConfigUtils.getLLMConfig(llmType);
|
||||
ModelBuildReq modelSchemaReq = new ModelBuildReq();
|
||||
modelSchemaReq.setChatModelConfig(llmConfig);
|
||||
modelSchemaReq.setBuildByLLM(true);
|
||||
modelSchemaReq.setDatabaseId(1L);
|
||||
modelSchemaReq.setDb("semantic");
|
||||
modelSchemaReq.setSql(
|
||||
"SELECT imp_date, user_name, page, 1 as pv, user_name as uv FROM s2_pv_uv_statis");
|
||||
Map<String, ModelSchema> modelSchemaMap = modelService.buildModelSchema(modelSchemaReq);
|
||||
|
||||
ModelSchema pvModelSchema = modelSchemaMap.values().iterator().next();
|
||||
Assertions.assertEquals(5, pvModelSchema.getColumnSchemas().size());
|
||||
Assertions.assertEquals(FieldType.data_time,
|
||||
pvModelSchema.getColumnByName("imp_date").getFiledType());
|
||||
Assertions.assertEquals(FieldType.dimension,
|
||||
pvModelSchema.getColumnByName("user_name").getFiledType());
|
||||
Assertions.assertEquals(FieldType.dimension,
|
||||
pvModelSchema.getColumnByName("page").getFiledType());
|
||||
Assertions.assertEquals(FieldType.measure,
|
||||
pvModelSchema.getColumnByName("pv").getFiledType());
|
||||
Assertions.assertEquals(AggOperatorEnum.SUM, pvModelSchema.getColumnByName("pv").getAgg());
|
||||
Assertions.assertEquals(FieldType.measure,
|
||||
pvModelSchema.getColumnByName("uv").getFiledType());
|
||||
Assertions.assertEquals(AggOperatorEnum.COUNT_DISTINCT,
|
||||
pvModelSchema.getColumnByName("uv").getAgg());
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user