From 74ed269544be650bc41f8060ced9fb55e1a1a699 Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Tue, 14 Nov 2023 11:12:27 +0800 Subject: [PATCH] (improvement)(chat) linkingValues is not passed to llm and optimize SchemaCorrector code (#378) --- .../chat/corrector/SchemaCorrector.java | 17 +++++++---- .../chat/corrector/WhereCorrector.java | 3 +- .../parser/llm/s2sql/LLMRequestService.java | 30 +++++++++---------- .../chat/parser/llm/s2sql/LLMS2SQLParser.java | 9 ++++-- .../chat/parser/llm/s2sql/ParseResult.java | 4 +++ .../common/pojo/enums/TimeDimensionEnum.java | 10 ++----- 6 files changed, 42 insertions(+), 31 deletions(-) diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SchemaCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SchemaCorrector.java index 2362bf0fe..1623a1749 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SchemaCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SchemaCorrector.java @@ -4,7 +4,6 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.response.SqlInfo; import com.tencent.supersonic.chat.parser.llm.s2sql.ParseResult; -import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.util.JsonUtil; @@ -23,9 +22,9 @@ public class SchemaCorrector extends BaseSemanticCorrector { @Override public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) { - String sql = SqlParserReplaceHelper.replaceFunction(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), - AggregateEnum.getAggregateEnum()); - semanticParseInfo.getSqlInfo().setCorrectS2SQL(sql); + + correctAggFunction(semanticParseInfo); + replaceAlias(semanticParseInfo); updateFieldNameByLinkingValue(semanticParseInfo); @@ -35,6 +34,13 @@ public class SchemaCorrector extends BaseSemanticCorrector { correctFieldName(semanticParseInfo); } + private void correctAggFunction(SemanticParseInfo semanticParseInfo) { + Map aggregateEnum = AggregateEnum.getAggregateEnum(); + SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); + String sql = SqlParserReplaceHelper.replaceFunction(sqlInfo.getCorrectS2SQL(), aggregateEnum); + sqlInfo.setCorrectS2SQL(sql); + } + private void replaceAlias(SemanticParseInfo semanticParseInfo) { SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); String replaceAlias = SqlParserReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL()); @@ -74,8 +80,7 @@ public class SchemaCorrector extends BaseSemanticCorrector { if (Objects.isNull(parseResult) || Objects.isNull(parseResult.getLlmReq())) { return null; } - LLMReq llmReq = parseResult.getLlmReq(); - return llmReq.getLinking(); + return parseResult.getLinkingValues(); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java index 9df20f37e..1812b107a 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java @@ -74,7 +74,8 @@ public class WhereCorrector extends BaseSemanticCorrector { String currentDate = S2SQLDateHelper.getReferenceDate(semanticParseInfo.getModelId()); if (StringUtils.isNotBlank(currentDate)) { correctS2SQL = SqlParserAddHelper.addParenthesisToWhere(correctS2SQL); - correctS2SQL = SqlParserAddHelper.addWhere(correctS2SQL, TimeDimensionEnum.DAY.getChName(), currentDate); + correctS2SQL = SqlParserAddHelper.addWhere(correctS2SQL, TimeDimensionEnum.DAY.getChName(), + currentDate); } } semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMRequestService.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMRequestService.java index c88016afb..44935fc40 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMRequestService.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMRequestService.java @@ -104,7 +104,7 @@ public class LLMRequestService { return llmParserTool.orElse(null); } - public LLMReq getLlmReq(QueryContext queryCtx, Long modelId) { + public LLMReq getLlmReq(QueryContext queryCtx, Long modelId, List linkingValues) { SemanticSchema semanticSchema = schemaService.getSemanticSchema(); Map modelIdToName = semanticSchema.getModelIdToName(); String queryText = queryCtx.getRequest().getQueryText(); @@ -120,7 +120,7 @@ public class LLMRequestService { llmSchema.setModelName(modelIdToName.get(modelId)); llmSchema.setDomainName(modelIdToName.get(modelId)); - List fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema, llmParserConfig); + List fieldNameList = getFieldNameList(queryCtx, modelId, llmParserConfig); String priorExts = getPriorExts(modelId, fieldNameList); llmReq.setPriorExts(priorExts); @@ -131,7 +131,7 @@ public class LLMRequestService { List linking = new ArrayList<>(); if (optimizationConfig.isUseLinkingValueSwitch()) { - linking.addAll(getValueList(queryCtx, modelId, semanticSchema)); + linking.addAll(linkingValues); } llmReq.setLinking(linking); @@ -155,7 +155,7 @@ public class LLMRequestService { LLMResp.class); log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}", - System.currentTimeMillis() - startTime, url.toString(), entity, responseEntity.getBody()); + System.currentTimeMillis() - startTime, url, entity, responseEntity.getBody()); return responseEntity.getBody(); } catch (Exception e) { log.error("requestLLM error", e); @@ -163,12 +163,11 @@ public class LLMRequestService { return null; } - protected List getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema, - LLMParserConfig llmParserConfig) { + protected List getFieldNameList(QueryContext queryCtx, Long modelId, LLMParserConfig llmParserConfig) { - Set results = getTopNFieldNames(modelId, semanticSchema, llmParserConfig); + Set results = getTopNFieldNames(modelId, llmParserConfig); - Set fieldNameList = getMatchedFieldNames(queryCtx, modelId, semanticSchema); + Set fieldNameList = getMatchedFieldNames(queryCtx, modelId); results.addAll(fieldNameList); return new ArrayList<>(results); @@ -210,8 +209,8 @@ public class LLMRequestService { } - protected List getValueList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) { - Map itemIdToName = getItemIdToName(modelId, semanticSchema); + protected List getValueList(QueryContext queryCtx, Long modelId) { + Map itemIdToName = getItemIdToName(modelId); List matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId); if (CollectionUtils.isEmpty(matchedElements)) { @@ -233,14 +232,15 @@ public class LLMRequestService { return new ArrayList<>(valueMatches); } - protected Map getItemIdToName(Long modelId, SemanticSchema semanticSchema) { + protected Map getItemIdToName(Long modelId) { + SemanticSchema semanticSchema = schemaService.getSemanticSchema(); return semanticSchema.getDimensions(modelId).stream() .collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2)); } - private Set getTopNFieldNames(Long modelId, SemanticSchema semanticSchema, - LLMParserConfig llmParserConfig) { + private Set getTopNFieldNames(Long modelId, LLMParserConfig llmParserConfig) { + SemanticSchema semanticSchema = schemaService.getSemanticSchema(); Set results = semanticSchema.getDimensions(modelId).stream() .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) .limit(llmParserConfig.getDimensionTopN()) @@ -258,8 +258,8 @@ public class LLMRequestService { } - protected Set getMatchedFieldNames(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) { - Map itemIdToName = getItemIdToName(modelId, semanticSchema); + protected Set getMatchedFieldNames(QueryContext queryCtx, Long modelId) { + Map itemIdToName = getItemIdToName(modelId); List matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId); if (CollectionUtils.isEmpty(matchedElements)) { return new HashSet<>(); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMS2SQLParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMS2SQLParser.java index b5420d574..5f312c1c6 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMS2SQLParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMS2SQLParser.java @@ -6,8 +6,10 @@ import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; +import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue; import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.common.util.ContextUtils; +import java.util.List; import java.util.Map; import java.util.Objects; import lombok.extern.slf4j.Slf4j; @@ -36,7 +38,8 @@ public class LLMS2SQLParser implements SemanticParser { return; } //4.construct a request, call the API for the large model, and retrieve the results. - LLMReq llmReq = requestService.getLlmReq(queryCtx, modelId); + List linkingValues = requestService.getValueList(queryCtx, modelId); + LLMReq llmReq = requestService.getLlmReq(queryCtx, modelId, linkingValues); LLMResp llmResp = requestService.requestLLM(llmReq, modelId); if (Objects.isNull(llmResp)) { @@ -49,7 +52,9 @@ public class LLMS2SQLParser implements SemanticParser { .modelId(modelId) .commonAgentTool(commonAgentTool) .llmReq(llmReq) - .llmResp(llmResp).build(); + .llmResp(llmResp) + .linkingValues(linkingValues) + .build(); LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/ParseResult.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/ParseResult.java index a01fb8fad..65914f82d 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/ParseResult.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/ParseResult.java @@ -3,7 +3,9 @@ package com.tencent.supersonic.chat.parser.llm.s2sql; import com.tencent.supersonic.chat.agent.tool.CommonAgentTool; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; +import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue; import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp; +import java.util.List; import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; @@ -24,4 +26,6 @@ public class ParseResult { private QueryReq request; private CommonAgentTool commonAgentTool; + + private List linkingValues; } diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/TimeDimensionEnum.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/TimeDimensionEnum.java index 352b17f2c..2ab03c342 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/TimeDimensionEnum.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/TimeDimensionEnum.java @@ -1,16 +1,15 @@ package com.tencent.supersonic.common.pojo.enums; import cn.hutool.core.collection.CollectionUtil; - import java.util.Arrays; import java.util.List; -import java.util.Set; import java.util.stream.Collectors; public enum TimeDimensionEnum { DAY("sys_imp_date", "数据日期"), + WEEK("sys_imp_week", "数据日期_周"), MONTH("sys_imp_month", "数据日期_月"); @@ -28,10 +27,6 @@ public enum TimeDimensionEnum { return Arrays.stream(TimeDimensionEnum.values()).map(TimeDimensionEnum::getName).collect(Collectors.toList()); } - public static Set getChNameSet() { - return Arrays.stream(TimeDimensionEnum.values()).map(TimeDimensionEnum::getChName).collect(Collectors.toSet()); - } - public String getName() { return name; } @@ -42,8 +37,9 @@ public enum TimeDimensionEnum { /** * Determine if a time dimension field is included in a Chinese text field + * * @param fields field - * @return true/fase + * @return true/false */ public static boolean containsZhTimeDimension(List fields) { if (CollectionUtil.isEmpty(fields)) {