diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java index 790b3bdd7..ee6f8c51a 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java @@ -167,7 +167,7 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy variable.put("retrievedInfo", JSONObject.toJSONString(results)); Prompt prompt = PromptTemplate.from(LLM_FILTER_PROMPT).apply(variable); - ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(); + ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatQueryContext.getRequest().getChatAppConfig().get("REWRITE_MULTI_TURN").getChatModelConfig()); String response = chatLanguageModel.generate(prompt.toUserMessage().singleText()); if (StringUtils.isBlank(response)) { diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/SqlQueryParser.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/SqlQueryParser.java index b52e9bbea..04b89ee5e 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/SqlQueryParser.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/parser/SqlQueryParser.java @@ -47,12 +47,17 @@ public class SqlQueryParser implements QueryParser { SqlQuery sqlQuery = queryStatement.getSqlQuery(); List queryFields = SqlSelectHelper.getAllSelectFields(sqlQuery.getSql()); Set queryAliases = SqlSelectHelper.getAliasFields(sqlQuery.getSql()); + Set ontologyMetricsDimensions = Collections.synchronizedSet(new HashSet()); + Set ontologyBizNameMetricsDimensions = Collections.synchronizedSet(new HashSet()); queryFields.removeAll(queryAliases); Ontology ontology = queryStatement.getOntology(); OntologyQuery ontologyQuery = buildOntologyQuery(ontology, queryFields); + Set queryFieldsSet = new HashSet<>(queryFields); + ontologyQuery.getMetrics().forEach(m -> {ontologyMetricsDimensions.add(m.getName()); ontologyBizNameMetricsDimensions.add(m.getBizName());}); + ontologyQuery.getDimensions().forEach(d -> {ontologyMetricsDimensions.add(d.getName()); ontologyBizNameMetricsDimensions.add(d.getBizName());}); // check if there are fields not matched with any metric or dimension - if (queryFields.size() > ontologyQuery.getMetrics().size() - + ontologyQuery.getDimensions().size()) { + + if (!(queryFieldsSet.containsAll(ontologyMetricsDimensions)||queryFieldsSet.containsAll(ontologyBizNameMetricsDimensions))) { List semanticFields = Lists.newArrayList(); ontologyQuery.getMetrics().forEach(m -> semanticFields.add(m.getName())); ontologyQuery.getDimensions().forEach(d -> semanticFields.add(d.getName()));