mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 12:37:55 +00:00
(improvement)(chat) linkingValues is not passed to llm and optimize SchemaCorrector code (#378)
This commit is contained in:
@@ -4,7 +4,6 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
|||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||||
import com.tencent.supersonic.chat.parser.llm.s2sql.ParseResult;
|
import com.tencent.supersonic.chat.parser.llm.s2sql.ParseResult;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
@@ -23,9 +22,9 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||||
String sql = SqlParserReplaceHelper.replaceFunction(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
|
|
||||||
AggregateEnum.getAggregateEnum());
|
correctAggFunction(semanticParseInfo);
|
||||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(sql);
|
|
||||||
replaceAlias(semanticParseInfo);
|
replaceAlias(semanticParseInfo);
|
||||||
|
|
||||||
updateFieldNameByLinkingValue(semanticParseInfo);
|
updateFieldNameByLinkingValue(semanticParseInfo);
|
||||||
@@ -35,6 +34,13 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
|||||||
correctFieldName(semanticParseInfo);
|
correctFieldName(semanticParseInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void correctAggFunction(SemanticParseInfo semanticParseInfo) {
|
||||||
|
Map<String, String> aggregateEnum = AggregateEnum.getAggregateEnum();
|
||||||
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
|
String sql = SqlParserReplaceHelper.replaceFunction(sqlInfo.getCorrectS2SQL(), aggregateEnum);
|
||||||
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
|
}
|
||||||
|
|
||||||
private void replaceAlias(SemanticParseInfo semanticParseInfo) {
|
private void replaceAlias(SemanticParseInfo semanticParseInfo) {
|
||||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
String replaceAlias = SqlParserReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL());
|
String replaceAlias = SqlParserReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL());
|
||||||
@@ -74,8 +80,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
|||||||
if (Objects.isNull(parseResult) || Objects.isNull(parseResult.getLlmReq())) {
|
if (Objects.isNull(parseResult) || Objects.isNull(parseResult.getLlmReq())) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
LLMReq llmReq = parseResult.getLlmReq();
|
return parseResult.getLinkingValues();
|
||||||
return llmReq.getLinking();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -74,7 +74,8 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
|||||||
String currentDate = S2SQLDateHelper.getReferenceDate(semanticParseInfo.getModelId());
|
String currentDate = S2SQLDateHelper.getReferenceDate(semanticParseInfo.getModelId());
|
||||||
if (StringUtils.isNotBlank(currentDate)) {
|
if (StringUtils.isNotBlank(currentDate)) {
|
||||||
correctS2SQL = SqlParserAddHelper.addParenthesisToWhere(correctS2SQL);
|
correctS2SQL = SqlParserAddHelper.addParenthesisToWhere(correctS2SQL);
|
||||||
correctS2SQL = SqlParserAddHelper.addWhere(correctS2SQL, TimeDimensionEnum.DAY.getChName(), currentDate);
|
correctS2SQL = SqlParserAddHelper.addWhere(correctS2SQL, TimeDimensionEnum.DAY.getChName(),
|
||||||
|
currentDate);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||||
|
|||||||
@@ -104,7 +104,7 @@ public class LLMRequestService {
|
|||||||
return llmParserTool.orElse(null);
|
return llmParserTool.orElse(null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public LLMReq getLlmReq(QueryContext queryCtx, Long modelId) {
|
public LLMReq getLlmReq(QueryContext queryCtx, Long modelId, List<ElementValue> linkingValues) {
|
||||||
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
|
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
|
||||||
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
|
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
|
||||||
String queryText = queryCtx.getRequest().getQueryText();
|
String queryText = queryCtx.getRequest().getQueryText();
|
||||||
@@ -120,7 +120,7 @@ public class LLMRequestService {
|
|||||||
llmSchema.setModelName(modelIdToName.get(modelId));
|
llmSchema.setModelName(modelIdToName.get(modelId));
|
||||||
llmSchema.setDomainName(modelIdToName.get(modelId));
|
llmSchema.setDomainName(modelIdToName.get(modelId));
|
||||||
|
|
||||||
List<String> fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema, llmParserConfig);
|
List<String> fieldNameList = getFieldNameList(queryCtx, modelId, llmParserConfig);
|
||||||
|
|
||||||
String priorExts = getPriorExts(modelId, fieldNameList);
|
String priorExts = getPriorExts(modelId, fieldNameList);
|
||||||
llmReq.setPriorExts(priorExts);
|
llmReq.setPriorExts(priorExts);
|
||||||
@@ -131,7 +131,7 @@ public class LLMRequestService {
|
|||||||
|
|
||||||
List<ElementValue> linking = new ArrayList<>();
|
List<ElementValue> linking = new ArrayList<>();
|
||||||
if (optimizationConfig.isUseLinkingValueSwitch()) {
|
if (optimizationConfig.isUseLinkingValueSwitch()) {
|
||||||
linking.addAll(getValueList(queryCtx, modelId, semanticSchema));
|
linking.addAll(linkingValues);
|
||||||
}
|
}
|
||||||
llmReq.setLinking(linking);
|
llmReq.setLinking(linking);
|
||||||
|
|
||||||
@@ -155,7 +155,7 @@ public class LLMRequestService {
|
|||||||
LLMResp.class);
|
LLMResp.class);
|
||||||
|
|
||||||
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
|
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
|
||||||
System.currentTimeMillis() - startTime, url.toString(), entity, responseEntity.getBody());
|
System.currentTimeMillis() - startTime, url, entity, responseEntity.getBody());
|
||||||
return responseEntity.getBody();
|
return responseEntity.getBody();
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("requestLLM error", e);
|
log.error("requestLLM error", e);
|
||||||
@@ -163,12 +163,11 @@ public class LLMRequestService {
|
|||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected List<String> getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema,
|
protected List<String> getFieldNameList(QueryContext queryCtx, Long modelId, LLMParserConfig llmParserConfig) {
|
||||||
LLMParserConfig llmParserConfig) {
|
|
||||||
|
|
||||||
Set<String> results = getTopNFieldNames(modelId, semanticSchema, llmParserConfig);
|
Set<String> results = getTopNFieldNames(modelId, llmParserConfig);
|
||||||
|
|
||||||
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, modelId, semanticSchema);
|
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, modelId);
|
||||||
|
|
||||||
results.addAll(fieldNameList);
|
results.addAll(fieldNameList);
|
||||||
return new ArrayList<>(results);
|
return new ArrayList<>(results);
|
||||||
@@ -210,8 +209,8 @@ public class LLMRequestService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
protected List<ElementValue> getValueList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
|
protected List<ElementValue> getValueList(QueryContext queryCtx, Long modelId) {
|
||||||
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
|
Map<Long, String> itemIdToName = getItemIdToName(modelId);
|
||||||
|
|
||||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
|
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
|
||||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||||
@@ -233,14 +232,15 @@ public class LLMRequestService {
|
|||||||
return new ArrayList<>(valueMatches);
|
return new ArrayList<>(valueMatches);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected Map<Long, String> getItemIdToName(Long modelId, SemanticSchema semanticSchema) {
|
protected Map<Long, String> getItemIdToName(Long modelId) {
|
||||||
|
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
|
||||||
return semanticSchema.getDimensions(modelId).stream()
|
return semanticSchema.getDimensions(modelId).stream()
|
||||||
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
|
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private Set<String> getTopNFieldNames(Long modelId, SemanticSchema semanticSchema,
|
private Set<String> getTopNFieldNames(Long modelId, LLMParserConfig llmParserConfig) {
|
||||||
LLMParserConfig llmParserConfig) {
|
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
|
||||||
Set<String> results = semanticSchema.getDimensions(modelId).stream()
|
Set<String> results = semanticSchema.getDimensions(modelId).stream()
|
||||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||||
.limit(llmParserConfig.getDimensionTopN())
|
.limit(llmParserConfig.getDimensionTopN())
|
||||||
@@ -258,8 +258,8 @@ public class LLMRequestService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
|
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, Long modelId) {
|
||||||
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
|
Map<Long, String> itemIdToName = getItemIdToName(modelId);
|
||||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
|
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
|
||||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||||
return new HashSet<>();
|
return new HashSet<>();
|
||||||
|
|||||||
@@ -6,8 +6,10 @@ import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
|||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
@@ -36,7 +38,8 @@ public class LLMS2SQLParser implements SemanticParser {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
//4.construct a request, call the API for the large model, and retrieve the results.
|
//4.construct a request, call the API for the large model, and retrieve the results.
|
||||||
LLMReq llmReq = requestService.getLlmReq(queryCtx, modelId);
|
List<ElementValue> linkingValues = requestService.getValueList(queryCtx, modelId);
|
||||||
|
LLMReq llmReq = requestService.getLlmReq(queryCtx, modelId, linkingValues);
|
||||||
LLMResp llmResp = requestService.requestLLM(llmReq, modelId);
|
LLMResp llmResp = requestService.requestLLM(llmReq, modelId);
|
||||||
|
|
||||||
if (Objects.isNull(llmResp)) {
|
if (Objects.isNull(llmResp)) {
|
||||||
@@ -49,7 +52,9 @@ public class LLMS2SQLParser implements SemanticParser {
|
|||||||
.modelId(modelId)
|
.modelId(modelId)
|
||||||
.commonAgentTool(commonAgentTool)
|
.commonAgentTool(commonAgentTool)
|
||||||
.llmReq(llmReq)
|
.llmReq(llmReq)
|
||||||
.llmResp(llmResp).build();
|
.llmResp(llmResp)
|
||||||
|
.linkingValues(linkingValues)
|
||||||
|
.build();
|
||||||
|
|
||||||
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
|
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,9 @@ package com.tencent.supersonic.chat.parser.llm.s2sql;
|
|||||||
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
|
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||||
|
import java.util.List;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -24,4 +26,6 @@ public class ParseResult {
|
|||||||
private QueryReq request;
|
private QueryReq request;
|
||||||
|
|
||||||
private CommonAgentTool commonAgentTool;
|
private CommonAgentTool commonAgentTool;
|
||||||
|
|
||||||
|
private List<ElementValue> linkingValues;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,16 +1,15 @@
|
|||||||
package com.tencent.supersonic.common.pojo.enums;
|
package com.tencent.supersonic.common.pojo.enums;
|
||||||
|
|
||||||
import cn.hutool.core.collection.CollectionUtil;
|
import cn.hutool.core.collection.CollectionUtil;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Set;
|
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
|
||||||
public enum TimeDimensionEnum {
|
public enum TimeDimensionEnum {
|
||||||
|
|
||||||
DAY("sys_imp_date", "数据日期"),
|
DAY("sys_imp_date", "数据日期"),
|
||||||
|
|
||||||
WEEK("sys_imp_week", "数据日期_周"),
|
WEEK("sys_imp_week", "数据日期_周"),
|
||||||
|
|
||||||
MONTH("sys_imp_month", "数据日期_月");
|
MONTH("sys_imp_month", "数据日期_月");
|
||||||
@@ -28,10 +27,6 @@ public enum TimeDimensionEnum {
|
|||||||
return Arrays.stream(TimeDimensionEnum.values()).map(TimeDimensionEnum::getName).collect(Collectors.toList());
|
return Arrays.stream(TimeDimensionEnum.values()).map(TimeDimensionEnum::getName).collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Set<String> getChNameSet() {
|
|
||||||
return Arrays.stream(TimeDimensionEnum.values()).map(TimeDimensionEnum::getChName).collect(Collectors.toSet());
|
|
||||||
}
|
|
||||||
|
|
||||||
public String getName() {
|
public String getName() {
|
||||||
return name;
|
return name;
|
||||||
}
|
}
|
||||||
@@ -42,8 +37,9 @@ public enum TimeDimensionEnum {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Determine if a time dimension field is included in a Chinese text field
|
* Determine if a time dimension field is included in a Chinese text field
|
||||||
|
*
|
||||||
* @param fields field
|
* @param fields field
|
||||||
* @return true/fase
|
* @return true/false
|
||||||
*/
|
*/
|
||||||
public static boolean containsZhTimeDimension(List<String> fields) {
|
public static boolean containsZhTimeDimension(List<String> fields) {
|
||||||
if (CollectionUtil.isEmpty(fields)) {
|
if (CollectionUtil.isEmpty(fields)) {
|
||||||
|
|||||||
Reference in New Issue
Block a user