(improvement)(chat) linkingValues is not passed to llm and optimize SchemaCorrector code (#378)

This commit is contained in:
lexluo09
2023-11-14 11:12:27 +08:00
committed by GitHub
parent 1ad2c5402b
commit 74ed269544
6 changed files with 42 additions and 31 deletions

View File

@@ -4,7 +4,6 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo; import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.parser.llm.s2sql.ParseResult; import com.tencent.supersonic.chat.parser.llm.s2sql.ParseResult;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
@@ -23,9 +22,9 @@ public class SchemaCorrector extends BaseSemanticCorrector {
@Override @Override
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) { public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
String sql = SqlParserReplaceHelper.replaceFunction(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
AggregateEnum.getAggregateEnum()); correctAggFunction(semanticParseInfo);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(sql);
replaceAlias(semanticParseInfo); replaceAlias(semanticParseInfo);
updateFieldNameByLinkingValue(semanticParseInfo); updateFieldNameByLinkingValue(semanticParseInfo);
@@ -35,6 +34,13 @@ public class SchemaCorrector extends BaseSemanticCorrector {
correctFieldName(semanticParseInfo); correctFieldName(semanticParseInfo);
} }
/**
 * Normalizes the function names used in the corrected S2SQL of the given parse info.
 *
 * <p>The current corrected S2SQL is handed to {@code SqlParserReplaceHelper.replaceFunction}
 * together with the mapping returned by {@code AggregateEnum.getAggregateEnum()}, and the
 * rewritten statement is stored back on the same {@code SqlInfo}.
 * NOTE(review): the mapping presumably goes from aggregate aliases to canonical function
 * names - confirm against {@code AggregateEnum}.
 *
 * @param semanticParseInfo parse info whose {@code SqlInfo} is updated in place
 */
private void correctAggFunction(SemanticParseInfo semanticParseInfo) {
    Map<String, String> functionMapping = AggregateEnum.getAggregateEnum();
    SqlInfo info = semanticParseInfo.getSqlInfo();
    info.setCorrectS2SQL(SqlParserReplaceHelper.replaceFunction(info.getCorrectS2SQL(), functionMapping));
}
private void replaceAlias(SemanticParseInfo semanticParseInfo) { private void replaceAlias(SemanticParseInfo semanticParseInfo) {
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String replaceAlias = SqlParserReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL()); String replaceAlias = SqlParserReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL());
@@ -74,8 +80,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
if (Objects.isNull(parseResult) || Objects.isNull(parseResult.getLlmReq())) { if (Objects.isNull(parseResult) || Objects.isNull(parseResult.getLlmReq())) {
return null; return null;
} }
LLMReq llmReq = parseResult.getLlmReq(); return parseResult.getLinkingValues();
return llmReq.getLinking();
} }

View File

@@ -74,7 +74,8 @@ public class WhereCorrector extends BaseSemanticCorrector {
String currentDate = S2SQLDateHelper.getReferenceDate(semanticParseInfo.getModelId()); String currentDate = S2SQLDateHelper.getReferenceDate(semanticParseInfo.getModelId());
if (StringUtils.isNotBlank(currentDate)) { if (StringUtils.isNotBlank(currentDate)) {
correctS2SQL = SqlParserAddHelper.addParenthesisToWhere(correctS2SQL); correctS2SQL = SqlParserAddHelper.addParenthesisToWhere(correctS2SQL);
correctS2SQL = SqlParserAddHelper.addWhere(correctS2SQL, TimeDimensionEnum.DAY.getChName(), currentDate); correctS2SQL = SqlParserAddHelper.addWhere(correctS2SQL, TimeDimensionEnum.DAY.getChName(),
currentDate);
} }
} }
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL); semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);

View File

@@ -104,7 +104,7 @@ public class LLMRequestService {
return llmParserTool.orElse(null); return llmParserTool.orElse(null);
} }
public LLMReq getLlmReq(QueryContext queryCtx, Long modelId) { public LLMReq getLlmReq(QueryContext queryCtx, Long modelId, List<ElementValue> linkingValues) {
SemanticSchema semanticSchema = schemaService.getSemanticSchema(); SemanticSchema semanticSchema = schemaService.getSemanticSchema();
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName(); Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
String queryText = queryCtx.getRequest().getQueryText(); String queryText = queryCtx.getRequest().getQueryText();
@@ -120,7 +120,7 @@ public class LLMRequestService {
llmSchema.setModelName(modelIdToName.get(modelId)); llmSchema.setModelName(modelIdToName.get(modelId));
llmSchema.setDomainName(modelIdToName.get(modelId)); llmSchema.setDomainName(modelIdToName.get(modelId));
List<String> fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema, llmParserConfig); List<String> fieldNameList = getFieldNameList(queryCtx, modelId, llmParserConfig);
String priorExts = getPriorExts(modelId, fieldNameList); String priorExts = getPriorExts(modelId, fieldNameList);
llmReq.setPriorExts(priorExts); llmReq.setPriorExts(priorExts);
@@ -131,7 +131,7 @@ public class LLMRequestService {
List<ElementValue> linking = new ArrayList<>(); List<ElementValue> linking = new ArrayList<>();
if (optimizationConfig.isUseLinkingValueSwitch()) { if (optimizationConfig.isUseLinkingValueSwitch()) {
linking.addAll(getValueList(queryCtx, modelId, semanticSchema)); linking.addAll(linkingValues);
} }
llmReq.setLinking(linking); llmReq.setLinking(linking);
@@ -155,7 +155,7 @@ public class LLMRequestService {
LLMResp.class); LLMResp.class);
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}", log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
System.currentTimeMillis() - startTime, url.toString(), entity, responseEntity.getBody()); System.currentTimeMillis() - startTime, url, entity, responseEntity.getBody());
return responseEntity.getBody(); return responseEntity.getBody();
} catch (Exception e) { } catch (Exception e) {
log.error("requestLLM error", e); log.error("requestLLM error", e);
@@ -163,12 +163,11 @@ public class LLMRequestService {
return null; return null;
} }
protected List<String> getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema, protected List<String> getFieldNameList(QueryContext queryCtx, Long modelId, LLMParserConfig llmParserConfig) {
LLMParserConfig llmParserConfig) {
Set<String> results = getTopNFieldNames(modelId, semanticSchema, llmParserConfig); Set<String> results = getTopNFieldNames(modelId, llmParserConfig);
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, modelId, semanticSchema); Set<String> fieldNameList = getMatchedFieldNames(queryCtx, modelId);
results.addAll(fieldNameList); results.addAll(fieldNameList);
return new ArrayList<>(results); return new ArrayList<>(results);
@@ -210,8 +209,8 @@ public class LLMRequestService {
} }
protected List<ElementValue> getValueList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) { protected List<ElementValue> getValueList(QueryContext queryCtx, Long modelId) {
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema); Map<Long, String> itemIdToName = getItemIdToName(modelId);
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId); List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
if (CollectionUtils.isEmpty(matchedElements)) { if (CollectionUtils.isEmpty(matchedElements)) {
@@ -233,14 +232,15 @@ public class LLMRequestService {
return new ArrayList<>(valueMatches); return new ArrayList<>(valueMatches);
} }
protected Map<Long, String> getItemIdToName(Long modelId, SemanticSchema semanticSchema) { protected Map<Long, String> getItemIdToName(Long modelId) {
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
return semanticSchema.getDimensions(modelId).stream() return semanticSchema.getDimensions(modelId).stream()
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2)); .collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
} }
private Set<String> getTopNFieldNames(Long modelId, SemanticSchema semanticSchema, private Set<String> getTopNFieldNames(Long modelId, LLMParserConfig llmParserConfig) {
LLMParserConfig llmParserConfig) { SemanticSchema semanticSchema = schemaService.getSemanticSchema();
Set<String> results = semanticSchema.getDimensions(modelId).stream() Set<String> results = semanticSchema.getDimensions(modelId).stream()
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getDimensionTopN()) .limit(llmParserConfig.getDimensionTopN())
@@ -258,8 +258,8 @@ public class LLMRequestService {
} }
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) { protected Set<String> getMatchedFieldNames(QueryContext queryCtx, Long modelId) {
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema); Map<Long, String> itemIdToName = getItemIdToName(modelId);
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId); List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
if (CollectionUtils.isEmpty(matchedElements)) { if (CollectionUtils.isEmpty(matchedElements)) {
return new HashSet<>(); return new HashSet<>();

View File

@@ -6,8 +6,10 @@ import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@@ -36,7 +38,8 @@ public class LLMS2SQLParser implements SemanticParser {
return; return;
} }
//4.construct a request, call the API for the large model, and retrieve the results. //4.construct a request, call the API for the large model, and retrieve the results.
LLMReq llmReq = requestService.getLlmReq(queryCtx, modelId); List<ElementValue> linkingValues = requestService.getValueList(queryCtx, modelId);
LLMReq llmReq = requestService.getLlmReq(queryCtx, modelId, linkingValues);
LLMResp llmResp = requestService.requestLLM(llmReq, modelId); LLMResp llmResp = requestService.requestLLM(llmReq, modelId);
if (Objects.isNull(llmResp)) { if (Objects.isNull(llmResp)) {
@@ -49,7 +52,9 @@ public class LLMS2SQLParser implements SemanticParser {
.modelId(modelId) .modelId(modelId)
.commonAgentTool(commonAgentTool) .commonAgentTool(commonAgentTool)
.llmReq(llmReq) .llmReq(llmReq)
.llmResp(llmResp).build(); .llmResp(llmResp)
.linkingValues(linkingValues)
.build();
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class); LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);

View File

@@ -3,7 +3,9 @@ package com.tencent.supersonic.chat.parser.llm.s2sql;
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool; import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import java.util.List;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
@@ -24,4 +26,6 @@ public class ParseResult {
private QueryReq request; private QueryReq request;
private CommonAgentTool commonAgentTool; private CommonAgentTool commonAgentTool;
private List<ElementValue> linkingValues;
} }

View File

@@ -1,16 +1,15 @@
package com.tencent.supersonic.common.pojo.enums; package com.tencent.supersonic.common.pojo.enums;
import cn.hutool.core.collection.CollectionUtil; import cn.hutool.core.collection.CollectionUtil;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
public enum TimeDimensionEnum { public enum TimeDimensionEnum {
DAY("sys_imp_date", "数据日期"), DAY("sys_imp_date", "数据日期"),
WEEK("sys_imp_week", "数据日期_周"), WEEK("sys_imp_week", "数据日期_周"),
MONTH("sys_imp_month", "数据日期_月"); MONTH("sys_imp_month", "数据日期_月");
@@ -28,10 +27,6 @@ public enum TimeDimensionEnum {
return Arrays.stream(TimeDimensionEnum.values()).map(TimeDimensionEnum::getName).collect(Collectors.toList()); return Arrays.stream(TimeDimensionEnum.values()).map(TimeDimensionEnum::getName).collect(Collectors.toList());
} }
/**
 * Collects the {@code chName} of every constant declared in this enum.
 *
 * @return a set with one entry per distinct {@code getChName()} value
 */
public static Set<String> getChNameSet() {
    return Arrays.stream(TimeDimensionEnum.values())
            .map(TimeDimensionEnum::getChName)
            .collect(Collectors.toSet());
}
public String getName() { public String getName() {
return name; return name;
} }
@@ -42,8 +37,9 @@ public enum TimeDimensionEnum {
/** /**
* Determine if a time dimension field is included in a Chinese text field * Determine if a time dimension field is included in a Chinese text field
*
* @param fields field * @param fields field
* @return true/fase * @return true/false
*/ */
public static boolean containsZhTimeDimension(List<String> fields) { public static boolean containsZhTimeDimension(List<String> fields) {
if (CollectionUtil.isEmpty(fields)) { if (CollectionUtil.isEmpty(fields)) {