[improvement][headless]Optimize code structure and code style.

This commit is contained in:
jerryjzhang
2024-07-27 18:29:08 +08:00
parent e5504473a4
commit ccd79e4830
55 changed files with 138 additions and 168 deletions

View File

@@ -33,11 +33,6 @@
<version>${guava.version}</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.apache.commons/commons-lang3 -->
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
</dependency>
<dependency>
<groupId>com.tencent.supersonic</groupId>
<artifactId>headless-core</artifactId>

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless.server.facade.service.impl;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
@@ -26,7 +27,6 @@ import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.corrector.GrammarCorrector;
import com.tencent.supersonic.headless.chat.corrector.SchemaCorrector;
import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder;
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import com.tencent.supersonic.headless.server.service.RetrieveService;
import com.tencent.supersonic.headless.api.pojo.MetaFilter;
@@ -64,20 +64,17 @@ public class S2ChatLayerService implements ChatLayerService {
@Override
public MapResp performMapping(QueryNLReq queryNLReq) {
MapResp mapResp = new MapResp();
MapResp mapResp = new MapResp(queryNLReq.getQueryText());
ChatQueryContext queryCtx = buildChatQueryContext(queryNLReq);
ComponentFactory.getSchemaMappers().forEach(mapper -> {
mapper.map(queryCtx);
});
SchemaMapInfo mapInfo = queryCtx.getMapInfo();
mapResp.setMapInfo(mapInfo);
mapResp.setQueryText(queryNLReq.getQueryText());
mapResp.setMapInfo(queryCtx.getMapInfo());
return mapResp;
}
@Override
public MapInfoResp map(QueryMapReq queryMapReq) {
QueryNLReq queryNLReq = new QueryNLReq();
BeanUtils.copyProperties(queryMapReq, queryNLReq);
List<DataSetResp> dataSets = dataSetService.getDataSets(queryMapReq.getDataSetNames(), queryMapReq.getUser());
@@ -92,19 +89,13 @@ public class S2ChatLayerService implements ChatLayerService {
@Override
public ParseResp performParsing(QueryNLReq queryNLReq) {
ParseResp parseResult = new ParseResp(queryNLReq.getQueryText());
// build queryContext
ChatQueryContext queryCtx = buildChatQueryContext(queryNLReq);
chatWorkflowEngine.execute(queryCtx, parseResult);
List<SemanticParseInfo> parseInfos = queryCtx.getCandidateQueries().stream()
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
parseResult.setSelectedParses(parseInfos);
return parseResult;
}
public ChatQueryContext buildChatQueryContext(QueryNLReq queryNLReq) {
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
private ChatQueryContext buildChatQueryContext(QueryNLReq queryNLReq) {
SemanticSchema semanticSchema = schemaService.getSemanticSchema(queryNLReq.getDataSetIds());
Map<Long, List<Long>> modelIdToDataSetIds = dataSetService.getModelIdToDataSetIds();
ChatQueryContext queryCtx = ChatQueryContext.builder()
.queryFilters(queryNLReq.getQueryFilters())
@@ -138,7 +129,8 @@ public class S2ChatLayerService implements ChatLayerService {
private SemanticParseInfo correctSqlReq(QuerySqlReq querySqlReq, User user) {
ChatQueryContext queryCtx = new ChatQueryContext();
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
SemanticSchema semanticSchema = schemaService.getSemanticSchema(
Sets.newHashSet(querySqlReq.getDataSetId()));
queryCtx.setSemanticSchema(semanticSchema);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
SqlInfo sqlInfo = new SqlInfo();
@@ -277,7 +269,7 @@ public class S2ChatLayerService implements ChatLayerService {
* @return
*/
private SchemaElementMatch getTimeDimension(Long dataSetId, String dataSetName) {
SchemaElement element = SchemaElement.builder().dataSet(dataSetId).dataSetName(dataSetName)
SchemaElement element = SchemaElement.builder().dataSetId(dataSetId).dataSetName(dataSetName)
.type(SchemaElementType.DIMENSION).bizName(TimeDimensionEnum.DAY.getName()).build();
SchemaElementMatch timeDimensionMatch = SchemaElementMatch.builder().element(element)

View File

@@ -82,7 +82,7 @@ import java.util.stream.Collectors;
@Slf4j
public class S2SemanticLayerService implements SemanticLayerService {
private StatUtils statUtils;
private final StatUtils statUtils;
private final QueryUtils queryUtils;
private final QueryReqConverter queryReqConverter;
private final SemanticSchemaManager semanticSchemaManager;
@@ -93,8 +93,8 @@ public class S2SemanticLayerService implements SemanticLayerService {
private final KnowledgeBaseService knowledgeBaseService;
private final MetricService metricService;
private final DimensionService dimensionService;
private QueryCache queryCache = ComponentFactory.getQueryCache();
private List<QueryExecutor> queryExecutors = ComponentFactory.getQueryExecutors();
private final QueryCache queryCache = ComponentFactory.getQueryCache();
private final List<QueryExecutor> queryExecutors = ComponentFactory.getQueryExecutors();
public S2SemanticLayerService(
StatUtils statUtils,
@@ -322,7 +322,7 @@ public class S2SemanticLayerService implements SemanticLayerService {
}
private Set<SchemaElement> getMetrics(EntityInfo modelInfo) {
Set<SchemaElement> metrics = new LinkedHashSet();
Set<SchemaElement> metrics = Sets.newHashSet();
for (DataInfo metricValue : modelInfo.getMetrics()) {
SchemaElement metric = new SchemaElement();
BeanUtils.copyProperties(metricValue, metric);
@@ -439,7 +439,7 @@ public class S2SemanticLayerService implements SemanticLayerService {
if (dataSetSchema == null) {
return entityInfo;
}
Long dataSetId = dataSetSchema.getDataSet().getDataSet();
Long dataSetId = dataSetSchema.getDataSet().getDataSetId();
DataSetInfo dataSetInfo = new DataSetInfo();
dataSetInfo.setItemId(dataSetId.intValue());
dataSetInfo.setName(dataSetSchema.getDataSet().getName());
@@ -518,7 +518,7 @@ public class S2SemanticLayerService implements SemanticLayerService {
//add filter
QueryFilter chatFilter = getQueryFilter(entityInfo);
Set<QueryFilter> chatFilters = new LinkedHashSet();
Set<QueryFilter> chatFilters = Sets.newHashSet();
chatFilters.add(chatFilter);
semanticParseInfo.setDimensionFilters(chatFilters);

View File

@@ -1,8 +1,7 @@
package com.tencent.supersonic.headless.server.pojo.yaml;
import com.google.common.collect.Lists;
import lombok.Data;
import org.apache.commons.compress.utils.Lists;
import java.util.List;
@Data

View File

@@ -24,6 +24,7 @@ import com.tencent.supersonic.headless.server.pojo.yaml.MetricYamlTpl;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutionException;
public interface SchemaService {
@@ -32,6 +33,8 @@ public interface SchemaService {
SemanticSchema getSemanticSchema();
SemanticSchema getSemanticSchema(Set<Long> dataSetIds);
SemanticSchemaResp fetchSemanticSchema(SchemaFilterReq schemaFilterReq);
List<ModelSchemaResp> fetchModelSchemaResps(List<Long> modelIds);

View File

@@ -64,7 +64,7 @@ public class RetrieveServiceImpl implements RetrieveService {
String queryText = queryNLReq.getQueryText();
// 1.get meta info
SemanticSchema semanticSchemaDb = schemaService.getSemanticSchema();
SemanticSchema semanticSchemaDb = schemaService.getSemanticSchema(queryNLReq.getDataSetIds());
List<SchemaElement> metricsDb = semanticSchemaDb.getMetrics();
final Map<Long, String> dataSetIdToName = semanticSchemaDb.getDataSetIdToName();
Map<Long, List<Long>> modelIdToDataSetIds =
@@ -228,7 +228,7 @@ public class RetrieveServiceImpl implements RetrieveService {
return Lists.newArrayList();
}
return metricsDb.stream()
.filter(mapDO -> Objects.nonNull(mapDO) && model.equals(mapDO.getDataSet()))
.filter(mapDO -> Objects.nonNull(mapDO) && model.equals(mapDO.getDataSetId()))
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.flatMap(entry -> {
List<String> result = new ArrayList<>();

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.server.service.impl;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.ItemDateResp;
import com.tencent.supersonic.common.pojo.ModelRela;
@@ -64,6 +65,7 @@ import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -140,15 +142,15 @@ public class SchemaServiceImpl implements SchemaService {
return fetchDataSetSchema(new DataSetFilterReq(dataSetId)).stream().findFirst().orElse(null);
}
private List<DataSetSchemaResp> fetchDataSetSchema(List<Long> ids) {
private List<DataSetSchemaResp> fetchDataSetSchema(Set<Long> ids) {
DataSetFilterReq dataSetFilterReq = new DataSetFilterReq();
dataSetFilterReq.setDataSetIds(ids);
dataSetFilterReq.setDataSetIds(new ArrayList(ids));
return fetchDataSetSchema(dataSetFilterReq);
}
@Override
public DataSetSchema getDataSetSchema(Long dataSetId) {
List<Long> ids = new ArrayList<>();
Set<Long> ids = Sets.newHashSet();
ids.add(dataSetId);
List<DataSetSchemaResp> dataSetSchemaResps = fetchDataSetSchema(ids);
if (!CollectionUtils.isEmpty(dataSetSchemaResps)) {
@@ -162,7 +164,7 @@ public class SchemaServiceImpl implements SchemaService {
return null;
}
public List<DataSetSchema> getDataSetSchema(List<Long> ids) {
public List<DataSetSchema> getDataSetSchema(Set<Long> ids) {
List<DataSetSchema> domainSchemaList = new ArrayList<>();
for (DataSetSchemaResp resp : fetchDataSetSchema(ids)) {
@@ -174,7 +176,12 @@ public class SchemaServiceImpl implements SchemaService {
@Override
public SemanticSchema getSemanticSchema() {
return new SemanticSchema(getDataSetSchema(new ArrayList<>()));
return new SemanticSchema(getDataSetSchema(Collections.EMPTY_SET));
}
@Override
public SemanticSchema getSemanticSchema(Set<Long> dataSetIds) {
return new SemanticSchema(getDataSetSchema(dataSetIds));
}
public List<DataSetSchemaResp> buildDataSetSchema(DataSetFilterReq filter) {

View File

@@ -59,6 +59,9 @@ public class ChatWorkflowEngine {
parseResult.setErrorMsg("No semantic queries can be parsed out.");
queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
} else {
List<SemanticParseInfo> parseInfos = queryCtx.getCandidateQueries().stream()
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
parseResult.setSelectedParses(parseInfos);
queryCtx.setChatWorkflowState(ChatWorkflowState.CORRECTING);
}
break;

View File

@@ -31,7 +31,7 @@ public class DataSetSchemaBuilder {
DataSetSchema dataSetSchema = new DataSetSchema();
dataSetSchema.setQueryConfig(resp.getQueryConfig());
SchemaElement dataSet = SchemaElement.builder()
.dataSet(resp.getId())
.dataSetId(resp.getId())
.dataSetName(resp.getName())
.id(resp.getId())
.name(resp.getName())
@@ -71,7 +71,7 @@ public class DataSetSchemaBuilder {
List<String> alias = SchemaItem.getAliasList(metric.getAlias());
if (metric.getIsTag() == 1) {
SchemaElement tagToAdd = SchemaElement.builder()
.dataSet(resp.getId())
.dataSetId(resp.getId())
.dataSetName(resp.getName())
.model(metric.getModelId())
.id(metric.getId())
@@ -105,7 +105,7 @@ public class DataSetSchemaBuilder {
}
if (dim.getIsTag() == 1) {
SchemaElement tagToAdd = SchemaElement.builder()
.dataSet(resp.getId())
.dataSetId(resp.getId())
.dataSetName(resp.getName())
.model(dim.getModelId())
.id(dim.getId())
@@ -130,7 +130,7 @@ public class DataSetSchemaBuilder {
return null;
}
return SchemaElement.builder()
.dataSet(resp.getId())
.dataSetId(resp.getId())
.model(dim.getModelId())
.id(dim.getId())
.name(dim.getName())
@@ -155,7 +155,7 @@ public class DataSetSchemaBuilder {
}
}
SchemaElement dimToAdd = SchemaElement.builder()
.dataSet(resp.getId())
.dataSetId(resp.getId())
.dataSetName(resp.getName())
.model(dim.getModelId())
.id(dim.getId())
@@ -189,7 +189,7 @@ public class DataSetSchemaBuilder {
}
}
SchemaElement dimValueToAdd = SchemaElement.builder()
.dataSet(resp.getId())
.dataSetId(resp.getId())
.dataSetName(resp.getName())
.model(dim.getModelId())
.id(dim.getId())
@@ -213,7 +213,7 @@ public class DataSetSchemaBuilder {
List<String> alias = SchemaItem.getAliasList(metric.getAlias());
SchemaElement metricToAdd = SchemaElement.builder()
.dataSet(resp.getId())
.dataSetId(resp.getId())
.dataSetName(resp.getName())
.model(metric.getModelId())
.id(metric.getId())
@@ -239,7 +239,7 @@ public class DataSetSchemaBuilder {
for (TermResp termResp : resp.getTermResps()) {
List<String> alias = termResp.getAlias();
SchemaElement metricToAdd = SchemaElement.builder()
.dataSet(resp.getId())
.dataSetId(resp.getId())
.dataSetName(resp.getName())
.model(-1L)
.id(termResp.getId())