(improvement)(headless)Refactor text2sql prompts by getting more use of the data semantics. #1149

This commit is contained in:
jerryjzhang
2024-06-15 01:01:14 +08:00
parent f8b818cb82
commit eadd20046e
5 changed files with 93 additions and 15 deletions

View File

@@ -23,6 +23,7 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
@@ -80,15 +81,17 @@ public class LLMRequestService {
llmSchema.setDomainName(dataSetIdToName.get(dataSetId));
List<String> fieldNameList = getFieldNameList(queryCtx, dataSetId, llmParserConfig);
fieldNameList.add(TimeDimensionEnum.DAY.getChName());
llmSchema.setFieldNameList(fieldNameList);
llmSchema.setMetrics(getMatchedMetrics(queryCtx, dataSetId));
llmSchema.setDimensions(getMatchedDimensions(queryCtx, dataSetId));
llmSchema.setTerms(getTerms(queryCtx, dataSetId));
llmReq.setSchema(llmSchema);
String priorExts = getPriorExts(queryCtx, fieldNameList);
llmReq.setPriorExts(priorExts);
fieldNameList.add(TimeDimensionEnum.DAY.getChName());
llmSchema.setFieldNameList(fieldNameList);
llmSchema.setTerms(getTerms(queryCtx, dataSetId));
llmReq.setSchema(llmSchema);
List<ElementValue> linking = new ArrayList<>();
boolean linkingValueEnabled = Boolean.valueOf(parserConfig.getParameterValue(PARSER_LINKING_VALUE_ENABLE));
@@ -104,10 +107,11 @@ public class LLMRequestService {
llmReq.setCurrentDate(currentDate);
llmReq.setSqlGenType(LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
llmReq.setLlmConfig(queryCtx.getLlmConfig());
return llmReq;
}
public LLMResp invokeLLM(LLMReq llmReq) {
public LLMResp runText2SQL(LLMReq llmReq) {
return ComponentFactory.getLLMProxy().text2sql(llmReq);
}
@@ -169,7 +173,7 @@ public class LLMRequestService {
return extraInfoSb.toString();
}
public List<ElementValue> getValueList(QueryContext queryCtx, Long dataSetId) {
public List<ElementValue> getValues(QueryContext queryCtx, Long dataSetId) {
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, dataSetId);
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
if (CollectionUtils.isEmpty(matchedElements)) {
@@ -216,6 +220,40 @@ public class LLMRequestService {
return results;
}
protected List<SchemaElement> getMatchedMetrics(QueryContext queryCtx, Long dataSetId) {
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
if (CollectionUtils.isEmpty(matchedElements)) {
return Collections.emptyList();
}
List<SchemaElement> schemaElements = matchedElements.stream()
.filter(schemaElementMatch -> {
SchemaElementType elementType = schemaElementMatch.getElement().getType();
return SchemaElementType.METRIC.equals(elementType);
})
.map(schemaElementMatch -> {
return schemaElementMatch.getElement();
})
.collect(Collectors.toList());
return schemaElements;
}
protected List<SchemaElement> getMatchedDimensions(QueryContext queryCtx, Long dataSetId) {
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
if (CollectionUtils.isEmpty(matchedElements)) {
return Collections.emptyList();
}
List<SchemaElement> schemaElements = matchedElements.stream()
.filter(schemaElementMatch -> {
SchemaElementType elementType = schemaElementMatch.getElement().getType();
return SchemaElementType.DIMENSION.equals(elementType);
})
.map(schemaElementMatch -> {
return schemaElementMatch.getElement();
})
.collect(Collectors.toList());
return schemaElements;
}
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, Long dataSetId) {
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, dataSetId);
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);

View File

@@ -40,10 +40,10 @@ public class LLMSqlParser implements SemanticParser {
log.info("Generate query statement for dataSetId:{}", dataSetId);
//3.invoke LLM service to do parsing.
List<LLMReq.ElementValue> linkingValues = requestService.getValueList(queryCtx, dataSetId);
List<LLMReq.ElementValue> linkingValues = requestService.getValues(queryCtx, dataSetId);
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId, semanticSchema, linkingValues);
LLMResp llmResp = requestService.invokeLLM(llmReq);
LLMResp llmResp = requestService.runText2SQL(llmReq);
if (Objects.isNull(llmResp)) {
return;
}

View File

@@ -12,6 +12,7 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.stereotype.Service;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -68,7 +69,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
+ "1.ALWAYS use `数据日期` as the date field.\n"
+ "2.ALWAYS use `datediff()` as the date function.\n"
+ "3.DO NOT specify date filter in the where clause if not explicitly mentioned in the query.\n"
+ "4.ONLY output SQL statement.\n"
+ "4.ONLY respond with the converted SQL statement.\n"
+ "#Exemplars:\n%s"
+ "#UserQuery: %s "
+ "#DatabaseMetadata: %s "
@@ -85,11 +86,11 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
}
Pair<String, String> questionPrompt = promptHelper.transformQuestionPrompt(llmReq);
String dbSchema = questionPrompt.getLeft();
String dataSemanticsStr = promptHelper.buildMetadataStr(llmReq);
String questionAugmented = questionPrompt.getRight();
String promptStr = String.format(instruction, exemplarsStr, questionAugmented, dbSchema);
String promptStr = String.format(instruction, exemplarsStr, questionAugmented, dataSemanticsStr);
return PromptTemplate.from(promptStr).apply(new HashMap<>());
return PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
}
@Override

View File

@@ -65,14 +65,48 @@ public class PromptHelper {
}
String currentDataStr = "当前的日期是" + currentDate;
String linkingListStr = String.join("", priorLinkingList);
String termStr = getTermStr(llmReq);
String termStr = buildTermStr(llmReq);
String questionAugmented = String.format("%s (补充信息:%s;%s;%s;%s)", llmReq.getQueryText(),
linkingListStr, currentDataStr, termStr, priorExts);
return Pair.of(dbSchema, questionAugmented);
}
private String getTermStr(LLMReq llmReq) {
public String buildMetadataStr(LLMReq llmReq) {
String tableStr = llmReq.getSchema().getDataSetName();
StringBuilder metricStr = new StringBuilder();
StringBuilder dimensionStr = new StringBuilder();
llmReq.getSchema().getMetrics().stream().forEach(
metric -> {
metricStr.append(metric.getName());
if (StringUtils.isNotEmpty(metric.getDescription())) {
metricStr.append(" COMMENT '" + metric.getDescription() + "'");
}
if (StringUtils.isNotEmpty(metric.getDefaultAgg())) {
metricStr.append(" AGGREGATE '" + metric.getDefaultAgg().toUpperCase() + "'");
}
metricStr.append(",");
}
);
llmReq.getSchema().getDimensions().stream().forEach(
dimension -> {
dimensionStr.append(dimension.getName());
if (StringUtils.isNotEmpty(dimension.getDescription())) {
dimensionStr.append(" COMMENT '" + dimension.getDescription() + "'");
}
dimensionStr.append(",");
}
);
String template = "Table: %s, Metrics: [%s], Dimensions: [%s]";
return String.format(template, tableStr, metricStr, dimensionStr);
}
private String buildTermStr(LLMReq llmReq) {
List<LLMReq.Term> terms = llmReq.getSchema().getTerms();
StringBuilder termsDesc = new StringBuilder();
if (!CollectionUtils.isEmpty(terms)) {

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.core.chat.query.llm.s2sql;
import com.fasterxml.jackson.annotation.JsonValue;
import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import lombok.Data;
import java.util.List;
@@ -46,6 +47,10 @@ public class LLMReq {
private List<String> fieldNameList;
private List<SchemaElement> metrics;
private List<SchemaElement> dimensions;
private List<Term> terms;
}