mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 12:37:55 +00:00
(improvement)(headless)Refactor text2sql prompts by getting more use of the data semantics. #1149
This commit is contained in:
@@ -23,6 +23,7 @@ import org.springframework.beans.factory.annotation.Autowired;
|
|||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -80,15 +81,17 @@ public class LLMRequestService {
|
|||||||
llmSchema.setDomainName(dataSetIdToName.get(dataSetId));
|
llmSchema.setDomainName(dataSetIdToName.get(dataSetId));
|
||||||
|
|
||||||
List<String> fieldNameList = getFieldNameList(queryCtx, dataSetId, llmParserConfig);
|
List<String> fieldNameList = getFieldNameList(queryCtx, dataSetId, llmParserConfig);
|
||||||
|
fieldNameList.add(TimeDimensionEnum.DAY.getChName());
|
||||||
|
llmSchema.setFieldNameList(fieldNameList);
|
||||||
|
|
||||||
|
llmSchema.setMetrics(getMatchedMetrics(queryCtx, dataSetId));
|
||||||
|
llmSchema.setDimensions(getMatchedDimensions(queryCtx, dataSetId));
|
||||||
|
llmSchema.setTerms(getTerms(queryCtx, dataSetId));
|
||||||
|
llmReq.setSchema(llmSchema);
|
||||||
|
|
||||||
String priorExts = getPriorExts(queryCtx, fieldNameList);
|
String priorExts = getPriorExts(queryCtx, fieldNameList);
|
||||||
llmReq.setPriorExts(priorExts);
|
llmReq.setPriorExts(priorExts);
|
||||||
|
|
||||||
fieldNameList.add(TimeDimensionEnum.DAY.getChName());
|
|
||||||
llmSchema.setFieldNameList(fieldNameList);
|
|
||||||
llmSchema.setTerms(getTerms(queryCtx, dataSetId));
|
|
||||||
llmReq.setSchema(llmSchema);
|
|
||||||
|
|
||||||
List<ElementValue> linking = new ArrayList<>();
|
List<ElementValue> linking = new ArrayList<>();
|
||||||
boolean linkingValueEnabled = Boolean.valueOf(parserConfig.getParameterValue(PARSER_LINKING_VALUE_ENABLE));
|
boolean linkingValueEnabled = Boolean.valueOf(parserConfig.getParameterValue(PARSER_LINKING_VALUE_ENABLE));
|
||||||
|
|
||||||
@@ -104,10 +107,11 @@ public class LLMRequestService {
|
|||||||
llmReq.setCurrentDate(currentDate);
|
llmReq.setCurrentDate(currentDate);
|
||||||
llmReq.setSqlGenType(LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
|
llmReq.setSqlGenType(LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
|
||||||
llmReq.setLlmConfig(queryCtx.getLlmConfig());
|
llmReq.setLlmConfig(queryCtx.getLlmConfig());
|
||||||
|
|
||||||
return llmReq;
|
return llmReq;
|
||||||
}
|
}
|
||||||
|
|
||||||
public LLMResp invokeLLM(LLMReq llmReq) {
|
public LLMResp runText2SQL(LLMReq llmReq) {
|
||||||
return ComponentFactory.getLLMProxy().text2sql(llmReq);
|
return ComponentFactory.getLLMProxy().text2sql(llmReq);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -169,7 +173,7 @@ public class LLMRequestService {
|
|||||||
return extraInfoSb.toString();
|
return extraInfoSb.toString();
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<ElementValue> getValueList(QueryContext queryCtx, Long dataSetId) {
|
public List<ElementValue> getValues(QueryContext queryCtx, Long dataSetId) {
|
||||||
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, dataSetId);
|
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, dataSetId);
|
||||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||||
@@ -216,6 +220,40 @@ public class LLMRequestService {
|
|||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected List<SchemaElement> getMatchedMetrics(QueryContext queryCtx, Long dataSetId) {
|
||||||
|
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||||
|
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||||
|
return Collections.emptyList();
|
||||||
|
}
|
||||||
|
List<SchemaElement> schemaElements = matchedElements.stream()
|
||||||
|
.filter(schemaElementMatch -> {
|
||||||
|
SchemaElementType elementType = schemaElementMatch.getElement().getType();
|
||||||
|
return SchemaElementType.METRIC.equals(elementType);
|
||||||
|
})
|
||||||
|
.map(schemaElementMatch -> {
|
||||||
|
return schemaElementMatch.getElement();
|
||||||
|
})
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
return schemaElements;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected List<SchemaElement> getMatchedDimensions(QueryContext queryCtx, Long dataSetId) {
|
||||||
|
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||||
|
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||||
|
return Collections.emptyList();
|
||||||
|
}
|
||||||
|
List<SchemaElement> schemaElements = matchedElements.stream()
|
||||||
|
.filter(schemaElementMatch -> {
|
||||||
|
SchemaElementType elementType = schemaElementMatch.getElement().getType();
|
||||||
|
return SchemaElementType.DIMENSION.equals(elementType);
|
||||||
|
})
|
||||||
|
.map(schemaElementMatch -> {
|
||||||
|
return schemaElementMatch.getElement();
|
||||||
|
})
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
return schemaElements;
|
||||||
|
}
|
||||||
|
|
||||||
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, Long dataSetId) {
|
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, Long dataSetId) {
|
||||||
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, dataSetId);
|
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, dataSetId);
|
||||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||||
|
|||||||
@@ -40,10 +40,10 @@ public class LLMSqlParser implements SemanticParser {
|
|||||||
log.info("Generate query statement for dataSetId:{}", dataSetId);
|
log.info("Generate query statement for dataSetId:{}", dataSetId);
|
||||||
|
|
||||||
//3.invoke LLM service to do parsing.
|
//3.invoke LLM service to do parsing.
|
||||||
List<LLMReq.ElementValue> linkingValues = requestService.getValueList(queryCtx, dataSetId);
|
List<LLMReq.ElementValue> linkingValues = requestService.getValues(queryCtx, dataSetId);
|
||||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||||
LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId, semanticSchema, linkingValues);
|
LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId, semanticSchema, linkingValues);
|
||||||
LLMResp llmResp = requestService.invokeLLM(llmReq);
|
LLMResp llmResp = requestService.runText2SQL(llmReq);
|
||||||
if (Objects.isNull(llmResp)) {
|
if (Objects.isNull(llmResp)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import lombok.extern.slf4j.Slf4j;
|
|||||||
import org.apache.commons.lang3.tuple.Pair;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
@@ -68,7 +69,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
|||||||
+ "1.ALWAYS use `数据日期` as the date field.\n"
|
+ "1.ALWAYS use `数据日期` as the date field.\n"
|
||||||
+ "2.ALWAYS use `datediff()` as the date function.\n"
|
+ "2.ALWAYS use `datediff()` as the date function.\n"
|
||||||
+ "3.DO NOT specify date filter in the where clause if not explicitly mentioned in the query.\n"
|
+ "3.DO NOT specify date filter in the where clause if not explicitly mentioned in the query.\n"
|
||||||
+ "4.ONLY output SQL statement.\n"
|
+ "4.ONLY respond with the converted SQL statement.\n"
|
||||||
+ "#Exemplars:\n%s"
|
+ "#Exemplars:\n%s"
|
||||||
+ "#UserQuery: %s "
|
+ "#UserQuery: %s "
|
||||||
+ "#DatabaseMetadata: %s "
|
+ "#DatabaseMetadata: %s "
|
||||||
@@ -85,11 +86,11 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Pair<String, String> questionPrompt = promptHelper.transformQuestionPrompt(llmReq);
|
Pair<String, String> questionPrompt = promptHelper.transformQuestionPrompt(llmReq);
|
||||||
String dbSchema = questionPrompt.getLeft();
|
String dataSemanticsStr = promptHelper.buildMetadataStr(llmReq);
|
||||||
String questionAugmented = questionPrompt.getRight();
|
String questionAugmented = questionPrompt.getRight();
|
||||||
String promptStr = String.format(instruction, exemplarsStr, questionAugmented, dbSchema);
|
String promptStr = String.format(instruction, exemplarsStr, questionAugmented, dataSemanticsStr);
|
||||||
|
|
||||||
return PromptTemplate.from(promptStr).apply(new HashMap<>());
|
return PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -65,14 +65,48 @@ public class PromptHelper {
|
|||||||
}
|
}
|
||||||
String currentDataStr = "当前的日期是" + currentDate;
|
String currentDataStr = "当前的日期是" + currentDate;
|
||||||
String linkingListStr = String.join(",", priorLinkingList);
|
String linkingListStr = String.join(",", priorLinkingList);
|
||||||
String termStr = getTermStr(llmReq);
|
String termStr = buildTermStr(llmReq);
|
||||||
String questionAugmented = String.format("%s (补充信息:%s;%s;%s;%s)", llmReq.getQueryText(),
|
String questionAugmented = String.format("%s (补充信息:%s;%s;%s;%s)", llmReq.getQueryText(),
|
||||||
linkingListStr, currentDataStr, termStr, priorExts);
|
linkingListStr, currentDataStr, termStr, priorExts);
|
||||||
|
|
||||||
return Pair.of(dbSchema, questionAugmented);
|
return Pair.of(dbSchema, questionAugmented);
|
||||||
}
|
}
|
||||||
|
|
||||||
private String getTermStr(LLMReq llmReq) {
|
public String buildMetadataStr(LLMReq llmReq) {
|
||||||
|
String tableStr = llmReq.getSchema().getDataSetName();
|
||||||
|
StringBuilder metricStr = new StringBuilder();
|
||||||
|
StringBuilder dimensionStr = new StringBuilder();
|
||||||
|
|
||||||
|
llmReq.getSchema().getMetrics().stream().forEach(
|
||||||
|
metric -> {
|
||||||
|
metricStr.append(metric.getName());
|
||||||
|
if (StringUtils.isNotEmpty(metric.getDescription())) {
|
||||||
|
metricStr.append(" COMMENT '" + metric.getDescription() + "'");
|
||||||
|
}
|
||||||
|
if (StringUtils.isNotEmpty(metric.getDefaultAgg())) {
|
||||||
|
metricStr.append(" AGGREGATE '" + metric.getDefaultAgg().toUpperCase() + "'");
|
||||||
|
}
|
||||||
|
metricStr.append(",");
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
llmReq.getSchema().getDimensions().stream().forEach(
|
||||||
|
dimension -> {
|
||||||
|
dimensionStr.append(dimension.getName());
|
||||||
|
if (StringUtils.isNotEmpty(dimension.getDescription())) {
|
||||||
|
dimensionStr.append(" COMMENT '" + dimension.getDescription() + "'");
|
||||||
|
}
|
||||||
|
dimensionStr.append(",");
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
String template = "Table: %s, Metrics: [%s], Dimensions: [%s]";
|
||||||
|
|
||||||
|
|
||||||
|
return String.format(template, tableStr, metricStr, dimensionStr);
|
||||||
|
}
|
||||||
|
|
||||||
|
private String buildTermStr(LLMReq llmReq) {
|
||||||
List<LLMReq.Term> terms = llmReq.getSchema().getTerms();
|
List<LLMReq.Term> terms = llmReq.getSchema().getTerms();
|
||||||
StringBuilder termsDesc = new StringBuilder();
|
StringBuilder termsDesc = new StringBuilder();
|
||||||
if (!CollectionUtils.isEmpty(terms)) {
|
if (!CollectionUtils.isEmpty(terms)) {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.core.chat.query.llm.s2sql;
|
|||||||
import com.fasterxml.jackson.annotation.JsonValue;
|
import com.fasterxml.jackson.annotation.JsonValue;
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
|
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -46,6 +47,10 @@ public class LLMReq {
|
|||||||
|
|
||||||
private List<String> fieldNameList;
|
private List<String> fieldNameList;
|
||||||
|
|
||||||
|
private List<SchemaElement> metrics;
|
||||||
|
|
||||||
|
private List<SchemaElement> dimensions;
|
||||||
|
|
||||||
private List<Term> terms;
|
private List<Term> terms;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user