(improvement)(headless)Remove unnecessary performExecution method from ChatQueryService.

This commit is contained in:
jerryjzhang
2024-07-09 17:37:36 +08:00
parent f0b4eb46cf
commit ea4aa3eacf
10 changed files with 126 additions and 141 deletions

View File

@@ -28,6 +28,7 @@ import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
import com.tencent.supersonic.headless.server.facade.service.RetrieveService;
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@@ -48,6 +49,9 @@ public class ChatServiceImpl implements ChatService {
private RetrieveService retrieveService;
@Autowired
private AgentService agentService;
@Autowired
private SemanticLayerService semanticLayerService;
private List<ChatParser> chatParsers = ComponentFactory.getChatParsers();
private List<ChatExecutor> chatExecutors = ComponentFactory.getChatExecutors();
private List<ParseResultProcessor> parseResultProcessors = ComponentFactory.getParseProcessors();

View File

@@ -1,10 +1,14 @@
package com.tencent.supersonic.headless.server.facade.rest;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
import com.tencent.supersonic.headless.server.facade.service.RetrieveService;
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.PostMapping;
@@ -25,6 +29,9 @@ public class ChatQueryApiController {
@Autowired
private RetrieveService retrieveService;
@Autowired
private SemanticLayerService semanticLayerService;
@PostMapping("/chat/search")
public Object search(@RequestBody QueryNLReq queryNLReq,
HttpServletRequest request,
@@ -34,9 +41,9 @@ public class ChatQueryApiController {
}
@PostMapping("/chat/map")
public MapResp map(@RequestBody QueryNLReq queryNLReq,
HttpServletRequest request,
HttpServletResponse response) {
public Object map(@RequestBody QueryNLReq queryNLReq,
HttpServletRequest request,
HttpServletResponse response) {
queryNLReq.setUser(UserHolder.findUser(request, response));
return chatQueryService.performMapping(queryNLReq);
}
@@ -49,4 +56,21 @@ public class ChatQueryApiController {
return chatQueryService.performParsing(queryNLReq);
}
@PostMapping("/chat")
public Object queryByNL(@RequestBody QueryNLReq queryNLReq,
HttpServletRequest request,
HttpServletResponse response) throws Exception {
User user = UserHolder.findUser(request, response);
ParseResp parseResp = chatQueryService.performParsing(queryNLReq);
if (parseResp.getState().equals(ParseResp.ParseState.COMPLETED)) {
SemanticParseInfo parseInfo = parseResp.getSelectedParses().get(0);
QuerySqlReq sqlReq = new QuerySqlReq();
sqlReq.setSql(parseInfo.getSqlInfo().getCorrectedS2SQL());
sqlReq.setSqlInfo(parseInfo.getSqlInfo());
return semanticLayerService.queryByReq(sqlReq, user);
}
throw new RuntimeException("Failed to parse natural language query: " + queryNLReq.getQueryText());
}
}

View File

@@ -24,7 +24,7 @@ public class DataSetQueryApiController {
@Autowired
private DataSetService dataSetService;
@Autowired
private SemanticLayerService queryService;
private SemanticLayerService semanticLayerService;
@PostMapping("/dataSet")
public Object queryByDataSet(@RequestBody QueryDataSetReq queryDataSetReq,
@@ -32,7 +32,7 @@ public class DataSetQueryApiController {
HttpServletResponse response) throws Exception {
User user = UserHolder.findUser(request, response);
SemanticQueryReq queryReq = dataSetService.convert(queryDataSetReq);
return queryService.queryByReq(queryReq, user);
return semanticLayerService.queryByReq(queryReq, user);
}
}

View File

@@ -25,7 +25,7 @@ import javax.servlet.http.HttpServletResponse;
public class MetricQueryApiController {
@Autowired
private SemanticLayerService queryService;
private SemanticLayerService semanticLayerService;
@Autowired
private MetricService metricService;
@@ -39,7 +39,7 @@ public class MetricQueryApiController {
HttpServletResponse response) throws Exception {
User user = UserHolder.findUser(request, response);
QueryStructReq queryStructReq = metricService.convert(queryMetricReq);
return queryService.queryByReq(queryStructReq.convert(true), user);
return semanticLayerService.queryByReq(queryStructReq.convert(true), user);
}
@PostMapping("/download/metric")

View File

@@ -30,7 +30,7 @@ import java.util.stream.Collectors;
public class SqlQueryApiController {
@Autowired
private SemanticLayerService queryService;
private SemanticLayerService semanticLayerService;
@Autowired
private ChatQueryService chatQueryService;
@@ -43,7 +43,7 @@ public class SqlQueryApiController {
String sql = querySqlReq.getSql();
querySqlReq.setSql(StringUtil.replaceBackticks(sql));
chatQueryService.correct(querySqlReq, user);
return queryService.queryByReq(querySqlReq, user);
return semanticLayerService.queryByReq(querySqlReq, user);
}
@PostMapping("/sqls")
@@ -63,7 +63,7 @@ public class SqlQueryApiController {
List<CompletableFuture<SemanticQueryResp>> futures = semanticQueryReqs.stream()
.map(querySqlReq -> CompletableFuture.supplyAsync(() -> {
try {
return queryService.queryByReq(querySqlReq, user);
return semanticLayerService.queryByReq(querySqlReq, user);
} catch (Exception e) {
log.error("querySqlReq:{},queryByReq error:", querySqlReq, e);
return new SemanticQueryResp();
@@ -88,7 +88,7 @@ public class SqlQueryApiController {
List<SemanticQueryResp> semanticQueryRespList = new ArrayList<>();
try {
for (SemanticQueryReq semanticQueryReq : semanticQueryReqs) {
SemanticQueryResp semanticQueryResp = queryService.queryByReq(semanticQueryReq, user);
SemanticQueryResp semanticQueryResp = semanticLayerService.queryByReq(semanticQueryReq, user);
semanticQueryRespList.add(semanticQueryResp);
}
} catch (Exception e) {

View File

@@ -20,14 +20,14 @@ import org.springframework.web.bind.annotation.RestController;
public class TagQueryApiController {
@Autowired
private SemanticLayerService queryService;
private SemanticLayerService semanticLayerService;
@PostMapping("/tag")
public Object queryByTag(@RequestBody QueryStructReq queryStructReq,
HttpServletRequest request,
HttpServletResponse response) throws Exception {
User user = UserHolder.findUser(request, response);
return queryService.queryByReq(queryStructReq.convert(), user);
return semanticLayerService.queryByReq(queryStructReq.convert(), user);
}
}

View File

@@ -4,7 +4,6 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SqlEvaluation;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryMapReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
@@ -21,19 +20,16 @@ public interface ChatQueryService {
MapResp performMapping(QueryNLReq queryNLReq);
MapInfoResp map(QueryMapReq queryMapReq);
ParseResp performParsing(QueryNLReq queryNLReq);
@Deprecated
QueryResult performExecution(ExecuteQueryReq queryReq) throws Exception;
SemanticParseInfo queryContext(Integer chatId);
QueryResult executeDirectQuery(QueryDataReq queryData, User user) throws Exception;
Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception;
MapInfoResp map(QueryMapReq queryMapReq);
void correct(QuerySqlReq querySqlReq, User user);
SqlEvaluation validate(QuerySqlReq querySqlReq, User user);

View File

@@ -7,6 +7,6 @@ import java.util.List;
public interface RetrieveService {
List<SearchResult> retrieve(QueryNLReq queryCtx);
List<SearchResult> retrieve(QueryNLReq queryNLReq);
}

View File

@@ -25,9 +25,7 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlEvaluation;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.enums.CostType;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq;
@@ -62,7 +60,6 @@ import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
import com.tencent.supersonic.headless.server.utils.ChatWorkflowEngine;
import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
import com.tencent.supersonic.headless.server.utils.ComponentFactory;
import com.tencent.supersonic.headless.server.web.service.ChatContextService;
@@ -176,36 +173,6 @@ public class ChatQueryServiceImpl implements ChatQueryService {
return queryCtx;
}
@Override
@Deprecated
public QueryResult performExecution(ExecuteQueryReq queryReq) throws Exception {
List<StatisticsDO> timeCostDOList = new ArrayList<>();
SemanticParseInfo parseInfo = queryReq.getParseInfo();
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
if (semanticQuery == null) {
return null;
}
semanticQuery.setParseInfo(parseInfo);
// in order to support multi-turn conversation, chat context is needed
ChatContext chatCtx = chatContextService.getOrCreateContext(queryReq.getChatId());
long startTime = System.currentTimeMillis();
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
QueryResult queryResult = doExecution(semanticQueryReq, parseInfo, queryReq.getUser());
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
.interfaceName(semanticQuery.getClass().getSimpleName()).type(CostType.QUERY.getType()).build());
queryResult.setQueryTimeCost(timeCostDOList.get(0).getCost().longValue());
queryResult.setChatContext(parseInfo);
// update chat context after a successful semantic query
if (QueryState.SUCCESS.equals(queryResult.getQueryState()) && queryReq.isSaveAnswer()) {
chatCtx.setParseInfo(parseInfo);
chatContextService.updateContext(chatCtx);
}
chatCtx.setQueryText(queryReq.getQueryText());
chatCtx.setUser(queryReq.getUser().getName());
return queryResult;
}
private QueryResult doExecution(SemanticQueryReq semanticQueryReq,
SemanticParseInfo parseInfo, User user) throws Exception {
SemanticQueryResp queryResp = semanticLayerService.queryByReq(semanticQueryReq, user);

View File

@@ -76,8 +76,9 @@ public class S2SemanticLayerService implements SemanticLayerService {
private final DataSetService dataSetService;
private final SchemaService schemaService;
private final SemanticTranslator semanticTranslator;
private final MetricDrillDownChecker metricDrillDownChecker;
private QueryCache queryCache = ComponentFactory.getQueryCache();
private List<QueryExecutor> queryExecutors = ComponentFactory.getQueryExecutors();
public S2SemanticLayerService(
StatUtils statUtils,
@@ -102,6 +103,18 @@ public class S2SemanticLayerService implements SemanticLayerService {
return schemaService.getDataSetSchema(id);
}
@S2DataPermission
@Override
public SemanticTranslateResp translate(SemanticQueryReq queryReq, User user) throws Exception {
QueryStatement queryStatement = buildQueryStatement(queryReq, user);
semanticTranslator.translate(queryStatement);
return SemanticTranslateResp.builder()
.querySQL(queryStatement.getSql())
.isOk(queryStatement.isOk())
.errMsg(queryStatement.getErrMsg())
.build();
}
@Override
@S2DataPermission
@SneakyThrows
@@ -111,8 +124,9 @@ public class S2SemanticLayerService implements SemanticLayerService {
try {
//1.initStatInfo
statUtils.initStatInfo(queryReq, user);
//2.query from cache
QueryCache queryCache = ComponentFactory.getQueryCache();
String cacheKey = queryCache.getCacheKey(queryReq);
log.debug("cacheKey:{}", cacheKey);
Object query = queryCache.query(queryReq, cacheKey);
@@ -122,19 +136,35 @@ public class S2SemanticLayerService implements SemanticLayerService {
return queryResp;
}
StatUtils.get().setUseResultCache(false);
//3 query
QueryStatement queryStatement = buildQueryStatement(queryReq, user);
SemanticQueryResp result = doQuery(queryStatement);
SemanticQueryResp queryResp = null;
// skip translation if already done.
if (!queryStatement.isTranslated()) {
semanticTranslator.translate(queryStatement);
}
queryPreCheck(queryStatement);
for (QueryExecutor queryExecutor : queryExecutors) {
if (queryExecutor.accept(queryStatement)) {
queryResp = queryExecutor.execute(queryStatement);
queryUtils.fillItemNameInfo(queryResp, queryStatement.getSemanticSchemaResp());
}
}
//4 reset cache and set stateInfo
Boolean setCacheSuccess = queryCache.put(cacheKey, result);
Boolean setCacheSuccess = queryCache.put(cacheKey, queryResp);
if (setCacheSuccess) {
// if result is not null, update cache data
statUtils.updateResultCacheKey(cacheKey);
}
if (Objects.isNull(result)) {
if (Objects.isNull(queryResp)) {
state = TaskStatusEnum.ERROR;
}
return result;
return queryResp;
} catch (Exception e) {
log.error("exception in queryByStruct, e: ", e);
state = TaskStatusEnum.ERROR;
@@ -144,6 +174,49 @@ public class S2SemanticLayerService implements SemanticLayerService {
}
}
@Override
@SneakyThrows
public SemanticQueryResp queryDimValue(QueryDimValueReq queryDimValueReq, User user) {
QuerySqlReq querySqlReq = buildQuerySqlReq(queryDimValueReq);
return queryByReq(querySqlReq, user);
}
public EntityInfo getEntityInfo(SemanticParseInfo parseInfo, DataSetSchema dataSetSchema, User user) {
if (parseInfo != null && parseInfo.getDataSetId() != null && parseInfo.getDataSetId() > 0) {
EntityInfo entityInfo = getEntityBasicInfo(dataSetSchema);
if (parseInfo.getDimensionFilters().size() <= 0 || entityInfo.getDataSetInfo() == null) {
entityInfo.setMetrics(null);
entityInfo.setDimensions(null);
return entityInfo;
}
String primaryKey = entityInfo.getDataSetInfo().getPrimaryKey();
if (StringUtils.isNotBlank(primaryKey)) {
String entityId = "";
for (QueryFilter chatFilter : parseInfo.getDimensionFilters()) {
if (chatFilter != null && chatFilter.getBizName() != null && chatFilter.getBizName()
.equals(primaryKey)) {
if (chatFilter.getOperator().equals(FilterOperatorEnum.EQUALS)) {
entityId = chatFilter.getValue().toString();
}
}
}
entityInfo.setEntityId(entityId);
try {
fillEntityInfoValue(entityInfo, dataSetSchema, user);
return entityInfo;
} catch (Exception e) {
log.error("setMainModel error", e);
}
}
}
return null;
}
@Override
public List<ItemResp> getDomainDataSetTree() {
return schemaService.getDomainDataSetTree();
}
private QueryStatement buildSqlQueryStatement(QuerySqlReq querySqlReq, User user) throws Exception {
//If dataSetId or DataSetName is empty, parse dataSetId from the SQL
if (querySqlReq.needGetDataSetId()) {
@@ -171,8 +244,8 @@ public class S2SemanticLayerService implements SemanticLayerService {
if (semanticQueryReq instanceof QueryMultiStructReq) {
queryStatement = buildMultiStructQueryStatement((QueryMultiStructReq) semanticQueryReq, user);
}
if (Objects.nonNull(queryStatement) && Objects.nonNull(semanticQueryReq.getSqlInfo()) && StringUtils.isNotBlank(
semanticQueryReq.getSqlInfo().getQuerySQL())) {
if (Objects.nonNull(queryStatement) && Objects.nonNull(semanticQueryReq.getSqlInfo())
&& StringUtils.isNotBlank(semanticQueryReq.getSqlInfo().getQuerySQL())) {
queryStatement.setSql(semanticQueryReq.getSqlInfo().getQuerySQL());
queryStatement.setDataSetId(semanticQueryReq.getDataSetId());
queryStatement.setIsTranslated(true);
@@ -218,29 +291,6 @@ public class S2SemanticLayerService implements SemanticLayerService {
return schemaFilterReq;
}
@Override
@SneakyThrows
public SemanticQueryResp queryDimValue(QueryDimValueReq queryDimValueReq, User user) {
QuerySqlReq querySqlReq = buildQuerySqlReq(queryDimValueReq);
return queryByReq(querySqlReq, user);
}
@S2DataPermission
@Override
public SemanticTranslateResp translate(SemanticQueryReq queryReq, User user) throws Exception {
QueryStatement queryStatement = buildQueryStatement(queryReq, user);
semanticTranslator.translate(queryStatement);
return SemanticTranslateResp.builder()
.querySQL(queryStatement.getSql())
.isOk(queryStatement.isOk())
.errMsg(queryStatement.getErrMsg())
.build();
}
public List<ItemResp> getDomainDataSetTree() {
return schemaService.getDomainDataSetTree();
}
private QuerySqlReq buildQuerySqlReq(QueryDimValueReq queryDimValueReq) {
QuerySqlReq querySqlReq = new QuerySqlReq();
List<ModelResp> modelResps = schemaService.getModelList(Lists.newArrayList(queryDimValueReq.getModelId()));
@@ -263,68 +313,12 @@ public class S2SemanticLayerService implements SemanticLayerService {
return querySqlReq;
}
private SemanticQueryResp doQuery(QueryStatement queryStatement) {
SemanticQueryResp semanticQueryResp = null;
try {
//1 translate
if (!queryStatement.isTranslated()) {
semanticTranslator.translate(queryStatement);
}
//2. query pre-check
queryPreCheck(queryStatement);
//3 execute
for (QueryExecutor queryExecutor : ComponentFactory.getQueryExecutors()) {
if (queryExecutor.accept(queryStatement)) {
semanticQueryResp = queryExecutor.execute(queryStatement);
queryUtils.fillItemNameInfo(semanticQueryResp, queryStatement.getSemanticSchemaResp());
}
}
return semanticQueryResp;
} catch (Exception e) {
log.error("exception in query, e: ", e);
throw e;
}
}
private void queryPreCheck(QueryStatement queryStatement) {
//Check whether the dimensions of the metric drill-down are correct temporarily,
//add the abstraction of a validator later.
metricDrillDownChecker.checkQuery(queryStatement);
}
public EntityInfo getEntityInfo(SemanticParseInfo parseInfo, DataSetSchema dataSetSchema, User user) {
if (parseInfo != null && parseInfo.getDataSetId() != null && parseInfo.getDataSetId() > 0) {
EntityInfo entityInfo = getEntityBasicInfo(dataSetSchema);
if (parseInfo.getDimensionFilters().size() <= 0 || entityInfo.getDataSetInfo() == null) {
entityInfo.setMetrics(null);
entityInfo.setDimensions(null);
return entityInfo;
}
String primaryKey = entityInfo.getDataSetInfo().getPrimaryKey();
if (StringUtils.isNotBlank(primaryKey)) {
String entityId = "";
for (QueryFilter chatFilter : parseInfo.getDimensionFilters()) {
if (chatFilter != null && chatFilter.getBizName() != null && chatFilter.getBizName()
.equals(primaryKey)) {
if (chatFilter.getOperator().equals(FilterOperatorEnum.EQUALS)) {
entityId = chatFilter.getValue().toString();
}
}
}
entityInfo.setEntityId(entityId);
try {
fillEntityInfoValue(entityInfo, dataSetSchema, user);
return entityInfo;
} catch (Exception e) {
log.error("setMainModel error", e);
}
}
}
return null;
}
private EntityInfo getEntityBasicInfo(DataSetSchema dataSetSchema) {
EntityInfo entityInfo = new EntityInfo();