(improvement)(project) support explain in semantic and show explain sql in web and fix chat start error (#103)

This commit is contained in:
lexluo09
2023-09-19 16:38:24 +08:00
committed by GitHub
parent 13dcf0edb9
commit a94a44826b
17 changed files with 367 additions and 110 deletions

View File

@@ -15,7 +15,10 @@ import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
import com.tencent.supersonic.semantic.api.query.request.QueryDslReq;
import java.util.ArrayList;
import java.util.List;
@@ -42,12 +45,10 @@ public class DslQuery extends PluginSemanticQuery {
@Override
public QueryResult execute(User user) {
String json = JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT));
DSLParseResult dslParseResult = JsonUtil.toObject(json, DSLParseResult.class);
LLMResp llmResp = dslParseResult.getLlmResp();
LLMResp llmResp = getLlmResp();
long startTime = System.currentTimeMillis();
QueryDslReq queryDslReq = QueryReqBuilder.buildDslReq(llmResp.getCorrectorSql(), parseInfo.getModelId());
QueryDslReq queryDslReq = getQueryDslReq(llmResp);
QueryResultWithSchemaResp queryResp = semanticLayer.queryByDsl(queryDslReq, user);
log.info("queryByDsl cost:{},querySql:{}", System.currentTimeMillis() - startTime, llmResp.getSqlOutput());
@@ -71,4 +72,30 @@ public class DslQuery extends PluginSemanticQuery {
parseInfo.setProperties(null);
return queryResult;
}
private LLMResp getLlmResp() {
String json = JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT));
DSLParseResult dslParseResult = JsonUtil.toObject(json, DSLParseResult.class);
return dslParseResult.getLlmResp();
}
private QueryDslReq getQueryDslReq(LLMResp llmResp) {
QueryDslReq queryDslReq = QueryReqBuilder.buildDslReq(llmResp.getCorrectorSql(), parseInfo.getModelId());
return queryDslReq;
}
@Override
public ExplainResp explain(User user) {
ExplainSqlReq explainSqlReq = null;
try {
explainSqlReq = ExplainSqlReq.builder()
.queryTypeEnum(QueryTypeEnum.SQL)
.queryReq(getQueryDslReq(getLlmResp()))
.build();
return semanticLayer.explain(explainSqlReq, user);
} catch (Exception e) {
log.error("explain error explainSqlReq:{}", explainSqlReq, e);
}
return null;
}
}

View File

@@ -1,7 +1,9 @@
package com.tencent.supersonic.chat.query.plugin;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import lombok.extern.slf4j.Slf4j;
@Slf4j
@@ -17,5 +19,8 @@ public abstract class PluginSemanticQuery implements SemanticQuery {
return parseInfo;
}
@Override
public ExplainResp explain(User user) {
return null;
}
}

View File

@@ -21,8 +21,11 @@ import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.io.Serializable;
@@ -215,6 +218,22 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
return queryResult;
}
@Override
public ExplainResp explain(User user) {
ExplainSqlReq explainSqlReq = null;
try {
explainSqlReq = ExplainSqlReq.builder()
.queryTypeEnum(QueryTypeEnum.STRUCT)
.queryReq(convertQueryStruct())
.build();
return semanticLayer.explain(explainSqlReq, user);
} catch (Exception e) {
log.error("explain error explainSqlReq:{}", explainSqlReq, e);
}
return null;
}
public QueryResult multiStructExecute(User user) {
String queryMode = parseInfo.getQueryMode();

View File

@@ -27,6 +27,7 @@ import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.chat.service.StatisticsService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import java.util.List;
import java.util.ArrayList;
import java.util.Set;
@@ -35,9 +36,7 @@ import java.util.Comparator;
import java.util.Objects;
import java.util.stream.Collectors;
//import com.tencent.supersonic.common.pojo.Aggregator;
import com.tencent.supersonic.common.pojo.DateConf;
//import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
@@ -64,8 +63,6 @@ public class QueryServiceImpl implements QueryService {
@Autowired
private StatisticsService statisticsService;
private final String entity = "ENTITY";
@Value("${time.threshold: 100}")
private Integer timeThreshold;
@@ -109,12 +106,16 @@ public class QueryServiceImpl implements QueryService {
.map(SemanticQuery::getParseInfo)
.sorted(Comparator.comparingDouble(SemanticParseInfo::getScore).reversed())
.collect(Collectors.toList());
selectedParses.forEach(parseInfo -> {
if (parseInfo.getQueryMode().contains(entity)) {
String queryMode = parseInfo.getQueryMode();
if (QueryManager.isEntityQuery(queryMode)) {
EntityInfo entityInfo = ContextUtils.getBean(SemanticService.class)
.getEntityInfo(parseInfo, queryReq.getUser());
parseInfo.setEntityInfo(entityInfo);
}
addExplainSql(queryReq, parseInfo);
});
List<SemanticParseInfo> candidateParses = queryCtx.getCandidateQueries().stream()
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
@@ -138,6 +139,19 @@ public class QueryServiceImpl implements QueryService {
return parseResult;
}
private void addExplainSql(QueryReq queryReq, SemanticParseInfo parseInfo) {
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
if (Objects.isNull(semanticQuery)) {
return;
}
semanticQuery.setParseInfo(parseInfo);
ExplainResp explain = semanticQuery.explain(queryReq.getUser());
if (Objects.isNull(explain)) {
return;
}
parseInfo.setSql(explain.getSql());
}
@Override
public QueryResult performExecution(ExecuteQueryReq queryReq) throws Exception {
ChatParseDO chatParseDO = chatService.getParseInfo(queryReq.getQueryId(),
@@ -155,9 +169,9 @@ public class QueryServiceImpl implements QueryService {
chatCtx.setAgentId(queryReq.getAgentId());
Long startTime = System.currentTimeMillis();
QueryResult queryResult = semanticQuery.execute(queryReq.getUser());
Long endTime = System.currentTimeMillis();
if (queryResult != null) {
timeCostDOList.add(StatisticsDO.builder().cost((int) (endTime - startTime))
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
.interfaceName(semanticQuery.getClass().getSimpleName()).type(CostType.QUERY.getType()).build());
saveInfo(timeCostDOList, queryReq.getQueryText(), queryReq.getQueryId(),
queryReq.getUser().getName(), queryReq.getChatId().longValue());
@@ -169,7 +183,6 @@ public class QueryServiceImpl implements QueryService {
}
chatCtx.setQueryText(queryReq.getQueryText());
chatCtx.setUser(queryReq.getUser().getName());
//chatService.addQuery(queryResult, chatCtx);
chatService.updateQuery(queryReq.getQueryId(), queryResult, chatCtx);
} else {
chatService.deleteChatQuery(queryReq.getQueryId());
@@ -179,8 +192,8 @@ public class QueryServiceImpl implements QueryService {
}
public void saveInfo(List<StatisticsDO> timeCostDOList,
String queryText, Long queryId,
String userName, Long chatId) {
String queryText, Long queryId,
String userName, Long chatId) {
List<StatisticsDO> list = timeCostDOList.stream()
.filter(o -> o.getCost() > timeThreshold).collect(Collectors.toList());
list.forEach(o -> {
@@ -264,13 +277,6 @@ public class QueryServiceImpl implements QueryService {
dateConf.setPeriod("DAY");
queryStructReq.setDateInfo(dateConf);
queryStructReq.setLimit(20L);
// List<Aggregator> aggregators = new ArrayList<>();
// Aggregator aggregator = new Aggregator(dimensionValueReq.getQueryFilter().getBizName(),
// AggOperatorEnum.DISTINCT);
// aggregators.add(aggregator);
// queryStructReq.setAggregators(aggregators);
queryStructReq.setModelId(dimensionValueReq.getModelId());
queryStructReq.setNativeQuery(true);
List<String> groups = new ArrayList<>();