(improvement)(Headless) Give term to the LLM as a reference (#1044)

Co-authored-by: jolunoluo
This commit is contained in:
LXW
2024-05-29 16:43:48 +08:00
committed by GitHub
parent 26ab536c32
commit c7cb6df80b
4 changed files with 61 additions and 6 deletions

View File

@@ -6,7 +6,6 @@ import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import javax.validation.constraints.NotNull;
import java.util.List;
@Data
@@ -16,7 +15,6 @@ public class TermResp extends RecordInfo {
private Long id;
@NotNull(message = "主题域ID不可为空")
private Long domainId;
private String name;

View File

@@ -82,6 +82,7 @@ public class LLMRequestService {
fieldNameList.add(TimeDimensionEnum.DAY.getChName());
llmSchema.setFieldNameList(fieldNameList);
llmSchema.setTerms(getTerms(queryCtx, dataSetId));
llmReq.setSchema(llmSchema);
List<ElementValue> linking = new ArrayList<>();
@@ -115,6 +116,24 @@ public class LLMRequestService {
return new ArrayList<>(results);
}
/**
 * Collects the business terms matched for a data set so they can be handed
 * to the LLM as reference material.
 *
 * <p>Only matched elements whose type is {@code SchemaElementType.TERM} are
 * kept; each one is copied into an {@link LLMReq.Term} carrying the element's
 * name, description and alias list.
 *
 * @param queryCtx  query context whose map info holds the matched elements
 * @param dataSetId id of the data set whose matches are inspected
 * @return the terms found, in match order; an empty list when nothing matched
 */
protected List<LLMReq.Term> getTerms(QueryContext queryCtx, Long dataSetId) {
    List<SchemaElementMatch> matches = queryCtx.getMapInfo().getMatchedElements(dataSetId);
    List<LLMReq.Term> terms = new ArrayList<>();
    if (CollectionUtils.isEmpty(matches)) {
        return terms;
    }
    for (SchemaElementMatch match : matches) {
        // Skip everything that is not a term (metrics, dimensions, values, ...).
        if (!SchemaElementType.TERM.equals(match.getElement().getType())) {
            continue;
        }
        LLMReq.Term term = new LLMReq.Term();
        term.setName(match.getElement().getName());
        term.setDescription(match.getElement().getDescription());
        term.setAlias(match.getElement().getAlias());
        terms.add(term);
    }
    return terms;
}
private String getPriorExts(QueryContext queryContext, List<String> fieldNameList) {
StringBuilder extraInfoSb = new StringBuilder();
SemanticSchema semanticSchema = queryContext.getSemanticSchema();

View File

@@ -3,14 +3,16 @@ package com.tencent.supersonic.headless.core.chat.parser.llm;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.ElementValue;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.ArrayList;
import java.util.Map;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@Component
@Slf4j
@@ -112,11 +114,33 @@ public class PromptGenerator {
}
String currentDataStr = "当前的日期是" + currentDate;
String linkingListStr = String.join("", priorLinkingList);
String questionAugmented = String.format("%s (补充信息:%s 。 %s) (备注: %s)", llmReq.getQueryText(), linkingListStr,
currentDataStr, priorExts);
String termStr = getTermStr(llmReq);
String questionAugmented = String.format("%s (补充信息:%s . %s . %s) (备注: %s)", llmReq.getQueryText(), linkingListStr,
currentDataStr, termStr, priorExts);
return Pair.of(dbSchema, questionAugmented);
}
/**
 * Renders the business terms attached to the request schema as a single
 * prompt fragment that is spliced into the augmented question.
 *
 * <p>The result starts with the header {@code 相关业务术语:} followed by one
 * numbered entry per term, entries being separated by {@code ;}. An entry
 * always names the term and, when present, adds its description and the
 * list of aliases.
 *
 * <p>Entries are joined with a separator placed <em>between</em> them, so
 * nothing has to be trimmed afterwards. The earlier version appended no
 * separator yet still dropped the last character, which truncated the final
 * entry and glued consecutive entries together.
 *
 * @param llmReq request whose schema carries the terms
 * @return the formatted term description, or an empty string when the
 *         schema has no terms
 */
private String getTermStr(LLMReq llmReq) {
    List<LLMReq.Term> terms = llmReq.getSchema().getTerms();
    if (CollectionUtils.isEmpty(terms)) {
        return "";
    }
    StringBuilder termsDesc = new StringBuilder("相关业务术语:");
    for (int idx = 0; idx < terms.size(); idx++) {
        if (idx > 0) {
            // Separator goes between entries only, so no trailing character needs removal.
            termsDesc.append(';');
        }
        LLMReq.Term term = terms.get(idx);
        String description = term.getDescription();
        List<String> alias = term.getAlias();
        // Description and aliases are optional; omit their clause entirely when absent.
        String descPart = StringUtils.isBlank(description) ? "" : String.format(",它通常是指<%s>", description);
        String aliasPart = CollectionUtils.isEmpty(alias) ? "" : String.format(",类似的表达还有%s", alias);
        termsDesc.append(String.format("%d.<%s>是业务术语%s%s", idx + 1, term.getName(), descPart, aliasPart));
    }
    return termsDesc.toString();
}
public List<String> generateSqlPromptPool(LLMReq llmReq, List<String> schemaLinkStrPool,
List<List<Map<String, String>>> fewshotExampleListPool) {
List<String> sqlPromptPool = new ArrayList<>();

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless.core.chat.query.llm.s2sql;
import com.fasterxml.jackson.annotation.JsonValue;
import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
import lombok.Data;
@@ -45,6 +46,8 @@ public class LLMReq {
private List<String> fieldNameList;
private List<Term> terms;
}
@Data
@@ -53,6 +56,17 @@ public class LLMReq {
private String tableName;
}
/**
 * A business term passed to the LLM as reference material alongside the schema.
 *
 * <p>Accessors are generated by Lombok's {@code @Data}.
 */
@Data
public static class Term {
// Name of the term as it appears in the prompt.
private String name;
// Optional explanation of what the term means; may be blank.
private String description;
// Alternative expressions for the term. Starts as an empty mutable list,
// but a setter call can replace it (NOTE(review): possibly with null — confirm against callers).
private List<String> alias = Lists.newArrayList();
}
public enum SqlGenType {
ONE_PASS_AUTO_COT("1_pass_auto_cot"),