diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/plugin/PluginSemanticQuery.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/plugin/PluginSemanticQuery.java index 5dda72c1f..19a35a4cc 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/plugin/PluginSemanticQuery.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/plugin/PluginSemanticQuery.java @@ -1,8 +1,19 @@ package com.tencent.supersonic.chat.query.plugin; +import com.google.common.collect.Lists; import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; +import com.tencent.supersonic.chat.api.pojo.SchemaElementType; +import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; +import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; +import com.tencent.supersonic.chat.api.pojo.request.QueryReq; +import com.tencent.supersonic.chat.plugin.PluginParseResult; import com.tencent.supersonic.chat.query.BaseSemanticQuery; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import lombok.extern.slf4j.Slf4j; +import org.springframework.util.CollectionUtils; @Slf4j public abstract class PluginSemanticQuery extends BaseSemanticQuery { @@ -16,4 +27,73 @@ public abstract class PluginSemanticQuery extends BaseSemanticQuery { public void initS2Sql(User user) { } + + private Map getFilterMap(PluginParseResult pluginParseResult) { + Map map = new HashMap<>(); + QueryReq queryReq = pluginParseResult.getRequest(); + if (queryReq == null || queryReq.getQueryFilters() == null) { + return map; + } + QueryFilters queryFilters = queryReq.getQueryFilters(); + List queryFilterList = queryFilters.getFilters(); + if (CollectionUtils.isEmpty(queryFilterList)) { + return map; + } + for (QueryFilter queryFilter : queryFilterList) { + map.put(queryFilter.getElementID(), queryFilter.getValue()); + } + return map; + } + + protected Map getElementMap(PluginParseResult pluginParseResult) { + Map elementValueMap = new HashMap<>(); + Map filterValueMap = getFilterMap(pluginParseResult); + List schemaElementMatchList = parseInfo.getElementMatches(); + if (!CollectionUtils.isEmpty(schemaElementMatchList)) { + schemaElementMatchList.stream() + .filter(schemaElementMatch -> + SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()) + || SchemaElementType.ID.equals(schemaElementMatch.getElement().getType())) + .filter(schemaElementMatch -> schemaElementMatch.getSimilarity() == 1.0) + .forEach(schemaElementMatch -> { + Object queryFilterValue = filterValueMap.get(schemaElementMatch.getElement().getId()); + if (queryFilterValue != null) { + if (String.valueOf(queryFilterValue).equals(String.valueOf(schemaElementMatch.getWord()))) { + elementValueMap.put( + String.valueOf(schemaElementMatch.getElement().getId()), + schemaElementMatch.getWord()); + } + } else { + elementValueMap.computeIfAbsent( + String.valueOf(schemaElementMatch.getElement().getId()), + k -> schemaElementMatch.getWord()); + } + }); + } + return elementValueMap; + } + + protected WebBase fillWebBaseResult(WebBase webPage, PluginParseResult pluginParseResult) { + WebBase webBaseResult = new WebBase(); + webBaseResult.setUrl(webPage.getUrl()); + Map elementValueMap = getElementMap(pluginParseResult); + List paramOptions = Lists.newArrayList(); + if (!CollectionUtils.isEmpty(webPage.getParamOptions()) && !CollectionUtils.isEmpty(elementValueMap)) { + for (ParamOption paramOption : webPage.getParamOptions()) { + if (paramOption.getModelId() != null + && !parseInfo.getModel().getModelIds().contains(paramOption.getModelId())) { + continue; + } + paramOptions.add(paramOption); + if (!ParamOption.ParamType.SEMANTIC.equals(paramOption.getParamType())) { + continue; + } + String elementId = String.valueOf(paramOption.getElementId()); + Object elementValue = elementValueMap.get(elementId); + paramOption.setValue(elementValue); + } + } + webBaseResult.setParamOptions(paramOptions); + return webBaseResult; + } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/plugin/webpage/WebPageQuery.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/plugin/webpage/WebPageQuery.java index d7cf8da7c..1ee5b271e 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/plugin/webpage/WebPageQuery.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/plugin/webpage/WebPageQuery.java @@ -60,78 +60,8 @@ public class WebPageQuery extends PluginSemanticQuery { webPageResponse.setPluginId(plugin.getId()); webPageResponse.setPluginType(plugin.getType()); WebBase webPage = JsonUtil.toObject(plugin.getConfig(), WebBase.class); - WebBase webBase = buildWebPageResult(webPage, pluginParseResult); + WebBase webBase = fillWebBaseResult(webPage, pluginParseResult); webPageResponse.setWebPage(webBase); return webPageResponse; } - - private WebBase buildWebPageResult(WebBase webPage, PluginParseResult pluginParseResult) { - WebBase webBaseResult = new WebBase(); - webBaseResult.setUrl(webPage.getUrl()); - Map elementValueMap = getElementMap(pluginParseResult); - List paramOptions = Lists.newArrayList(); - if (!CollectionUtils.isEmpty(webPage.getParamOptions()) && !CollectionUtils.isEmpty(elementValueMap)) { - for (ParamOption paramOption : webPage.getParamOptions()) { - if (paramOption.getModelId() != null - && !parseInfo.getModel().getModelIds().contains(paramOption.getModelId())) { - continue; - } - paramOptions.add(paramOption); - if (!ParamOption.ParamType.SEMANTIC.equals(paramOption.getParamType())) { - continue; - } - String elementId = String.valueOf(paramOption.getElementId()); - Object elementValue = elementValueMap.get(elementId); - paramOption.setValue(elementValue); - } - } - webBaseResult.setParamOptions(paramOptions); - return webBaseResult; - } - - protected Map getElementMap(PluginParseResult pluginParseResult) { - Map elementValueMap = new HashMap<>(); - Map filterValueMap = getFilterMap(pluginParseResult); - List schemaElementMatchList = parseInfo.getElementMatches(); - if (!CollectionUtils.isEmpty(schemaElementMatchList)) { - schemaElementMatchList.stream() - .filter(schemaElementMatch -> - SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()) - || SchemaElementType.ID.equals(schemaElementMatch.getElement().getType())) - .filter(schemaElementMatch -> schemaElementMatch.getSimilarity() == 1.0) - .forEach(schemaElementMatch -> { - Object queryFilterValue = filterValueMap.get(schemaElementMatch.getElement().getId()); - if (queryFilterValue != null) { - if (String.valueOf(queryFilterValue).equals(String.valueOf(schemaElementMatch.getWord()))) { - elementValueMap.put( - String.valueOf(schemaElementMatch.getElement().getId()), - schemaElementMatch.getWord()); - } - } else { - elementValueMap.computeIfAbsent( - String.valueOf(schemaElementMatch.getElement().getId()), - k -> schemaElementMatch.getWord()); - } - }); - } - return elementValueMap; - } - - private Map getFilterMap(PluginParseResult pluginParseResult) { - Map map = new HashMap<>(); - QueryReq queryReq = pluginParseResult.getRequest(); - if (queryReq == null || queryReq.getQueryFilters() == null) { - return map; - } - QueryFilters queryFilters = queryReq.getQueryFilters(); - List queryFilterList = queryFilters.getFilters(); - if (CollectionUtils.isEmpty(queryFilterList)) { - return map; - } - for (QueryFilter queryFilter : queryFilterList) { - map.put(queryFilter.getElementID(), queryFilter.getValue()); - } - return map; - } - } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/plugin/webservice/WebServiceQuery.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/plugin/webservice/WebServiceQuery.java index 9664bd049..cbabfa9c0 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/plugin/webservice/WebServiceQuery.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/plugin/webservice/WebServiceQuery.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.chat.query.plugin.webservice; import com.alibaba.fastjson.JSON; +import com.google.common.collect.Lists; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryState; @@ -26,6 +27,7 @@ import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.stereotype.Component; +import org.springframework.util.CollectionUtils; import org.springframework.web.client.RestTemplate; import org.springframework.web.util.UriComponentsBuilder; @@ -75,7 +77,7 @@ public class WebServiceQuery extends PluginSemanticQuery { protected WebServiceResp buildResponse(PluginParseResult pluginParseResult) { WebServiceResp webServiceResponse = new WebServiceResp(); Plugin plugin = pluginParseResult.getPlugin(); - WebBase webBase = JsonUtil.toObject(plugin.getConfig(), WebBase.class); + WebBase webBase = fillWebBaseResult(JsonUtil.toObject(plugin.getConfig(), WebBase.class), pluginParseResult); webServiceResponse.setWebBase(webBase); List paramOptions = webBase.getParamOptions(); Map params = new HashMap<>();