[improvement][project]Opt some code structures.

This commit is contained in:
jerryjzhang
2024-10-27 10:38:30 +08:00
parent 3a905d7fb1
commit 1e3daffade
9 changed files with 35 additions and 48 deletions

View File

@@ -40,13 +40,15 @@ public class ChatQueryContext {
private Map<Long, List<Long>> modelIdToDataSetIds;
private User user;
private boolean saveAnswer;
@Builder.Default
private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM;
private QueryFilters queryFilters;
private List<SemanticQuery> candidateQueries = new ArrayList<>();
private SchemaMapInfo mapInfo = new SchemaMapInfo();
private SemanticParseInfo contextParseInfo;
@Builder.Default
private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM;
@Builder.Default
private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
@Builder.Default
private QueryDataType queryDataType = QueryDataType.ALL;
@JsonIgnore
private SemanticSchema semanticSchema;

View File

@@ -78,10 +78,9 @@ public class QueryTypeParser implements SemanticParser {
}
private static List<String> filterByTimeFields(List<String> whereFields) {
List<String> selectAndWhereFilterByTimeFields = whereFields.stream()
return whereFields.stream()
.filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
.collect(Collectors.toList());
return selectAndWhereFilterByTimeFields;
}
private static boolean selectContainsMetric(SqlInfo sqlInfo, Long dataSetId,

View File

@@ -65,7 +65,7 @@ public class LLMRequestService {
llmSchema.setPrimaryKey(getPrimaryKey(queryCtx, dataSetId));
boolean linkingValueEnabled =
Boolean.valueOf(parserConfig.getParameterValue(PARSER_LINKING_VALUE_ENABLE));
Boolean.parseBoolean(parserConfig.getParameterValue(PARSER_LINKING_VALUE_ENABLE));
if (linkingValueEnabled) {
llmSchema.setValues(getMappedValues(queryCtx, dataSetId));
}
@@ -135,13 +135,10 @@ public class LLMRequestService {
if (CollectionUtils.isEmpty(matchedElements)) {
return Collections.emptyList();
}
List<SchemaElement> schemaElements = matchedElements.stream().filter(schemaElementMatch -> {
return matchedElements.stream().filter(schemaElementMatch -> {
SchemaElementType elementType = schemaElementMatch.getElement().getType();
return SchemaElementType.METRIC.equals(elementType);
}).map(schemaElementMatch -> {
return schemaElementMatch.getElement();
}).collect(Collectors.toList());
return schemaElements;
}).map(SchemaElementMatch::getElement).collect(Collectors.toList());
}
protected List<SchemaElement> getMappedDimensions(@NotNull ChatQueryContext queryCtx,

View File

@@ -23,8 +23,8 @@ import java.util.Objects;
@Service
public class LLMResponseService {
public SemanticParseInfo addParseInfo(ChatQueryContext queryCtx, ParseResult parseResult,
String s2SQL, Double weight) {
public void addParseInfo(ChatQueryContext queryCtx, ParseResult parseResult, String s2SQL,
Double weight) {
if (Objects.isNull(weight)) {
weight = 0D;
}
@@ -49,7 +49,6 @@ public class LLMResponseService {
parseInfo.setQueryMode(semanticQuery.getQueryMode());
parseInfo.getSqlInfo().setParsedS2SQL(s2SQL);
queryCtx.getCandidateQueries().add(semanticQuery);
return parseInfo;
}
public Map<String, LLMSqlResp> getDeduplicationSqlResp(int currentRetry, LLMResp llmResp) {

View File

@@ -31,22 +31,22 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
public static final String APP_KEY = "S2SQL_PARSER";
public static final String INSTRUCTION = ""
+ "#Role: You are a data analyst experienced in SQL languages."
+ "\n#Task: You will be provided with a natural language question asked by users,"
+ "please convert it to a SQL query so that relevant data could be returned "
+ "by executing the SQL query against underlying database." + "\n#Rules:"
+ "\n1.ALWAYS generate columns and values specified in the `Schema`, DO NOT hallucinate."
+ "\n2.ALWAYS be cautious, word in the `Schema` does not mean it must appear in the SQL."
+ "\n3.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator."
+ "\n4.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
+ "\n5.DO NOT calculate date range using functions."
+ "\n6.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
+ "\n7.ALWAYS use `with` statement if nested aggregation is needed."
+ "\n8.ALWAYS enclose alias created by `AS` command in underscores."
+ "\n9.ALWAYS translate alias created by `AS` command to the same language as the `#Question`."
+ "\n#Exemplars: {{exemplar}}"
+ "\n#Query: Question:{{question}},Schema:{{schema}},SideInfo:{{information}}";
public static final String INSTRUCTION =
"#Role: You are a data analyst experienced in SQL languages."
+ "\n#Task: You will be provided with a natural language question asked by users,"
+ "please convert it to a SQL query so that relevant data could be returned "
+ "by executing the SQL query against underlying database." + "\n#Rules:"
+ "\n1.ALWAYS generate columns and values specified in the `Schema`, DO NOT hallucinate."
+ "\n2.ALWAYS be cautious, word in the `Schema` does not mean it must appear in the SQL."
+ "\n3.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator."
+ "\n4.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
+ "\n5.DO NOT calculate date range using functions."
+ "\n6.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
+ "\n7.ALWAYS use `with` statement if nested aggregation is needed."
+ "\n8.ALWAYS enclose alias created by `AS` command in underscores."
+ "\n9.ALWAYS translate alias created by `AS` command to the same language as the `#Question`."
+ "\n#Exemplars: {{exemplar}}"
+ "\n#Query: Question:{{question}},Schema:{{schema}},SideInfo:{{information}}";
public OnePassSCSqlGenStrategy() {
ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("语义SQL解析")

View File

@@ -86,13 +86,13 @@ public class AggregateTypeParser implements SemanticParser {
AggregateTypeEnum type =
aggregateCount.entrySet().stream().max(Map.Entry.comparingByValue())
.map(entry -> entry.getKey()).orElse(AggregateTypeEnum.NONE);
.map(Map.Entry::getKey).orElse(AggregateTypeEnum.NONE);
String detectWord = aggregateWord.get(type);
return new AggregateConf(type, detectWord);
}
@AllArgsConstructor
class AggregateConf {
static class AggregateConf {
public AggregateTypeEnum type;
public String detectWord;
}

View File

@@ -17,8 +17,8 @@ import java.util.List;
@Slf4j
public class RuleSqlParser implements SemanticParser {
private static List<SemanticParser> auxiliaryParsers = Arrays.asList(new ContextInheritParser(),
new TimeRangeParser(), new AggregateTypeParser());
private static final List<SemanticParser> auxiliaryParsers = Arrays
.asList(new ContextInheritParser(), new TimeRangeParser(), new AggregateTypeParser());
@Override
public void parse(ChatQueryContext chatQueryContext) {
@@ -38,6 +38,6 @@ public class RuleSqlParser implements SemanticParser {
}
}
auxiliaryParsers.stream().forEach(p -> p.parse(chatQueryContext));
auxiliaryParsers.forEach(p -> p.parse(chatQueryContext));
}
}

View File

@@ -66,9 +66,7 @@ public class S2ChatLayerService implements ChatLayerService {
public MapResp map(QueryNLReq queryNLReq) {
MapResp mapResp = new MapResp(queryNLReq.getQueryText());
ChatQueryContext queryCtx = buildChatQueryContext(queryNLReq);
ComponentFactory.getSchemaMappers().forEach(mapper -> {
mapper.map(queryCtx);
});
ComponentFactory.getSchemaMappers().forEach(mapper -> mapper.map(queryCtx));
mapResp.setMapInfo(queryCtx.getMapInfo());
return mapResp;
}
@@ -264,22 +262,16 @@ public class S2ChatLayerService implements ChatLayerService {
/**
* * get time dimension SchemaElementMatch
*
* @param dataSetId
* @param dataSetName
* @return
*/
private SchemaElementMatch getTimeDimension(Long dataSetId, String dataSetName) {
SchemaElement element = SchemaElement.builder().dataSetId(dataSetId)
.dataSetName(dataSetName).type(SchemaElementType.DIMENSION)
.bizName(TimeDimensionEnum.DAY.getName()).build();
SchemaElementMatch timeDimensionMatch = SchemaElementMatch.builder().element(element)
return SchemaElementMatch.builder().element(element)
.detectWord(TimeDimensionEnum.DAY.getChName())
.word(TimeDimensionEnum.DAY.getChName()).similarity(1L)
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY).build();
return timeDimensionMatch;
}
private Function<SchemaElement, SchemaElementMatch> mergeFunction() {

View File

@@ -119,9 +119,7 @@ public class ChatWorkflowEngine {
}
private void performProcessing(ChatQueryContext queryCtx, ParseResp parseResult) {
resultProcessors.forEach(processor -> {
processor.process(parseResult, queryCtx);
});
resultProcessors.forEach(processor -> processor.process(parseResult, queryCtx));
}
private void performTranslating(ChatQueryContext chatQueryContext, ParseResp parseResult) {
@@ -160,7 +158,7 @@ public class ChatWorkflowEngine {
}
});
if (!errorMsg.isEmpty()) {
parseResult.setErrorMsg(errorMsg.stream().collect(Collectors.joining("\n")));
parseResult.setErrorMsg(String.join("\n", errorMsg));
}
}
}