mirror of
https://github.com/tencentmusic/supersonic.git
synced 2026-04-19 04:44:19 +08:00
[improvement][headless]Optimize code structure and code style.
This commit is contained in:
@@ -44,6 +44,7 @@ public class ChatQueryContext {
|
||||
private QueryFilters queryFilters;
|
||||
private List<SemanticQuery> candidateQueries = new ArrayList<>();
|
||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||
private SemanticParseInfo contextParseInfo;
|
||||
private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
|
||||
@JsonIgnore
|
||||
private SemanticSchema semanticSchema;
|
||||
@@ -53,7 +54,6 @@ public class ChatQueryContext {
|
||||
private ChatModelConfig modelConfig;
|
||||
private PromptConfig promptConfig;
|
||||
private List<Text2SQLExemplar> dynamicExemplars;
|
||||
private SemanticParseInfo contextParseInfo;
|
||||
|
||||
public List<SemanticQuery> getCandidateQueries() {
|
||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||
|
||||
@@ -51,7 +51,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
|
||||
// support fieldName and field alias
|
||||
Map<String, String> result = dbAllFields.stream()
|
||||
.filter(entry -> dataSetId.equals(entry.getDataSet()))
|
||||
.filter(entry -> dataSetId.equals(entry.getDataSetId()))
|
||||
.flatMap(schemaElement -> {
|
||||
Set<String> elements = new HashSet<>();
|
||||
elements.add(schemaElement.getName());
|
||||
@@ -75,7 +75,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
protected void addAggregateToMetric(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
//add aggregate to all metric
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
Long dataSetId = semanticParseInfo.getDataSet().getDataSet();
|
||||
Long dataSetId = semanticParseInfo.getDataSet().getDataSetId();
|
||||
List<SchemaElement> metrics = getMetricElements(chatQueryContext, dataSetId);
|
||||
|
||||
Map<String, String> metricToAggregate = metrics.stream()
|
||||
|
||||
@@ -39,7 +39,7 @@ public class HavingCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
|
||||
private void addHaving(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
Long dataSet = semanticParseInfo.getDataSet().getDataSet();
|
||||
Long dataSet = semanticParseInfo.getDataSet().getDataSetId();
|
||||
|
||||
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ public class ModelWordBuilder extends BaseWordWithAliasBuilder {
|
||||
public DictWord getOneWordNature(String word, SchemaElement schemaElement, boolean isSuffix) {
|
||||
DictWord dictWord = new DictWord();
|
||||
dictWord.setWord(word);
|
||||
String nature = DictWordType.NATURE_SPILT + schemaElement.getDataSet();
|
||||
String nature = DictWordType.NATURE_SPILT + schemaElement.getDataSetId();
|
||||
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
|
||||
return dictWord;
|
||||
}
|
||||
|
||||
@@ -30,7 +30,7 @@ public class TermWordBuilder extends BaseWordWithAliasBuilder {
|
||||
public DictWord getOneWordNature(String word, SchemaElement schemaElement, boolean isSuffix) {
|
||||
DictWord dictWord = new DictWord();
|
||||
dictWord.setWord(word);
|
||||
Long dataSet = schemaElement.getDataSet();
|
||||
Long dataSet = schemaElement.getDataSetId();
|
||||
String nature = DictWordType.NATURE_SPILT + dataSet + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
+ DictWordType.TERM.getType();
|
||||
if (isSuffix) {
|
||||
|
||||
@@ -67,7 +67,7 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
||||
Set<SchemaElement> schemaElements = entry.getValue();
|
||||
if (!CollectionUtils.isEmpty(detectDataSetIds)) {
|
||||
schemaElements = schemaElements.stream()
|
||||
.filter(schemaElement -> detectDataSetIds.contains(schemaElement.getDataSet()))
|
||||
.filter(schemaElement -> detectDataSetIds.contains(schemaElement.getDataSetId()))
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
for (SchemaElement schemaElement : schemaElements) {
|
||||
|
||||
@@ -103,12 +103,12 @@ public class KeywordMapper extends BaseMapper {
|
||||
.similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName()))
|
||||
.build();
|
||||
log.info("add to schema, elementMatch {}", schemaElementMatch);
|
||||
addToSchemaMap(chatQueryContext.getMapInfo(), schemaElement.getDataSet(), schemaElementMatch);
|
||||
addToSchemaMap(chatQueryContext.getMapInfo(), schemaElement.getDataSetId(), schemaElementMatch);
|
||||
}
|
||||
}
|
||||
|
||||
private Set<Long> getRegElementSet(SchemaMapInfo schemaMap, SchemaElement schemaElement) {
|
||||
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getDataSet());
|
||||
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getDataSetId());
|
||||
if (CollectionUtils.isEmpty(elements)) {
|
||||
return new HashSet<>();
|
||||
}
|
||||
|
||||
@@ -64,7 +64,7 @@ public class QueryFilterMapper extends BaseMapper {
|
||||
.name(String.valueOf(filter.getValue()))
|
||||
.type(SchemaElementType.VALUE)
|
||||
.bizName(filter.getBizName())
|
||||
.dataSet(dataSetId)
|
||||
.dataSetId(dataSetId)
|
||||
.build();
|
||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||
.element(element)
|
||||
|
||||
@@ -95,7 +95,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
|
||||
private void fillSchemaElement(SemanticParseInfo parseInfo, SemanticSchema semanticSchema) {
|
||||
Set<Long> dataSetIds = parseInfo.getElementMatches().stream().map(SchemaElementMatch::getElement)
|
||||
.map(SchemaElement::getDataSet).collect(Collectors.toSet());
|
||||
.map(SchemaElement::getDataSetId).collect(Collectors.toSet());
|
||||
Long dataSetId = dataSetIds.iterator().next();
|
||||
parseInfo.setDataSet(semanticSchema.getDataSet(dataSetId));
|
||||
Map<Long, List<SchemaElementMatch>> dim2Values = new HashMap<>();
|
||||
|
||||
@@ -24,7 +24,7 @@ class AggCorrectorTest {
|
||||
ChatQueryContext chatQueryContext = buildQueryContext(dataSetId);
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
SchemaElement dataSet = new SchemaElement();
|
||||
dataSet.setDataSet(dataSetId);
|
||||
dataSet.setDataSetId(dataSetId);
|
||||
semanticParseInfo.setDataSet(dataSet);
|
||||
SqlInfo sqlInfo = new SqlInfo();
|
||||
String sql = "SELECT 用户, 访问次数 FROM 超音数数据集 WHERE 部门 = 'sales' AND"
|
||||
@@ -47,11 +47,11 @@ class AggCorrectorTest {
|
||||
QueryConfig queryConfig = new QueryConfig();
|
||||
dataSetSchema.setQueryConfig(queryConfig);
|
||||
SchemaElement schemaElement = new SchemaElement();
|
||||
schemaElement.setDataSet(dataSetId);
|
||||
schemaElement.setDataSetId(dataSetId);
|
||||
dataSetSchema.setDataSet(schemaElement);
|
||||
Set<SchemaElement> dimensions = new HashSet<>();
|
||||
SchemaElement element1 = new SchemaElement();
|
||||
element1.setDataSet(1L);
|
||||
element1.setDataSetId(1L);
|
||||
element1.setName("部门");
|
||||
dimensions.add(element1);
|
||||
|
||||
@@ -59,7 +59,7 @@ class AggCorrectorTest {
|
||||
|
||||
Set<SchemaElement> metrics = new HashSet<>();
|
||||
SchemaElement metric1 = new SchemaElement();
|
||||
metric1.setDataSet(1L);
|
||||
metric1.setDataSetId(1L);
|
||||
metric1.setName("访问次数");
|
||||
metrics.add(metric1);
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ class SchemaCorrectorTest {
|
||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||
|
||||
SchemaElement schemaElement = new SchemaElement();
|
||||
schemaElement.setDataSet(dataSetId);
|
||||
schemaElement.setDataSetId(dataSetId);
|
||||
semanticParseInfo.setDataSet(schemaElement);
|
||||
|
||||
|
||||
@@ -107,21 +107,21 @@ class SchemaCorrectorTest {
|
||||
QueryConfig queryConfig = new QueryConfig();
|
||||
dataSetSchema.setQueryConfig(queryConfig);
|
||||
SchemaElement schemaElement = new SchemaElement();
|
||||
schemaElement.setDataSet(dataSetId);
|
||||
schemaElement.setDataSetId(dataSetId);
|
||||
dataSetSchema.setDataSet(schemaElement);
|
||||
Set<SchemaElement> dimensions = new HashSet<>();
|
||||
SchemaElement element1 = new SchemaElement();
|
||||
element1.setDataSet(1L);
|
||||
element1.setDataSetId(1L);
|
||||
element1.setName("歌曲名");
|
||||
dimensions.add(element1);
|
||||
|
||||
SchemaElement element2 = new SchemaElement();
|
||||
element2.setDataSet(1L);
|
||||
element2.setDataSetId(1L);
|
||||
element2.setName("商务组");
|
||||
dimensions.add(element2);
|
||||
|
||||
SchemaElement element3 = new SchemaElement();
|
||||
element3.setDataSet(1L);
|
||||
element3.setDataSetId(1L);
|
||||
element3.setName("发行日期");
|
||||
dimensions.add(element3);
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ class SelectCorrectorTest {
|
||||
ChatQueryContext chatQueryContext = buildQueryContext(dataSetId);
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
SchemaElement dataSet = new SchemaElement();
|
||||
dataSet.setDataSet(dataSetId);
|
||||
dataSet.setDataSetId(dataSetId);
|
||||
semanticParseInfo.setDataSet(dataSet);
|
||||
semanticParseInfo.setQueryType(QueryType.DETAIL);
|
||||
SqlInfo sqlInfo = new SqlInfo();
|
||||
@@ -74,23 +74,23 @@ class SelectCorrectorTest {
|
||||
|
||||
dataSetSchema.setQueryConfig(queryConfig);
|
||||
SchemaElement schemaElement = new SchemaElement();
|
||||
schemaElement.setDataSet(dataSetId);
|
||||
schemaElement.setDataSetId(dataSetId);
|
||||
dataSetSchema.setDataSet(schemaElement);
|
||||
Set<SchemaElement> dimensions = new HashSet<>();
|
||||
SchemaElement element1 = new SchemaElement();
|
||||
element1.setDataSet(dataSetId);
|
||||
element1.setDataSetId(dataSetId);
|
||||
element1.setId(1L);
|
||||
element1.setName("艺人名");
|
||||
dimensions.add(element1);
|
||||
|
||||
SchemaElement element2 = new SchemaElement();
|
||||
element2.setDataSet(dataSetId);
|
||||
element2.setDataSetId(dataSetId);
|
||||
element2.setId(2L);
|
||||
element2.setName("性别");
|
||||
dimensions.add(element2);
|
||||
|
||||
SchemaElement element3 = new SchemaElement();
|
||||
element3.setDataSet(dataSetId);
|
||||
element3.setDataSetId(dataSetId);
|
||||
element3.setId(3L);
|
||||
element3.setName("国籍");
|
||||
dimensions.add(element3);
|
||||
@@ -99,7 +99,7 @@ class SelectCorrectorTest {
|
||||
|
||||
Set<SchemaElement> metrics = new HashSet<>();
|
||||
SchemaElement metric1 = new SchemaElement();
|
||||
metric1.setDataSet(dataSetId);
|
||||
metric1.setDataSetId(dataSetId);
|
||||
metric1.setId(4L);
|
||||
metric1.setName("粉丝数");
|
||||
metrics.add(metric1);
|
||||
|
||||
@@ -29,7 +29,7 @@ class LLMSqlParserTest {
|
||||
SchemaElement schemaElement = SchemaElement.builder()
|
||||
.bizName("singer_name")
|
||||
.name("歌手名")
|
||||
.dataSet(2L)
|
||||
.dataSetId(2L)
|
||||
.schemaValueMaps(schemaValueMaps)
|
||||
.build();
|
||||
dimensions.add(schemaElement);
|
||||
@@ -37,7 +37,7 @@ class LLMSqlParserTest {
|
||||
SchemaElement schemaElement2 = SchemaElement.builder()
|
||||
.bizName("publish_time")
|
||||
.name("发布时间")
|
||||
.dataSet(2L)
|
||||
.dataSetId(2L)
|
||||
.build();
|
||||
dimensions.add(schemaElement2);
|
||||
|
||||
@@ -47,7 +47,7 @@ class LLMSqlParserTest {
|
||||
SchemaElement metric = SchemaElement.builder()
|
||||
.bizName("play_count")
|
||||
.name("播放量")
|
||||
.dataSet(2L)
|
||||
.dataSetId(2L)
|
||||
.build();
|
||||
metrics.add(metric);
|
||||
|
||||
|
||||
@@ -114,7 +114,7 @@ class S2SqlDateHelperTest {
|
||||
QueryConfig queryConfig = new QueryConfig();
|
||||
dataSetSchema.setQueryConfig(queryConfig);
|
||||
SchemaElement schemaElement = new SchemaElement();
|
||||
schemaElement.setDataSet(dataSetId);
|
||||
schemaElement.setDataSetId(dataSetId);
|
||||
dataSetSchema.setDataSet(schemaElement);
|
||||
dataSetSchemaList.add(dataSetSchema);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user