[improvement][project]Opt some code structures.

This commit is contained in:
jerryjzhang
2024-10-27 10:38:30 +08:00
parent 3a905d7fb1
commit 1e3daffade
9 changed files with 35 additions and 48 deletions

View File

@@ -40,13 +40,15 @@ public class ChatQueryContext {
private Map<Long, List<Long>> modelIdToDataSetIds;
private User user;
private boolean saveAnswer;
@Builder.Default
private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM;
private QueryFilters queryFilters;
private List<SemanticQuery> candidateQueries = new ArrayList<>();
private SchemaMapInfo mapInfo = new SchemaMapInfo();
private SemanticParseInfo contextParseInfo;
@Builder.Default
private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM;
@Builder.Default
private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
@Builder.Default
private QueryDataType queryDataType = QueryDataType.ALL;
@JsonIgnore
private SemanticSchema semanticSchema;

View File

@@ -78,10 +78,9 @@ public class QueryTypeParser implements SemanticParser {
}
private static List<String> filterByTimeFields(List<String> whereFields) {
List<String> selectAndWhereFilterByTimeFields = whereFields.stream()
return whereFields.stream()
.filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
.collect(Collectors.toList());
return selectAndWhereFilterByTimeFields;
}
private static boolean selectContainsMetric(SqlInfo sqlInfo, Long dataSetId,

View File

@@ -65,7 +65,7 @@ public class LLMRequestService {
llmSchema.setPrimaryKey(getPrimaryKey(queryCtx, dataSetId));
boolean linkingValueEnabled =
Boolean.valueOf(parserConfig.getParameterValue(PARSER_LINKING_VALUE_ENABLE));
Boolean.parseBoolean(parserConfig.getParameterValue(PARSER_LINKING_VALUE_ENABLE));
if (linkingValueEnabled) {
llmSchema.setValues(getMappedValues(queryCtx, dataSetId));
}
@@ -135,13 +135,10 @@ public class LLMRequestService {
if (CollectionUtils.isEmpty(matchedElements)) {
return Collections.emptyList();
}
List<SchemaElement> schemaElements = matchedElements.stream().filter(schemaElementMatch -> {
return matchedElements.stream().filter(schemaElementMatch -> {
SchemaElementType elementType = schemaElementMatch.getElement().getType();
return SchemaElementType.METRIC.equals(elementType);
}).map(schemaElementMatch -> {
return schemaElementMatch.getElement();
}).collect(Collectors.toList());
return schemaElements;
}).map(SchemaElementMatch::getElement).collect(Collectors.toList());
}
protected List<SchemaElement> getMappedDimensions(@NotNull ChatQueryContext queryCtx,

View File

@@ -23,8 +23,8 @@ import java.util.Objects;
@Service
public class LLMResponseService {
public SemanticParseInfo addParseInfo(ChatQueryContext queryCtx, ParseResult parseResult,
String s2SQL, Double weight) {
public void addParseInfo(ChatQueryContext queryCtx, ParseResult parseResult, String s2SQL,
Double weight) {
if (Objects.isNull(weight)) {
weight = 0D;
}
@@ -49,7 +49,6 @@ public class LLMResponseService {
parseInfo.setQueryMode(semanticQuery.getQueryMode());
parseInfo.getSqlInfo().setParsedS2SQL(s2SQL);
queryCtx.getCandidateQueries().add(semanticQuery);
return parseInfo;
}
public Map<String, LLMSqlResp> getDeduplicationSqlResp(int currentRetry, LLMResp llmResp) {

View File

@@ -31,22 +31,22 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
public static final String APP_KEY = "S2SQL_PARSER";
public static final String INSTRUCTION = ""
+ "#Role: You are a data analyst experienced in SQL languages."
+ "\n#Task: You will be provided with a natural language question asked by users,"
+ "please convert it to a SQL query so that relevant data could be returned "
+ "by executing the SQL query against underlying database." + "\n#Rules:"
+ "\n1.ALWAYS generate columns and values specified in the `Schema`, DO NOT hallucinate."
+ "\n2.ALWAYS be cautious, word in the `Schema` does not mean it must appear in the SQL."
+ "\n3.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator."
+ "\n4.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
+ "\n5.DO NOT calculate date range using functions."
+ "\n6.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
+ "\n7.ALWAYS use `with` statement if nested aggregation is needed."
+ "\n8.ALWAYS enclose alias created by `AS` command in underscores."
+ "\n9.ALWAYS translate alias created by `AS` command to the same language as the `#Question`."
+ "\n#Exemplars: {{exemplar}}"
+ "\n#Query: Question:{{question}},Schema:{{schema}},SideInfo:{{information}}";
public static final String INSTRUCTION =
"#Role: You are a data analyst experienced in SQL languages."
+ "\n#Task: You will be provided with a natural language question asked by users,"
+ "please convert it to a SQL query so that relevant data could be returned "
+ "by executing the SQL query against underlying database." + "\n#Rules:"
+ "\n1.ALWAYS generate columns and values specified in the `Schema`, DO NOT hallucinate."
+ "\n2.ALWAYS be cautious, word in the `Schema` does not mean it must appear in the SQL."
+ "\n3.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator."
+ "\n4.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
+ "\n5.DO NOT calculate date range using functions."
+ "\n6.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
+ "\n7.ALWAYS use `with` statement if nested aggregation is needed."
+ "\n8.ALWAYS enclose alias created by `AS` command in underscores."
+ "\n9.ALWAYS translate alias created by `AS` command to the same language as the `#Question`."
+ "\n#Exemplars: {{exemplar}}"
+ "\n#Query: Question:{{question}},Schema:{{schema}},SideInfo:{{information}}";
public OnePassSCSqlGenStrategy() {
ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("语义SQL解析")

View File

@@ -86,13 +86,13 @@ public class AggregateTypeParser implements SemanticParser {
AggregateTypeEnum type =
aggregateCount.entrySet().stream().max(Map.Entry.comparingByValue())
.map(entry -> entry.getKey()).orElse(AggregateTypeEnum.NONE);
.map(Map.Entry::getKey).orElse(AggregateTypeEnum.NONE);
String detectWord = aggregateWord.get(type);
return new AggregateConf(type, detectWord);
}
@AllArgsConstructor
class AggregateConf {
static class AggregateConf {
public AggregateTypeEnum type;
public String detectWord;
}

View File

@@ -17,8 +17,8 @@ import java.util.List;
@Slf4j
public class RuleSqlParser implements SemanticParser {
private static List<SemanticParser> auxiliaryParsers = Arrays.asList(new ContextInheritParser(),
new TimeRangeParser(), new AggregateTypeParser());
private static final List<SemanticParser> auxiliaryParsers = Arrays
.asList(new ContextInheritParser(), new TimeRangeParser(), new AggregateTypeParser());
@Override
public void parse(ChatQueryContext chatQueryContext) {
@@ -38,6 +38,6 @@ public class RuleSqlParser implements SemanticParser {
}
}
auxiliaryParsers.stream().forEach(p -> p.parse(chatQueryContext));
auxiliaryParsers.forEach(p -> p.parse(chatQueryContext));
}
}