mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
[improvement][project]Opt some code structures.
This commit is contained in:
@@ -40,13 +40,15 @@ public class ChatQueryContext {
|
||||
private Map<Long, List<Long>> modelIdToDataSetIds;
|
||||
private User user;
|
||||
private boolean saveAnswer;
|
||||
@Builder.Default
|
||||
private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM;
|
||||
private QueryFilters queryFilters;
|
||||
private List<SemanticQuery> candidateQueries = new ArrayList<>();
|
||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||
private SemanticParseInfo contextParseInfo;
|
||||
@Builder.Default
|
||||
private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM;
|
||||
@Builder.Default
|
||||
private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
|
||||
@Builder.Default
|
||||
private QueryDataType queryDataType = QueryDataType.ALL;
|
||||
@JsonIgnore
|
||||
private SemanticSchema semanticSchema;
|
||||
|
||||
@@ -78,10 +78,9 @@ public class QueryTypeParser implements SemanticParser {
|
||||
}
|
||||
|
||||
private static List<String> filterByTimeFields(List<String> whereFields) {
|
||||
List<String> selectAndWhereFilterByTimeFields = whereFields.stream()
|
||||
return whereFields.stream()
|
||||
.filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
|
||||
.collect(Collectors.toList());
|
||||
return selectAndWhereFilterByTimeFields;
|
||||
}
|
||||
|
||||
private static boolean selectContainsMetric(SqlInfo sqlInfo, Long dataSetId,
|
||||
|
||||
@@ -65,7 +65,7 @@ public class LLMRequestService {
|
||||
llmSchema.setPrimaryKey(getPrimaryKey(queryCtx, dataSetId));
|
||||
|
||||
boolean linkingValueEnabled =
|
||||
Boolean.valueOf(parserConfig.getParameterValue(PARSER_LINKING_VALUE_ENABLE));
|
||||
Boolean.parseBoolean(parserConfig.getParameterValue(PARSER_LINKING_VALUE_ENABLE));
|
||||
if (linkingValueEnabled) {
|
||||
llmSchema.setValues(getMappedValues(queryCtx, dataSetId));
|
||||
}
|
||||
@@ -135,13 +135,10 @@ public class LLMRequestService {
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
List<SchemaElement> schemaElements = matchedElements.stream().filter(schemaElementMatch -> {
|
||||
return matchedElements.stream().filter(schemaElementMatch -> {
|
||||
SchemaElementType elementType = schemaElementMatch.getElement().getType();
|
||||
return SchemaElementType.METRIC.equals(elementType);
|
||||
}).map(schemaElementMatch -> {
|
||||
return schemaElementMatch.getElement();
|
||||
}).collect(Collectors.toList());
|
||||
return schemaElements;
|
||||
}).map(SchemaElementMatch::getElement).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
protected List<SchemaElement> getMappedDimensions(@NotNull ChatQueryContext queryCtx,
|
||||
|
||||
@@ -23,8 +23,8 @@ import java.util.Objects;
|
||||
@Service
|
||||
public class LLMResponseService {
|
||||
|
||||
public SemanticParseInfo addParseInfo(ChatQueryContext queryCtx, ParseResult parseResult,
|
||||
String s2SQL, Double weight) {
|
||||
public void addParseInfo(ChatQueryContext queryCtx, ParseResult parseResult, String s2SQL,
|
||||
Double weight) {
|
||||
if (Objects.isNull(weight)) {
|
||||
weight = 0D;
|
||||
}
|
||||
@@ -49,7 +49,6 @@ public class LLMResponseService {
|
||||
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
||||
parseInfo.getSqlInfo().setParsedS2SQL(s2SQL);
|
||||
queryCtx.getCandidateQueries().add(semanticQuery);
|
||||
return parseInfo;
|
||||
}
|
||||
|
||||
public Map<String, LLMSqlResp> getDeduplicationSqlResp(int currentRetry, LLMResp llmResp) {
|
||||
|
||||
@@ -31,22 +31,22 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
|
||||
public static final String APP_KEY = "S2SQL_PARSER";
|
||||
public static final String INSTRUCTION = ""
|
||||
+ "#Role: You are a data analyst experienced in SQL languages."
|
||||
+ "\n#Task: You will be provided with a natural language question asked by users,"
|
||||
+ "please convert it to a SQL query so that relevant data could be returned "
|
||||
+ "by executing the SQL query against underlying database." + "\n#Rules:"
|
||||
+ "\n1.ALWAYS generate columns and values specified in the `Schema`, DO NOT hallucinate."
|
||||
+ "\n2.ALWAYS be cautious, word in the `Schema` does not mean it must appear in the SQL."
|
||||
+ "\n3.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator."
|
||||
+ "\n4.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
|
||||
+ "\n5.DO NOT calculate date range using functions."
|
||||
+ "\n6.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
|
||||
+ "\n7.ALWAYS use `with` statement if nested aggregation is needed."
|
||||
+ "\n8.ALWAYS enclose alias created by `AS` command in underscores."
|
||||
+ "\n9.ALWAYS translate alias created by `AS` command to the same language as the `#Question`."
|
||||
+ "\n#Exemplars: {{exemplar}}"
|
||||
+ "\n#Query: Question:{{question}},Schema:{{schema}},SideInfo:{{information}}";
|
||||
public static final String INSTRUCTION =
|
||||
"#Role: You are a data analyst experienced in SQL languages."
|
||||
+ "\n#Task: You will be provided with a natural language question asked by users,"
|
||||
+ "please convert it to a SQL query so that relevant data could be returned "
|
||||
+ "by executing the SQL query against underlying database." + "\n#Rules:"
|
||||
+ "\n1.ALWAYS generate columns and values specified in the `Schema`, DO NOT hallucinate."
|
||||
+ "\n2.ALWAYS be cautious, word in the `Schema` does not mean it must appear in the SQL."
|
||||
+ "\n3.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator."
|
||||
+ "\n4.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
|
||||
+ "\n5.DO NOT calculate date range using functions."
|
||||
+ "\n6.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
|
||||
+ "\n7.ALWAYS use `with` statement if nested aggregation is needed."
|
||||
+ "\n8.ALWAYS enclose alias created by `AS` command in underscores."
|
||||
+ "\n9.ALWAYS translate alias created by `AS` command to the same language as the `#Question`."
|
||||
+ "\n#Exemplars: {{exemplar}}"
|
||||
+ "\n#Query: Question:{{question}},Schema:{{schema}},SideInfo:{{information}}";
|
||||
|
||||
public OnePassSCSqlGenStrategy() {
|
||||
ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("语义SQL解析")
|
||||
|
||||
@@ -86,13 +86,13 @@ public class AggregateTypeParser implements SemanticParser {
|
||||
|
||||
AggregateTypeEnum type =
|
||||
aggregateCount.entrySet().stream().max(Map.Entry.comparingByValue())
|
||||
.map(entry -> entry.getKey()).orElse(AggregateTypeEnum.NONE);
|
||||
.map(Map.Entry::getKey).orElse(AggregateTypeEnum.NONE);
|
||||
String detectWord = aggregateWord.get(type);
|
||||
return new AggregateConf(type, detectWord);
|
||||
}
|
||||
|
||||
@AllArgsConstructor
|
||||
class AggregateConf {
|
||||
static class AggregateConf {
|
||||
public AggregateTypeEnum type;
|
||||
public String detectWord;
|
||||
}
|
||||
|
||||
@@ -17,8 +17,8 @@ import java.util.List;
|
||||
@Slf4j
|
||||
public class RuleSqlParser implements SemanticParser {
|
||||
|
||||
private static List<SemanticParser> auxiliaryParsers = Arrays.asList(new ContextInheritParser(),
|
||||
new TimeRangeParser(), new AggregateTypeParser());
|
||||
private static final List<SemanticParser> auxiliaryParsers = Arrays
|
||||
.asList(new ContextInheritParser(), new TimeRangeParser(), new AggregateTypeParser());
|
||||
|
||||
@Override
|
||||
public void parse(ChatQueryContext chatQueryContext) {
|
||||
@@ -38,6 +38,6 @@ public class RuleSqlParser implements SemanticParser {
|
||||
}
|
||||
}
|
||||
|
||||
auxiliaryParsers.stream().forEach(p -> p.parse(chatQueryContext));
|
||||
auxiliaryParsers.forEach(p -> p.parse(chatQueryContext));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,9 +66,7 @@ public class S2ChatLayerService implements ChatLayerService {
|
||||
public MapResp map(QueryNLReq queryNLReq) {
|
||||
MapResp mapResp = new MapResp(queryNLReq.getQueryText());
|
||||
ChatQueryContext queryCtx = buildChatQueryContext(queryNLReq);
|
||||
ComponentFactory.getSchemaMappers().forEach(mapper -> {
|
||||
mapper.map(queryCtx);
|
||||
});
|
||||
ComponentFactory.getSchemaMappers().forEach(mapper -> mapper.map(queryCtx));
|
||||
mapResp.setMapInfo(queryCtx.getMapInfo());
|
||||
return mapResp;
|
||||
}
|
||||
@@ -264,22 +262,16 @@ public class S2ChatLayerService implements ChatLayerService {
|
||||
|
||||
/**
|
||||
* * get time dimension SchemaElementMatch
|
||||
*
|
||||
* @param dataSetId
|
||||
* @param dataSetName
|
||||
* @return
|
||||
*/
|
||||
private SchemaElementMatch getTimeDimension(Long dataSetId, String dataSetName) {
|
||||
SchemaElement element = SchemaElement.builder().dataSetId(dataSetId)
|
||||
.dataSetName(dataSetName).type(SchemaElementType.DIMENSION)
|
||||
.bizName(TimeDimensionEnum.DAY.getName()).build();
|
||||
|
||||
SchemaElementMatch timeDimensionMatch = SchemaElementMatch.builder().element(element)
|
||||
return SchemaElementMatch.builder().element(element)
|
||||
.detectWord(TimeDimensionEnum.DAY.getChName())
|
||||
.word(TimeDimensionEnum.DAY.getChName()).similarity(1L)
|
||||
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY).build();
|
||||
|
||||
return timeDimensionMatch;
|
||||
}
|
||||
|
||||
private Function<SchemaElement, SchemaElementMatch> mergeFunction() {
|
||||
|
||||
@@ -119,9 +119,7 @@ public class ChatWorkflowEngine {
|
||||
}
|
||||
|
||||
private void performProcessing(ChatQueryContext queryCtx, ParseResp parseResult) {
|
||||
resultProcessors.forEach(processor -> {
|
||||
processor.process(parseResult, queryCtx);
|
||||
});
|
||||
resultProcessors.forEach(processor -> processor.process(parseResult, queryCtx));
|
||||
}
|
||||
|
||||
private void performTranslating(ChatQueryContext chatQueryContext, ParseResp parseResult) {
|
||||
@@ -160,7 +158,7 @@ public class ChatWorkflowEngine {
|
||||
}
|
||||
});
|
||||
if (!errorMsg.isEmpty()) {
|
||||
parseResult.setErrorMsg(errorMsg.stream().collect(Collectors.joining("\n")));
|
||||
parseResult.setErrorMsg(String.join("\n", errorMsg));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user