mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
(improvement)(Headless) Give term to the LLM as a reference (#1044)
Co-authored-by: jolunoluo
This commit is contained in:
@@ -6,7 +6,6 @@ import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import javax.validation.constraints.NotNull;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@@ -16,7 +15,6 @@ public class TermResp extends RecordInfo {
|
||||
|
||||
private Long id;
|
||||
|
||||
@NotNull(message = "主题域ID不可为空")
|
||||
private Long domainId;
|
||||
|
||||
private String name;
|
||||
|
||||
@@ -82,6 +82,7 @@ public class LLMRequestService {
|
||||
|
||||
fieldNameList.add(TimeDimensionEnum.DAY.getChName());
|
||||
llmSchema.setFieldNameList(fieldNameList);
|
||||
llmSchema.setTerms(getTerms(queryCtx, dataSetId));
|
||||
llmReq.setSchema(llmSchema);
|
||||
|
||||
List<ElementValue> linking = new ArrayList<>();
|
||||
@@ -115,6 +116,24 @@ public class LLMRequestService {
|
||||
return new ArrayList<>(results);
|
||||
}
|
||||
|
||||
protected List<LLMReq.Term> getTerms(QueryContext queryCtx, Long dataSetId) {
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
return matchedElements.stream()
|
||||
.filter(schemaElementMatch -> {
|
||||
SchemaElementType elementType = schemaElementMatch.getElement().getType();
|
||||
return SchemaElementType.TERM.equals(elementType);
|
||||
}).map(schemaElementMatch -> {
|
||||
LLMReq.Term term = new LLMReq.Term();
|
||||
term.setName(schemaElementMatch.getElement().getName());
|
||||
term.setDescription(schemaElementMatch.getElement().getDescription());
|
||||
term.setAlias(schemaElementMatch.getElement().getAlias());
|
||||
return term;
|
||||
}).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private String getPriorExts(QueryContext queryContext, List<String> fieldNameList) {
|
||||
StringBuilder extraInfoSb = new StringBuilder();
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
|
||||
@@ -3,14 +3,16 @@ package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Map;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
@@ -112,11 +114,33 @@ public class PromptGenerator {
|
||||
}
|
||||
String currentDataStr = "当前的日期是" + currentDate;
|
||||
String linkingListStr = String.join(",", priorLinkingList);
|
||||
String questionAugmented = String.format("%s (补充信息:%s 。 %s) (备注: %s)", llmReq.getQueryText(), linkingListStr,
|
||||
currentDataStr, priorExts);
|
||||
String termStr = getTermStr(llmReq);
|
||||
String questionAugmented = String.format("%s (补充信息:%s . %s . %s) (备注: %s)", llmReq.getQueryText(), linkingListStr,
|
||||
currentDataStr, termStr, priorExts);
|
||||
return Pair.of(dbSchema, questionAugmented);
|
||||
}
|
||||
|
||||
private String getTermStr(LLMReq llmReq) {
|
||||
List<LLMReq.Term> terms = llmReq.getSchema().getTerms();
|
||||
StringBuilder termsDesc = new StringBuilder();
|
||||
if (!CollectionUtils.isEmpty(terms)) {
|
||||
termsDesc.append("相关业务术语:");
|
||||
for (int idx = 0 ; idx < terms.size() ; idx++) {
|
||||
LLMReq.Term term = terms.get(idx);
|
||||
String name = term.getName();
|
||||
String description = term.getDescription();
|
||||
List<String> alias = term.getAlias();
|
||||
String descPart = StringUtils.isBlank(description) ? "" : String.format(",它通常是指<%s>", description);
|
||||
String aliasPart = CollectionUtils.isEmpty(alias) ? "" : String.format(",类似的表达还有%s", alias);
|
||||
termsDesc.append(String.format("%d.<%s>是业务术语%s%s;", idx + 1, name, descPart, aliasPart));
|
||||
}
|
||||
if (termsDesc.length() > 0) {
|
||||
termsDesc.setLength(termsDesc.length() - 1);
|
||||
}
|
||||
}
|
||||
return termsDesc.toString();
|
||||
}
|
||||
|
||||
public List<String> generateSqlPromptPool(LLMReq llmReq, List<String> schemaLinkStrPool,
|
||||
List<List<Map<String, String>>> fewshotExampleListPool) {
|
||||
List<String> sqlPromptPool = new ArrayList<>();
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.headless.core.chat.query.llm.s2sql;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonValue;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
|
||||
import lombok.Data;
|
||||
|
||||
@@ -45,6 +46,8 @@ public class LLMReq {
|
||||
|
||||
private List<String> fieldNameList;
|
||||
|
||||
private List<Term> terms;
|
||||
|
||||
}
|
||||
|
||||
@Data
|
||||
@@ -53,6 +56,17 @@ public class LLMReq {
|
||||
private String tableName;
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class Term {
|
||||
|
||||
private String name;
|
||||
|
||||
private String description;
|
||||
|
||||
private List<String> alias = Lists.newArrayList();
|
||||
|
||||
}
|
||||
|
||||
public enum SqlGenType {
|
||||
|
||||
ONE_PASS_AUTO_COT("1_pass_auto_cot"),
|
||||
|
||||
Reference in New Issue
Block a user