diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java index bc901a51e..bbc84ce07 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java @@ -40,13 +40,15 @@ public class ChatQueryContext { private Map> modelIdToDataSetIds; private User user; private boolean saveAnswer; - @Builder.Default - private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM; private QueryFilters queryFilters; private List candidateQueries = new ArrayList<>(); private SchemaMapInfo mapInfo = new SchemaMapInfo(); private SemanticParseInfo contextParseInfo; + @Builder.Default + private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM; + @Builder.Default private MapModeEnum mapModeEnum = MapModeEnum.STRICT; + @Builder.Default private QueryDataType queryDataType = QueryDataType.ALL; @JsonIgnore private SemanticSchema semanticSchema; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java index 5ff17608e..60d3af5ce 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java @@ -78,10 +78,9 @@ public class QueryTypeParser implements SemanticParser { } private static List filterByTimeFields(List whereFields) { - List selectAndWhereFilterByTimeFields = whereFields.stream() + return whereFields.stream() .filter(field -> !TimeDimensionEnum.containsTimeDimension(field)) .collect(Collectors.toList()); - return selectAndWhereFilterByTimeFields; } private static boolean selectContainsMetric(SqlInfo sqlInfo, Long dataSetId, diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java index 232121014..429d02e45 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java @@ -65,7 +65,7 @@ public class LLMRequestService { llmSchema.setPrimaryKey(getPrimaryKey(queryCtx, dataSetId)); boolean linkingValueEnabled = - Boolean.valueOf(parserConfig.getParameterValue(PARSER_LINKING_VALUE_ENABLE)); + Boolean.parseBoolean(parserConfig.getParameterValue(PARSER_LINKING_VALUE_ENABLE)); if (linkingValueEnabled) { llmSchema.setValues(getMappedValues(queryCtx, dataSetId)); } @@ -135,13 +135,10 @@ public class LLMRequestService { if (CollectionUtils.isEmpty(matchedElements)) { return Collections.emptyList(); } - List schemaElements = matchedElements.stream().filter(schemaElementMatch -> { + return matchedElements.stream().filter(schemaElementMatch -> { SchemaElementType elementType = schemaElementMatch.getElement().getType(); return SchemaElementType.METRIC.equals(elementType); - }).map(schemaElementMatch -> { - return schemaElementMatch.getElement(); - }).collect(Collectors.toList()); - return schemaElements; + }).map(SchemaElementMatch::getElement).collect(Collectors.toList()); } protected List getMappedDimensions(@NotNull ChatQueryContext queryCtx, diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java index 6aa750505..b0556747e 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java @@ -23,8 +23,8 @@ import java.util.Objects; @Service public class LLMResponseService { - public SemanticParseInfo addParseInfo(ChatQueryContext queryCtx, ParseResult parseResult, - String s2SQL, Double weight) { + public void addParseInfo(ChatQueryContext queryCtx, ParseResult parseResult, String s2SQL, + Double weight) { if (Objects.isNull(weight)) { weight = 0D; } @@ -49,7 +49,6 @@ public class LLMResponseService { parseInfo.setQueryMode(semanticQuery.getQueryMode()); parseInfo.getSqlInfo().setParsedS2SQL(s2SQL); queryCtx.getCandidateQueries().add(semanticQuery); - return parseInfo; } public Map getDeduplicationSqlResp(int currentRetry, LLMResp llmResp) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java index b9e35f779..fd5f176e3 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java @@ -31,22 +31,22 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); public static final String APP_KEY = "S2SQL_PARSER"; - public static final String INSTRUCTION = "" - + "#Role: You are a data analyst experienced in SQL languages." - + "\n#Task: You will be provided with a natural language question asked by users," - + "please convert it to a SQL query so that relevant data could be returned " - + "by executing the SQL query against underlying database." + "\n#Rules:" - + "\n1.ALWAYS generate columns and values specified in the `Schema`, DO NOT hallucinate." - + "\n2.ALWAYS be cautious, word in the `Schema` does not mean it must appear in the SQL." - + "\n3.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator." - + "\n4.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`." - + "\n5.DO NOT calculate date range using functions." - + "\n6.DO NOT miss the AGGREGATE operator of metrics, always add it as needed." - + "\n7.ALWAYS use `with` statement if nested aggregation is needed." - + "\n8.ALWAYS enclose alias created by `AS` command in underscores." - + "\n9.ALWAYS translate alias created by `AS` command to the same language as the `#Question`." - + "\n#Exemplars: {{exemplar}}" - + "\n#Query: Question:{{question}},Schema:{{schema}},SideInfo:{{information}}"; + public static final String INSTRUCTION = + "#Role: You are a data analyst experienced in SQL languages." + + "\n#Task: You will be provided with a natural language question asked by users," + + "please convert it to a SQL query so that relevant data could be returned " + + "by executing the SQL query against underlying database." + "\n#Rules:" + + "\n1.ALWAYS generate columns and values specified in the `Schema`, DO NOT hallucinate." + + "\n2.ALWAYS be cautious, word in the `Schema` does not mean it must appear in the SQL." + + "\n3.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator." + + "\n4.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`." + + "\n5.DO NOT calculate date range using functions." + + "\n6.DO NOT miss the AGGREGATE operator of metrics, always add it as needed." + + "\n7.ALWAYS use `with` statement if nested aggregation is needed." + + "\n8.ALWAYS enclose alias created by `AS` command in underscores." + + "\n9.ALWAYS translate alias created by `AS` command to the same language as the `#Question`." + + "\n#Exemplars: {{exemplar}}" + + "\n#Query: Question:{{question}},Schema:{{schema}},SideInfo:{{information}}"; public OnePassSCSqlGenStrategy() { ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("语义SQL解析") diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/AggregateTypeParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/AggregateTypeParser.java index 2d04dd1d7..a5e0a05e1 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/AggregateTypeParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/AggregateTypeParser.java @@ -86,13 +86,13 @@ public class AggregateTypeParser implements SemanticParser { AggregateTypeEnum type = aggregateCount.entrySet().stream().max(Map.Entry.comparingByValue()) - .map(entry -> entry.getKey()).orElse(AggregateTypeEnum.NONE); + .map(Map.Entry::getKey).orElse(AggregateTypeEnum.NONE); String detectWord = aggregateWord.get(type); return new AggregateConf(type, detectWord); } @AllArgsConstructor - class AggregateConf { + static class AggregateConf { public AggregateTypeEnum type; public String detectWord; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java index 40113401d..d5b1c3956 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java @@ -17,8 +17,8 @@ import java.util.List; @Slf4j public class RuleSqlParser implements SemanticParser { - private static List auxiliaryParsers = Arrays.asList(new ContextInheritParser(), - new TimeRangeParser(), new AggregateTypeParser()); + private static final List auxiliaryParsers = Arrays + .asList(new ContextInheritParser(), new TimeRangeParser(), new AggregateTypeParser()); @Override public void parse(ChatQueryContext chatQueryContext) { @@ -38,6 +38,6 @@ public class RuleSqlParser implements SemanticParser { } } - auxiliaryParsers.stream().forEach(p -> p.parse(chatQueryContext)); + auxiliaryParsers.forEach(p -> p.parse(chatQueryContext)); } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java index 186d27777..b9edcd3e4 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java @@ -66,9 +66,7 @@ public class S2ChatLayerService implements ChatLayerService { public MapResp map(QueryNLReq queryNLReq) { MapResp mapResp = new MapResp(queryNLReq.getQueryText()); ChatQueryContext queryCtx = buildChatQueryContext(queryNLReq); - ComponentFactory.getSchemaMappers().forEach(mapper -> { - mapper.map(queryCtx); - }); + ComponentFactory.getSchemaMappers().forEach(mapper -> mapper.map(queryCtx)); mapResp.setMapInfo(queryCtx.getMapInfo()); return mapResp; } @@ -264,22 +262,16 @@ public class S2ChatLayerService implements ChatLayerService { /** * * get time dimension SchemaElementMatch - * - * @param dataSetId - * @param dataSetName - * @return */ private SchemaElementMatch getTimeDimension(Long dataSetId, String dataSetName) { SchemaElement element = SchemaElement.builder().dataSetId(dataSetId) .dataSetName(dataSetName).type(SchemaElementType.DIMENSION) .bizName(TimeDimensionEnum.DAY.getName()).build(); - SchemaElementMatch timeDimensionMatch = SchemaElementMatch.builder().element(element) + return SchemaElementMatch.builder().element(element) .detectWord(TimeDimensionEnum.DAY.getChName()) .word(TimeDimensionEnum.DAY.getChName()).similarity(1L) .frequency(BaseWordBuilder.DEFAULT_FREQUENCY).build(); - - return timeDimensionMatch; } private Function mergeFunction() { diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java index 61562a6b3..1e987ad25 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java @@ -119,9 +119,7 @@ public class ChatWorkflowEngine { } private void performProcessing(ChatQueryContext queryCtx, ParseResp parseResult) { - resultProcessors.forEach(processor -> { - processor.process(parseResult, queryCtx); - }); + resultProcessors.forEach(processor -> processor.process(parseResult, queryCtx)); } private void performTranslating(ChatQueryContext chatQueryContext, ParseResp parseResult) { @@ -160,7 +158,7 @@ public class ChatWorkflowEngine { } }); if (!errorMsg.isEmpty()) { - parseResult.setErrorMsg(errorMsg.stream().collect(Collectors.joining("\n"))); + parseResult.setErrorMsg(String.join("\n", errorMsg)); } } }