mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
[improvement][headless]Restructure LLMReq and LLMSchema.
This commit is contained in:
@@ -110,7 +110,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
if (Objects.isNull(parseResult) || Objects.isNull(parseResult.getLlmReq())) {
|
||||
return null;
|
||||
}
|
||||
return parseResult.getLinkingValues();
|
||||
return parseResult.getLlmReq().getSchema().getValues();
|
||||
}
|
||||
|
||||
private void updateFieldValueByLinkingValue(SemanticParseInfo semanticParseInfo) {
|
||||
|
||||
@@ -13,7 +13,6 @@ import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
@@ -21,10 +20,8 @@ import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@@ -39,7 +36,7 @@ public class LLMRequestService {
|
||||
|
||||
public boolean isSkip(ChatQueryContext queryCtx) {
|
||||
if (!queryCtx.getText2SQLType().enableLLM()) {
|
||||
log.info("not enable llm, skip");
|
||||
log.info("LLM disabled, skip");
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -57,33 +54,28 @@ public class LLMRequestService {
|
||||
}
|
||||
|
||||
public LLMReq getLlmReq(ChatQueryContext queryCtx, Long dataSetId) {
|
||||
List<LLMReq.ElementValue> linkingValues = getValues(queryCtx, dataSetId);
|
||||
Map<Long, String> dataSetIdToName = queryCtx.getSemanticSchema().getDataSetIdToName();
|
||||
String queryText = queryCtx.getQueryText();
|
||||
|
||||
LLMReq llmReq = new LLMReq();
|
||||
|
||||
llmReq.setQueryText(queryText);
|
||||
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
|
||||
llmReq.setSchema(llmSchema);
|
||||
llmSchema.setDataSetId(dataSetId);
|
||||
llmSchema.setDataSetName(dataSetIdToName.get(dataSetId));
|
||||
llmSchema.setMetrics(getMatchedMetrics(queryCtx, dataSetId));
|
||||
llmSchema.setDimensions(getMatchedDimensions(queryCtx, dataSetId));
|
||||
llmSchema.setMetrics(getMappedMetrics(queryCtx, dataSetId));
|
||||
llmSchema.setDimensions(getMappedDimensions(queryCtx, dataSetId));
|
||||
llmSchema.setPartitionTime(getPartitionTime(queryCtx, dataSetId));
|
||||
llmSchema.setPrimaryKey(getPrimaryKey(queryCtx, dataSetId));
|
||||
llmSchema.setTerms(getTerms(queryCtx, dataSetId));
|
||||
llmReq.setSchema(llmSchema);
|
||||
|
||||
List<LLMReq.ElementValue> linking = new ArrayList<>();
|
||||
boolean linkingValueEnabled =
|
||||
Boolean.valueOf(parserConfig.getParameterValue(PARSER_LINKING_VALUE_ENABLE));
|
||||
|
||||
if (linkingValueEnabled) {
|
||||
linking.addAll(linkingValues);
|
||||
llmSchema.setValues(getMappedValues(queryCtx, dataSetId));
|
||||
}
|
||||
llmReq.setLinking(linking);
|
||||
|
||||
llmReq.setCurrentDate(DateUtils.getBeforeDate(0));
|
||||
llmReq.setTerms(getMappedTerms(queryCtx, dataSetId));
|
||||
llmReq.setSqlGenType(
|
||||
LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
|
||||
llmReq.setModelConfig(queryCtx.getModelConfig());
|
||||
@@ -102,7 +94,7 @@ public class LLMRequestService {
|
||||
return result;
|
||||
}
|
||||
|
||||
protected List<LLMReq.Term> getTerms(ChatQueryContext queryCtx, Long dataSetId) {
|
||||
protected List<LLMReq.Term> getMappedTerms(ChatQueryContext queryCtx, Long dataSetId) {
|
||||
List<SchemaElementMatch> matchedElements =
|
||||
queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
@@ -126,31 +118,8 @@ public class LLMRequestService {
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private Map<String, String> getFieldNameToDataFormatTypeMap(SemanticSchema semanticSchema) {
|
||||
return semanticSchema.getMetrics().stream()
|
||||
.filter(metric -> Objects.nonNull(metric.getDataFormatType()))
|
||||
.flatMap(
|
||||
metric -> {
|
||||
Set<Pair<String, String>> fieldFormatPairs = new HashSet<>();
|
||||
String dataFormatType = metric.getDataFormatType();
|
||||
fieldFormatPairs.add(Pair.of(metric.getName(), dataFormatType));
|
||||
List<String> aliasList = metric.getAlias();
|
||||
if (!CollectionUtils.isEmpty(aliasList)) {
|
||||
aliasList.forEach(
|
||||
alias ->
|
||||
fieldFormatPairs.add(
|
||||
Pair.of(alias, dataFormatType)));
|
||||
}
|
||||
return fieldFormatPairs.stream();
|
||||
})
|
||||
.collect(
|
||||
Collectors.toMap(
|
||||
Pair::getLeft,
|
||||
Pair::getRight,
|
||||
(existing, replacement) -> existing));
|
||||
}
|
||||
|
||||
public List<LLMReq.ElementValue> getValues(@NotNull ChatQueryContext queryCtx, Long dataSetId) {
|
||||
protected List<LLMReq.ElementValue> getMappedValues(
|
||||
@NotNull ChatQueryContext queryCtx, Long dataSetId) {
|
||||
List<SchemaElementMatch> matchedElements =
|
||||
queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
@@ -177,7 +146,7 @@ public class LLMRequestService {
|
||||
return new ArrayList<>(valueMatches);
|
||||
}
|
||||
|
||||
protected List<SchemaElement> getMatchedMetrics(
|
||||
protected List<SchemaElement> getMappedMetrics(
|
||||
@NotNull ChatQueryContext queryCtx, Long dataSetId) {
|
||||
List<SchemaElementMatch> matchedElements =
|
||||
queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||
@@ -200,7 +169,7 @@ public class LLMRequestService {
|
||||
return schemaElements;
|
||||
}
|
||||
|
||||
protected List<SchemaElement> getMatchedDimensions(
|
||||
protected List<SchemaElement> getMappedDimensions(
|
||||
@NotNull ChatQueryContext queryCtx, Long dataSetId) {
|
||||
|
||||
List<SchemaElementMatch> matchedElements =
|
||||
|
||||
@@ -66,7 +66,6 @@ public class LLMSqlParser implements SemanticParser {
|
||||
.dataSetId(dataSetId)
|
||||
.llmReq(llmReq)
|
||||
.llmResp(llmResp)
|
||||
.linkingValues(llmReq.getLinking())
|
||||
.build();
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -8,8 +8,6 @@ import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@AllArgsConstructor
|
||||
@@ -23,6 +21,4 @@ public class ParseResult {
|
||||
private LLMResp llmResp;
|
||||
|
||||
private QueryNLReq request;
|
||||
|
||||
private List<LLMReq.ElementValue> linkingValues;
|
||||
}
|
||||
|
||||
@@ -138,7 +138,7 @@ public class PromptHelper {
|
||||
});
|
||||
|
||||
List<String> values = Lists.newArrayList();
|
||||
llmReq.getLinking().stream()
|
||||
llmReq.getSchema().getValues().stream()
|
||||
.forEach(
|
||||
value -> {
|
||||
StringBuilder valueStr = new StringBuilder();
|
||||
@@ -176,7 +176,7 @@ public class PromptHelper {
|
||||
}
|
||||
|
||||
private String buildTermStr(LLMReq llmReq) {
|
||||
List<LLMReq.Term> terms = llmReq.getSchema().getTerms();
|
||||
List<LLMReq.Term> terms = llmReq.getTerms();
|
||||
List<String> termStr = Lists.newArrayList();
|
||||
terms.stream()
|
||||
.forEach(
|
||||
|
||||
@@ -7,14 +7,17 @@ import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import lombok.Data;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Data
|
||||
public class LLMReq {
|
||||
private String queryText;
|
||||
private LLMSchema schema;
|
||||
private List<ElementValue> linking;
|
||||
private List<Term> terms;
|
||||
private String currentDate;
|
||||
private String priorExts;
|
||||
private SqlGenType sqlGenType;
|
||||
@@ -32,12 +35,30 @@ public class LLMReq {
|
||||
public static class LLMSchema {
|
||||
private Long dataSetId;
|
||||
private String dataSetName;
|
||||
private List<String> fieldNameList;
|
||||
private List<SchemaElement> metrics;
|
||||
private List<SchemaElement> dimensions;
|
||||
private List<ElementValue> values;
|
||||
private SchemaElement partitionTime;
|
||||
private SchemaElement primaryKey;
|
||||
private List<Term> terms;
|
||||
|
||||
public List<String> getFieldNameList() {
|
||||
List<String> fieldNameList = new ArrayList<>();
|
||||
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||
fieldNameList.addAll(
|
||||
metrics.stream()
|
||||
.map(metric -> metric.getName())
|
||||
.collect(Collectors.toList()));
|
||||
}
|
||||
if (CollectionUtils.isNotEmpty(dimensions)) {
|
||||
fieldNameList.addAll(
|
||||
dimensions.stream()
|
||||
.map(dimension -> dimension.getName())
|
||||
.collect(Collectors.toList()));
|
||||
}
|
||||
fieldNameList.add(partitionTime.getName());
|
||||
fieldNameList.add(primaryKey.getName());
|
||||
return fieldNameList;
|
||||
}
|
||||
}
|
||||
|
||||
@Data
|
||||
|
||||
@@ -13,6 +13,7 @@ import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.parser.llm.ParseResult;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
@@ -20,6 +21,7 @@ import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
@Disabled
|
||||
class SchemaCorrectorTest {
|
||||
|
||||
private String json =
|
||||
@@ -37,17 +39,10 @@ class SchemaCorrectorTest {
|
||||
+ " \"数据日期\"\n"
|
||||
+ " ]\n"
|
||||
+ " },\n"
|
||||
+ " \"linking\": [\n"
|
||||
+ "\n"
|
||||
+ " ],\n"
|
||||
+ " \"currentDate\": \"2024-02-24\",\n"
|
||||
+ " \"priorExts\": \"播放份额是小数; \",\n"
|
||||
+ " \"sqlGenType\": \"1_pass_self_consistency\"\n"
|
||||
+ " },\n"
|
||||
+ " \"request\": null,\n"
|
||||
+ " \"linkingValues\": [\n"
|
||||
+ "\n"
|
||||
+ " ]\n"
|
||||
+ " \"request\": null\n"
|
||||
+ "}";
|
||||
|
||||
@Test
|
||||
@@ -86,7 +81,6 @@ class SchemaCorrectorTest {
|
||||
elementValue.setFieldName("商务组");
|
||||
elementValue.setFieldValue("xxx");
|
||||
linkingValues.add(elementValue);
|
||||
parseResult.setLinkingValues(linkingValues);
|
||||
semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult);
|
||||
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(sql);
|
||||
|
||||
Reference in New Issue
Block a user