(improvement)(chat) linkingValues is not passed to llm and optimize SchemaCorrector code (#378)

This commit is contained in:
lexluo09
2023-11-14 11:12:27 +08:00
committed by GitHub
parent 1ad2c5402b
commit 74ed269544
6 changed files with 42 additions and 31 deletions

View File

@@ -4,7 +4,6 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.parser.llm.s2sql.ParseResult;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.JsonUtil;
@@ -23,9 +22,9 @@ public class SchemaCorrector extends BaseSemanticCorrector {
@Override
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
String sql = SqlParserReplaceHelper.replaceFunction(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
AggregateEnum.getAggregateEnum());
semanticParseInfo.getSqlInfo().setCorrectS2SQL(sql);
correctAggFunction(semanticParseInfo);
replaceAlias(semanticParseInfo);
updateFieldNameByLinkingValue(semanticParseInfo);
@@ -35,6 +34,13 @@ public class SchemaCorrector extends BaseSemanticCorrector {
correctFieldName(semanticParseInfo);
}
/**
 * Runs the current corrected S2SQL through {@code SqlParserReplaceHelper.replaceFunction}
 * and writes the result back onto the same {@code SqlInfo}.
 *
 * @param semanticParseInfo parse info whose {@code SqlInfo} supplies the corrected S2SQL
 *                          and receives the rewritten statement (updated in place)
 */
private void correctAggFunction(SemanticParseInfo semanticParseInfo) {
// NOTE(review): presumably maps aggregate-function aliases to their canonical names;
// confirm against AggregateEnum.getAggregateEnum().
Map<String, String> aggregateEnum = AggregateEnum.getAggregateEnum();
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlParserReplaceHelper.replaceFunction(sqlInfo.getCorrectS2SQL(), aggregateEnum);
// Overwrite the corrected S2SQL so callers reading it afterwards see the rewritten SQL.
sqlInfo.setCorrectS2SQL(sql);
}
private void replaceAlias(SemanticParseInfo semanticParseInfo) {
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String replaceAlias = SqlParserReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL());
@@ -74,8 +80,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
if (Objects.isNull(parseResult) || Objects.isNull(parseResult.getLlmReq())) {
return null;
}
LLMReq llmReq = parseResult.getLlmReq();
return llmReq.getLinking();
return parseResult.getLinkingValues();
}

View File

@@ -74,7 +74,8 @@ public class WhereCorrector extends BaseSemanticCorrector {
String currentDate = S2SQLDateHelper.getReferenceDate(semanticParseInfo.getModelId());
if (StringUtils.isNotBlank(currentDate)) {
correctS2SQL = SqlParserAddHelper.addParenthesisToWhere(correctS2SQL);
correctS2SQL = SqlParserAddHelper.addWhere(correctS2SQL, TimeDimensionEnum.DAY.getChName(), currentDate);
correctS2SQL = SqlParserAddHelper.addWhere(correctS2SQL, TimeDimensionEnum.DAY.getChName(),
currentDate);
}
}
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);

View File

@@ -104,7 +104,7 @@ public class LLMRequestService {
return llmParserTool.orElse(null);
}
public LLMReq getLlmReq(QueryContext queryCtx, Long modelId) {
public LLMReq getLlmReq(QueryContext queryCtx, Long modelId, List<ElementValue> linkingValues) {
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
String queryText = queryCtx.getRequest().getQueryText();
@@ -120,7 +120,7 @@ public class LLMRequestService {
llmSchema.setModelName(modelIdToName.get(modelId));
llmSchema.setDomainName(modelIdToName.get(modelId));
List<String> fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema, llmParserConfig);
List<String> fieldNameList = getFieldNameList(queryCtx, modelId, llmParserConfig);
String priorExts = getPriorExts(modelId, fieldNameList);
llmReq.setPriorExts(priorExts);
@@ -131,7 +131,7 @@ public class LLMRequestService {
List<ElementValue> linking = new ArrayList<>();
if (optimizationConfig.isUseLinkingValueSwitch()) {
linking.addAll(getValueList(queryCtx, modelId, semanticSchema));
linking.addAll(linkingValues);
}
llmReq.setLinking(linking);
@@ -155,7 +155,7 @@ public class LLMRequestService {
LLMResp.class);
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
System.currentTimeMillis() - startTime, url.toString(), entity, responseEntity.getBody());
System.currentTimeMillis() - startTime, url, entity, responseEntity.getBody());
return responseEntity.getBody();
} catch (Exception e) {
log.error("requestLLM error", e);
@@ -163,12 +163,11 @@ public class LLMRequestService {
return null;
}
protected List<String> getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema,
LLMParserConfig llmParserConfig) {
protected List<String> getFieldNameList(QueryContext queryCtx, Long modelId, LLMParserConfig llmParserConfig) {
Set<String> results = getTopNFieldNames(modelId, semanticSchema, llmParserConfig);
Set<String> results = getTopNFieldNames(modelId, llmParserConfig);
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, modelId, semanticSchema);
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, modelId);
results.addAll(fieldNameList);
return new ArrayList<>(results);
@@ -210,8 +209,8 @@ public class LLMRequestService {
}
protected List<ElementValue> getValueList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
protected List<ElementValue> getValueList(QueryContext queryCtx, Long modelId) {
Map<Long, String> itemIdToName = getItemIdToName(modelId);
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
if (CollectionUtils.isEmpty(matchedElements)) {
@@ -233,14 +232,15 @@ public class LLMRequestService {
return new ArrayList<>(valueMatches);
}
protected Map<Long, String> getItemIdToName(Long modelId, SemanticSchema semanticSchema) {
protected Map<Long, String> getItemIdToName(Long modelId) {
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
return semanticSchema.getDimensions(modelId).stream()
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
}
private Set<String> getTopNFieldNames(Long modelId, SemanticSchema semanticSchema,
LLMParserConfig llmParserConfig) {
private Set<String> getTopNFieldNames(Long modelId, LLMParserConfig llmParserConfig) {
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
Set<String> results = semanticSchema.getDimensions(modelId).stream()
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getDimensionTopN())
@@ -258,8 +258,8 @@ public class LLMRequestService {
}
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, Long modelId) {
Map<Long, String> itemIdToName = getItemIdToName(modelId);
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
if (CollectionUtils.isEmpty(matchedElements)) {
return new HashSet<>();

View File

@@ -6,8 +6,10 @@ import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
@@ -36,7 +38,8 @@ public class LLMS2SQLParser implements SemanticParser {
return;
}
//4.construct a request, call the API for the large model, and retrieve the results.
LLMReq llmReq = requestService.getLlmReq(queryCtx, modelId);
List<ElementValue> linkingValues = requestService.getValueList(queryCtx, modelId);
LLMReq llmReq = requestService.getLlmReq(queryCtx, modelId, linkingValues);
LLMResp llmResp = requestService.requestLLM(llmReq, modelId);
if (Objects.isNull(llmResp)) {
@@ -49,7 +52,9 @@ public class LLMS2SQLParser implements SemanticParser {
.modelId(modelId)
.commonAgentTool(commonAgentTool)
.llmReq(llmReq)
.llmResp(llmResp).build();
.llmResp(llmResp)
.linkingValues(linkingValues)
.build();
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);

View File

@@ -3,7 +3,9 @@ package com.tencent.supersonic.chat.parser.llm.s2sql;
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
@@ -24,4 +26,6 @@ public class ParseResult {
private QueryReq request;
private CommonAgentTool commonAgentTool;
private List<ElementValue> linkingValues;
}

View File

@@ -1,16 +1,15 @@
package com.tencent.supersonic.common.pojo.enums;
import cn.hutool.core.collection.CollectionUtil;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
public enum TimeDimensionEnum {
DAY("sys_imp_date", "数据日期"),
WEEK("sys_imp_week", "数据日期_周"),
MONTH("sys_imp_month", "数据日期_月");
@@ -28,10 +27,6 @@ public enum TimeDimensionEnum {
return Arrays.stream(TimeDimensionEnum.values()).map(TimeDimensionEnum::getName).collect(Collectors.toList());
}
/**
 * Collects the {@code getChName()} value of every {@code TimeDimensionEnum} constant.
 *
 * @return a set holding the chName of each constant (duplicates, if any, collapse)
 */
public static Set<String> getChNameSet() {
return Arrays.stream(TimeDimensionEnum.values()).map(TimeDimensionEnum::getChName).collect(Collectors.toSet());
}
/**
 * Returns this constant's {@code name} field.
 *
 * @return the value of {@code name}
 */
public String getName() {
return name;
}
@@ -42,8 +37,9 @@ public enum TimeDimensionEnum {
/**
* Determine if a time dimension field is included in a Chinese text field
*
* @param fields field
* @return true/fase
* @return true/false
*/
public static boolean containsZhTimeDimension(List<String> fields) {
if (CollectionUtil.isEmpty(fields)) {