mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 12:37:55 +00:00
[improvement][chat] Optimize and modify the mapper method for terminology (#1866)
This commit is contained in:
@@ -10,6 +10,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
import lombok.Data;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
@@ -18,10 +19,9 @@ import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Data
|
||||
public class ChatQueryContext {
|
||||
public class ChatQueryContext implements Serializable {
|
||||
|
||||
private QueryNLReq request;
|
||||
private String oriQueryText;
|
||||
private Map<Long, List<Long>> modelIdToDataSetIds;
|
||||
private List<SemanticQuery> candidateQueries = new ArrayList<>();
|
||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||
|
||||
@@ -1,41 +1,44 @@
|
||||
package com.tencent.supersonic.headless.chat.mapper;
|
||||
|
||||
import com.tencent.supersonic.common.util.DeepCopyUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/** * A mapper that map the description of the term. */
|
||||
/**
|
||||
* A mapper that map the description of the term.
|
||||
*/
|
||||
@Slf4j
|
||||
public class TermDescMapper extends BaseMapper {
|
||||
|
||||
@Override
|
||||
public void doMap(ChatQueryContext chatQueryContext) {
|
||||
List<SchemaElement> termDescriptionToMap =
|
||||
chatQueryContext.getMapInfo().getTermDescriptionToMap();
|
||||
if (CollectionUtils.isEmpty(termDescriptionToMap)) {
|
||||
SchemaMapInfo mapInfo = chatQueryContext.getMapInfo();
|
||||
List<SchemaElement> termElements = mapInfo.getTermDescriptionToMap();
|
||||
if (CollectionUtils.isEmpty(termElements)) {
|
||||
return;
|
||||
}
|
||||
if (StringUtils.isBlank(chatQueryContext.getOriQueryText())) {
|
||||
chatQueryContext.setOriQueryText(chatQueryContext.getRequest().getQueryText());
|
||||
}
|
||||
for (SchemaElement schemaElement : termDescriptionToMap) {
|
||||
if (schemaElement.isDescriptionMapped()) {
|
||||
continue;
|
||||
}
|
||||
if (chatQueryContext.getRequest().getQueryText()
|
||||
.equals(schemaElement.getDescription())) {
|
||||
schemaElement.setDescriptionMapped(true);
|
||||
continue;
|
||||
}
|
||||
chatQueryContext.getRequest().setQueryText(schemaElement.getDescription());
|
||||
break;
|
||||
}
|
||||
if (CollectionUtils.isEmpty(chatQueryContext.getMapInfo().getTermDescriptionToMap())) {
|
||||
chatQueryContext.getRequest().setQueryText(chatQueryContext.getOriQueryText());
|
||||
for (SchemaElement schemaElement : termElements) {
|
||||
ChatQueryContext queryCtx =
|
||||
buildQueryContext(chatQueryContext, schemaElement.getDescription());
|
||||
ComponentFactory.getSchemaMappers().forEach(mapper -> mapper.map(queryCtx));
|
||||
chatQueryContext.getMapInfo().addMatchedElements(queryCtx.getMapInfo());
|
||||
}
|
||||
}
|
||||
|
||||
private static ChatQueryContext buildQueryContext(ChatQueryContext chatQueryContext,
|
||||
String queryText) {
|
||||
ChatQueryContext queryContext = DeepCopyUtil.deepCopy(chatQueryContext);
|
||||
queryContext.getRequest().setQueryText(queryText);
|
||||
queryContext.setMapInfo(new SchemaMapInfo());
|
||||
queryContext.setSemanticSchema(chatQueryContext.getSemanticSchema());
|
||||
queryContext.setModelIdToDataSetIds(chatQueryContext.getModelIdToDataSetIds());
|
||||
queryContext.setChatWorkflowState(chatQueryContext.getChatWorkflowState());
|
||||
return queryContext;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,19 +1,45 @@
|
||||
package com.tencent.supersonic.headless.chat.utils;
|
||||
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.chat.corrector.SemanticCorrector;
|
||||
import com.tencent.supersonic.headless.chat.mapper.SchemaMapper;
|
||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||
import com.tencent.supersonic.headless.chat.parser.llm.DataSetResolver;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.springframework.core.io.support.SpringFactoriesLoader;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
/** HeadlessConverter QueryOptimizer QueryExecutor object factory */
|
||||
/**
|
||||
* QueryConverter QueryOptimizer QueryExecutor object factory
|
||||
*/
|
||||
@Slf4j
|
||||
public class ComponentFactory {
|
||||
|
||||
private static List<SchemaMapper> schemaMappers = new ArrayList<>();
|
||||
private static List<SemanticParser> semanticParsers = new ArrayList<>();
|
||||
private static List<SemanticCorrector> semanticCorrectors = new ArrayList<>();
|
||||
private static DataSetResolver modelResolver;
|
||||
|
||||
public static List<SchemaMapper> getSchemaMappers() {
|
||||
return CollectionUtils.isEmpty(schemaMappers) ? init(SchemaMapper.class, schemaMappers)
|
||||
: schemaMappers;
|
||||
}
|
||||
|
||||
public static List<SemanticParser> getSemanticParsers() {
|
||||
return CollectionUtils.isEmpty(semanticParsers)
|
||||
? init(SemanticParser.class, semanticParsers)
|
||||
: semanticParsers;
|
||||
}
|
||||
|
||||
public static List<SemanticCorrector> getSemanticCorrectors() {
|
||||
return CollectionUtils.isEmpty(semanticCorrectors)
|
||||
? init(SemanticCorrector.class, semanticCorrectors)
|
||||
: semanticCorrectors;
|
||||
}
|
||||
|
||||
public static DataSetResolver getModelResolver() {
|
||||
if (Objects.isNull(modelResolver)) {
|
||||
modelResolver = init(DataSetResolver.class);
|
||||
@@ -25,13 +51,13 @@ public class ComponentFactory {
|
||||
return ContextUtils.getContext().getBean(name, tClass);
|
||||
}
|
||||
|
||||
private static <T> List<T> init(Class<T> factoryType, List list) {
|
||||
protected static <T> List<T> init(Class<T> factoryType, List list) {
|
||||
list.addAll(SpringFactoriesLoader.loadFactories(factoryType,
|
||||
Thread.currentThread().getContextClassLoader()));
|
||||
return list;
|
||||
}
|
||||
|
||||
private static <T> T init(Class<T> factoryType) {
|
||||
protected static <T> T init(Class<T> factoryType) {
|
||||
return SpringFactoriesLoader
|
||||
.loadFactories(factoryType, Thread.currentThread().getContextClassLoader()).get(0);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user