mirror of
https://github.com/tencentmusic/supersonic.git
synced 2026-04-30 04:54:25 +08:00
Compare commits
4 Commits
6c7051535f
...
1f1367f4a8
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1f1367f4a8 | ||
|
|
004d802f76 | ||
|
|
4081bd6c80 | ||
|
|
e6598a79bb |
@@ -261,6 +261,11 @@
|
|||||||
<version>${mockito-inline.version}</version>
|
<version>${mockito-inline.version}</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.microsoft.onnxruntime</groupId>
|
||||||
|
<artifactId>onnxruntime</artifactId>
|
||||||
|
<version>1.21.0</version>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
</project>
|
</project>
|
||||||
|
|||||||
@@ -33,15 +33,13 @@ public class PromptHelper {
|
|||||||
|
|
||||||
public List<List<Text2SQLExemplar>> getFewShotExemplars(LLMReq llmReq) {
|
public List<List<Text2SQLExemplar>> getFewShotExemplars(LLMReq llmReq) {
|
||||||
int exemplarRecallNumber =
|
int exemplarRecallNumber =
|
||||||
Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
|
Integer.parseInt(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
|
||||||
int fewShotNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_FEW_SHOT_NUMBER));
|
int fewShotNumber = Integer.parseInt(parserConfig.getParameterValue(PARSER_FEW_SHOT_NUMBER));
|
||||||
int selfConsistencyNumber =
|
int selfConsistencyNumber =
|
||||||
Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER));
|
Integer.parseInt(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER));
|
||||||
|
|
||||||
List<Text2SQLExemplar> exemplars = Lists.newArrayList();
|
List<Text2SQLExemplar> exemplars = Lists.newArrayList();
|
||||||
llmReq.getDynamicExemplars().stream().forEach(e -> {
|
exemplars.addAll(llmReq.getDynamicExemplars());
|
||||||
exemplars.add(e);
|
|
||||||
});
|
|
||||||
|
|
||||||
int recallSize = exemplarRecallNumber - llmReq.getDynamicExemplars().size();
|
int recallSize = exemplarRecallNumber - llmReq.getDynamicExemplars().size();
|
||||||
if (recallSize > 0) {
|
if (recallSize > 0) {
|
||||||
@@ -85,60 +83,63 @@ public class PromptHelper {
|
|||||||
String tableStr = llmReq.getSchema().getDataSetName();
|
String tableStr = llmReq.getSchema().getDataSetName();
|
||||||
|
|
||||||
List<String> metrics = Lists.newArrayList();
|
List<String> metrics = Lists.newArrayList();
|
||||||
llmReq.getSchema().getMetrics().stream().forEach(metric -> {
|
llmReq.getSchema().getMetrics().forEach(metric -> {
|
||||||
StringBuilder metricStr = new StringBuilder();
|
StringBuilder metricStr = new StringBuilder();
|
||||||
metricStr.append("<");
|
metricStr.append("<");
|
||||||
metricStr.append(metric.getName());
|
metricStr.append(metric.getName());
|
||||||
if (!CollectionUtils.isEmpty(metric.getAlias())) {
|
if (!CollectionUtils.isEmpty(metric.getAlias())) {
|
||||||
StringBuilder alias = new StringBuilder();
|
StringBuilder alias = new StringBuilder();
|
||||||
metric.getAlias().stream().forEach(a -> alias.append(a + ","));
|
metric.getAlias().forEach(a -> alias.append(a).append(","));
|
||||||
metricStr.append(" ALIAS '" + alias + "'");
|
metricStr.append(" ALIAS '").append(alias).append("'");
|
||||||
}
|
}
|
||||||
if (StringUtils.isNotEmpty(metric.getDataFormatType())) {
|
if (StringUtils.isNotEmpty(metric.getDataFormatType())) {
|
||||||
String dataFormatType = metric.getDataFormatType();
|
String dataFormatType = metric.getDataFormatType();
|
||||||
if (DataFormatTypeEnum.DECIMAL.getName().equalsIgnoreCase(dataFormatType)
|
if (DataFormatTypeEnum.DECIMAL.getName().equalsIgnoreCase(dataFormatType)
|
||||||
|| DataFormatTypeEnum.PERCENT.getName().equalsIgnoreCase(dataFormatType)) {
|
|| DataFormatTypeEnum.PERCENT.getName().equalsIgnoreCase(dataFormatType)) {
|
||||||
metricStr.append(" FORMAT '" + dataFormatType + "'");
|
metricStr.append(" FORMAT '").append(dataFormatType).append("'");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (StringUtils.isNotEmpty(metric.getDescription())) {
|
if (StringUtils.isNotEmpty(metric.getDescription())) {
|
||||||
metricStr.append(" COMMENT '" + metric.getDescription() + "'");
|
metricStr.append(" COMMENT '").append(metric.getDescription()).append("'");
|
||||||
}
|
}
|
||||||
if (StringUtils.isNotEmpty(metric.getDefaultAgg())) {
|
if (StringUtils.isNotEmpty(metric.getDefaultAgg())) {
|
||||||
metricStr.append(" AGGREGATE '" + metric.getDefaultAgg().toUpperCase() + "'");
|
metricStr.append(" AGGREGATE '").append(metric.getDefaultAgg().toUpperCase()).append("'");
|
||||||
}
|
}
|
||||||
metricStr.append(">");
|
metricStr.append(">");
|
||||||
metrics.add(metricStr.toString());
|
metrics.add(metricStr.toString());
|
||||||
});
|
});
|
||||||
|
|
||||||
List<String> dimensions = Lists.newArrayList();
|
List<String> dimensions = Lists.newArrayList();
|
||||||
llmReq.getSchema().getDimensions().stream().forEach(dimension -> {
|
llmReq.getSchema().getDimensions().forEach(dimension -> {
|
||||||
StringBuilder dimensionStr = new StringBuilder();
|
StringBuilder dimensionStr = new StringBuilder();
|
||||||
dimensionStr.append("<");
|
dimensionStr.append("<");
|
||||||
dimensionStr.append(dimension.getName());
|
dimensionStr.append(dimension.getName());
|
||||||
if (!CollectionUtils.isEmpty(dimension.getAlias())) {
|
if (!CollectionUtils.isEmpty(dimension.getAlias())) {
|
||||||
StringBuilder alias = new StringBuilder();
|
StringBuilder alias = new StringBuilder();
|
||||||
dimension.getAlias().stream().forEach(a -> alias.append(a + ";"));
|
dimension.getAlias().forEach(a -> alias.append(a).append(";"));
|
||||||
dimensionStr.append(" ALIAS '" + alias + "'");
|
dimensionStr.append(" ALIAS '").append(alias).append("'");
|
||||||
}
|
}
|
||||||
if (StringUtils.isNotEmpty(dimension.getTimeFormat())) {
|
if (StringUtils.isNotEmpty(dimension.getTimeFormat())) {
|
||||||
dimensionStr.append(" FORMAT '" + dimension.getTimeFormat() + "'");
|
dimensionStr.append(" FORMAT '").append(dimension.getTimeFormat()).append("'");
|
||||||
}
|
}
|
||||||
if (StringUtils.isNotEmpty(dimension.getDescription())) {
|
if (StringUtils.isNotEmpty(dimension.getDescription())) {
|
||||||
dimensionStr.append(" COMMENT '" + dimension.getDescription() + "'");
|
dimensionStr.append(" COMMENT '").append(dimension.getDescription()).append("'");
|
||||||
}
|
}
|
||||||
dimensionStr.append(">");
|
dimensionStr.append(">");
|
||||||
dimensions.add(dimensionStr.toString());
|
dimensions.add(dimensionStr.toString());
|
||||||
});
|
});
|
||||||
|
|
||||||
List<String> values = Lists.newArrayList();
|
List<String> values = Lists.newArrayList();
|
||||||
llmReq.getSchema().getValues().stream().forEach(value -> {
|
List<LLMReq.ElementValue> elementValueList = llmReq.getSchema().getValues();
|
||||||
StringBuilder valueStr = new StringBuilder();
|
if (elementValueList != null) {
|
||||||
String fieldName = value.getFieldName();
|
elementValueList.forEach(value -> {
|
||||||
String fieldValue = value.getFieldValue();
|
StringBuilder valueStr = new StringBuilder();
|
||||||
valueStr.append(String.format("<%s='%s'>", fieldName, fieldValue));
|
String fieldName = value.getFieldName();
|
||||||
values.add(valueStr.toString());
|
String fieldValue = value.getFieldValue();
|
||||||
});
|
valueStr.append(String.format("<%s='%s'>", fieldName, fieldValue));
|
||||||
|
values.add(valueStr.toString());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
String partitionTimeStr = "";
|
String partitionTimeStr = "";
|
||||||
if (llmReq.getSchema().getPartitionTime() != null) {
|
if (llmReq.getSchema().getPartitionTime() != null) {
|
||||||
@@ -172,14 +173,14 @@ public class PromptHelper {
|
|||||||
private String buildTermStr(LLMReq llmReq) {
|
private String buildTermStr(LLMReq llmReq) {
|
||||||
List<LLMReq.Term> terms = llmReq.getTerms();
|
List<LLMReq.Term> terms = llmReq.getTerms();
|
||||||
List<String> termStr = Lists.newArrayList();
|
List<String> termStr = Lists.newArrayList();
|
||||||
terms.stream().forEach(term -> {
|
terms.forEach(term -> {
|
||||||
StringBuilder termsDesc = new StringBuilder();
|
StringBuilder termsDesc = new StringBuilder();
|
||||||
String description = term.getDescription();
|
String description = term.getDescription();
|
||||||
termsDesc.append(String.format("<%s COMMENT '%s'>", term.getName(), description));
|
termsDesc.append(String.format("<%s COMMENT '%s'>", term.getName(), description));
|
||||||
termStr.add(termsDesc.toString());
|
termStr.add(termsDesc.toString());
|
||||||
});
|
});
|
||||||
String ret = "";
|
String ret = "";
|
||||||
if (termStr.size() > 0) {
|
if (!termStr.isEmpty()) {
|
||||||
ret = String.join(",", termStr);
|
ret = String.join(",", termStr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -633,6 +633,9 @@ public class MetricServiceImpl extends ServiceImpl<MetricDOMapper, MetricDO>
|
|||||||
|
|
||||||
private DataItem getDataItem(MetricDO metricDO) {
|
private DataItem getDataItem(MetricDO metricDO) {
|
||||||
ModelResp modelResp = modelService.getModel(metricDO.getModelId());
|
ModelResp modelResp = modelService.getModel(metricDO.getModelId());
|
||||||
|
if (modelResp == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
MetricResp metricResp = MetricConverter.convert2MetricResp(metricDO,
|
MetricResp metricResp = MetricConverter.convert2MetricResp(metricDO,
|
||||||
ImmutableMap.of(modelResp.getId(), modelResp), Lists.newArrayList());
|
ImmutableMap.of(modelResp.getId(), modelResp), Lists.newArrayList());
|
||||||
fillDefaultAgg(metricResp);
|
fillDefaultAgg(metricResp);
|
||||||
|
|||||||
@@ -60,7 +60,7 @@
|
|||||||
<plugin>
|
<plugin>
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
<artifactId>maven-jar-plugin</artifactId>
|
<artifactId>maven-jar-plugin</artifactId>
|
||||||
<version>2.4</version>
|
<version>2.6</version>
|
||||||
<configuration>
|
<configuration>
|
||||||
<excludes>
|
<excludes>
|
||||||
<exclude>*.*</exclude>
|
<exclude>*.*</exclude>
|
||||||
@@ -70,7 +70,7 @@
|
|||||||
<plugin>
|
<plugin>
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
<artifactId>maven-assembly-plugin</artifactId>
|
<artifactId>maven-assembly-plugin</artifactId>
|
||||||
<version>2.4</version>
|
<version>2.6</version>
|
||||||
<configuration>
|
<configuration>
|
||||||
<tarLongFileMode>gnu</tarLongFileMode>
|
<tarLongFileMode>gnu</tarLongFileMode>
|
||||||
<skipAssembly>false</skipAssembly>
|
<skipAssembly>false</skipAssembly>
|
||||||
|
|||||||
@@ -71,7 +71,7 @@
|
|||||||
<plugin>
|
<plugin>
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
<artifactId>maven-jar-plugin</artifactId>
|
<artifactId>maven-jar-plugin</artifactId>
|
||||||
<version>2.4</version>
|
<version>2.6</version>
|
||||||
<configuration>
|
<configuration>
|
||||||
<excludes>
|
<excludes>
|
||||||
<exclude>*.*</exclude>
|
<exclude>*.*</exclude>
|
||||||
@@ -81,7 +81,7 @@
|
|||||||
<plugin>
|
<plugin>
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
<artifactId>maven-assembly-plugin</artifactId>
|
<artifactId>maven-assembly-plugin</artifactId>
|
||||||
<version>2.4</version>
|
<version>2.6</version>
|
||||||
<configuration>
|
<configuration>
|
||||||
<tarLongFileMode>gnu</tarLongFileMode>
|
<tarLongFileMode>gnu</tarLongFileMode>
|
||||||
<skipAssembly>false</skipAssembly>
|
<skipAssembly>false</skipAssembly>
|
||||||
|
|||||||
@@ -149,7 +149,7 @@
|
|||||||
<plugin>
|
<plugin>
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
<artifactId>maven-jar-plugin</artifactId>
|
<artifactId>maven-jar-plugin</artifactId>
|
||||||
<version>2.4</version>
|
<version>2.6</version>
|
||||||
<configuration>
|
<configuration>
|
||||||
<excludes>
|
<excludes>
|
||||||
<exclude>*.*</exclude>
|
<exclude>*.*</exclude>
|
||||||
@@ -159,7 +159,7 @@
|
|||||||
<plugin>
|
<plugin>
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
<artifactId>maven-assembly-plugin</artifactId>
|
<artifactId>maven-assembly-plugin</artifactId>
|
||||||
<version>2.4</version>
|
<version>2.6</version>
|
||||||
<configuration>
|
<configuration>
|
||||||
<tarLongFileMode>gnu</tarLongFileMode>
|
<tarLongFileMode>gnu</tarLongFileMode>
|
||||||
<skipAssembly>false</skipAssembly>
|
<skipAssembly>false</skipAssembly>
|
||||||
|
|||||||
2
pom.xml
2
pom.xml
@@ -323,7 +323,7 @@
|
|||||||
<plugin>
|
<plugin>
|
||||||
<groupId>org.sonarsource.scanner.maven</groupId>
|
<groupId>org.sonarsource.scanner.maven</groupId>
|
||||||
<artifactId>sonar-maven-plugin</artifactId>
|
<artifactId>sonar-maven-plugin</artifactId>
|
||||||
<version>3.6.0.1398</version>
|
<version>3.6.1.1688</version>
|
||||||
</plugin>
|
</plugin>
|
||||||
<plugin>
|
<plugin>
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
|
|||||||
Reference in New Issue
Block a user