mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
[improvement][chat]Introduce AllFieldMapper to increase parsing robustness when normal pipeline fails.
[improvement][chat]Introduce `AllFieldMapper` to increase parsing robustness when normal pipeline fails.
This commit is contained in:
@@ -100,10 +100,12 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
queryNLReq.setMapModeEnum(mode);
|
queryNLReq.setMapModeEnum(mode);
|
||||||
doParse(queryNLReq, parseResp);
|
doParse(queryNLReq, parseResp);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (parseResp.getSelectedParses().isEmpty()) {
|
if (parseResp.getSelectedParses().isEmpty()) {
|
||||||
queryNLReq.setMapModeEnum(MapModeEnum.LOOSE);
|
queryNLReq.setMapModeEnum(MapModeEnum.LOOSE);
|
||||||
doParse(queryNLReq, parseResp);
|
doParse(queryNLReq, parseResp);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (parseResp.getSelectedParses().isEmpty()) {
|
if (parseResp.getSelectedParses().isEmpty()) {
|
||||||
errMsg.append(parseResp.getErrorMsg());
|
errMsg.append(parseResp.getErrorMsg());
|
||||||
continue;
|
continue;
|
||||||
@@ -137,11 +139,18 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
SemanticParseInfo userSelectParse = parseContext.getRequest().getSelectedParse();
|
SemanticParseInfo userSelectParse = parseContext.getRequest().getSelectedParse();
|
||||||
queryNLReq.setSelectedParseInfo(Objects.nonNull(userSelectParse) ? userSelectParse
|
queryNLReq.setSelectedParseInfo(Objects.nonNull(userSelectParse) ? userSelectParse
|
||||||
: parseContext.getResponse().getSelectedParses().get(0));
|
: parseContext.getResponse().getSelectedParses().get(0));
|
||||||
|
|
||||||
parseContext.setResponse(new ChatParseResp(parseContext.getResponse().getQueryId()));
|
parseContext.setResponse(new ChatParseResp(parseContext.getResponse().getQueryId()));
|
||||||
|
|
||||||
rewriteMultiTurn(parseContext, queryNLReq);
|
rewriteMultiTurn(parseContext, queryNLReq);
|
||||||
addDynamicExemplars(parseContext, queryNLReq);
|
addDynamicExemplars(parseContext, queryNLReq);
|
||||||
doParse(queryNLReq, parseContext.getResponse());
|
doParse(queryNLReq, parseContext.getResponse());
|
||||||
|
|
||||||
|
// try again with all semantic fields passed to LLM
|
||||||
|
if (parseContext.getResponse().getState().equals(ParseResp.ParseState.FAILED)) {
|
||||||
|
queryNLReq.setSelectedParseInfo(null);
|
||||||
|
queryNLReq.setMapModeEnum(MapModeEnum.ALL);
|
||||||
|
doParse(queryNLReq, parseContext.getResponse());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,24 @@
|
|||||||
|
package com.tencent.supersonic.common.jsqlparser;
|
||||||
|
|
||||||
|
import net.sf.jsqlparser.expression.Alias;
|
||||||
|
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
|
||||||
|
import net.sf.jsqlparser.statement.select.SelectItem;
|
||||||
|
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
public class AliasAcquireVisitor extends ExpressionVisitorAdapter {
|
||||||
|
|
||||||
|
private Set<String> aliases;
|
||||||
|
|
||||||
|
public AliasAcquireVisitor(Set<String> aliases) {
|
||||||
|
this.aliases = aliases;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void visit(SelectItem selectItem) {
|
||||||
|
Alias alias = selectItem.getAlias();
|
||||||
|
if (alias != null) {
|
||||||
|
aliases.add(alias.getName());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
package com.tencent.supersonic.common.jsqlparser;
|
package com.tencent.supersonic.common.jsqlparser;
|
||||||
|
|
||||||
|
import com.google.common.collect.Sets;
|
||||||
import net.sf.jsqlparser.expression.Alias;
|
import net.sf.jsqlparser.expression.Alias;
|
||||||
import net.sf.jsqlparser.expression.Expression;
|
import net.sf.jsqlparser.expression.Expression;
|
||||||
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
|
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
|
||||||
@@ -11,6 +12,7 @@ import java.util.Set;
|
|||||||
public class FieldAcquireVisitor extends ExpressionVisitorAdapter {
|
public class FieldAcquireVisitor extends ExpressionVisitorAdapter {
|
||||||
|
|
||||||
private Set<String> fields;
|
private Set<String> fields;
|
||||||
|
private Set<String> aliases = Sets.newHashSet();
|
||||||
|
|
||||||
public FieldAcquireVisitor(Set<String> fields) {
|
public FieldAcquireVisitor(Set<String> fields) {
|
||||||
this.fields = fields;
|
this.fields = fields;
|
||||||
@@ -26,8 +28,9 @@ public class FieldAcquireVisitor extends ExpressionVisitorAdapter {
|
|||||||
public void visit(SelectItem selectItem) {
|
public void visit(SelectItem selectItem) {
|
||||||
Alias alias = selectItem.getAlias();
|
Alias alias = selectItem.getAlias();
|
||||||
if (alias != null) {
|
if (alias != null) {
|
||||||
fields.add(alias.getName());
|
aliases.add(alias.getName());
|
||||||
}
|
}
|
||||||
|
|
||||||
Expression expression = selectItem.getExpression();
|
Expression expression = selectItem.getExpression();
|
||||||
if (expression != null) {
|
if (expression != null) {
|
||||||
expression.accept(this);
|
expression.accept(this);
|
||||||
|
|||||||
@@ -133,6 +133,15 @@ public class SqlSelectHelper {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static Set<String> getAliasFields(PlainSelect plainSelect) {
|
||||||
|
Set<String> result = new HashSet<>();
|
||||||
|
List<SelectItem<?>> selectItems = plainSelect.getSelectItems();
|
||||||
|
for (SelectItem selectItem : selectItems) {
|
||||||
|
selectItem.accept(new AliasAcquireVisitor(result));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
public static List<PlainSelect> getPlainSelect(Select selectStatement) {
|
public static List<PlainSelect> getPlainSelect(Select selectStatement) {
|
||||||
if (selectStatement == null) {
|
if (selectStatement == null) {
|
||||||
return null;
|
return null;
|
||||||
@@ -264,10 +273,14 @@ public class SqlSelectHelper {
|
|||||||
public static List<String> getAllSelectFields(String sql) {
|
public static List<String> getAllSelectFields(String sql) {
|
||||||
List<PlainSelect> plainSelects = getPlainSelects(getPlainSelect(sql));
|
List<PlainSelect> plainSelects = getPlainSelects(getPlainSelect(sql));
|
||||||
Set<String> results = new HashSet<>();
|
Set<String> results = new HashSet<>();
|
||||||
|
Set<String> aliases = new HashSet<>();
|
||||||
for (PlainSelect plainSelect : plainSelects) {
|
for (PlainSelect plainSelect : plainSelects) {
|
||||||
List<String> fields = getFieldsByPlainSelect(plainSelect);
|
List<String> fields = getFieldsByPlainSelect(plainSelect);
|
||||||
results.addAll(fields);
|
results.addAll(fields);
|
||||||
|
aliases.addAll(getAliasFields(plainSelect));
|
||||||
}
|
}
|
||||||
|
// do not account in aliases
|
||||||
|
results.removeAll(aliases);
|
||||||
return new ArrayList<>(results);
|
return new ArrayList<>(results);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo.enums;
|
package com.tencent.supersonic.headless.api.pojo.enums;
|
||||||
|
|
||||||
public enum MapModeEnum {
|
public enum MapModeEnum {
|
||||||
STRICT(0), MODERATE(2), LOOSE(4);
|
STRICT(0), MODERATE(2), LOOSE(4), ALL(6);
|
||||||
|
|
||||||
public int threshold;
|
public int threshold;
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,40 @@
|
|||||||
|
package com.tencent.supersonic.headless.chat.mapper;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||||
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class AllFieldMapper extends BaseMapper {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean accept(ChatQueryContext chatQueryContext) {
|
||||||
|
return MapModeEnum.ALL.equals(chatQueryContext.getRequest().getMapModeEnum());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void doMap(ChatQueryContext chatQueryContext) {
|
||||||
|
Map<Long, DataSetSchema> schemaMap =
|
||||||
|
chatQueryContext.getSemanticSchema().getDataSetSchemaMap();
|
||||||
|
for (Map.Entry<Long, DataSetSchema> entry : schemaMap.entrySet()) {
|
||||||
|
List<SchemaElement> schemaElements = Lists.newArrayList();
|
||||||
|
schemaElements.addAll(entry.getValue().getDimensions());
|
||||||
|
schemaElements.addAll(entry.getValue().getMetrics());
|
||||||
|
|
||||||
|
for (SchemaElement schemaElement : schemaElements) {
|
||||||
|
chatQueryContext.getMapInfo().getMatchedElements(entry.getKey())
|
||||||
|
.add(SchemaElementMatch.builder().word(schemaElement.getName())
|
||||||
|
.element(schemaElement).detectWord(schemaElement.getName())
|
||||||
|
.similarity(1.0).build());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -27,6 +27,10 @@ public abstract class BaseMapper implements SchemaMapper {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void map(ChatQueryContext chatQueryContext) {
|
public void map(ChatQueryContext chatQueryContext) {
|
||||||
|
if (!accept(chatQueryContext)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
String simpleName = this.getClass().getSimpleName();
|
String simpleName = this.getClass().getSimpleName();
|
||||||
long startTime = System.currentTimeMillis();
|
long startTime = System.currentTimeMillis();
|
||||||
log.debug("before {},mapInfo:{}", simpleName,
|
log.debug("before {},mapInfo:{}", simpleName,
|
||||||
@@ -46,6 +50,10 @@ public abstract class BaseMapper implements SchemaMapper {
|
|||||||
|
|
||||||
public abstract void doMap(ChatQueryContext chatQueryContext);
|
public abstract void doMap(ChatQueryContext chatQueryContext);
|
||||||
|
|
||||||
|
protected boolean accept(ChatQueryContext chatQueryContext) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
public void addToSchemaMap(SchemaMapInfo schemaMap, Long dataSetId,
|
public void addToSchemaMap(SchemaMapInfo schemaMap, Long dataSetId,
|
||||||
SchemaElementMatch newElementMatch) {
|
SchemaElementMatch newElementMatch) {
|
||||||
Map<Long, List<SchemaElementMatch>> dataSetElementMatches =
|
Map<Long, List<SchemaElementMatch>> dataSetElementMatches =
|
||||||
|
|||||||
@@ -20,12 +20,13 @@ import java.util.Objects;
|
|||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class EmbeddingMapper extends BaseMapper {
|
public class EmbeddingMapper extends BaseMapper {
|
||||||
public void doMap(ChatQueryContext chatQueryContext) {
|
|
||||||
// Check if the map mode is LOOSE
|
|
||||||
if (!MapModeEnum.LOOSE.equals(chatQueryContext.getRequest().getMapModeEnum())) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean accept(ChatQueryContext chatQueryContext) {
|
||||||
|
return MapModeEnum.LOOSE.equals(chatQueryContext.getRequest().getMapModeEnum());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void doMap(ChatQueryContext chatQueryContext) {
|
||||||
// 1. Query from embedding by queryText
|
// 1. Query from embedding by queryText
|
||||||
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
|
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
|
||||||
List<EmbeddingResult> matchResults = getMatches(chatQueryContext, matchStrategy);
|
List<EmbeddingResult> matchResults = getMatches(chatQueryContext, matchStrategy);
|
||||||
@@ -62,4 +63,5 @@ public class EmbeddingMapper extends BaseMapper {
|
|||||||
addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch);
|
addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -35,11 +35,11 @@ public class MapperConfig extends ParameterConfig {
|
|||||||
"维度值相似度阈值在动态调整中的最低值", "number", "Mapper相关配置");
|
"维度值相似度阈值在动态调整中的最低值", "number", "Mapper相关配置");
|
||||||
|
|
||||||
public static final Parameter EMBEDDING_MAPPER_TEXT_SIZE =
|
public static final Parameter EMBEDDING_MAPPER_TEXT_SIZE =
|
||||||
new Parameter("s2.mapper.embedding.word.size", "4", "用于向量召回文本长度",
|
new Parameter("s2.mapper.embedding.word.size", "3", "用于向量召回文本长度",
|
||||||
"为提高向量召回效率, 按指定长度进行向量语义召回", "number", "Mapper相关配置");
|
"为提高向量召回效率, 按指定长度进行向量语义召回", "number", "Mapper相关配置");
|
||||||
|
|
||||||
public static final Parameter EMBEDDING_MAPPER_TEXT_STEP =
|
public static final Parameter EMBEDDING_MAPPER_TEXT_STEP =
|
||||||
new Parameter("s2.mapper.embedding.word.step", "3", "向量召回文本每步长度",
|
new Parameter("s2.mapper.embedding.word.step", "2", "向量召回文本每步长度",
|
||||||
"为提高向量召回效率, 按指定每步长度进行召回", "number", "Mapper相关配置");
|
"为提高向量召回效率, 按指定每步长度进行召回", "number", "Mapper相关配置");
|
||||||
|
|
||||||
public static final Parameter EMBEDDING_MAPPER_BATCH =
|
public static final Parameter EMBEDDING_MAPPER_BATCH =
|
||||||
@@ -51,7 +51,7 @@ public class MapperConfig extends ParameterConfig {
|
|||||||
"每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置");
|
"每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置");
|
||||||
|
|
||||||
public static final Parameter EMBEDDING_MAPPER_THRESHOLD =
|
public static final Parameter EMBEDDING_MAPPER_THRESHOLD =
|
||||||
new Parameter("s2.mapper.embedding.threshold", "0.98", "向量召回相似度阈值", "相似度小于该阈值的则舍弃",
|
new Parameter("s2.mapper.embedding.threshold", "0.8", "向量召回相似度阈值", "相似度小于该阈值的则舍弃",
|
||||||
"number", "Mapper相关配置");
|
"number", "Mapper相关配置");
|
||||||
|
|
||||||
public static final Parameter EMBEDDING_MAPPER_ROUND_NUMBER =
|
public static final Parameter EMBEDDING_MAPPER_ROUND_NUMBER =
|
||||||
|
|||||||
@@ -9,22 +9,23 @@ import lombok.extern.slf4j.Slf4j;
|
|||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class TimeFieldMapper extends BaseMapper {
|
public class PartitionTimeMapper extends BaseMapper {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean accept(ChatQueryContext chatQueryContext) {
|
||||||
|
return !(chatQueryContext.getRequest().getText2SQLType().equals(Text2SQLType.ONLY_RULE)
|
||||||
|
|| chatQueryContext.getMapInfo().isEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void doMap(ChatQueryContext chatQueryContext) {
|
public void doMap(ChatQueryContext chatQueryContext) {
|
||||||
if (chatQueryContext.getRequest().getText2SQLType().equals(Text2SQLType.ONLY_RULE)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
Map<Long, DataSetSchema> schemaMap =
|
Map<Long, DataSetSchema> schemaMap =
|
||||||
chatQueryContext.getSemanticSchema().getDataSetSchemaMap();
|
chatQueryContext.getSemanticSchema().getDataSetSchemaMap();
|
||||||
for (Map.Entry<Long, DataSetSchema> entry : schemaMap.entrySet()) {
|
for (Map.Entry<Long, DataSetSchema> entry : schemaMap.entrySet()) {
|
||||||
List<SchemaElement> timeDims = entry.getValue().getDimensions().stream()
|
List<SchemaElement> timeDims = entry.getValue().getDimensions().stream()
|
||||||
.filter(dim -> dim.getTimeFormat() != null).collect(Collectors.toList());
|
.filter(SchemaElement::isPartitionTime).toList();
|
||||||
for (SchemaElement schemaElement : timeDims) {
|
for (SchemaElement schemaElement : timeDims) {
|
||||||
chatQueryContext.getMapInfo().getMatchedElements(entry.getKey())
|
chatQueryContext.getMapInfo().getMatchedElements(entry.getKey())
|
||||||
.add(SchemaElementMatch.builder().word(schemaElement.getName())
|
.add(SchemaElementMatch.builder().word(schemaElement.getName())
|
||||||
@@ -21,14 +21,16 @@ import java.util.stream.Collectors;
|
|||||||
@Slf4j
|
@Slf4j
|
||||||
public class QueryFilterMapper extends BaseMapper {
|
public class QueryFilterMapper extends BaseMapper {
|
||||||
|
|
||||||
private double similarity = 1.0;
|
private final double similarity = 1.0;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean accept(ChatQueryContext chatQueryContext) {
|
||||||
|
return !chatQueryContext.getRequest().getDataSetIds().isEmpty();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void doMap(ChatQueryContext chatQueryContext) {
|
public void doMap(ChatQueryContext chatQueryContext) {
|
||||||
Set<Long> dataSetIds = chatQueryContext.getRequest().getDataSetIds();
|
Set<Long> dataSetIds = chatQueryContext.getRequest().getDataSetIds();
|
||||||
if (CollectionUtils.isEmpty(dataSetIds)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
SchemaMapInfo schemaMapInfo = chatQueryContext.getMapInfo();
|
SchemaMapInfo schemaMapInfo = chatQueryContext.getMapInfo();
|
||||||
clearOtherSchemaElementMatch(dataSetIds, schemaMapInfo);
|
clearOtherSchemaElementMatch(dataSetIds, schemaMapInfo);
|
||||||
for (Long dataSetId : dataSetIds) {
|
for (Long dataSetId : dataSetIds) {
|
||||||
|
|||||||
@@ -16,14 +16,15 @@ import java.util.List;
|
|||||||
@Slf4j
|
@Slf4j
|
||||||
public class TermDescMapper extends BaseMapper {
|
public class TermDescMapper extends BaseMapper {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean accept(ChatQueryContext chatQueryContext) {
|
||||||
|
return !(CollectionUtils.isEmpty(chatQueryContext.getMapInfo().getTermDescriptionToMap())
|
||||||
|
|| chatQueryContext.getRequest().isDescriptionMapped());
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void doMap(ChatQueryContext chatQueryContext) {
|
public void doMap(ChatQueryContext chatQueryContext) {
|
||||||
SchemaMapInfo mapInfo = chatQueryContext.getMapInfo();
|
List<SchemaElement> termElements = chatQueryContext.getMapInfo().getTermDescriptionToMap();
|
||||||
List<SchemaElement> termElements = mapInfo.getTermDescriptionToMap();
|
|
||||||
if (CollectionUtils.isEmpty(termElements)
|
|
||||||
|| chatQueryContext.getRequest().isDescriptionMapped()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
for (SchemaElement schemaElement : termElements) {
|
for (SchemaElement schemaElement : termElements) {
|
||||||
ChatQueryContext queryCtx =
|
ChatQueryContext queryCtx =
|
||||||
buildQueryContext(chatQueryContext, schemaElement.getDescription());
|
buildQueryContext(chatQueryContext, schemaElement.getDescription());
|
||||||
|
|||||||
@@ -30,6 +30,9 @@ public class DefaultSemanticTranslator implements SemanticTranslator {
|
|||||||
if (parser.accept(queryStatement)) {
|
if (parser.accept(queryStatement)) {
|
||||||
log.debug("QueryConverter accept [{}]", parser.getClass().getName());
|
log.debug("QueryConverter accept [{}]", parser.getClass().getName());
|
||||||
parser.parse(queryStatement);
|
parser.parse(queryStatement);
|
||||||
|
if (queryStatement.getStatus() != 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!queryStatement.isOk()) {
|
if (!queryStatement.isOk()) {
|
||||||
|
|||||||
@@ -63,6 +63,14 @@ public class SqlQueryParser implements QueryParser {
|
|||||||
List<String> metrics =
|
List<String> metrics =
|
||||||
metricSchemas.stream().map(SchemaItem::getBizName).collect(Collectors.toList());
|
metricSchemas.stream().map(SchemaItem::getBizName).collect(Collectors.toList());
|
||||||
Set<String> dimensions = getDimensions(semanticSchemaResp, allFields);
|
Set<String> dimensions = getDimensions(semanticSchemaResp, allFields);
|
||||||
|
// check if there are fields not matched with any metric or dimension
|
||||||
|
if (allFields.size() > metricSchemas.size() + dimensions.size()) {
|
||||||
|
queryStatement
|
||||||
|
.setErrMsg("There are fields in the SQL not matched with any semantic column.");
|
||||||
|
queryStatement.setStatus(1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
OntologyQuery ontologyQuery = new OntologyQuery();
|
OntologyQuery ontologyQuery = new OntologyQuery();
|
||||||
ontologyQuery.getMetrics().addAll(metrics);
|
ontologyQuery.getMetrics().addAll(metrics);
|
||||||
ontologyQuery.getDimensions().addAll(dimensions);
|
ontologyQuery.getDimensions().addAll(dimensions);
|
||||||
|
|||||||
@@ -76,7 +76,6 @@ public class ChatWorkflowEngine {
|
|||||||
long start = System.currentTimeMillis();
|
long start = System.currentTimeMillis();
|
||||||
performTranslating(queryCtx, parseResult);
|
performTranslating(queryCtx, parseResult);
|
||||||
parseResult.getParseTimeCost().setSqlTime(System.currentTimeMillis() - start);
|
parseResult.getParseTimeCost().setSqlTime(System.currentTimeMillis() - start);
|
||||||
parseResult.setState(ParseResp.ParseState.COMPLETED);
|
|
||||||
queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
|
queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
@@ -137,7 +136,12 @@ public class ChatWorkflowEngine {
|
|||||||
ContextUtils.getBean(SemanticLayerService.class);
|
ContextUtils.getBean(SemanticLayerService.class);
|
||||||
SemanticTranslateResp explain =
|
SemanticTranslateResp explain =
|
||||||
queryService.translate(semanticQueryReq, queryCtx.getRequest().getUser());
|
queryService.translate(semanticQueryReq, queryCtx.getRequest().getUser());
|
||||||
parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
|
if (explain.isOk()) {
|
||||||
|
parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
|
||||||
|
parseResult.setState(ParseResp.ParseState.COMPLETED);
|
||||||
|
} else {
|
||||||
|
parseResult.setState(ParseResp.ParseState.FAILED);
|
||||||
|
}
|
||||||
if (StringUtils.isNotBlank(explain.getErrMsg())) {
|
if (StringUtils.isNotBlank(explain.getErrMsg())) {
|
||||||
errorMsg.add(explain.getErrMsg());
|
errorMsg.add(explain.getErrMsg());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ com.tencent.supersonic.headless.chat.mapper.SchemaMapper=\
|
|||||||
com.tencent.supersonic.headless.chat.mapper.EmbeddingMapper, \
|
com.tencent.supersonic.headless.chat.mapper.EmbeddingMapper, \
|
||||||
com.tencent.supersonic.headless.chat.mapper.KeywordMapper, \
|
com.tencent.supersonic.headless.chat.mapper.KeywordMapper, \
|
||||||
com.tencent.supersonic.headless.chat.mapper.QueryFilterMapper, \
|
com.tencent.supersonic.headless.chat.mapper.QueryFilterMapper, \
|
||||||
com.tencent.supersonic.headless.chat.mapper.TimeFieldMapper,\
|
com.tencent.supersonic.headless.chat.mapper.PartitionTimeMapper,\
|
||||||
com.tencent.supersonic.headless.chat.mapper.TermDescMapper
|
com.tencent.supersonic.headless.chat.mapper.TermDescMapper
|
||||||
|
|
||||||
com.tencent.supersonic.headless.chat.parser.SemanticParser=\
|
com.tencent.supersonic.headless.chat.parser.SemanticParser=\
|
||||||
|
|||||||
@@ -4,8 +4,9 @@ com.tencent.supersonic.headless.chat.mapper.SchemaMapper=\
|
|||||||
com.tencent.supersonic.headless.chat.mapper.EmbeddingMapper, \
|
com.tencent.supersonic.headless.chat.mapper.EmbeddingMapper, \
|
||||||
com.tencent.supersonic.headless.chat.mapper.KeywordMapper, \
|
com.tencent.supersonic.headless.chat.mapper.KeywordMapper, \
|
||||||
com.tencent.supersonic.headless.chat.mapper.QueryFilterMapper, \
|
com.tencent.supersonic.headless.chat.mapper.QueryFilterMapper, \
|
||||||
com.tencent.supersonic.headless.chat.mapper.TimeFieldMapper,\
|
com.tencent.supersonic.headless.chat.mapper.PartitionTimeMapper,\
|
||||||
com.tencent.supersonic.headless.chat.mapper.TermDescMapper
|
com.tencent.supersonic.headless.chat.mapper.TermDescMapper,\
|
||||||
|
com.tencent.supersonic.headless.chat.mapper.AllFieldMapper
|
||||||
|
|
||||||
com.tencent.supersonic.headless.chat.parser.SemanticParser=\
|
com.tencent.supersonic.headless.chat.parser.SemanticParser=\
|
||||||
com.tencent.supersonic.headless.chat.parser.llm.LLMSqlParser,\
|
com.tencent.supersonic.headless.chat.parser.llm.LLMSqlParser,\
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import org.junit.jupiter.api.BeforeAll;
|
|||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.TestInstance;
|
import org.junit.jupiter.api.TestInstance;
|
||||||
|
import org.junitpioneer.jupiter.SetSystemProperty;
|
||||||
import org.springframework.test.context.TestPropertySource;
|
import org.springframework.test.context.TestPropertySource;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -95,6 +96,7 @@ public class Text2SQLEval extends BaseTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@SetSystemProperty(key = "s2.test", value = "true")
|
||||||
public void test_drilldown_and_topN() throws Exception {
|
public void test_drilldown_and_topN() throws Exception {
|
||||||
long start = System.currentTimeMillis();
|
long start = System.currentTimeMillis();
|
||||||
QueryResult result = submitNewChat("过去30天访问次数最高的部门top3", agent.getId());
|
QueryResult result = submitNewChat("过去30天访问次数最高的部门top3", agent.getId());
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
|||||||
import com.tencent.supersonic.util.DataUtils;
|
import com.tencent.supersonic.util.DataUtils;
|
||||||
import org.junit.jupiter.api.Assertions;
|
import org.junit.jupiter.api.Assertions;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junitpioneer.jupiter.SetSystemProperty;
|
||||||
|
|
||||||
import static java.time.LocalDate.now;
|
import static java.time.LocalDate.now;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
@@ -29,6 +30,7 @@ public class QueryBySqlTest extends BaseTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@SetSystemProperty(key = "s2.test", value = "true")
|
||||||
public void testSumQuery() throws Exception {
|
public void testSumQuery() throws Exception {
|
||||||
SemanticQueryResp semanticQueryResp =
|
SemanticQueryResp semanticQueryResp =
|
||||||
queryBySql("SELECT SUM(访问次数) AS 总访问次数 FROM 超音数PVUV统计 ");
|
queryBySql("SELECT SUM(访问次数) AS 总访问次数 FROM 超音数PVUV统计 ");
|
||||||
|
|||||||
Reference in New Issue
Block a user