[fix][chat]Fix logic in s2sql parsing.

This commit is contained in:
jerryjzhang
2024-12-27 14:12:10 +08:00
parent a23d1071a3
commit 0612833618
6 changed files with 33 additions and 18 deletions

View File

@@ -101,14 +101,9 @@ public class NL2SQLParser implements ChatQueryParser {
doParse(queryNLReq, parseResp); doParse(queryNLReq, parseResp);
} }
if (parseResp.getSelectedParses().isEmpty()) { if (parseResp.getSelectedParses().isEmpty() && candidateParses.isEmpty()) {
for (MapModeEnum mode : Lists.newArrayList(MapModeEnum.LOOSE)) { queryNLReq.setMapModeEnum(MapModeEnum.LOOSE);
queryNLReq.setMapModeEnum(mode);
doParse(queryNLReq, parseResp); doParse(queryNLReq, parseResp);
if (!parseResp.getSelectedParses().isEmpty()) {
break;
}
}
} }
if (parseResp.getSelectedParses().isEmpty()) { if (parseResp.getSelectedParses().isEmpty()) {

View File

@@ -276,8 +276,10 @@ public class SqlSelectHelper {
Set<String> aliases = new HashSet<>(); Set<String> aliases = new HashSet<>();
for (PlainSelect plainSelect : plainSelects) { for (PlainSelect plainSelect : plainSelects) {
List<String> fields = getFieldsByPlainSelect(plainSelect); List<String> fields = getFieldsByPlainSelect(plainSelect);
Set<String> subaliases = getAliasFields(plainSelect);
subaliases.removeAll(fields);
results.addAll(fields); results.addAll(fields);
aliases.addAll(getAliasFields(plainSelect)); aliases.addAll(subaliases);
} }
// do not account in aliases // do not account in aliases
results.removeAll(aliases); results.removeAll(aliases);

View File

@@ -51,7 +51,7 @@ public class MapperConfig extends ParameterConfig {
"每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置"); "每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_THRESHOLD = public static final Parameter EMBEDDING_MAPPER_THRESHOLD =
new Parameter("s2.mapper.embedding.threshold", "0.8", "向量召回相似度阈值", "相似度小于该阈值的则舍弃", new Parameter("s2.mapper.embedding.threshold", "0.9", "向量召回相似度阈值", "相似度小于该阈值的则舍弃",
"number", "Mapper相关配置"); "number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_ROUND_NUMBER = public static final Parameter EMBEDDING_MAPPER_ROUND_NUMBER =

View File

@@ -108,6 +108,7 @@ public class Text2SQLEval extends BaseTest {
} }
@Test @Test
@SetSystemProperty(key = "s2.test", value = "true")
public void test_filter_and_top() throws Exception { public void test_filter_and_top() throws Exception {
long start = System.currentTimeMillis(); long start = System.currentTimeMillis();
QueryResult result = submitNewChat("近半个月来marketing部门访问量最高的用户是谁", agent.getId()); QueryResult result = submitNewChat("近半个月来marketing部门访问量最高的用户是谁", agent.getId());

View File

@@ -10,17 +10,21 @@ import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService; import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
import com.tencent.supersonic.headless.server.persistence.dataobject.DomainDO; import com.tencent.supersonic.headless.server.persistence.dataobject.DomainDO;
import com.tencent.supersonic.headless.server.persistence.repository.DomainRepository; import com.tencent.supersonic.headless.server.persistence.repository.DomainRepository;
import com.tencent.supersonic.headless.server.service.DatabaseService;
import com.tencent.supersonic.headless.server.service.SchemaService; import com.tencent.supersonic.headless.server.service.SchemaService;
import com.tencent.supersonic.util.DataUtils; import com.tencent.supersonic.util.DataUtils;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import java.util.ArrayList; import java.util.ArrayList;
@@ -40,9 +44,12 @@ public class BaseTest extends BaseApplication {
protected SchemaService schemaService; protected SchemaService schemaService;
@Autowired @Autowired
private AgentService agentService; private AgentService agentService;
@Autowired
protected DatabaseService databaseService;
protected Agent agent; protected Agent agent;
protected SemanticSchema schema; protected SemanticSchema schema;
protected DatabaseResp databaseResp;
protected Agent getAgentByName(String agentName) { protected Agent getAgentByName(String agentName) {
Optional<Agent> agent = agentService.getAgents().stream() Optional<Agent> agent = agentService.getAgents().stream()
@@ -59,6 +66,16 @@ public class BaseTest extends BaseApplication {
return semanticLayerService.queryByReq(buildQuerySqlReq(sql), user); return semanticLayerService.queryByReq(buildQuerySqlReq(sql), user);
} }
protected void executeSql(String sql) {
if (databaseResp == null) {
databaseResp = databaseService.getDatabase(1L);
}
SemanticQueryResp queryResp = databaseService.executeSql(sql, databaseResp);
assert StringUtils.isBlank(queryResp.getErrorMsg());
System.out.println(
String.format("Execute result: %s", JsonUtil.toString(queryResp.getResultList())));
}
protected SemanticQueryReq buildQuerySqlReq(String sql) { protected SemanticQueryReq buildQuerySqlReq(String sql) {
QuerySqlReq querySqlCmd = new QuerySqlReq(); QuerySqlReq querySqlCmd = new QuerySqlReq();
querySqlCmd.setSql(sql); querySqlCmd.setSql(sql);