mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 04:27:39 +00:00
[fix][chat]Fix logic in s2sql parsing.
This commit is contained in:
@@ -101,14 +101,9 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
doParse(queryNLReq, parseResp);
|
doParse(queryNLReq, parseResp);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (parseResp.getSelectedParses().isEmpty()) {
|
if (parseResp.getSelectedParses().isEmpty() && candidateParses.isEmpty()) {
|
||||||
for (MapModeEnum mode : Lists.newArrayList(MapModeEnum.LOOSE)) {
|
queryNLReq.setMapModeEnum(MapModeEnum.LOOSE);
|
||||||
queryNLReq.setMapModeEnum(mode);
|
|
||||||
doParse(queryNLReq, parseResp);
|
doParse(queryNLReq, parseResp);
|
||||||
if (!parseResp.getSelectedParses().isEmpty()) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (parseResp.getSelectedParses().isEmpty()) {
|
if (parseResp.getSelectedParses().isEmpty()) {
|
||||||
|
|||||||
@@ -276,8 +276,10 @@ public class SqlSelectHelper {
|
|||||||
Set<String> aliases = new HashSet<>();
|
Set<String> aliases = new HashSet<>();
|
||||||
for (PlainSelect plainSelect : plainSelects) {
|
for (PlainSelect plainSelect : plainSelects) {
|
||||||
List<String> fields = getFieldsByPlainSelect(plainSelect);
|
List<String> fields = getFieldsByPlainSelect(plainSelect);
|
||||||
|
Set<String> subaliases = getAliasFields(plainSelect);
|
||||||
|
subaliases.removeAll(fields);
|
||||||
results.addAll(fields);
|
results.addAll(fields);
|
||||||
aliases.addAll(getAliasFields(plainSelect));
|
aliases.addAll(subaliases);
|
||||||
}
|
}
|
||||||
// do not account in aliases
|
// do not account in aliases
|
||||||
results.removeAll(aliases);
|
results.removeAll(aliases);
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ public class MapperConfig extends ParameterConfig {
|
|||||||
"每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置");
|
"每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置");
|
||||||
|
|
||||||
public static final Parameter EMBEDDING_MAPPER_THRESHOLD =
|
public static final Parameter EMBEDDING_MAPPER_THRESHOLD =
|
||||||
new Parameter("s2.mapper.embedding.threshold", "0.8", "向量召回相似度阈值", "相似度小于该阈值的则舍弃",
|
new Parameter("s2.mapper.embedding.threshold", "0.9", "向量召回相似度阈值", "相似度小于该阈值的则舍弃",
|
||||||
"number", "Mapper相关配置");
|
"number", "Mapper相关配置");
|
||||||
|
|
||||||
public static final Parameter EMBEDDING_MAPPER_ROUND_NUMBER =
|
public static final Parameter EMBEDDING_MAPPER_ROUND_NUMBER =
|
||||||
|
|||||||
@@ -108,6 +108,7 @@ public class Text2SQLEval extends BaseTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@SetSystemProperty(key = "s2.test", value = "true")
|
||||||
public void test_filter_and_top() throws Exception {
|
public void test_filter_and_top() throws Exception {
|
||||||
long start = System.currentTimeMillis();
|
long start = System.currentTimeMillis();
|
||||||
QueryResult result = submitNewChat("近半个月来marketing部门访问量最高的用户是谁", agent.getId());
|
QueryResult result = submitNewChat("近半个月来marketing部门访问量最高的用户是谁", agent.getId());
|
||||||
|
|||||||
@@ -10,17 +10,21 @@ import com.tencent.supersonic.common.pojo.Order;
|
|||||||
import com.tencent.supersonic.common.pojo.User;
|
import com.tencent.supersonic.common.pojo.User;
|
||||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||||
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||||
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||||
import com.tencent.supersonic.headless.server.persistence.dataobject.DomainDO;
|
import com.tencent.supersonic.headless.server.persistence.dataobject.DomainDO;
|
||||||
import com.tencent.supersonic.headless.server.persistence.repository.DomainRepository;
|
import com.tencent.supersonic.headless.server.persistence.repository.DomainRepository;
|
||||||
|
import com.tencent.supersonic.headless.server.service.DatabaseService;
|
||||||
import com.tencent.supersonic.headless.server.service.SchemaService;
|
import com.tencent.supersonic.headless.server.service.SchemaService;
|
||||||
import com.tencent.supersonic.util.DataUtils;
|
import com.tencent.supersonic.util.DataUtils;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
@@ -40,9 +44,12 @@ public class BaseTest extends BaseApplication {
|
|||||||
protected SchemaService schemaService;
|
protected SchemaService schemaService;
|
||||||
@Autowired
|
@Autowired
|
||||||
private AgentService agentService;
|
private AgentService agentService;
|
||||||
|
@Autowired
|
||||||
|
protected DatabaseService databaseService;
|
||||||
|
|
||||||
protected Agent agent;
|
protected Agent agent;
|
||||||
protected SemanticSchema schema;
|
protected SemanticSchema schema;
|
||||||
|
protected DatabaseResp databaseResp;
|
||||||
|
|
||||||
protected Agent getAgentByName(String agentName) {
|
protected Agent getAgentByName(String agentName) {
|
||||||
Optional<Agent> agent = agentService.getAgents().stream()
|
Optional<Agent> agent = agentService.getAgents().stream()
|
||||||
@@ -59,6 +66,16 @@ public class BaseTest extends BaseApplication {
|
|||||||
return semanticLayerService.queryByReq(buildQuerySqlReq(sql), user);
|
return semanticLayerService.queryByReq(buildQuerySqlReq(sql), user);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected void executeSql(String sql) {
|
||||||
|
if (databaseResp == null) {
|
||||||
|
databaseResp = databaseService.getDatabase(1L);
|
||||||
|
}
|
||||||
|
SemanticQueryResp queryResp = databaseService.executeSql(sql, databaseResp);
|
||||||
|
assert StringUtils.isBlank(queryResp.getErrorMsg());
|
||||||
|
System.out.println(
|
||||||
|
String.format("Execute result: %s", JsonUtil.toString(queryResp.getResultList())));
|
||||||
|
}
|
||||||
|
|
||||||
protected SemanticQueryReq buildQuerySqlReq(String sql) {
|
protected SemanticQueryReq buildQuerySqlReq(String sql) {
|
||||||
QuerySqlReq querySqlCmd = new QuerySqlReq();
|
QuerySqlReq querySqlCmd = new QuerySqlReq();
|
||||||
querySqlCmd.setSql(sql);
|
querySqlCmd.setSql(sql);
|
||||||
|
|||||||
Reference in New Issue
Block a user