diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/SemanticInterpreter.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/SemanticInterpreter.java index cfa20665f..0fa9a0cea 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/SemanticInterpreter.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/SemanticInterpreter.java @@ -11,6 +11,7 @@ import com.tencent.supersonic.semantic.api.model.response.DimensionResp; import com.tencent.supersonic.semantic.api.model.response.ExplainResp; import com.tencent.supersonic.semantic.api.model.response.ModelResp; import com.tencent.supersonic.semantic.api.model.response.MetricResp; +import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp; import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq; import com.tencent.supersonic.semantic.api.query.request.QueryDimValueReq; @@ -57,4 +58,6 @@ public interface SemanticInterpreter { ExplainResp explain(ExplainSqlReq explainSqlReq, User user) throws Exception; + List fetchModelSchema(List ids, Boolean cacheEnable); + } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2ql/LLMRequestService.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2ql/LLMRequestService.java new file mode 100644 index 000000000..8a720988f --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2ql/LLMRequestService.java @@ -0,0 +1,275 @@ +package com.tencent.supersonic.chat.parser.llm.s2ql; + +import com.tencent.supersonic.chat.agent.tool.AgentToolType; +import com.tencent.supersonic.chat.agent.tool.CommonAgentTool; +import com.tencent.supersonic.chat.api.component.SemanticInterpreter; +import com.tencent.supersonic.chat.api.pojo.ChatContext; +import com.tencent.supersonic.chat.api.pojo.QueryContext; +import com.tencent.supersonic.chat.api.pojo.SchemaElement; +import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; +import com.tencent.supersonic.chat.api.pojo.SchemaElementType; +import com.tencent.supersonic.chat.api.pojo.SemanticSchema; +import com.tencent.supersonic.chat.api.pojo.request.QueryReq; +import com.tencent.supersonic.chat.config.LLMParserConfig; +import com.tencent.supersonic.chat.parser.SatisfactionChecker; +import com.tencent.supersonic.chat.query.llm.s2ql.LLMReq; +import com.tencent.supersonic.chat.query.llm.s2ql.LLMReq.ElementValue; +import com.tencent.supersonic.chat.query.llm.s2ql.LLMResp; +import com.tencent.supersonic.chat.service.AgentService; +import com.tencent.supersonic.chat.utils.ComponentFactory; +import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum; +import com.tencent.supersonic.common.util.DateUtils; +import com.tencent.supersonic.common.util.JsonUtil; +import com.tencent.supersonic.knowledge.service.SchemaService; +import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem; +import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.Pair; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.stereotype.Service; +import org.springframework.util.CollectionUtils; +import org.springframework.web.client.RestTemplate; + +@Slf4j +@Service +public class LLMRequestService { + + protected SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer(); + @Autowired + private LLMParserConfig llmParserConfig; + @Autowired + private AgentService agentService; + @Autowired + private SchemaService schemaService; + + @Autowired + private RestTemplate restTemplate; + + public boolean check(QueryContext queryCtx) { + QueryReq request = queryCtx.getRequest(); + if (StringUtils.isEmpty(llmParserConfig.getUrl())) { + log.info("llm parser url is empty, skip {} , llmParserConfig:{}", LLMS2QLParser.class, llmParserConfig); + return true; + } + if (SatisfactionChecker.check(queryCtx)) { + log.info("skip {}, queryText:{}", LLMS2QLParser.class, request.getQueryText()); + return true; + } + return false; + } + + public Long getModelId(QueryContext queryCtx, ChatContext chatCtx, Integer agentId) { + Set distinctModelIds = agentService.getModelIds(agentId, AgentToolType.LLM_S2QL); + if (agentService.containsAllModel(distinctModelIds)) { + distinctModelIds = new HashSet<>(); + } + ModelResolver modelResolver = ComponentFactory.getModelResolver(); + Long modelId = modelResolver.resolve(queryCtx, chatCtx, distinctModelIds); + log.info("resolve modelId:{},llmParser Models:{}", modelId, distinctModelIds); + return modelId; + } + + public CommonAgentTool getParserTool(QueryReq request, Long modelId) { + List commonAgentTools = agentService.getParserTools(request.getAgentId(), + AgentToolType.LLM_S2QL); + Optional llmParserTool = commonAgentTools.stream() + .filter(tool -> { + List modelIds = tool.getModelIds(); + if (agentService.containsAllModel(new HashSet<>(modelIds))) { + return true; + } + return modelIds.contains(modelId); + }) + .findFirst(); + return llmParserTool.orElse(null); + } + + public LLMReq getLlmReq(QueryContext queryCtx, Long modelId) { + SemanticSchema semanticSchema = schemaService.getSemanticSchema(); + Map modelIdToName = semanticSchema.getModelIdToName(); + String queryText = queryCtx.getRequest().getQueryText(); + + LLMReq llmReq = new LLMReq(); + llmReq.setQueryText(queryText); + + LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema(); + llmSchema.setModelName(modelIdToName.get(modelId)); + llmSchema.setDomainName(modelIdToName.get(modelId)); + + List fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema, llmParserConfig); + + String priorExts = getPriorExts(modelId, fieldNameList); + llmReq.setPriorExts(priorExts); + + fieldNameList.add(DateUtils.DATE_FIELD); + llmSchema.setFieldNameList(fieldNameList); + llmReq.setSchema(llmSchema); + + List linking = new ArrayList<>(); + linking.addAll(getValueList(queryCtx, modelId, semanticSchema)); + llmReq.setLinking(linking); + + String currentDate = S2QLDateHelper.getReferenceDate(modelId); + if (StringUtils.isEmpty(currentDate)) { + currentDate = DateUtils.getBeforeDate(0); + } + llmReq.setCurrentDate(currentDate); + return llmReq; + } + + public LLMResp requestLLM(LLMReq llmReq, Long modelId) { + String questUrl = llmParserConfig.getUrl() + llmParserConfig.getQueryToSqlPath(); + long startTime = System.currentTimeMillis(); + log.info("requestLLM request, modelId:{},llmReq:{}", modelId, llmReq); + try { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + HttpEntity entity = new HttpEntity<>(JsonUtil.toString(llmReq), headers); + ResponseEntity responseEntity = restTemplate.exchange(questUrl, HttpMethod.POST, entity, + LLMResp.class); + + log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}", + System.currentTimeMillis() - startTime, questUrl, entity, responseEntity.getBody()); + return responseEntity.getBody(); + } catch (Exception e) { + log.error("requestLLM error", e); + } + return null; + } + + protected List getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema, + LLMParserConfig llmParserConfig) { + + Set results = getTopNFieldNames(modelId, semanticSchema, llmParserConfig); + + Set fieldNameList = getMatchedFieldNames(queryCtx, modelId, semanticSchema); + + results.addAll(fieldNameList); + return new ArrayList<>(results); + } + + private String getPriorExts(Long modelId, List fieldNameList) { + StringBuilder extraInfoSb = new StringBuilder(); + List modelSchemaResps = semanticInterpreter.fetchModelSchema( + Collections.singletonList(modelId), true); + if (!CollectionUtils.isEmpty(modelSchemaResps)) { + + ModelSchemaResp modelSchemaResp = modelSchemaResps.get(0); + Map fieldNameToDataFormatType = modelSchemaResp.getMetrics() + .stream().filter(metricSchemaResp -> Objects.nonNull(metricSchemaResp.getDataFormatType())) + .flatMap(metricSchemaResp -> { + Set> result = new HashSet<>(); + String dataFormatType = metricSchemaResp.getDataFormatType(); + result.add(Pair.of(metricSchemaResp.getName(), dataFormatType)); + List aliasList = SchemaItem.getAliasList(metricSchemaResp.getAlias()); + if (!CollectionUtils.isEmpty(aliasList)) { + for (String alias : aliasList) { + result.add(Pair.of(alias, dataFormatType)); + } + } + return result.stream(); + }) + .collect(Collectors.toMap(a -> a.getLeft(), a -> a.getRight(), (k1, k2) -> k1)); + + for (String fieldName : fieldNameList) { + String dataFormatType = fieldNameToDataFormatType.get(fieldName); + if (DataFormatTypeEnum.DECIMAL.getName().equalsIgnoreCase(dataFormatType) + || DataFormatTypeEnum.PERCENT.getName().equalsIgnoreCase(dataFormatType)) { + String format = String.format("%s 的字段类型是 %s", fieldName, "小数; "); + extraInfoSb.append(format); + } + } + } + return extraInfoSb.toString(); + } + + + protected List getValueList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) { + Map itemIdToName = getItemIdToName(modelId, semanticSchema); + + List matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId); + if (CollectionUtils.isEmpty(matchedElements)) { + return new ArrayList<>(); + } + Set valueMatches = matchedElements + .stream() + .filter(elementMatch -> !elementMatch.isInherited()) + .filter(schemaElementMatch -> { + SchemaElementType type = schemaElementMatch.getElement().getType(); + return SchemaElementType.VALUE.equals(type) || SchemaElementType.ID.equals(type); + }) + .map(elementMatch -> { + ElementValue elementValue = new ElementValue(); + elementValue.setFieldName(itemIdToName.get(elementMatch.getElement().getId())); + elementValue.setFieldValue(elementMatch.getWord()); + return elementValue; + }).collect(Collectors.toSet()); + return new ArrayList<>(valueMatches); + } + + protected Map getItemIdToName(Long modelId, SemanticSchema semanticSchema) { + return semanticSchema.getDimensions(modelId).stream() + .collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2)); + } + + + private Set getTopNFieldNames(Long modelId, SemanticSchema semanticSchema, + LLMParserConfig llmParserConfig) { + Set results = semanticSchema.getDimensions(modelId).stream() + .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) + .limit(llmParserConfig.getDimensionTopN()) + .map(entry -> entry.getName()) + .collect(Collectors.toSet()); + + Set metrics = semanticSchema.getMetrics(modelId).stream() + .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) + .limit(llmParserConfig.getMetricTopN()) + .map(entry -> entry.getName()) + .collect(Collectors.toSet()); + + results.addAll(metrics); + return results; + } + + + protected Set getMatchedFieldNames(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) { + Map itemIdToName = getItemIdToName(modelId, semanticSchema); + List matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId); + if (CollectionUtils.isEmpty(matchedElements)) { + return new HashSet<>(); + } + Set fieldNameList = matchedElements.stream() + .filter(schemaElementMatch -> { + SchemaElementType elementType = schemaElementMatch.getElement().getType(); + return SchemaElementType.METRIC.equals(elementType) + || SchemaElementType.DIMENSION.equals(elementType) + || SchemaElementType.VALUE.equals(elementType); + }) + .map(schemaElementMatch -> { + SchemaElement element = schemaElementMatch.getElement(); + SchemaElementType elementType = element.getType(); + if (SchemaElementType.VALUE.equals(elementType)) { + return itemIdToName.get(element.getId()); + } + return schemaElementMatch.getWord(); + }) + .collect(Collectors.toSet()); + return fieldNameList; + } +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2ql/LLMResponseService.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2ql/LLMResponseService.java new file mode 100644 index 000000000..8d0ee873c --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2ql/LLMResponseService.java @@ -0,0 +1,261 @@ +package com.tencent.supersonic.chat.parser.llm.s2ql; + +import com.google.common.collect.Lists; +import com.tencent.supersonic.chat.agent.tool.CommonAgentTool; +import com.tencent.supersonic.chat.api.component.SemanticCorrector; +import com.tencent.supersonic.chat.api.pojo.QueryContext; +import com.tencent.supersonic.chat.api.pojo.SchemaElement; +import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; +import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.chat.api.pojo.SemanticSchema; +import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; +import com.tencent.supersonic.chat.query.QueryManager; +import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery; +import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery; +import com.tencent.supersonic.chat.utils.ComponentFactory; +import com.tencent.supersonic.common.pojo.Constants; +import com.tencent.supersonic.common.pojo.DateConf; +import com.tencent.supersonic.common.pojo.DateConf.DateMode; +import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.common.util.DateUtils; +import com.tencent.supersonic.common.util.jsqlparser.FilterExpression; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; +import com.tencent.supersonic.knowledge.service.SchemaService; +import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.tuple.Pair; +import org.springframework.stereotype.Service; +import org.springframework.util.CollectionUtils; + +@Slf4j +@Service +public class LLMResponseService { + + public void addParseInfo(QueryContext queryCtx, ParseResult parseResult, String sql, Double weight) { + + SemanticParseInfo parseInfo = getParseInfo(queryCtx, parseResult, weight); + + SemanticCorrectInfo semanticCorrectInfo = getCorrectorSql(queryCtx, parseInfo, sql); + + parseInfo.getSqlInfo().setLogicSql(semanticCorrectInfo.getSql()); + + updateParseInfo(semanticCorrectInfo, parseResult.getModelId(), parseInfo); + } + + private Set getElements(Long modelId, List allFields, List elements) { + return elements.stream() + .filter(schemaElement -> modelId.equals(schemaElement.getModel()) + && allFields.contains(schemaElement.getName()) + ).collect(Collectors.toSet()); + } + + private List getFieldsExceptDate(List allFields) { + if (CollectionUtils.isEmpty(allFields)) { + return new ArrayList<>(); + } + return allFields.stream() + .filter(entry -> !DateUtils.DATE_FIELD.equalsIgnoreCase(entry)) + .collect(Collectors.toList()); + } + + public void updateParseInfo(SemanticCorrectInfo semanticCorrectInfo, Long modelId, SemanticParseInfo parseInfo) { + + String correctorSql = semanticCorrectInfo.getSql(); + parseInfo.getSqlInfo().setLogicSql(correctorSql); + + List expressions = SqlParserSelectHelper.getFilterExpression(correctorSql); + //set dataInfo + try { + if (!CollectionUtils.isEmpty(expressions)) { + DateConf dateInfo = getDateInfo(expressions); + parseInfo.setDateInfo(dateInfo); + } + } catch (Exception e) { + log.error("set dateInfo error :", e); + } + + //set filter + try { + Map fieldNameToElement = getNameToElement(modelId); + List result = getDimensionFilter(fieldNameToElement, expressions); + parseInfo.getDimensionFilters().addAll(result); + } catch (Exception e) { + log.error("set dimensionFilter error :", e); + } + + SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); + + if (Objects.isNull(semanticSchema)) { + return; + } + List allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(semanticCorrectInfo.getSql())); + + Set metrics = getElements(modelId, allFields, semanticSchema.getMetrics()); + parseInfo.setMetrics(metrics); + + if (SqlParserSelectFunctionHelper.hasAggregateFunction(semanticCorrectInfo.getSql())) { + parseInfo.setNativeQuery(false); + List groupByFields = SqlParserSelectHelper.getGroupByFields(semanticCorrectInfo.getSql()); + List groupByDimensions = getFieldsExceptDate(groupByFields); + parseInfo.setDimensions(getElements(modelId, groupByDimensions, semanticSchema.getDimensions())); + } else { + parseInfo.setNativeQuery(true); + List selectFields = SqlParserSelectHelper.getSelectFields(semanticCorrectInfo.getSql()); + List selectDimensions = getFieldsExceptDate(selectFields); + parseInfo.setDimensions(getElements(modelId, selectDimensions, semanticSchema.getDimensions())); + } + } + + private List getDimensionFilter(Map fieldNameToElement, + List filterExpressions) { + List result = Lists.newArrayList(); + for (FilterExpression expression : filterExpressions) { + QueryFilter dimensionFilter = new QueryFilter(); + dimensionFilter.setValue(expression.getFieldValue()); + SchemaElement schemaElement = fieldNameToElement.get(expression.getFieldName()); + if (Objects.isNull(schemaElement)) { + continue; + } + dimensionFilter.setName(schemaElement.getName()); + dimensionFilter.setBizName(schemaElement.getBizName()); + dimensionFilter.setElementID(schemaElement.getId()); + + FilterOperatorEnum operatorEnum = FilterOperatorEnum.getSqlOperator(expression.getOperator()); + dimensionFilter.setOperator(operatorEnum); + dimensionFilter.setFunction(expression.getFunction()); + result.add(dimensionFilter); + } + return result; + } + + private DateConf getDateInfo(List filterExpressions) { + List dateExpressions = filterExpressions.stream() + .filter(expression -> DateUtils.DATE_FIELD.equalsIgnoreCase(expression.getFieldName())) + .collect(Collectors.toList()); + if (CollectionUtils.isEmpty(dateExpressions)) { + return new DateConf(); + } + DateConf dateInfo = new DateConf(); + dateInfo.setDateMode(DateMode.BETWEEN); + FilterExpression firstExpression = dateExpressions.get(0); + + FilterOperatorEnum firstOperator = FilterOperatorEnum.getSqlOperator(firstExpression.getOperator()); + if (FilterOperatorEnum.EQUALS.equals(firstOperator) && Objects.nonNull(firstExpression.getFieldValue())) { + dateInfo.setStartDate(firstExpression.getFieldValue().toString()); + dateInfo.setEndDate(firstExpression.getFieldValue().toString()); + dateInfo.setDateMode(DateMode.BETWEEN); + return dateInfo; + } + if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.GREATER_THAN, + FilterOperatorEnum.GREATER_THAN_EQUALS)) { + dateInfo.setStartDate(firstExpression.getFieldValue().toString()); + if (hasSecondDate(dateExpressions)) { + dateInfo.setEndDate(dateExpressions.get(1).getFieldValue().toString()); + } + } + if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.MINOR_THAN, + FilterOperatorEnum.MINOR_THAN_EQUALS)) { + dateInfo.setEndDate(firstExpression.getFieldValue().toString()); + if (hasSecondDate(dateExpressions)) { + dateInfo.setStartDate(dateExpressions.get(1).getFieldValue().toString()); + } + } + return dateInfo; + } + + private boolean containOperators(FilterExpression expression, FilterOperatorEnum firstOperator, + FilterOperatorEnum... operatorEnums) { + return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue())); + } + + private boolean hasSecondDate(List dateExpressions) { + return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue()); + } + + private SemanticCorrectInfo getCorrectorSql(QueryContext queryCtx, SemanticParseInfo parseInfo, String sql) { + + SemanticCorrectInfo correctInfo = SemanticCorrectInfo.builder() + .queryFilters(queryCtx.getRequest().getQueryFilters()).sql(sql) + .parseInfo(parseInfo).build(); + + List corrections = ComponentFactory.getSqlCorrections(); + + corrections.forEach(correction -> { + try { + correction.correct(correctInfo); + log.info("sqlCorrection:{} sql:{}", correction.getClass().getSimpleName(), correctInfo.getSql()); + } catch (Exception e) { + log.error(String.format("correct error,correctInfo:%s", correctInfo), e); + } + }); + return correctInfo; + } + + private SemanticParseInfo getParseInfo(QueryContext queryCtx, ParseResult parseResult, Double weight) { + if (Objects.isNull(weight)) { + weight = 0D; + } + PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(S2QLQuery.QUERY_MODE); + SemanticParseInfo parseInfo = semanticQuery.getParseInfo(); + Long modelId = parseResult.getModelId(); + CommonAgentTool commonAgentTool = parseResult.getCommonAgentTool(); + parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(modelId)); + + Map properties = new HashMap<>(); + properties.put(Constants.CONTEXT, parseResult); + properties.put("type", "internal"); + properties.put("name", commonAgentTool.getName()); + + parseInfo.setProperties(properties); + parseInfo.setScore(queryCtx.getRequest().getQueryText().length() * (1 + weight)); + parseInfo.setQueryMode(semanticQuery.getQueryMode()); + parseInfo.getSqlInfo().setS2QL(parseResult.getLlmResp().getSqlOutput()); + + SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); + Map modelIdToName = semanticSchema.getModelIdToName(); + + SchemaElement model = new SchemaElement(); + model.setModel(modelId); + model.setId(modelId); + model.setName(modelIdToName.get(modelId)); + parseInfo.setModel(model); + queryCtx.getCandidateQueries().add(semanticQuery); + return parseInfo; + } + + protected Map getNameToElement(Long modelId) { + SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); + List dimensions = semanticSchema.getDimensions(); + List metrics = semanticSchema.getMetrics(); + + List allElements = Lists.newArrayList(); + allElements.addAll(dimensions); + allElements.addAll(metrics); + //support alias + return allElements.stream() + .filter(schemaElement -> schemaElement.getModel().equals(modelId)) + .flatMap(schemaElement -> { + Set> result = new HashSet<>(); + result.add(Pair.of(schemaElement.getName(), schemaElement)); + List aliasList = schemaElement.getAlias(); + if (!CollectionUtils.isEmpty(aliasList)) { + for (String alias : aliasList) { + result.add(Pair.of(alias, schemaElement)); + } + } + return result.stream(); + }) + .collect(Collectors.toMap(pair -> pair.getLeft(), pair -> pair.getRight(), (value1, value2) -> value2)); + } + +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2ql/LLMS2QLParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2ql/LLMS2QLParser.java index 9158344f3..0ade89541 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2ql/LLMS2QLParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2ql/LLMS2QLParser.java @@ -1,62 +1,16 @@ package com.tencent.supersonic.chat.parser.llm.s2ql; -import com.google.common.collect.Lists; -import com.tencent.supersonic.chat.agent.tool.AgentToolType; import com.tencent.supersonic.chat.agent.tool.CommonAgentTool; -import com.tencent.supersonic.chat.api.component.SemanticCorrector; import com.tencent.supersonic.chat.api.component.SemanticParser; import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.QueryContext; -import com.tencent.supersonic.chat.api.pojo.SchemaElement; -import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; -import com.tencent.supersonic.chat.api.pojo.SchemaElementType; -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; -import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.chat.api.pojo.SemanticSchema; -import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; -import com.tencent.supersonic.chat.config.LLMParserConfig; -import com.tencent.supersonic.chat.parser.SatisfactionChecker; -import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.chat.query.llm.s2ql.LLMReq; -import com.tencent.supersonic.chat.query.llm.s2ql.LLMReq.ElementValue; import com.tencent.supersonic.chat.query.llm.s2ql.LLMResp; -import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery; -import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery; -import com.tencent.supersonic.chat.service.AgentService; -import com.tencent.supersonic.chat.utils.ComponentFactory; -import com.tencent.supersonic.common.pojo.Constants; -import com.tencent.supersonic.common.pojo.DateConf; -import com.tencent.supersonic.common.pojo.DateConf.DateMode; import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.common.util.DateUtils; -import com.tencent.supersonic.common.util.JsonUtil; -import com.tencent.supersonic.common.util.jsqlparser.FilterExpression; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; -import com.tencent.supersonic.knowledge.service.SchemaService; -import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Comparator; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Optional; -import java.util.Set; -import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.StringUtils; -import org.apache.commons.lang3.tuple.Pair; -import org.springframework.http.HttpEntity; -import org.springframework.http.HttpHeaders; -import org.springframework.http.HttpMethod; -import org.springframework.http.MediaType; -import org.springframework.http.ResponseEntity; -import org.springframework.util.CollectionUtils; -import org.springframework.web.client.RestTemplate; @Slf4j public class LLMS2QLParser implements SemanticParser { @@ -64,46 +18,46 @@ public class LLMS2QLParser implements SemanticParser { @Override public void parse(QueryContext queryCtx, ChatContext chatCtx) { QueryReq request = queryCtx.getRequest(); - LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class); - if (StringUtils.isEmpty(llmParserConfig.getUrl())) { - log.info("llm parser url is empty, skip {} , llmParserConfig:{}", LLMS2QLParser.class, llmParserConfig); - return; - } + LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class); //1.determine whether to skip this parser. - if (SatisfactionChecker.check(queryCtx)) { - log.info("skip {}, queryText:{}", LLMS2QLParser.class, request.getQueryText()); + if (requestService.check(queryCtx)) { return; } try { //2.get modelId from queryCtx and chatCtx. - Long modelId = getModelId(queryCtx, chatCtx, request.getAgentId()); + Long modelId = requestService.getModelId(queryCtx, chatCtx, request.getAgentId()); if (Objects.isNull(modelId) || modelId <= 0) { return; } //3.get agent tool and determine whether to skip this parser. - CommonAgentTool commonAgentTool = getParserTool(request, modelId); + CommonAgentTool commonAgentTool = requestService.getParserTool(request, modelId); if (Objects.isNull(commonAgentTool)) { log.info("no tool in this agent, skip {}", LLMS2QLParser.class); return; } //4.construct a request, call the API for the large model, and retrieve the results. - LLMReq llmReq = getLlmReq(queryCtx, modelId, llmParserConfig); - LLMResp llmResp = requestLLM(llmReq, modelId, llmParserConfig); + LLMReq llmReq = requestService.getLlmReq(queryCtx, modelId); + LLMResp llmResp = requestService.requestLLM(llmReq, modelId); if (Objects.isNull(llmResp)) { return; } //5. get and update parserInfo and corrector sql Map sqlWeight = llmResp.getSqlWeight(); + ParseResult parseResult = ParseResult.builder() + .request(request) + .modelId(modelId) + .commonAgentTool(commonAgentTool) + .llmReq(llmReq) + .llmResp(llmResp).build(); - ParseResult parseResult = ParseResult.builder().request(request) - .commonAgentTool(commonAgentTool).llmReq(llmReq).llmResp(llmResp).build(); + LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class); - if (Objects.isNull(sqlWeight) || sqlWeight.size() <= 0) { - addParseInfo(queryCtx, parseResult, modelId, commonAgentTool, llmResp.getSqlOutput(), 1D); + if (Objects.isNull(sqlWeight) || sqlWeight.isEmpty()) { + responseService.addParseInfo(queryCtx, parseResult, llmResp.getSqlOutput(), 1D); } else { sqlWeight.forEach((sql, weight) -> { - addParseInfo(queryCtx, parseResult, modelId, commonAgentTool, sql, weight); + responseService.addParseInfo(queryCtx, parseResult, sql, weight); }); } @@ -112,384 +66,5 @@ public class LLMS2QLParser implements SemanticParser { } } - private void addParseInfo(QueryContext queryCtx, ParseResult parseResult, Long modelId, - CommonAgentTool commonAgentTool, String sql, Double weight) { - - SemanticParseInfo parseInfo = getParseInfo(queryCtx, modelId, commonAgentTool, parseResult, weight); - - SemanticCorrectInfo semanticCorrectInfo = getCorrectorSql(queryCtx, parseInfo, sql); - - parseInfo.getSqlInfo().setLogicSql(semanticCorrectInfo.getSql()); - - updateParseInfo(semanticCorrectInfo, modelId, parseInfo); - } - - private Set getElements(Long modelId, List allFields, List elements) { - return elements.stream() - .filter(schemaElement -> modelId.equals(schemaElement.getModel()) - && allFields.contains(schemaElement.getName()) - ).collect(Collectors.toSet()); - } - - private List getFieldsExceptDate(List allFields) { - if (CollectionUtils.isEmpty(allFields)) { - return new ArrayList<>(); - } - return allFields.stream() - .filter(entry -> !DateUtils.DATE_FIELD.equalsIgnoreCase(entry)) - .collect(Collectors.toList()); - } - - public void updateParseInfo(SemanticCorrectInfo semanticCorrectInfo, Long modelId, SemanticParseInfo parseInfo) { - - String correctorSql = semanticCorrectInfo.getSql(); - parseInfo.getSqlInfo().setLogicSql(correctorSql); - - List expressions = SqlParserSelectHelper.getFilterExpression(correctorSql); - //set dataInfo - try { - if (!CollectionUtils.isEmpty(expressions)) { - DateConf dateInfo = getDateInfo(expressions); - parseInfo.setDateInfo(dateInfo); - } - } catch (Exception e) { - log.error("set dateInfo error :", e); - } - - //set filter - try { - Map fieldNameToElement = getNameToElement(modelId); - List result = getDimensionFilter(fieldNameToElement, expressions); - parseInfo.getDimensionFilters().addAll(result); - } catch (Exception e) { - log.error("set dimensionFilter error :", e); - } - - SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); - - if (Objects.isNull(semanticSchema)) { - return; - } - List allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(semanticCorrectInfo.getSql())); - - Set metrics = getElements(modelId, allFields, semanticSchema.getMetrics()); - parseInfo.setMetrics(metrics); - - if (SqlParserSelectFunctionHelper.hasAggregateFunction(semanticCorrectInfo.getSql())) { - parseInfo.setNativeQuery(false); - List groupByFields = SqlParserSelectHelper.getGroupByFields(semanticCorrectInfo.getSql()); - List groupByDimensions = getFieldsExceptDate(groupByFields); - parseInfo.setDimensions(getElements(modelId, groupByDimensions, semanticSchema.getDimensions())); - } else { - parseInfo.setNativeQuery(true); - List selectFields = SqlParserSelectHelper.getSelectFields(semanticCorrectInfo.getSql()); - List selectDimensions = getFieldsExceptDate(selectFields); - parseInfo.setDimensions(getElements(modelId, selectDimensions, semanticSchema.getDimensions())); - } - } - - private List getDimensionFilter(Map fieldNameToElement, - List filterExpressions) { - List result = Lists.newArrayList(); - for (FilterExpression expression : filterExpressions) { - QueryFilter dimensionFilter = new QueryFilter(); - dimensionFilter.setValue(expression.getFieldValue()); - SchemaElement schemaElement = fieldNameToElement.get(expression.getFieldName()); - if (Objects.isNull(schemaElement)) { - continue; - } - dimensionFilter.setName(schemaElement.getName()); - dimensionFilter.setBizName(schemaElement.getBizName()); - dimensionFilter.setElementID(schemaElement.getId()); - - FilterOperatorEnum operatorEnum = FilterOperatorEnum.getSqlOperator(expression.getOperator()); - dimensionFilter.setOperator(operatorEnum); - dimensionFilter.setFunction(expression.getFunction()); - result.add(dimensionFilter); - } - return result; - } - - private DateConf getDateInfo(List filterExpressions) { - List dateExpressions = filterExpressions.stream() - .filter(expression -> DateUtils.DATE_FIELD.equalsIgnoreCase(expression.getFieldName())) - .collect(Collectors.toList()); - if (CollectionUtils.isEmpty(dateExpressions)) { - return new DateConf(); - } - DateConf dateInfo = new DateConf(); - dateInfo.setDateMode(DateMode.BETWEEN); - FilterExpression firstExpression = dateExpressions.get(0); - - FilterOperatorEnum firstOperator = FilterOperatorEnum.getSqlOperator(firstExpression.getOperator()); - if (FilterOperatorEnum.EQUALS.equals(firstOperator) && Objects.nonNull(firstExpression.getFieldValue())) { - dateInfo.setStartDate(firstExpression.getFieldValue().toString()); - dateInfo.setEndDate(firstExpression.getFieldValue().toString()); - dateInfo.setDateMode(DateMode.BETWEEN); - return dateInfo; - } - if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.GREATER_THAN, - FilterOperatorEnum.GREATER_THAN_EQUALS)) { - dateInfo.setStartDate(firstExpression.getFieldValue().toString()); - if (hasSecondDate(dateExpressions)) { - dateInfo.setEndDate(dateExpressions.get(1).getFieldValue().toString()); - } - } - if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.MINOR_THAN, - FilterOperatorEnum.MINOR_THAN_EQUALS)) { - dateInfo.setEndDate(firstExpression.getFieldValue().toString()); - if (hasSecondDate(dateExpressions)) { - dateInfo.setStartDate(dateExpressions.get(1).getFieldValue().toString()); - } - } - return dateInfo; - } - - private boolean containOperators(FilterExpression expression, FilterOperatorEnum firstOperator, - FilterOperatorEnum... operatorEnums) { - return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue())); - } - - private boolean hasSecondDate(List dateExpressions) { - return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue()); - } - - private SemanticCorrectInfo getCorrectorSql(QueryContext queryCtx, SemanticParseInfo parseInfo, String sql) { - - SemanticCorrectInfo correctInfo = SemanticCorrectInfo.builder() - .queryFilters(queryCtx.getRequest().getQueryFilters()).sql(sql) - .parseInfo(parseInfo).build(); - - List corrections = ComponentFactory.getSqlCorrections(); - - corrections.forEach(correction -> { - try { - correction.correct(correctInfo); - log.info("sqlCorrection:{} sql:{}", correction.getClass().getSimpleName(), correctInfo.getSql()); - } catch (Exception e) { - log.error(String.format("correct error,correctInfo:%s", correctInfo), e); - } - }); - return correctInfo; - } - - private SemanticParseInfo getParseInfo(QueryContext queryCtx, Long modelId, CommonAgentTool commonAgentTool, - ParseResult parseResult, Double weight) { - if (Objects.isNull(weight)) { - weight = 0D; - } - PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(S2QLQuery.QUERY_MODE); - SemanticParseInfo parseInfo = semanticQuery.getParseInfo(); - parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(modelId)); - - Map properties = new HashMap<>(); - properties.put(Constants.CONTEXT, parseResult); - properties.put("type", "internal"); - properties.put("name", commonAgentTool.getName()); - - parseInfo.setProperties(properties); - parseInfo.setScore(queryCtx.getRequest().getQueryText().length() * (1 + weight)); - parseInfo.setQueryMode(semanticQuery.getQueryMode()); - parseInfo.getSqlInfo().setS2QL(parseResult.getLlmResp().getSqlOutput()); - - SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); - Map modelIdToName = semanticSchema.getModelIdToName(); - - SchemaElement model = new SchemaElement(); - model.setModel(modelId); - model.setId(modelId); - model.setName(modelIdToName.get(modelId)); - parseInfo.setModel(model); - queryCtx.getCandidateQueries().add(semanticQuery); - return parseInfo; - } - - private CommonAgentTool getParserTool(QueryReq request, Long modelId) { - AgentService agentService = ContextUtils.getBean(AgentService.class); - List commonAgentTools = agentService.getParserTools(request.getAgentId(), - AgentToolType.LLM_S2QL); - Optional llmParserTool = commonAgentTools.stream() - .filter(tool -> { - List modelIds = tool.getModelIds(); - if (agentService.containsAllModel(new HashSet<>(modelIds))) { - return true; - } - return modelIds.contains(modelId); - }) - .findFirst(); - return llmParserTool.orElse(null); - } - - private Long getModelId(QueryContext queryCtx, ChatContext chatCtx, Integer agentId) { - AgentService agentService = ContextUtils.getBean(AgentService.class); - Set distinctModelIds = agentService.getModelIds(agentId, AgentToolType.LLM_S2QL); - if (agentService.containsAllModel(distinctModelIds)) { - distinctModelIds = new HashSet<>(); - } - ModelResolver modelResolver = ComponentFactory.getModelResolver(); - Long modelId = modelResolver.resolve(queryCtx, chatCtx, distinctModelIds); - log.info("resolve modelId:{},llmParser Models:{}", modelId, distinctModelIds); - return modelId; - } - - private LLMResp requestLLM(LLMReq llmReq, Long modelId, LLMParserConfig llmParserConfig) { - String questUrl = llmParserConfig.getUrl() + llmParserConfig.getQueryToSqlPath(); - long startTime = System.currentTimeMillis(); - log.info("requestLLM request, modelId:{},llmReq:{}", modelId, llmReq); - RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class); - try { - HttpHeaders headers = new HttpHeaders(); - headers.setContentType(MediaType.APPLICATION_JSON); - HttpEntity entity = new HttpEntity<>(JsonUtil.toString(llmReq), headers); - ResponseEntity responseEntity = restTemplate.exchange(questUrl, HttpMethod.POST, entity, - LLMResp.class); - - log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}", - System.currentTimeMillis() - startTime, questUrl, entity, responseEntity.getBody()); - return responseEntity.getBody(); - } catch (Exception e) { - log.error("requestLLM error", e); - } - return null; - } - - private LLMReq getLlmReq(QueryContext queryCtx, Long modelId, LLMParserConfig llmParserConfig) { - SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); - Map modelIdToName = semanticSchema.getModelIdToName(); - String queryText = queryCtx.getRequest().getQueryText(); - - LLMReq llmReq = new LLMReq(); - llmReq.setQueryText(queryText); - - LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema(); - llmSchema.setModelName(modelIdToName.get(modelId)); - llmSchema.setDomainName(modelIdToName.get(modelId)); - - List fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema, llmParserConfig); - - fieldNameList.add(DateUtils.DATE_FIELD); - llmSchema.setFieldNameList(fieldNameList); - llmReq.setSchema(llmSchema); - - List linking = new ArrayList<>(); - linking.addAll(getValueList(queryCtx, modelId, semanticSchema)); - llmReq.setLinking(linking); - - String currentDate = S2QLDateHelper.getReferenceDate(modelId); - if (StringUtils.isEmpty(currentDate)) { - currentDate = DateUtils.getBeforeDate(0); - } - llmReq.setCurrentDate(currentDate); - return llmReq; - } - - protected List getValueList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) { - Map itemIdToName = getItemIdToName(modelId, semanticSchema); - - List matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId); - if (CollectionUtils.isEmpty(matchedElements)) { - return new ArrayList<>(); - } - Set valueMatches = matchedElements - .stream() - .filter(elementMatch -> !elementMatch.isInherited()) - .filter(schemaElementMatch -> { - SchemaElementType type = schemaElementMatch.getElement().getType(); - return SchemaElementType.VALUE.equals(type) || SchemaElementType.ID.equals(type); - }) - .map(elementMatch -> { - ElementValue elementValue = new ElementValue(); - elementValue.setFieldName(itemIdToName.get(elementMatch.getElement().getId())); - elementValue.setFieldValue(elementMatch.getWord()); - return elementValue; - }).collect(Collectors.toSet()); - return new ArrayList<>(valueMatches); - } - - - protected Map getNameToElement(Long modelId) { - SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); - List dimensions = semanticSchema.getDimensions(); - List metrics = semanticSchema.getMetrics(); - - List allElements = Lists.newArrayList(); - allElements.addAll(dimensions); - allElements.addAll(metrics); - //support alias - return allElements.stream() - .filter(schemaElement -> schemaElement.getModel().equals(modelId)) - .flatMap(schemaElement -> { - Set> result = new HashSet<>(); - result.add(Pair.of(schemaElement.getName(), schemaElement)); - List aliasList = schemaElement.getAlias(); - if (!CollectionUtils.isEmpty(aliasList)) { - for (String alias : aliasList) { - result.add(Pair.of(alias, schemaElement)); - } - } - return result.stream(); - }) - .collect(Collectors.toMap(pair -> pair.getLeft(), pair -> pair.getRight(), (value1, value2) -> value2)); - } - - - protected List getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema, - LLMParserConfig llmParserConfig) { - - Set results = getTopNFieldNames(modelId, semanticSchema, llmParserConfig); - - Set fieldNameList = getMatchedFieldNames(queryCtx, modelId, semanticSchema); - - results.addAll(fieldNameList); - return new ArrayList<>(results); - } - - protected Set getMatchedFieldNames(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) { - Map itemIdToName = getItemIdToName(modelId, semanticSchema); - List matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId); - if (CollectionUtils.isEmpty(matchedElements)) { - return new HashSet<>(); - } - Set fieldNameList = matchedElements.stream() - .filter(schemaElementMatch -> { - SchemaElementType elementType = schemaElementMatch.getElement().getType(); - return SchemaElementType.METRIC.equals(elementType) - || SchemaElementType.DIMENSION.equals(elementType) - || SchemaElementType.VALUE.equals(elementType); - }) - .map(schemaElementMatch -> { - SchemaElement element = schemaElementMatch.getElement(); - SchemaElementType elementType = element.getType(); - if (SchemaElementType.VALUE.equals(elementType)) { - return itemIdToName.get(element.getId()); - } - return schemaElementMatch.getWord(); - }) - .collect(Collectors.toSet()); - return fieldNameList; - } - - private Set getTopNFieldNames(Long modelId, SemanticSchema semanticSchema, - LLMParserConfig llmParserConfig) { - Set results = semanticSchema.getDimensions(modelId).stream() - .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) - .limit(llmParserConfig.getDimensionTopN()) - .map(entry -> entry.getName()) - .collect(Collectors.toSet()); - - Set metrics = semanticSchema.getMetrics(modelId).stream() - .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) - .limit(llmParserConfig.getMetricTopN()) - .map(entry -> entry.getName()) - .collect(Collectors.toSet()); - - results.addAll(metrics); - return results; - } - - protected Map getItemIdToName(Long modelId, SemanticSchema semanticSchema) { - return semanticSchema.getDimensions(modelId).stream() - .collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2)); - } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2ql/ParseResult.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2ql/ParseResult.java index 4efc8a351..5f43da192 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2ql/ParseResult.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2ql/ParseResult.java @@ -15,6 +15,8 @@ import lombok.NoArgsConstructor; @NoArgsConstructor public class ParseResult { + private Long modelId; + private LLMReq llmReq; private LLMResp llmResp; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2ql/LLMReq.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2ql/LLMReq.java index d8c3ba5ef..cf036197a 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2ql/LLMReq.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2ql/LLMReq.java @@ -14,6 +14,8 @@ public class LLMReq { private String currentDate; + private String priorExts; + @Data public static class ElementValue { diff --git a/chat/core/src/test/java/com/tencent/supersonic/chat/parser/llm/s2ql/LLMS2QLParserTest.java b/chat/core/src/test/java/com/tencent/supersonic/chat/parser/llm/s2ql/LLMS2QLParserTest.java index 869efa136..b31b8a5f5 100644 --- a/chat/core/src/test/java/com/tencent/supersonic/chat/parser/llm/s2ql/LLMS2QLParserTest.java +++ b/chat/core/src/test/java/com/tencent/supersonic/chat/parser/llm/s2ql/LLMS2QLParserTest.java @@ -4,8 +4,6 @@ import static org.mockito.Mockito.when; import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SchemaValueMap; -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; -import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.knowledge.service.SchemaService; @@ -62,19 +60,5 @@ class LLMS2QLParserTest { when(mockSchemaService.getSemanticSchema()).thenReturn(mockSemanticSchema); mockContextUtils.when(() -> ContextUtils.getBean(SchemaService.class)).thenReturn(mockSchemaService); - - SemanticParseInfo parseInfo = new SemanticParseInfo(); - SchemaElement model = new SchemaElement(); - model.setId(2L); - parseInfo.setModel(model); - SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder() - .sql("select count(song_name) from 歌曲库 where singer_name = '周先生' and YEAR(publish_time) >= 2023 ") - .parseInfo(parseInfo) - .build(); - - LLMS2QLParser llms2QLParser = new LLMS2QLParser(); - - llms2QLParser.updateParseInfo(semanticCorrectInfo, 2L, parseInfo); - } } \ No newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/DataFormatTypeEnum.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/DataFormatTypeEnum.java new file mode 100644 index 000000000..a4c5b5564 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/DataFormatTypeEnum.java @@ -0,0 +1,19 @@ +package com.tencent.supersonic.common.pojo.enums; + +public enum DataFormatTypeEnum { + + PERCENT("percent"), + + DECIMAL("decimal"); + + private String name; + + DataFormatTypeEnum(String name) { + this.name = name; + } + + public String getName() { + return name; + } + +} \ No newline at end of file