From c7cb6df80bc2b5703956b69251c87198ef0837fc Mon Sep 17 00:00:00 2001 From: LXW <1264174498@qq.com> Date: Wed, 29 May 2024 16:43:48 +0800 Subject: [PATCH] (improvement)(Headless) Give term to the LLM as a reference (#1044) Co-authored-by: jolunoluo --- .../headless/api/pojo/response/TermResp.java | 2 -- .../chat/parser/llm/LLMRequestService.java | 19 +++++++++++ .../core/chat/parser/llm/PromptGenerator.java | 32 ++++++++++++++++--- .../core/chat/query/llm/s2sql/LLMReq.java | 14 ++++++++ 4 files changed, 61 insertions(+), 6 deletions(-) diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/TermResp.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/TermResp.java index ba34deb9a..39109f834 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/TermResp.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/TermResp.java @@ -6,7 +6,6 @@ import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor; -import javax.validation.constraints.NotNull; import java.util.List; @Data @@ -16,7 +15,6 @@ public class TermResp extends RecordInfo { private Long id; - @NotNull(message = "主题域ID不可为空") private Long domainId; private String name; diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java index 0018c1204..7945bb2a6 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java @@ -82,6 +82,7 @@ public class LLMRequestService { fieldNameList.add(TimeDimensionEnum.DAY.getChName()); llmSchema.setFieldNameList(fieldNameList); + llmSchema.setTerms(getTerms(queryCtx, dataSetId)); 
llmReq.setSchema(llmSchema); List<ElementValue> linking = new ArrayList<>(); @@ -115,6 +116,24 @@ public class LLMRequestService { return new ArrayList<>(results); } + protected List<LLMReq.Term> getTerms(QueryContext queryCtx, Long dataSetId) { + List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId); + if (CollectionUtils.isEmpty(matchedElements)) { + return new ArrayList<>(); + } + return matchedElements.stream() + .filter(schemaElementMatch -> { + SchemaElementType elementType = schemaElementMatch.getElement().getType(); + return SchemaElementType.TERM.equals(elementType); + }).map(schemaElementMatch -> { + LLMReq.Term term = new LLMReq.Term(); + term.setName(schemaElementMatch.getElement().getName()); + term.setDescription(schemaElementMatch.getElement().getDescription()); + term.setAlias(schemaElementMatch.getElement().getAlias()); + return term; + }).collect(Collectors.toList()); + } + private String getPriorExts(QueryContext queryContext, List<String> fieldNameList) { StringBuilder extraInfoSb = new StringBuilder(); SemanticSchema semanticSchema = queryContext.getSemanticSchema(); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/PromptGenerator.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/PromptGenerator.java index 36c86fd5f..afbc5a28d 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/PromptGenerator.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/PromptGenerator.java @@ -3,14 +3,16 @@ package com.tencent.supersonic.headless.core.chat.parser.llm; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.ElementValue; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; import org.springframework.stereotype.Component; +import 
org.springframework.util.CollectionUtils; -import java.util.List; import java.util.ArrayList; -import java.util.Map; import java.util.Arrays; import java.util.Collections; +import java.util.List; +import java.util.Map; @Component @Slf4j @@ -112,11 +114,33 @@ public class PromptGenerator { } String currentDataStr = "当前的日期是" + currentDate; String linkingListStr = String.join(",", priorLinkingList); - String questionAugmented = String.format("%s (补充信息:%s 。 %s) (备注: %s)", llmReq.getQueryText(), linkingListStr, - currentDataStr, priorExts); + String termStr = getTermStr(llmReq); + String questionAugmented = String.format("%s (补充信息:%s . %s . %s) (备注: %s)", llmReq.getQueryText(), linkingListStr, + currentDataStr, termStr, priorExts); return Pair.of(dbSchema, questionAugmented); } + private String getTermStr(LLMReq llmReq) { + List<LLMReq.Term> terms = llmReq.getSchema().getTerms(); + StringBuilder termsDesc = new StringBuilder(); + if (!CollectionUtils.isEmpty(terms)) { + termsDesc.append("相关业务术语:"); + for (int idx = 0 ; idx < terms.size() ; idx++) { + LLMReq.Term term = terms.get(idx); + String name = term.getName(); + String description = term.getDescription(); + List<String> alias = term.getAlias(); + String descPart = StringUtils.isBlank(description) ? "" : String.format(",它通常是指<%s>", description); + String aliasPart = CollectionUtils.isEmpty(alias) ? 
"" : String.format(",类似的表达还有%s", alias); + termsDesc.append(String.format("%d.<%s>是业务术语%s%s;", idx + 1, name, descPart, aliasPart)); + } + if (termsDesc.length() > 0) { + termsDesc.setLength(termsDesc.length() - 1); + } + } + return termsDesc.toString(); + } + public List<String> generateSqlPromptPool(LLMReq llmReq, List<String> schemaLinkStrPool, List<List<Map<String, String>>> fewshotExampleListPool) { List<String> sqlPromptPool = new ArrayList<>(); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/llm/s2sql/LLMReq.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/llm/s2sql/LLMReq.java index f739b6c9a..9fe011da6 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/llm/s2sql/LLMReq.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/llm/s2sql/LLMReq.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.headless.core.chat.query.llm.s2sql; import com.fasterxml.jackson.annotation.JsonValue; +import com.google.common.collect.Lists; import com.tencent.supersonic.headless.api.pojo.LLMConfig; import lombok.Data; @@ -45,6 +46,8 @@ public class LLMReq { private List<String> fieldNameList; + private List<Term> terms; + } @Data @@ -53,6 +56,17 @@ public class LLMReq { private String tableName; } + @Data + public static class Term { + + private String name; + + private String description; + + private List<String> alias = Lists.newArrayList(); + + } + public enum SqlGenType { ONE_PASS_AUTO_COT("1_pass_auto_cot"),