diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java index f8d37313b..cd224d4ae 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java @@ -23,6 +23,7 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; import java.util.ArrayList; +import java.util.Collections; import java.util.Comparator; import java.util.HashSet; import java.util.List; @@ -80,15 +81,17 @@ public class LLMRequestService { llmSchema.setDomainName(dataSetIdToName.get(dataSetId)); List fieldNameList = getFieldNameList(queryCtx, dataSetId, llmParserConfig); + fieldNameList.add(TimeDimensionEnum.DAY.getChName()); + llmSchema.setFieldNameList(fieldNameList); + + llmSchema.setMetrics(getMatchedMetrics(queryCtx, dataSetId)); + llmSchema.setDimensions(getMatchedDimensions(queryCtx, dataSetId)); + llmSchema.setTerms(getTerms(queryCtx, dataSetId)); + llmReq.setSchema(llmSchema); String priorExts = getPriorExts(queryCtx, fieldNameList); llmReq.setPriorExts(priorExts); - fieldNameList.add(TimeDimensionEnum.DAY.getChName()); - llmSchema.setFieldNameList(fieldNameList); - llmSchema.setTerms(getTerms(queryCtx, dataSetId)); - llmReq.setSchema(llmSchema); - List linking = new ArrayList<>(); boolean linkingValueEnabled = Boolean.valueOf(parserConfig.getParameterValue(PARSER_LINKING_VALUE_ENABLE)); @@ -104,10 +107,11 @@ public class LLMRequestService { llmReq.setCurrentDate(currentDate); llmReq.setSqlGenType(LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE))); llmReq.setLlmConfig(queryCtx.getLlmConfig()); + return llmReq; } - public LLMResp invokeLLM(LLMReq llmReq) { + public LLMResp runText2SQL(LLMReq llmReq) { return ComponentFactory.getLLMProxy().text2sql(llmReq); } @@ -169,7 +173,7 @@ public class LLMRequestService { return extraInfoSb.toString(); } - public List getValueList(QueryContext queryCtx, Long dataSetId) { + public List getValues(QueryContext queryCtx, Long dataSetId) { Map itemIdToName = getItemIdToName(queryCtx, dataSetId); List matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId); if (CollectionUtils.isEmpty(matchedElements)) { @@ -216,6 +220,40 @@ public class LLMRequestService { return results; } + protected List getMatchedMetrics(QueryContext queryCtx, Long dataSetId) { + List matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId); + if (CollectionUtils.isEmpty(matchedElements)) { + return Collections.emptyList(); + } + List schemaElements = matchedElements.stream() + .filter(schemaElementMatch -> { + SchemaElementType elementType = schemaElementMatch.getElement().getType(); + return SchemaElementType.METRIC.equals(elementType); + }) + .map(schemaElementMatch -> { + return schemaElementMatch.getElement(); + }) + .collect(Collectors.toList()); + return schemaElements; + } + + protected List getMatchedDimensions(QueryContext queryCtx, Long dataSetId) { + List matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId); + if (CollectionUtils.isEmpty(matchedElements)) { + return Collections.emptyList(); + } + List schemaElements = matchedElements.stream() + .filter(schemaElementMatch -> { + SchemaElementType elementType = schemaElementMatch.getElement().getType(); + return SchemaElementType.DIMENSION.equals(elementType); + }) + .map(schemaElementMatch -> { + return schemaElementMatch.getElement(); + }) + .collect(Collectors.toList()); + return schemaElements; + } + protected Set getMatchedFieldNames(QueryContext queryCtx, Long dataSetId) { Map itemIdToName = getItemIdToName(queryCtx, dataSetId); List matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMSqlParser.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMSqlParser.java index cc4507445..4f530654c 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMSqlParser.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMSqlParser.java @@ -40,10 +40,10 @@ public class LLMSqlParser implements SemanticParser { log.info("Generate query statement for dataSetId:{}", dataSetId); //3.invoke LLM service to do parsing. - List linkingValues = requestService.getValueList(queryCtx, dataSetId); + List linkingValues = requestService.getValues(queryCtx, dataSetId); SemanticSchema semanticSchema = queryCtx.getSemanticSchema(); LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId, semanticSchema, linkingValues); - LLMResp llmResp = requestService.invokeLLM(llmReq); + LLMResp llmResp = requestService.runText2SQL(llmReq); if (Objects.isNull(llmResp)) { return; } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSCSqlGenStrategy.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSCSqlGenStrategy.java index a9d4c0b7b..a50c91484 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSCSqlGenStrategy.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSCSqlGenStrategy.java @@ -12,6 +12,7 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.tuple.Pair; import org.springframework.stereotype.Service; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -68,7 +69,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { + "1.ALWAYS use `数据日期` as the date field.\n" + "2.ALWAYS use `datediff()` as the date function.\n" + "3.DO NOT specify date filter in the where clause if not explicitly mentioned in the query.\n" - + "4.ONLY output SQL statement.\n" + + "4.ONLY respond with the converted SQL statement.\n" + "#Exemplars:\n%s" + "#UserQuery: %s " + "#DatabaseMetadata: %s " @@ -85,11 +86,11 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { } Pair questionPrompt = promptHelper.transformQuestionPrompt(llmReq); - String dbSchema = questionPrompt.getLeft(); + String dataSemanticsStr = promptHelper.buildMetadataStr(llmReq); String questionAugmented = questionPrompt.getRight(); - String promptStr = String.format(instruction, exemplarsStr, questionAugmented, dbSchema); + String promptStr = String.format(instruction, exemplarsStr, questionAugmented, dataSemanticsStr); - return PromptTemplate.from(promptStr).apply(new HashMap<>()); + return PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP); } @Override diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/PromptHelper.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/PromptHelper.java index f147a049c..0a4a24e5c 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/PromptHelper.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/PromptHelper.java @@ -65,14 +65,48 @@ public class PromptHelper { } String currentDataStr = "当前的日期是" + currentDate; String linkingListStr = String.join(",", priorLinkingList); - String termStr = getTermStr(llmReq); + String termStr = buildTermStr(llmReq); String questionAugmented = String.format("%s (补充信息:%s;%s;%s;%s)", llmReq.getQueryText(), linkingListStr, currentDataStr, termStr, priorExts); return Pair.of(dbSchema, questionAugmented); } - private String getTermStr(LLMReq llmReq) { + public String buildMetadataStr(LLMReq llmReq) { + String tableStr = llmReq.getSchema().getDataSetName(); + StringBuilder metricStr = new StringBuilder(); + StringBuilder dimensionStr = new StringBuilder(); + + llmReq.getSchema().getMetrics().stream().forEach( + metric -> { + metricStr.append(metric.getName()); + if (StringUtils.isNotEmpty(metric.getDescription())) { + metricStr.append(" COMMENT '" + metric.getDescription() + "'"); + } + if (StringUtils.isNotEmpty(metric.getDefaultAgg())) { + metricStr.append(" AGGREGATE '" + metric.getDefaultAgg().toUpperCase() + "'"); + } + metricStr.append(","); + } + ); + + llmReq.getSchema().getDimensions().stream().forEach( + dimension -> { + dimensionStr.append(dimension.getName()); + if (StringUtils.isNotEmpty(dimension.getDescription())) { + dimensionStr.append(" COMMENT '" + dimension.getDescription() + "'"); + } + dimensionStr.append(","); + } + ); + + String template = "Table: %s, Metrics: [%s], Dimensions: [%s]"; + + + return String.format(template, tableStr, metricStr, dimensionStr); + } + + private String buildTermStr(LLMReq llmReq) { List terms = llmReq.getSchema().getTerms(); StringBuilder termsDesc = new StringBuilder(); if (!CollectionUtils.isEmpty(terms)) { diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/llm/s2sql/LLMReq.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/llm/s2sql/LLMReq.java index c79f74c6c..d995796c7 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/llm/s2sql/LLMReq.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/llm/s2sql/LLMReq.java @@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.core.chat.query.llm.s2sql; import com.fasterxml.jackson.annotation.JsonValue; import com.google.common.collect.Lists; import com.tencent.supersonic.headless.api.pojo.LLMConfig; +import com.tencent.supersonic.headless.api.pojo.SchemaElement; import lombok.Data; import java.util.List; @@ -46,6 +47,10 @@ public class LLMReq { private List fieldNameList; + private List metrics; + + private List dimensions; + private List terms; }