mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
[improvement][headless]Restructure LLMReq and LLMSchema.
This commit is contained in:
@@ -110,7 +110,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
|||||||
if (Objects.isNull(parseResult) || Objects.isNull(parseResult.getLlmReq())) {
|
if (Objects.isNull(parseResult) || Objects.isNull(parseResult.getLlmReq())) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
return parseResult.getLinkingValues();
|
return parseResult.getLlmReq().getSchema().getValues();
|
||||||
}
|
}
|
||||||
|
|
||||||
private void updateFieldValueByLinkingValue(SemanticParseInfo semanticParseInfo) {
|
private void updateFieldValueByLinkingValue(SemanticParseInfo semanticParseInfo) {
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
|||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||||
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
|
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
|
||||||
import org.jetbrains.annotations.NotNull;
|
import org.jetbrains.annotations.NotNull;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
@@ -21,10 +20,8 @@ import org.springframework.util.CollectionUtils;
|
|||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@@ -39,7 +36,7 @@ public class LLMRequestService {
|
|||||||
|
|
||||||
public boolean isSkip(ChatQueryContext queryCtx) {
|
public boolean isSkip(ChatQueryContext queryCtx) {
|
||||||
if (!queryCtx.getText2SQLType().enableLLM()) {
|
if (!queryCtx.getText2SQLType().enableLLM()) {
|
||||||
log.info("not enable llm, skip");
|
log.info("LLM disabled, skip");
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -57,33 +54,28 @@ public class LLMRequestService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public LLMReq getLlmReq(ChatQueryContext queryCtx, Long dataSetId) {
|
public LLMReq getLlmReq(ChatQueryContext queryCtx, Long dataSetId) {
|
||||||
List<LLMReq.ElementValue> linkingValues = getValues(queryCtx, dataSetId);
|
|
||||||
Map<Long, String> dataSetIdToName = queryCtx.getSemanticSchema().getDataSetIdToName();
|
Map<Long, String> dataSetIdToName = queryCtx.getSemanticSchema().getDataSetIdToName();
|
||||||
String queryText = queryCtx.getQueryText();
|
String queryText = queryCtx.getQueryText();
|
||||||
|
|
||||||
LLMReq llmReq = new LLMReq();
|
LLMReq llmReq = new LLMReq();
|
||||||
|
|
||||||
llmReq.setQueryText(queryText);
|
llmReq.setQueryText(queryText);
|
||||||
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
|
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
|
||||||
|
llmReq.setSchema(llmSchema);
|
||||||
llmSchema.setDataSetId(dataSetId);
|
llmSchema.setDataSetId(dataSetId);
|
||||||
llmSchema.setDataSetName(dataSetIdToName.get(dataSetId));
|
llmSchema.setDataSetName(dataSetIdToName.get(dataSetId));
|
||||||
llmSchema.setMetrics(getMatchedMetrics(queryCtx, dataSetId));
|
llmSchema.setMetrics(getMappedMetrics(queryCtx, dataSetId));
|
||||||
llmSchema.setDimensions(getMatchedDimensions(queryCtx, dataSetId));
|
llmSchema.setDimensions(getMappedDimensions(queryCtx, dataSetId));
|
||||||
llmSchema.setPartitionTime(getPartitionTime(queryCtx, dataSetId));
|
llmSchema.setPartitionTime(getPartitionTime(queryCtx, dataSetId));
|
||||||
llmSchema.setPrimaryKey(getPrimaryKey(queryCtx, dataSetId));
|
llmSchema.setPrimaryKey(getPrimaryKey(queryCtx, dataSetId));
|
||||||
llmSchema.setTerms(getTerms(queryCtx, dataSetId));
|
|
||||||
llmReq.setSchema(llmSchema);
|
|
||||||
|
|
||||||
List<LLMReq.ElementValue> linking = new ArrayList<>();
|
|
||||||
boolean linkingValueEnabled =
|
boolean linkingValueEnabled =
|
||||||
Boolean.valueOf(parserConfig.getParameterValue(PARSER_LINKING_VALUE_ENABLE));
|
Boolean.valueOf(parserConfig.getParameterValue(PARSER_LINKING_VALUE_ENABLE));
|
||||||
|
|
||||||
if (linkingValueEnabled) {
|
if (linkingValueEnabled) {
|
||||||
linking.addAll(linkingValues);
|
llmSchema.setValues(getMappedValues(queryCtx, dataSetId));
|
||||||
}
|
}
|
||||||
llmReq.setLinking(linking);
|
|
||||||
|
|
||||||
llmReq.setCurrentDate(DateUtils.getBeforeDate(0));
|
llmReq.setCurrentDate(DateUtils.getBeforeDate(0));
|
||||||
|
llmReq.setTerms(getMappedTerms(queryCtx, dataSetId));
|
||||||
llmReq.setSqlGenType(
|
llmReq.setSqlGenType(
|
||||||
LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
|
LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
|
||||||
llmReq.setModelConfig(queryCtx.getModelConfig());
|
llmReq.setModelConfig(queryCtx.getModelConfig());
|
||||||
@@ -102,7 +94,7 @@ public class LLMRequestService {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected List<LLMReq.Term> getTerms(ChatQueryContext queryCtx, Long dataSetId) {
|
protected List<LLMReq.Term> getMappedTerms(ChatQueryContext queryCtx, Long dataSetId) {
|
||||||
List<SchemaElementMatch> matchedElements =
|
List<SchemaElementMatch> matchedElements =
|
||||||
queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||||
@@ -126,31 +118,8 @@ public class LLMRequestService {
|
|||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
private Map<String, String> getFieldNameToDataFormatTypeMap(SemanticSchema semanticSchema) {
|
protected List<LLMReq.ElementValue> getMappedValues(
|
||||||
return semanticSchema.getMetrics().stream()
|
@NotNull ChatQueryContext queryCtx, Long dataSetId) {
|
||||||
.filter(metric -> Objects.nonNull(metric.getDataFormatType()))
|
|
||||||
.flatMap(
|
|
||||||
metric -> {
|
|
||||||
Set<Pair<String, String>> fieldFormatPairs = new HashSet<>();
|
|
||||||
String dataFormatType = metric.getDataFormatType();
|
|
||||||
fieldFormatPairs.add(Pair.of(metric.getName(), dataFormatType));
|
|
||||||
List<String> aliasList = metric.getAlias();
|
|
||||||
if (!CollectionUtils.isEmpty(aliasList)) {
|
|
||||||
aliasList.forEach(
|
|
||||||
alias ->
|
|
||||||
fieldFormatPairs.add(
|
|
||||||
Pair.of(alias, dataFormatType)));
|
|
||||||
}
|
|
||||||
return fieldFormatPairs.stream();
|
|
||||||
})
|
|
||||||
.collect(
|
|
||||||
Collectors.toMap(
|
|
||||||
Pair::getLeft,
|
|
||||||
Pair::getRight,
|
|
||||||
(existing, replacement) -> existing));
|
|
||||||
}
|
|
||||||
|
|
||||||
public List<LLMReq.ElementValue> getValues(@NotNull ChatQueryContext queryCtx, Long dataSetId) {
|
|
||||||
List<SchemaElementMatch> matchedElements =
|
List<SchemaElementMatch> matchedElements =
|
||||||
queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||||
@@ -177,7 +146,7 @@ public class LLMRequestService {
|
|||||||
return new ArrayList<>(valueMatches);
|
return new ArrayList<>(valueMatches);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected List<SchemaElement> getMatchedMetrics(
|
protected List<SchemaElement> getMappedMetrics(
|
||||||
@NotNull ChatQueryContext queryCtx, Long dataSetId) {
|
@NotNull ChatQueryContext queryCtx, Long dataSetId) {
|
||||||
List<SchemaElementMatch> matchedElements =
|
List<SchemaElementMatch> matchedElements =
|
||||||
queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||||
@@ -200,7 +169,7 @@ public class LLMRequestService {
|
|||||||
return schemaElements;
|
return schemaElements;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected List<SchemaElement> getMatchedDimensions(
|
protected List<SchemaElement> getMappedDimensions(
|
||||||
@NotNull ChatQueryContext queryCtx, Long dataSetId) {
|
@NotNull ChatQueryContext queryCtx, Long dataSetId) {
|
||||||
|
|
||||||
List<SchemaElementMatch> matchedElements =
|
List<SchemaElementMatch> matchedElements =
|
||||||
|
|||||||
@@ -66,7 +66,6 @@ public class LLMSqlParser implements SemanticParser {
|
|||||||
.dataSetId(dataSetId)
|
.dataSetId(dataSetId)
|
||||||
.llmReq(llmReq)
|
.llmReq(llmReq)
|
||||||
.llmResp(llmResp)
|
.llmResp(llmResp)
|
||||||
.linkingValues(llmReq.getLinking())
|
|
||||||
.build();
|
.build();
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ import lombok.Builder;
|
|||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@@ -23,6 +21,4 @@ public class ParseResult {
|
|||||||
private LLMResp llmResp;
|
private LLMResp llmResp;
|
||||||
|
|
||||||
private QueryNLReq request;
|
private QueryNLReq request;
|
||||||
|
|
||||||
private List<LLMReq.ElementValue> linkingValues;
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -138,7 +138,7 @@ public class PromptHelper {
|
|||||||
});
|
});
|
||||||
|
|
||||||
List<String> values = Lists.newArrayList();
|
List<String> values = Lists.newArrayList();
|
||||||
llmReq.getLinking().stream()
|
llmReq.getSchema().getValues().stream()
|
||||||
.forEach(
|
.forEach(
|
||||||
value -> {
|
value -> {
|
||||||
StringBuilder valueStr = new StringBuilder();
|
StringBuilder valueStr = new StringBuilder();
|
||||||
@@ -176,7 +176,7 @@ public class PromptHelper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private String buildTermStr(LLMReq llmReq) {
|
private String buildTermStr(LLMReq llmReq) {
|
||||||
List<LLMReq.Term> terms = llmReq.getSchema().getTerms();
|
List<LLMReq.Term> terms = llmReq.getTerms();
|
||||||
List<String> termStr = Lists.newArrayList();
|
List<String> termStr = Lists.newArrayList();
|
||||||
terms.stream()
|
terms.stream()
|
||||||
.forEach(
|
.forEach(
|
||||||
|
|||||||
@@ -7,14 +7,17 @@ import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
|||||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import org.apache.commons.collections4.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class LLMReq {
|
public class LLMReq {
|
||||||
private String queryText;
|
private String queryText;
|
||||||
private LLMSchema schema;
|
private LLMSchema schema;
|
||||||
private List<ElementValue> linking;
|
private List<Term> terms;
|
||||||
private String currentDate;
|
private String currentDate;
|
||||||
private String priorExts;
|
private String priorExts;
|
||||||
private SqlGenType sqlGenType;
|
private SqlGenType sqlGenType;
|
||||||
@@ -32,12 +35,30 @@ public class LLMReq {
|
|||||||
public static class LLMSchema {
|
public static class LLMSchema {
|
||||||
private Long dataSetId;
|
private Long dataSetId;
|
||||||
private String dataSetName;
|
private String dataSetName;
|
||||||
private List<String> fieldNameList;
|
|
||||||
private List<SchemaElement> metrics;
|
private List<SchemaElement> metrics;
|
||||||
private List<SchemaElement> dimensions;
|
private List<SchemaElement> dimensions;
|
||||||
|
private List<ElementValue> values;
|
||||||
private SchemaElement partitionTime;
|
private SchemaElement partitionTime;
|
||||||
private SchemaElement primaryKey;
|
private SchemaElement primaryKey;
|
||||||
private List<Term> terms;
|
|
||||||
|
public List<String> getFieldNameList() {
|
||||||
|
List<String> fieldNameList = new ArrayList<>();
|
||||||
|
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||||
|
fieldNameList.addAll(
|
||||||
|
metrics.stream()
|
||||||
|
.map(metric -> metric.getName())
|
||||||
|
.collect(Collectors.toList()));
|
||||||
|
}
|
||||||
|
if (CollectionUtils.isNotEmpty(dimensions)) {
|
||||||
|
fieldNameList.addAll(
|
||||||
|
dimensions.stream()
|
||||||
|
.map(dimension -> dimension.getName())
|
||||||
|
.collect(Collectors.toList()));
|
||||||
|
}
|
||||||
|
fieldNameList.add(partitionTime.getName());
|
||||||
|
fieldNameList.add(primaryKey.getName());
|
||||||
|
return fieldNameList;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
|||||||
import com.tencent.supersonic.headless.chat.parser.llm.ParseResult;
|
import com.tencent.supersonic.headless.chat.parser.llm.ParseResult;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
@@ -20,6 +21,7 @@ import java.util.HashSet;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
|
@Disabled
|
||||||
class SchemaCorrectorTest {
|
class SchemaCorrectorTest {
|
||||||
|
|
||||||
private String json =
|
private String json =
|
||||||
@@ -37,17 +39,10 @@ class SchemaCorrectorTest {
|
|||||||
+ " \"数据日期\"\n"
|
+ " \"数据日期\"\n"
|
||||||
+ " ]\n"
|
+ " ]\n"
|
||||||
+ " },\n"
|
+ " },\n"
|
||||||
+ " \"linking\": [\n"
|
|
||||||
+ "\n"
|
|
||||||
+ " ],\n"
|
|
||||||
+ " \"currentDate\": \"2024-02-24\",\n"
|
+ " \"currentDate\": \"2024-02-24\",\n"
|
||||||
+ " \"priorExts\": \"播放份额是小数; \",\n"
|
|
||||||
+ " \"sqlGenType\": \"1_pass_self_consistency\"\n"
|
+ " \"sqlGenType\": \"1_pass_self_consistency\"\n"
|
||||||
+ " },\n"
|
+ " },\n"
|
||||||
+ " \"request\": null,\n"
|
+ " \"request\": null\n"
|
||||||
+ " \"linkingValues\": [\n"
|
|
||||||
+ "\n"
|
|
||||||
+ " ]\n"
|
|
||||||
+ "}";
|
+ "}";
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@@ -86,7 +81,6 @@ class SchemaCorrectorTest {
|
|||||||
elementValue.setFieldName("商务组");
|
elementValue.setFieldName("商务组");
|
||||||
elementValue.setFieldValue("xxx");
|
elementValue.setFieldValue("xxx");
|
||||||
linkingValues.add(elementValue);
|
linkingValues.add(elementValue);
|
||||||
parseResult.setLinkingValues(linkingValues);
|
|
||||||
semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult);
|
semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult);
|
||||||
|
|
||||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(sql);
|
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(sql);
|
||||||
|
|||||||
Reference in New Issue
Block a user