[improvement][project] Optimize some code structures.

This commit is contained in:
jerryjzhang
2024-10-27 10:38:30 +08:00
parent 3a905d7fb1
commit 1e3daffade
9 changed files with 35 additions and 48 deletions

View File

@@ -40,13 +40,15 @@ public class ChatQueryContext {
private Map<Long, List<Long>> modelIdToDataSetIds; private Map<Long, List<Long>> modelIdToDataSetIds;
private User user; private User user;
private boolean saveAnswer; private boolean saveAnswer;
@Builder.Default
private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM;
private QueryFilters queryFilters; private QueryFilters queryFilters;
private List<SemanticQuery> candidateQueries = new ArrayList<>(); private List<SemanticQuery> candidateQueries = new ArrayList<>();
private SchemaMapInfo mapInfo = new SchemaMapInfo(); private SchemaMapInfo mapInfo = new SchemaMapInfo();
private SemanticParseInfo contextParseInfo; private SemanticParseInfo contextParseInfo;
@Builder.Default
private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM;
@Builder.Default
private MapModeEnum mapModeEnum = MapModeEnum.STRICT; private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
@Builder.Default
private QueryDataType queryDataType = QueryDataType.ALL; private QueryDataType queryDataType = QueryDataType.ALL;
@JsonIgnore @JsonIgnore
private SemanticSchema semanticSchema; private SemanticSchema semanticSchema;

View File

@@ -78,10 +78,9 @@ public class QueryTypeParser implements SemanticParser {
} }
private static List<String> filterByTimeFields(List<String> whereFields) { private static List<String> filterByTimeFields(List<String> whereFields) {
List<String> selectAndWhereFilterByTimeFields = whereFields.stream() return whereFields.stream()
.filter(field -> !TimeDimensionEnum.containsTimeDimension(field)) .filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
.collect(Collectors.toList()); .collect(Collectors.toList());
return selectAndWhereFilterByTimeFields;
} }
private static boolean selectContainsMetric(SqlInfo sqlInfo, Long dataSetId, private static boolean selectContainsMetric(SqlInfo sqlInfo, Long dataSetId,

View File

@@ -65,7 +65,7 @@ public class LLMRequestService {
llmSchema.setPrimaryKey(getPrimaryKey(queryCtx, dataSetId)); llmSchema.setPrimaryKey(getPrimaryKey(queryCtx, dataSetId));
boolean linkingValueEnabled = boolean linkingValueEnabled =
Boolean.valueOf(parserConfig.getParameterValue(PARSER_LINKING_VALUE_ENABLE)); Boolean.parseBoolean(parserConfig.getParameterValue(PARSER_LINKING_VALUE_ENABLE));
if (linkingValueEnabled) { if (linkingValueEnabled) {
llmSchema.setValues(getMappedValues(queryCtx, dataSetId)); llmSchema.setValues(getMappedValues(queryCtx, dataSetId));
} }
@@ -135,13 +135,10 @@ public class LLMRequestService {
if (CollectionUtils.isEmpty(matchedElements)) { if (CollectionUtils.isEmpty(matchedElements)) {
return Collections.emptyList(); return Collections.emptyList();
} }
List<SchemaElement> schemaElements = matchedElements.stream().filter(schemaElementMatch -> { return matchedElements.stream().filter(schemaElementMatch -> {
SchemaElementType elementType = schemaElementMatch.getElement().getType(); SchemaElementType elementType = schemaElementMatch.getElement().getType();
return SchemaElementType.METRIC.equals(elementType); return SchemaElementType.METRIC.equals(elementType);
}).map(schemaElementMatch -> { }).map(SchemaElementMatch::getElement).collect(Collectors.toList());
return schemaElementMatch.getElement();
}).collect(Collectors.toList());
return schemaElements;
} }
protected List<SchemaElement> getMappedDimensions(@NotNull ChatQueryContext queryCtx, protected List<SchemaElement> getMappedDimensions(@NotNull ChatQueryContext queryCtx,

View File

@@ -23,8 +23,8 @@ import java.util.Objects;
@Service @Service
public class LLMResponseService { public class LLMResponseService {
public SemanticParseInfo addParseInfo(ChatQueryContext queryCtx, ParseResult parseResult, public void addParseInfo(ChatQueryContext queryCtx, ParseResult parseResult, String s2SQL,
String s2SQL, Double weight) { Double weight) {
if (Objects.isNull(weight)) { if (Objects.isNull(weight)) {
weight = 0D; weight = 0D;
} }
@@ -49,7 +49,6 @@ public class LLMResponseService {
parseInfo.setQueryMode(semanticQuery.getQueryMode()); parseInfo.setQueryMode(semanticQuery.getQueryMode());
parseInfo.getSqlInfo().setParsedS2SQL(s2SQL); parseInfo.getSqlInfo().setParsedS2SQL(s2SQL);
queryCtx.getCandidateQueries().add(semanticQuery); queryCtx.getCandidateQueries().add(semanticQuery);
return parseInfo;
} }
public Map<String, LLMSqlResp> getDeduplicationSqlResp(int currentRetry, LLMResp llmResp) { public Map<String, LLMSqlResp> getDeduplicationSqlResp(int currentRetry, LLMResp llmResp) {

View File

@@ -31,22 +31,22 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
public static final String APP_KEY = "S2SQL_PARSER"; public static final String APP_KEY = "S2SQL_PARSER";
public static final String INSTRUCTION = "" public static final String INSTRUCTION =
+ "#Role: You are a data analyst experienced in SQL languages." "#Role: You are a data analyst experienced in SQL languages."
+ "\n#Task: You will be provided with a natural language question asked by users," + "\n#Task: You will be provided with a natural language question asked by users,"
+ "please convert it to a SQL query so that relevant data could be returned " + "please convert it to a SQL query so that relevant data could be returned "
+ "by executing the SQL query against underlying database." + "\n#Rules:" + "by executing the SQL query against underlying database." + "\n#Rules:"
+ "\n1.ALWAYS generate columns and values specified in the `Schema`, DO NOT hallucinate." + "\n1.ALWAYS generate columns and values specified in the `Schema`, DO NOT hallucinate."
+ "\n2.ALWAYS be cautious, word in the `Schema` does not mean it must appear in the SQL." + "\n2.ALWAYS be cautious, word in the `Schema` does not mean it must appear in the SQL."
+ "\n3.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator." + "\n3.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator."
+ "\n4.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`." + "\n4.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
+ "\n5.DO NOT calculate date range using functions." + "\n5.DO NOT calculate date range using functions."
+ "\n6.DO NOT miss the AGGREGATE operator of metrics, always add it as needed." + "\n6.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
+ "\n7.ALWAYS use `with` statement if nested aggregation is needed." + "\n7.ALWAYS use `with` statement if nested aggregation is needed."
+ "\n8.ALWAYS enclose alias created by `AS` command in underscores." + "\n8.ALWAYS enclose alias created by `AS` command in underscores."
+ "\n9.ALWAYS translate alias created by `AS` command to the same language as the `#Question`." + "\n9.ALWAYS translate alias created by `AS` command to the same language as the `#Question`."
+ "\n#Exemplars: {{exemplar}}" + "\n#Exemplars: {{exemplar}}"
+ "\n#Query: Question:{{question}},Schema:{{schema}},SideInfo:{{information}}"; + "\n#Query: Question:{{question}},Schema:{{schema}},SideInfo:{{information}}";
public OnePassSCSqlGenStrategy() { public OnePassSCSqlGenStrategy() {
ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("语义SQL解析") ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("语义SQL解析")

View File

@@ -86,13 +86,13 @@ public class AggregateTypeParser implements SemanticParser {
AggregateTypeEnum type = AggregateTypeEnum type =
aggregateCount.entrySet().stream().max(Map.Entry.comparingByValue()) aggregateCount.entrySet().stream().max(Map.Entry.comparingByValue())
.map(entry -> entry.getKey()).orElse(AggregateTypeEnum.NONE); .map(Map.Entry::getKey).orElse(AggregateTypeEnum.NONE);
String detectWord = aggregateWord.get(type); String detectWord = aggregateWord.get(type);
return new AggregateConf(type, detectWord); return new AggregateConf(type, detectWord);
} }
@AllArgsConstructor @AllArgsConstructor
class AggregateConf { static class AggregateConf {
public AggregateTypeEnum type; public AggregateTypeEnum type;
public String detectWord; public String detectWord;
} }

View File

@@ -17,8 +17,8 @@ import java.util.List;
@Slf4j @Slf4j
public class RuleSqlParser implements SemanticParser { public class RuleSqlParser implements SemanticParser {
private static List<SemanticParser> auxiliaryParsers = Arrays.asList(new ContextInheritParser(), private static final List<SemanticParser> auxiliaryParsers = Arrays
new TimeRangeParser(), new AggregateTypeParser()); .asList(new ContextInheritParser(), new TimeRangeParser(), new AggregateTypeParser());
@Override @Override
public void parse(ChatQueryContext chatQueryContext) { public void parse(ChatQueryContext chatQueryContext) {
@@ -38,6 +38,6 @@ public class RuleSqlParser implements SemanticParser {
} }
} }
auxiliaryParsers.stream().forEach(p -> p.parse(chatQueryContext)); auxiliaryParsers.forEach(p -> p.parse(chatQueryContext));
} }
} }

View File

@@ -66,9 +66,7 @@ public class S2ChatLayerService implements ChatLayerService {
public MapResp map(QueryNLReq queryNLReq) { public MapResp map(QueryNLReq queryNLReq) {
MapResp mapResp = new MapResp(queryNLReq.getQueryText()); MapResp mapResp = new MapResp(queryNLReq.getQueryText());
ChatQueryContext queryCtx = buildChatQueryContext(queryNLReq); ChatQueryContext queryCtx = buildChatQueryContext(queryNLReq);
ComponentFactory.getSchemaMappers().forEach(mapper -> { ComponentFactory.getSchemaMappers().forEach(mapper -> mapper.map(queryCtx));
mapper.map(queryCtx);
});
mapResp.setMapInfo(queryCtx.getMapInfo()); mapResp.setMapInfo(queryCtx.getMapInfo());
return mapResp; return mapResp;
} }
@@ -264,22 +262,16 @@ public class S2ChatLayerService implements ChatLayerService {
/** /**
* * get time dimension SchemaElementMatch * * get time dimension SchemaElementMatch
*
* @param dataSetId
* @param dataSetName
* @return
*/ */
private SchemaElementMatch getTimeDimension(Long dataSetId, String dataSetName) { private SchemaElementMatch getTimeDimension(Long dataSetId, String dataSetName) {
SchemaElement element = SchemaElement.builder().dataSetId(dataSetId) SchemaElement element = SchemaElement.builder().dataSetId(dataSetId)
.dataSetName(dataSetName).type(SchemaElementType.DIMENSION) .dataSetName(dataSetName).type(SchemaElementType.DIMENSION)
.bizName(TimeDimensionEnum.DAY.getName()).build(); .bizName(TimeDimensionEnum.DAY.getName()).build();
SchemaElementMatch timeDimensionMatch = SchemaElementMatch.builder().element(element) return SchemaElementMatch.builder().element(element)
.detectWord(TimeDimensionEnum.DAY.getChName()) .detectWord(TimeDimensionEnum.DAY.getChName())
.word(TimeDimensionEnum.DAY.getChName()).similarity(1L) .word(TimeDimensionEnum.DAY.getChName()).similarity(1L)
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY).build(); .frequency(BaseWordBuilder.DEFAULT_FREQUENCY).build();
return timeDimensionMatch;
} }
private Function<SchemaElement, SchemaElementMatch> mergeFunction() { private Function<SchemaElement, SchemaElementMatch> mergeFunction() {

View File

@@ -119,9 +119,7 @@ public class ChatWorkflowEngine {
} }
private void performProcessing(ChatQueryContext queryCtx, ParseResp parseResult) { private void performProcessing(ChatQueryContext queryCtx, ParseResp parseResult) {
resultProcessors.forEach(processor -> { resultProcessors.forEach(processor -> processor.process(parseResult, queryCtx));
processor.process(parseResult, queryCtx);
});
} }
private void performTranslating(ChatQueryContext chatQueryContext, ParseResp parseResult) { private void performTranslating(ChatQueryContext chatQueryContext, ParseResp parseResult) {
@@ -160,7 +158,7 @@ public class ChatWorkflowEngine {
} }
}); });
if (!errorMsg.isEmpty()) { if (!errorMsg.isEmpty()) {
parseResult.setErrorMsg(errorMsg.stream().collect(Collectors.joining("\n"))); parseResult.setErrorMsg(String.join("\n", errorMsg));
} }
} }
} }