[improvement][headless] Optimize code structure and code style.

This commit is contained in:
jerryjzhang
2024-07-27 18:29:08 +08:00
parent e5504473a4
commit ccd79e4830
55 changed files with 138 additions and 168 deletions

View File

@@ -44,6 +44,7 @@ public class ChatQueryContext {
private QueryFilters queryFilters;
private List<SemanticQuery> candidateQueries = new ArrayList<>();
private SchemaMapInfo mapInfo = new SchemaMapInfo();
private SemanticParseInfo contextParseInfo;
private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
@JsonIgnore
private SemanticSchema semanticSchema;
@@ -53,7 +54,6 @@ public class ChatQueryContext {
private ChatModelConfig modelConfig;
private PromptConfig promptConfig;
private List<Text2SQLExemplar> dynamicExemplars;
private SemanticParseInfo contextParseInfo;
public List<SemanticQuery> getCandidateQueries() {
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);

View File

@@ -51,7 +51,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
// support fieldName and field alias
Map<String, String> result = dbAllFields.stream()
.filter(entry -> dataSetId.equals(entry.getDataSet()))
.filter(entry -> dataSetId.equals(entry.getDataSetId()))
.flatMap(schemaElement -> {
Set<String> elements = new HashSet<>();
elements.add(schemaElement.getName());
@@ -75,7 +75,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
protected void addAggregateToMetric(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
//add aggregate to all metric
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
Long dataSetId = semanticParseInfo.getDataSet().getDataSet();
Long dataSetId = semanticParseInfo.getDataSet().getDataSetId();
List<SchemaElement> metrics = getMetricElements(chatQueryContext, dataSetId);
Map<String, String> metricToAggregate = metrics.stream()

View File

@@ -39,7 +39,7 @@ public class HavingCorrector extends BaseSemanticCorrector {
}
private void addHaving(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
Long dataSet = semanticParseInfo.getDataSet().getDataSet();
Long dataSet = semanticParseInfo.getDataSet().getDataSetId();
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();

View File

@@ -31,7 +31,7 @@ public class ModelWordBuilder extends BaseWordWithAliasBuilder {
public DictWord getOneWordNature(String word, SchemaElement schemaElement, boolean isSuffix) {
DictWord dictWord = new DictWord();
dictWord.setWord(word);
String nature = DictWordType.NATURE_SPILT + schemaElement.getDataSet();
String nature = DictWordType.NATURE_SPILT + schemaElement.getDataSetId();
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
return dictWord;
}

View File

@@ -30,7 +30,7 @@ public class TermWordBuilder extends BaseWordWithAliasBuilder {
public DictWord getOneWordNature(String word, SchemaElement schemaElement, boolean isSuffix) {
DictWord dictWord = new DictWord();
dictWord.setWord(word);
Long dataSet = schemaElement.getDataSet();
Long dataSet = schemaElement.getDataSetId();
String nature = DictWordType.NATURE_SPILT + dataSet + DictWordType.NATURE_SPILT + schemaElement.getId()
+ DictWordType.TERM.getType();
if (isSuffix) {

View File

@@ -67,7 +67,7 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
Set<SchemaElement> schemaElements = entry.getValue();
if (!CollectionUtils.isEmpty(detectDataSetIds)) {
schemaElements = schemaElements.stream()
.filter(schemaElement -> detectDataSetIds.contains(schemaElement.getDataSet()))
.filter(schemaElement -> detectDataSetIds.contains(schemaElement.getDataSetId()))
.collect(Collectors.toSet());
}
for (SchemaElement schemaElement : schemaElements) {

View File

@@ -103,12 +103,12 @@ public class KeywordMapper extends BaseMapper {
.similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName()))
.build();
log.info("add to schema, elementMatch {}", schemaElementMatch);
addToSchemaMap(chatQueryContext.getMapInfo(), schemaElement.getDataSet(), schemaElementMatch);
addToSchemaMap(chatQueryContext.getMapInfo(), schemaElement.getDataSetId(), schemaElementMatch);
}
}
private Set<Long> getRegElementSet(SchemaMapInfo schemaMap, SchemaElement schemaElement) {
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getDataSet());
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getDataSetId());
if (CollectionUtils.isEmpty(elements)) {
return new HashSet<>();
}

View File

@@ -64,7 +64,7 @@ public class QueryFilterMapper extends BaseMapper {
.name(String.valueOf(filter.getValue()))
.type(SchemaElementType.VALUE)
.bizName(filter.getBizName())
.dataSet(dataSetId)
.dataSetId(dataSetId)
.build();
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(element)

View File

@@ -95,7 +95,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
private void fillSchemaElement(SemanticParseInfo parseInfo, SemanticSchema semanticSchema) {
Set<Long> dataSetIds = parseInfo.getElementMatches().stream().map(SchemaElementMatch::getElement)
.map(SchemaElement::getDataSet).collect(Collectors.toSet());
.map(SchemaElement::getDataSetId).collect(Collectors.toSet());
Long dataSetId = dataSetIds.iterator().next();
parseInfo.setDataSet(semanticSchema.getDataSet(dataSetId));
Map<Long, List<SchemaElementMatch>> dim2Values = new HashMap<>();

View File

@@ -24,7 +24,7 @@ class AggCorrectorTest {
ChatQueryContext chatQueryContext = buildQueryContext(dataSetId);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
SchemaElement dataSet = new SchemaElement();
dataSet.setDataSet(dataSetId);
dataSet.setDataSetId(dataSetId);
semanticParseInfo.setDataSet(dataSet);
SqlInfo sqlInfo = new SqlInfo();
String sql = "SELECT 用户, 访问次数 FROM 超音数数据集 WHERE 部门 = 'sales' AND"
@@ -47,11 +47,11 @@ class AggCorrectorTest {
QueryConfig queryConfig = new QueryConfig();
dataSetSchema.setQueryConfig(queryConfig);
SchemaElement schemaElement = new SchemaElement();
schemaElement.setDataSet(dataSetId);
schemaElement.setDataSetId(dataSetId);
dataSetSchema.setDataSet(schemaElement);
Set<SchemaElement> dimensions = new HashSet<>();
SchemaElement element1 = new SchemaElement();
element1.setDataSet(1L);
element1.setDataSetId(1L);
element1.setName("部门");
dimensions.add(element1);
@@ -59,7 +59,7 @@ class AggCorrectorTest {
Set<SchemaElement> metrics = new HashSet<>();
SchemaElement metric1 = new SchemaElement();
metric1.setDataSet(1L);
metric1.setDataSetId(1L);
metric1.setName("访问次数");
metrics.add(metric1);

View File

@@ -70,7 +70,7 @@ class SchemaCorrectorTest {
semanticParseInfo.setSqlInfo(sqlInfo);
SchemaElement schemaElement = new SchemaElement();
schemaElement.setDataSet(dataSetId);
schemaElement.setDataSetId(dataSetId);
semanticParseInfo.setDataSet(schemaElement);
@@ -107,21 +107,21 @@ class SchemaCorrectorTest {
QueryConfig queryConfig = new QueryConfig();
dataSetSchema.setQueryConfig(queryConfig);
SchemaElement schemaElement = new SchemaElement();
schemaElement.setDataSet(dataSetId);
schemaElement.setDataSetId(dataSetId);
dataSetSchema.setDataSet(schemaElement);
Set<SchemaElement> dimensions = new HashSet<>();
SchemaElement element1 = new SchemaElement();
element1.setDataSet(1L);
element1.setDataSetId(1L);
element1.setName("歌曲名");
dimensions.add(element1);
SchemaElement element2 = new SchemaElement();
element2.setDataSet(1L);
element2.setDataSetId(1L);
element2.setName("商务组");
dimensions.add(element2);
SchemaElement element3 = new SchemaElement();
element3.setDataSet(1L);
element3.setDataSetId(1L);
element3.setName("发行日期");
dimensions.add(element3);

View File

@@ -39,7 +39,7 @@ class SelectCorrectorTest {
ChatQueryContext chatQueryContext = buildQueryContext(dataSetId);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
SchemaElement dataSet = new SchemaElement();
dataSet.setDataSet(dataSetId);
dataSet.setDataSetId(dataSetId);
semanticParseInfo.setDataSet(dataSet);
semanticParseInfo.setQueryType(QueryType.DETAIL);
SqlInfo sqlInfo = new SqlInfo();
@@ -74,23 +74,23 @@ class SelectCorrectorTest {
dataSetSchema.setQueryConfig(queryConfig);
SchemaElement schemaElement = new SchemaElement();
schemaElement.setDataSet(dataSetId);
schemaElement.setDataSetId(dataSetId);
dataSetSchema.setDataSet(schemaElement);
Set<SchemaElement> dimensions = new HashSet<>();
SchemaElement element1 = new SchemaElement();
element1.setDataSet(dataSetId);
element1.setDataSetId(dataSetId);
element1.setId(1L);
element1.setName("艺人名");
dimensions.add(element1);
SchemaElement element2 = new SchemaElement();
element2.setDataSet(dataSetId);
element2.setDataSetId(dataSetId);
element2.setId(2L);
element2.setName("性别");
dimensions.add(element2);
SchemaElement element3 = new SchemaElement();
element3.setDataSet(dataSetId);
element3.setDataSetId(dataSetId);
element3.setId(3L);
element3.setName("国籍");
dimensions.add(element3);
@@ -99,7 +99,7 @@ class SelectCorrectorTest {
Set<SchemaElement> metrics = new HashSet<>();
SchemaElement metric1 = new SchemaElement();
metric1.setDataSet(dataSetId);
metric1.setDataSetId(dataSetId);
metric1.setId(4L);
metric1.setName("粉丝数");
metrics.add(metric1);

View File

@@ -29,7 +29,7 @@ class LLMSqlParserTest {
SchemaElement schemaElement = SchemaElement.builder()
.bizName("singer_name")
.name("歌手名")
.dataSet(2L)
.dataSetId(2L)
.schemaValueMaps(schemaValueMaps)
.build();
dimensions.add(schemaElement);
@@ -37,7 +37,7 @@ class LLMSqlParserTest {
SchemaElement schemaElement2 = SchemaElement.builder()
.bizName("publish_time")
.name("发布时间")
.dataSet(2L)
.dataSetId(2L)
.build();
dimensions.add(schemaElement2);
@@ -47,7 +47,7 @@ class LLMSqlParserTest {
SchemaElement metric = SchemaElement.builder()
.bizName("play_count")
.name("播放量")
.dataSet(2L)
.dataSetId(2L)
.build();
metrics.add(metric);

View File

@@ -114,7 +114,7 @@ class S2SqlDateHelperTest {
QueryConfig queryConfig = new QueryConfig();
dataSetSchema.setQueryConfig(queryConfig);
SchemaElement schemaElement = new SchemaElement();
schemaElement.setDataSet(dataSetId);
schemaElement.setDataSetId(dataSetId);
dataSetSchema.setDataSet(schemaElement);
dataSetSchemaList.add(dataSetSchema);