mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
(improvement)(chat) linkingValues is not passed to llm and optimize SchemaCorrector code (#378)
This commit is contained in:
@@ -4,7 +4,6 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.chat.parser.llm.s2sql.ParseResult;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
@@ -23,9 +22,9 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
String sql = SqlParserReplaceHelper.replaceFunction(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
|
||||
AggregateEnum.getAggregateEnum());
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(sql);
|
||||
|
||||
correctAggFunction(semanticParseInfo);
|
||||
|
||||
replaceAlias(semanticParseInfo);
|
||||
|
||||
updateFieldNameByLinkingValue(semanticParseInfo);
|
||||
@@ -35,6 +34,13 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
correctFieldName(semanticParseInfo);
|
||||
}
|
||||
|
||||
private void correctAggFunction(SemanticParseInfo semanticParseInfo) {
|
||||
Map<String, String> aggregateEnum = AggregateEnum.getAggregateEnum();
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String sql = SqlParserReplaceHelper.replaceFunction(sqlInfo.getCorrectS2SQL(), aggregateEnum);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
}
|
||||
|
||||
private void replaceAlias(SemanticParseInfo semanticParseInfo) {
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String replaceAlias = SqlParserReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL());
|
||||
@@ -74,8 +80,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
if (Objects.isNull(parseResult) || Objects.isNull(parseResult.getLlmReq())) {
|
||||
return null;
|
||||
}
|
||||
LLMReq llmReq = parseResult.getLlmReq();
|
||||
return llmReq.getLinking();
|
||||
return parseResult.getLinkingValues();
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -74,7 +74,8 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
String currentDate = S2SQLDateHelper.getReferenceDate(semanticParseInfo.getModelId());
|
||||
if (StringUtils.isNotBlank(currentDate)) {
|
||||
correctS2SQL = SqlParserAddHelper.addParenthesisToWhere(correctS2SQL);
|
||||
correctS2SQL = SqlParserAddHelper.addWhere(correctS2SQL, TimeDimensionEnum.DAY.getChName(), currentDate);
|
||||
correctS2SQL = SqlParserAddHelper.addWhere(correctS2SQL, TimeDimensionEnum.DAY.getChName(),
|
||||
currentDate);
|
||||
}
|
||||
}
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
|
||||
@@ -104,7 +104,7 @@ public class LLMRequestService {
|
||||
return llmParserTool.orElse(null);
|
||||
}
|
||||
|
||||
public LLMReq getLlmReq(QueryContext queryCtx, Long modelId) {
|
||||
public LLMReq getLlmReq(QueryContext queryCtx, Long modelId, List<ElementValue> linkingValues) {
|
||||
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
|
||||
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
|
||||
String queryText = queryCtx.getRequest().getQueryText();
|
||||
@@ -120,7 +120,7 @@ public class LLMRequestService {
|
||||
llmSchema.setModelName(modelIdToName.get(modelId));
|
||||
llmSchema.setDomainName(modelIdToName.get(modelId));
|
||||
|
||||
List<String> fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema, llmParserConfig);
|
||||
List<String> fieldNameList = getFieldNameList(queryCtx, modelId, llmParserConfig);
|
||||
|
||||
String priorExts = getPriorExts(modelId, fieldNameList);
|
||||
llmReq.setPriorExts(priorExts);
|
||||
@@ -131,7 +131,7 @@ public class LLMRequestService {
|
||||
|
||||
List<ElementValue> linking = new ArrayList<>();
|
||||
if (optimizationConfig.isUseLinkingValueSwitch()) {
|
||||
linking.addAll(getValueList(queryCtx, modelId, semanticSchema));
|
||||
linking.addAll(linkingValues);
|
||||
}
|
||||
llmReq.setLinking(linking);
|
||||
|
||||
@@ -155,7 +155,7 @@ public class LLMRequestService {
|
||||
LLMResp.class);
|
||||
|
||||
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
|
||||
System.currentTimeMillis() - startTime, url.toString(), entity, responseEntity.getBody());
|
||||
System.currentTimeMillis() - startTime, url, entity, responseEntity.getBody());
|
||||
return responseEntity.getBody();
|
||||
} catch (Exception e) {
|
||||
log.error("requestLLM error", e);
|
||||
@@ -163,12 +163,11 @@ public class LLMRequestService {
|
||||
return null;
|
||||
}
|
||||
|
||||
protected List<String> getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema,
|
||||
LLMParserConfig llmParserConfig) {
|
||||
protected List<String> getFieldNameList(QueryContext queryCtx, Long modelId, LLMParserConfig llmParserConfig) {
|
||||
|
||||
Set<String> results = getTopNFieldNames(modelId, semanticSchema, llmParserConfig);
|
||||
Set<String> results = getTopNFieldNames(modelId, llmParserConfig);
|
||||
|
||||
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, modelId, semanticSchema);
|
||||
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, modelId);
|
||||
|
||||
results.addAll(fieldNameList);
|
||||
return new ArrayList<>(results);
|
||||
@@ -210,8 +209,8 @@ public class LLMRequestService {
|
||||
}
|
||||
|
||||
|
||||
protected List<ElementValue> getValueList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
|
||||
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
|
||||
protected List<ElementValue> getValueList(QueryContext queryCtx, Long modelId) {
|
||||
Map<Long, String> itemIdToName = getItemIdToName(modelId);
|
||||
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
@@ -233,14 +232,15 @@ public class LLMRequestService {
|
||||
return new ArrayList<>(valueMatches);
|
||||
}
|
||||
|
||||
protected Map<Long, String> getItemIdToName(Long modelId, SemanticSchema semanticSchema) {
|
||||
protected Map<Long, String> getItemIdToName(Long modelId) {
|
||||
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
|
||||
return semanticSchema.getDimensions(modelId).stream()
|
||||
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
|
||||
}
|
||||
|
||||
|
||||
private Set<String> getTopNFieldNames(Long modelId, SemanticSchema semanticSchema,
|
||||
LLMParserConfig llmParserConfig) {
|
||||
private Set<String> getTopNFieldNames(Long modelId, LLMParserConfig llmParserConfig) {
|
||||
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
|
||||
Set<String> results = semanticSchema.getDimensions(modelId).stream()
|
||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(llmParserConfig.getDimensionTopN())
|
||||
@@ -258,8 +258,8 @@ public class LLMRequestService {
|
||||
}
|
||||
|
||||
|
||||
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
|
||||
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
|
||||
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, Long modelId) {
|
||||
Map<Long, String> itemIdToName = getItemIdToName(modelId);
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
return new HashSet<>();
|
||||
|
||||
@@ -6,8 +6,10 @@ import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -36,7 +38,8 @@ public class LLMS2SQLParser implements SemanticParser {
|
||||
return;
|
||||
}
|
||||
//4.construct a request, call the API for the large model, and retrieve the results.
|
||||
LLMReq llmReq = requestService.getLlmReq(queryCtx, modelId);
|
||||
List<ElementValue> linkingValues = requestService.getValueList(queryCtx, modelId);
|
||||
LLMReq llmReq = requestService.getLlmReq(queryCtx, modelId, linkingValues);
|
||||
LLMResp llmResp = requestService.requestLLM(llmReq, modelId);
|
||||
|
||||
if (Objects.isNull(llmResp)) {
|
||||
@@ -49,7 +52,9 @@ public class LLMS2SQLParser implements SemanticParser {
|
||||
.modelId(modelId)
|
||||
.commonAgentTool(commonAgentTool)
|
||||
.llmReq(llmReq)
|
||||
.llmResp(llmResp).build();
|
||||
.llmResp(llmResp)
|
||||
.linkingValues(linkingValues)
|
||||
.build();
|
||||
|
||||
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
|
||||
|
||||
|
||||
@@ -3,7 +3,9 @@ package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import java.util.List;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
@@ -24,4 +26,6 @@ public class ParseResult {
|
||||
private QueryReq request;
|
||||
|
||||
private CommonAgentTool commonAgentTool;
|
||||
|
||||
private List<ElementValue> linkingValues;
|
||||
}
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
package com.tencent.supersonic.common.pojo.enums;
|
||||
|
||||
import cn.hutool.core.collection.CollectionUtil;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
|
||||
public enum TimeDimensionEnum {
|
||||
|
||||
DAY("sys_imp_date", "数据日期"),
|
||||
|
||||
WEEK("sys_imp_week", "数据日期_周"),
|
||||
|
||||
MONTH("sys_imp_month", "数据日期_月");
|
||||
@@ -28,10 +27,6 @@ public enum TimeDimensionEnum {
|
||||
return Arrays.stream(TimeDimensionEnum.values()).map(TimeDimensionEnum::getName).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public static Set<String> getChNameSet() {
|
||||
return Arrays.stream(TimeDimensionEnum.values()).map(TimeDimensionEnum::getChName).collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
@@ -42,8 +37,9 @@ public enum TimeDimensionEnum {
|
||||
|
||||
/**
|
||||
* Determine if a time dimension field is included in a Chinese text field
|
||||
*
|
||||
* @param fields field
|
||||
* @return true/fase
|
||||
* @return true/false
|
||||
*/
|
||||
public static boolean containsZhTimeDimension(List<String> fields) {
|
||||
if (CollectionUtil.isEmpty(fields)) {
|
||||
|
||||
Reference in New Issue
Block a user