[improvement][chat] Optimize and modify the mapper method for terminology (#1866)

This commit is contained in:
lexluo09
2024-10-31 11:18:35 +08:00
committed by GitHub
parent 838745d415
commit cf359f3e2f
27 changed files with 172 additions and 131 deletions

View File

@@ -29,12 +29,12 @@ import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.corrector.GrammarCorrector;
import com.tencent.supersonic.headless.chat.corrector.SchemaCorrector;
import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder;
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import com.tencent.supersonic.headless.server.service.DataSetService;
import com.tencent.supersonic.headless.server.service.RetrieveService;
import com.tencent.supersonic.headless.server.service.SchemaService;
import com.tencent.supersonic.headless.server.utils.ChatWorkflowEngine;
import com.tencent.supersonic.headless.server.utils.ComponentFactory;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.BeanUtils;

View File

@@ -45,7 +45,7 @@ import com.tencent.supersonic.headless.server.service.DomainService;
import com.tencent.supersonic.headless.server.service.MetricService;
import com.tencent.supersonic.headless.server.service.ModelRelaService;
import com.tencent.supersonic.headless.server.service.ModelService;
import com.tencent.supersonic.headless.server.utils.ComponentFactory;
import com.tencent.supersonic.headless.server.utils.CoreComponentFactory;
import com.tencent.supersonic.headless.server.utils.ModelConverter;
import com.tencent.supersonic.headless.server.utils.NameCheckUtils;
import lombok.extern.slf4j.Slf4j;
@@ -222,7 +222,7 @@ public class ModelServiceImpl implements ModelService {
private void doBuild(ModelBuildReq modelBuildReq, DbSchema curSchema, List<DbSchema> dbSchemas,
Map<String, ModelSchema> modelSchemaMap) {
SemanticModeller semanticModeller = ComponentFactory.getSemanticModeller();
SemanticModeller semanticModeller = CoreComponentFactory.getSemanticModeller();
ModelSchema modelSchema = semanticModeller.build(curSchema, dbSchemas, modelBuildReq);
modelSchemaMap.put(curSchema.getTable(), modelSchema);
}

View File

@@ -30,11 +30,12 @@ import java.util.stream.Collectors;
@Slf4j
public class ChatWorkflowEngine {
private final List<SchemaMapper> schemaMappers = ComponentFactory.getSchemaMappers();
private final List<SemanticParser> semanticParsers = ComponentFactory.getSemanticParsers();
private final List<SchemaMapper> schemaMappers = CoreComponentFactory.getSchemaMappers();
private final List<SemanticParser> semanticParsers = CoreComponentFactory.getSemanticParsers();
private final List<SemanticCorrector> semanticCorrectors =
ComponentFactory.getSemanticCorrectors();
private final List<ResultProcessor> resultProcessors = ComponentFactory.getResultProcessors();
CoreComponentFactory.getSemanticCorrectors();
private final List<ResultProcessor> resultProcessors =
CoreComponentFactory.getResultProcessors();
public void start(ChatWorkflowState initialState, ChatQueryContext queryCtx,
ParseResp parseResult) {
@@ -48,8 +49,6 @@ public class ChatWorkflowEngine {
parseResult.setErrorMsg(
"No semantic entities can be mapped against user question.");
queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
} else if (queryCtx.getMapInfo().needContinueMap()) {
queryCtx.setChatWorkflowState(ChatWorkflowState.MAPPING);
} else {
queryCtx.setChatWorkflowState(ChatWorkflowState.PARSING);
}
@@ -91,8 +90,7 @@ public class ChatWorkflowEngine {
private void performMapping(ChatQueryContext queryCtx) {
if (Objects.isNull(queryCtx.getMapInfo())
|| MapUtils.isEmpty(queryCtx.getMapInfo().getDataSetElementMatches())
|| queryCtx.getMapInfo().needContinueMap()) {
|| MapUtils.isEmpty(queryCtx.getMapInfo().getDataSetElementMatches())) {
schemaMappers.forEach(mapper -> mapper.map(queryCtx));
}
}

View File

@@ -1,66 +0,0 @@
package com.tencent.supersonic.headless.server.utils;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.chat.corrector.SemanticCorrector;
import com.tencent.supersonic.headless.chat.mapper.SchemaMapper;
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
import com.tencent.supersonic.headless.server.modeller.SemanticModeller;
import com.tencent.supersonic.headless.server.processor.ResultProcessor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.core.io.support.SpringFactoriesLoader;
import java.util.ArrayList;
import java.util.List;
/** QueryConverter QueryOptimizer QueryExecutor object factory */
@Slf4j
public class ComponentFactory {
private static List<ResultProcessor> resultProcessors = new ArrayList<>();
private static List<SchemaMapper> schemaMappers = new ArrayList<>();
private static List<SemanticParser> semanticParsers = new ArrayList<>();
private static List<SemanticCorrector> semanticCorrectors = new ArrayList<>();
private static SemanticModeller semanticModeller;
public static List<ResultProcessor> getResultProcessors() {
return CollectionUtils.isEmpty(resultProcessors)
? init(ResultProcessor.class, resultProcessors)
: resultProcessors;
}
public static List<SchemaMapper> getSchemaMappers() {
return CollectionUtils.isEmpty(schemaMappers) ? init(SchemaMapper.class, schemaMappers)
: schemaMappers;
}
public static List<SemanticParser> getSemanticParsers() {
return CollectionUtils.isEmpty(semanticParsers)
? init(SemanticParser.class, semanticParsers)
: semanticParsers;
}
public static List<SemanticCorrector> getSemanticCorrectors() {
return CollectionUtils.isEmpty(semanticCorrectors)
? init(SemanticCorrector.class, semanticCorrectors)
: semanticCorrectors;
}
public static SemanticModeller getSemanticModeller() {
return semanticModeller == null ? init(SemanticModeller.class) : semanticModeller;
}
public static <T> T getBean(String name, Class<T> tClass) {
return ContextUtils.getContext().getBean(name, tClass);
}
private static <T> List<T> init(Class<T> factoryType, List list) {
list.addAll(SpringFactoriesLoader.loadFactories(factoryType,
Thread.currentThread().getContextClassLoader()));
return list;
}
private static <T> T init(Class<T> factoryType) {
return SpringFactoriesLoader
.loadFactories(factoryType, Thread.currentThread().getContextClassLoader()).get(0);
}
}

View File

@@ -0,0 +1,31 @@
package com.tencent.supersonic.headless.server.utils;
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
import com.tencent.supersonic.headless.server.modeller.SemanticModeller;
import com.tencent.supersonic.headless.server.processor.ResultProcessor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import java.util.ArrayList;
import java.util.List;
/**
* QueryConverter QueryOptimizer QueryExecutor object factory
*/
@Slf4j
public class CoreComponentFactory extends ComponentFactory {
private static List<ResultProcessor> resultProcessors = new ArrayList<>();
private static SemanticModeller semanticModeller;
public static List<ResultProcessor> getResultProcessors() {
return CollectionUtils.isEmpty(resultProcessors)
? init(ResultProcessor.class, resultProcessors)
: resultProcessors;
}
public static SemanticModeller getSemanticModeller() {
return semanticModeller == null ? init(SemanticModeller.class) : semanticModeller;
}
}