diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index f3dee66f3..7022a54cf 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -101,14 +101,9 @@ public class NL2SQLParser implements ChatQueryParser { doParse(queryNLReq, parseResp); } - if (parseResp.getSelectedParses().isEmpty()) { - for (MapModeEnum mode : Lists.newArrayList(MapModeEnum.LOOSE)) { - queryNLReq.setMapModeEnum(mode); - doParse(queryNLReq, parseResp); - if (!parseResp.getSelectedParses().isEmpty()) { - break; - } - } + if (parseResp.getSelectedParses().isEmpty() && candidateParses.isEmpty()) { + queryNLReq.setMapModeEnum(MapModeEnum.LOOSE); + doParse(queryNLReq, parseResp); } if (parseResp.getSelectedParses().isEmpty()) { diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java index 0811350af..133971fcd 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java @@ -276,8 +276,10 @@ public class SqlSelectHelper { Set aliases = new HashSet<>(); for (PlainSelect plainSelect : plainSelects) { List fields = getFieldsByPlainSelect(plainSelect); + Set subaliases = getAliasFields(plainSelect); + subaliases.removeAll(fields); results.addAll(fields); - aliases.addAll(getAliasFields(plainSelect)); + aliases.addAll(subaliases); } // do not account in aliases results.removeAll(aliases); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapperConfig.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapperConfig.java index 38178b7cd..583071c84 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapperConfig.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapperConfig.java @@ -51,7 +51,7 @@ public class MapperConfig extends ParameterConfig { "每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置"); public static final Parameter EMBEDDING_MAPPER_THRESHOLD = - new Parameter("s2.mapper.embedding.threshold", "0.8", "向量召回相似度阈值", "相似度小于该阈值的则舍弃", + new Parameter("s2.mapper.embedding.threshold", "0.9", "向量召回相似度阈值", "相似度小于该阈值的则舍弃", "number", "Mapper相关配置"); public static final Parameter EMBEDDING_MAPPER_ROUND_NUMBER = diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java index 8a0888b0e..6a6736971 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java @@ -97,10 +97,10 @@ public class ModelServiceImpl implements ModelService { private final ThreadPoolExecutor executor; public ModelServiceImpl(ModelRepository modelRepository, DatabaseService databaseService, - @Lazy DimensionService dimensionService, @Lazy MetricService metricService, - DomainService domainService, UserService userService, DataSetService dataSetService, - DateInfoRepository dateInfoRepository, ModelRelaService modelRelaService, - @Qualifier("commonExecutor") ThreadPoolExecutor executor) { + @Lazy DimensionService dimensionService, @Lazy MetricService metricService, + DomainService domainService, UserService userService, DataSetService dataSetService, + DateInfoRepository dateInfoRepository, ModelRelaService modelRelaService, + @Qualifier("commonExecutor") ThreadPoolExecutor executor) { this.modelRepository = modelRepository; this.databaseService = databaseService; this.dimensionService = dimensionService; @@ -233,7 +233,7 @@ public class ModelServiceImpl implements ModelService { } private void doBuild(ModelBuildReq modelBuildReq, DbSchema curSchema, List dbSchemas, - Map modelSchemaMap) { + Map modelSchemaMap) { ModelSchema modelSchema = new ModelSchema(); List semanticModellers = CoreComponentFactory.getSemanticModellers(); for (SemanticModeller semanticModeller : semanticModellers) { @@ -251,7 +251,7 @@ public class ModelServiceImpl implements ModelService { } private List convert(Map> dbColumnMap, - ModelBuildReq modelBuildReq) { + ModelBuildReq modelBuildReq) { return dbColumnMap.keySet().stream() .map(key -> convert(modelBuildReq, key, dbColumnMap.get(key))) .collect(Collectors.toList()); @@ -406,7 +406,7 @@ public class ModelServiceImpl implements ModelService { } public List getModelRespAuthInheritDomain(User user, Long domainId, - AuthType authType) { + AuthType authType) { List domainIds = domainService.getDomainAuthSet(user, authType).stream().filter(domainResp -> { if (domainId == null) { @@ -581,7 +581,7 @@ public class ModelServiceImpl implements ModelService { } public static boolean checkDataSetPermission(Set orgIds, User user, - ModelResp modelResp) { + ModelResp modelResp) { if (checkAdminPermission(orgIds, user, modelResp)) { return true; } diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java index bc90f22c0..028f269cf 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java @@ -108,6 +108,7 @@ public class Text2SQLEval extends BaseTest { } @Test + @SetSystemProperty(key = "s2.test", value = "true") public void test_filter_and_top() throws Exception { long start = System.currentTimeMillis(); QueryResult result = submitNewChat("近半个月来marketing部门访问量最高的用户是谁", agent.getId()); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/BaseTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/BaseTest.java index b3703e194..263a0225e 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/BaseTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/BaseTest.java @@ -10,17 +10,21 @@ import com.tencent.supersonic.common.pojo.Order; import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; import com.tencent.supersonic.common.pojo.enums.QueryType; +import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; +import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService; import com.tencent.supersonic.headless.server.persistence.dataobject.DomainDO; import com.tencent.supersonic.headless.server.persistence.repository.DomainRepository; +import com.tencent.supersonic.headless.server.service.DatabaseService; import com.tencent.supersonic.headless.server.service.SchemaService; import com.tencent.supersonic.util.DataUtils; import org.apache.commons.collections.CollectionUtils; +import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Autowired; import java.util.ArrayList; @@ -40,9 +44,12 @@ public class BaseTest extends BaseApplication { protected SchemaService schemaService; @Autowired private AgentService agentService; + @Autowired + protected DatabaseService databaseService; protected Agent agent; protected SemanticSchema schema; + protected DatabaseResp databaseResp; protected Agent getAgentByName(String agentName) { Optional agent = agentService.getAgents().stream() @@ -59,6 +66,16 @@ public class BaseTest extends BaseApplication { return semanticLayerService.queryByReq(buildQuerySqlReq(sql), user); } + protected void executeSql(String sql) { + if (databaseResp == null) { + databaseResp = databaseService.getDatabase(1L); + } + SemanticQueryResp queryResp = databaseService.executeSql(sql, databaseResp); + assert StringUtils.isBlank(queryResp.getErrorMsg()); + System.out.println( + String.format("Execute result: %s", JsonUtil.toString(queryResp.getResultList()))); + } + protected SemanticQueryReq buildQuerySqlReq(String sql) { QuerySqlReq querySqlCmd = new QuerySqlReq(); querySqlCmd.setSql(sql);