(improvement)(chat) dsl supports revision (#200)

This commit is contained in:
mainmain
2023-10-12 21:45:40 +08:00
committed by GitHub
parent 88b8130d37
commit 26beff1080
13 changed files with 162 additions and 54 deletions

View File

@@ -118,9 +118,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
} catch (Exception e) { } catch (Exception e) {
log.info("database insert has an exception:{}", e.toString()); log.info("database insert has an exception:{}", e.toString());
} }
Long queryId = chatQueryDO.getQuestionId();
ChatQueryDO lastChatQuery = getLastChatQuery(chatCtx.getChatId());
Long queryId = lastChatQuery.getQuestionId();
parseResult.setQueryId(queryId); parseResult.setQueryId(queryId);
return queryId; return queryId;
} }

View File

@@ -77,6 +77,10 @@ public class QueryManager {
return ruleQueryMap.get(queryMode) instanceof EntitySemanticQuery; return ruleQueryMap.get(queryMode) instanceof EntitySemanticQuery;
} }
public static boolean isPluginQuery(String queryMode) {
return queryMode != null && pluginQueryMap.containsKey(queryMode);
}
public static RuleSemanticQuery getRuleQuery(String queryMode) { public static RuleSemanticQuery getRuleQuery(String queryMode) {
if (queryMode == null) { if (queryMode == null) {
return null; return null;
@@ -92,4 +96,4 @@ public class QueryManager {
return new ArrayList<>(pluginQueryMap.keySet()); return new ArrayList<>(pluginQueryMap.keySet());
} }
} }

View File

@@ -5,6 +5,7 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo; import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import java.util.List; import java.util.List;
@@ -21,6 +22,9 @@ public class EntityInfoExecuteResponder implements ExecuteResponder {
if (semanticParseInfo == null || semanticParseInfo.getModelId() <= 0L) { if (semanticParseInfo == null || semanticParseInfo.getModelId() <= 0L) {
return; return;
} }
if (QueryManager.isPluginQuery(semanticParseInfo.getQueryMode())) {
return;
}
SemanticService semanticService = ContextUtils.getBean(SemanticService.class); SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
User user = queryReq.getUser(); User user = queryReq.getUser();
EntityInfo entityInfo = semanticService.getEntityInfo(semanticParseInfo, user); EntityInfo entityInfo = semanticService.getEntityInfo(semanticParseInfo, user);
@@ -50,4 +54,4 @@ public class EntityInfoExecuteResponder implements ExecuteResponder {
} }
} }
} }

View File

@@ -6,12 +6,15 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo; import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import java.util.List; import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
@Slf4j
public class EntityInfoParseResponder implements ParseResponder { public class EntityInfoParseResponder implements ParseResponder {
@Override @Override
@@ -22,15 +25,20 @@ public class EntityInfoParseResponder implements ParseResponder {
} }
QueryReq queryReq = queryContext.getRequest(); QueryReq queryReq = queryContext.getRequest();
selectedParses.forEach(parseInfo -> { selectedParses.forEach(parseInfo -> {
if (QueryManager.isPluginQuery(parseInfo.getQueryMode())
&& !parseInfo.getQueryMode().equals(DslQuery.QUERY_MODE)) {
return;
}
//1. set entity info //1. set entity info
SemanticService semanticService = ContextUtils.getBean(SemanticService.class); SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, queryReq.getUser()); EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, queryReq.getUser());
if (QueryManager.isEntityQuery(parseInfo.getQueryMode()) if (QueryManager.isEntityQuery(parseInfo.getQueryMode())
|| QueryManager.isMetricQuery(parseInfo.getQueryMode())) { || QueryManager.isMetricQuery(parseInfo.getQueryMode())) {
parseInfo.setEntityInfo(entityInfo); parseInfo.setEntityInfo(entityInfo);
} }
//2. set native value //2. set native value
entityInfo = semanticService.getEntityInfo(parseInfo.getModelId());
log.info("entityInfo:{}", entityInfo);
String primaryEntityBizName = semanticService.getPrimaryEntityBizName(entityInfo); String primaryEntityBizName = semanticService.getPrimaryEntityBizName(entityInfo);
if (StringUtils.isNotEmpty(primaryEntityBizName)) { if (StringUtils.isNotEmpty(primaryEntityBizName)) {
//if exist primaryEntityBizName in parseInfo's dimensions, set nativeQuery to true //if exist primaryEntityBizName in parseInfo's dimensions, set nativeQuery to true
@@ -40,4 +48,4 @@ public class EntityInfoParseResponder implements ParseResponder {
} }
}); });
} }
} }

View File

@@ -15,12 +15,16 @@ public class ExplainSqlParseResponder implements ParseResponder {
@Override @Override
public void fillResponse(ParseResp parseResp, QueryContext queryContext) { public void fillResponse(ParseResp parseResp, QueryContext queryContext) {
List<SemanticParseInfo> selectedParses = parseResp.getSelectedParses(); QueryReq queryReq = queryContext.getRequest();
if (CollectionUtils.isEmpty(selectedParses)) { addExplainSql(queryReq, parseResp.getSelectedParses());
addExplainSql(queryReq, parseResp.getCandidateParses());
}
private void addExplainSql(QueryReq queryReq, List<SemanticParseInfo> semanticParseInfos) {
if (CollectionUtils.isEmpty(semanticParseInfos)) {
return; return;
} }
QueryReq queryReq = queryContext.getRequest(); semanticParseInfos.forEach(parseInfo -> {
selectedParses.forEach(parseInfo -> {
addExplainSql(queryReq, parseInfo); addExplainSql(queryReq, parseInfo);
}); });
} }
@@ -38,4 +42,4 @@ public class ExplainSqlParseResponder implements ParseResponder {
parseInfo.getSqlInfo().setQuerySql(explain.getSql()); parseInfo.getSqlInfo().setQuerySql(explain.getSql());
} }
} }

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.chat.rest; package com.tencent.supersonic.chat.rest;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.api.pojo.request.DimensionValueReq; import com.tencent.supersonic.chat.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
@@ -34,7 +35,7 @@ public class ChatQueryController {
@PostMapping("search") @PostMapping("search")
public Object search(@RequestBody QueryReq queryCtx, HttpServletRequest request, public Object search(@RequestBody QueryReq queryCtx, HttpServletRequest request,
HttpServletResponse response) { HttpServletResponse response) {
queryCtx.setUser(UserHolder.findUser(request, response)); queryCtx.setUser(UserHolder.findUser(request, response));
return searchService.search(queryCtx); return searchService.search(queryCtx);
} }
@@ -55,7 +56,7 @@ public class ChatQueryController {
@PostMapping("execute") @PostMapping("execute")
public Object execute(@RequestBody ExecuteQueryReq queryReq, public Object execute(@RequestBody ExecuteQueryReq queryReq,
HttpServletRequest request, HttpServletResponse response) HttpServletRequest request, HttpServletResponse response)
throws Exception { throws Exception {
queryReq.setUser(UserHolder.findUser(request, response)); queryReq.setUser(UserHolder.findUser(request, response));
return queryService.performExecution(queryReq); return queryService.performExecution(queryReq);
@@ -63,14 +64,14 @@ public class ChatQueryController {
@PostMapping("queryContext") @PostMapping("queryContext")
public Object queryContext(@RequestBody QueryReq queryCtx, HttpServletRequest request, public Object queryContext(@RequestBody QueryReq queryCtx, HttpServletRequest request,
HttpServletResponse response) throws Exception { HttpServletResponse response) throws Exception {
queryCtx.setUser(UserHolder.findUser(request, response)); queryCtx.setUser(UserHolder.findUser(request, response));
return queryService.queryContext(queryCtx); return queryService.queryContext(queryCtx);
} }
@PostMapping("queryData") @PostMapping("queryData")
public Object queryData(@RequestBody QueryDataReq queryData, public Object queryData(@RequestBody QueryDataReq queryData,
HttpServletRequest request, HttpServletResponse response) HttpServletRequest request, HttpServletResponse response)
throws Exception { throws Exception {
queryData.setUser(UserHolder.findUser(request, response)); queryData.setUser(UserHolder.findUser(request, response));
return queryService.executeDirectQuery(queryData, UserHolder.findUser(request, response)); return queryService.executeDirectQuery(queryData, UserHolder.findUser(request, response));
@@ -83,4 +84,12 @@ public class ChatQueryController {
return queryService.queryDimensionValue(dimensionValueReq, UserHolder.findUser(request, response)); return queryService.queryDimensionValue(dimensionValueReq, UserHolder.findUser(request, response));
} }
@RequestMapping("/getEntityInfo")
public Object getEntityInfo(Long queryId, Integer parseId,
HttpServletRequest request,
HttpServletResponse response) {
User user = UserHolder.findUser(request, response);
return queryService.getEntityInfo(queryId, parseId, user);
}
} }

View File

@@ -5,6 +5,7 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.DimensionValueReq; import com.tencent.supersonic.chat.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq; import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
@@ -25,5 +26,8 @@ public interface QueryService {
QueryResult executeDirectQuery(QueryDataReq queryData, User user) throws SqlParseException; QueryResult executeDirectQuery(QueryDataReq queryData, User user) throws SqlParseException;
EntityInfo getEntityInfo(Long queryId, Integer parseId, User user);
Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception; Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception;
} }

View File

@@ -15,6 +15,7 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.request.SolvedQueryReq; import com.tencent.supersonic.chat.api.pojo.request.SolvedQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState; import com.tencent.supersonic.chat.api.pojo.response.QueryState;
@@ -31,24 +32,24 @@ import com.tencent.supersonic.chat.responder.execute.ExecuteResponder;
import com.tencent.supersonic.chat.responder.parse.ParseResponder; import com.tencent.supersonic.chat.responder.parse.ParseResponder;
import com.tencent.supersonic.chat.service.ChatService; import com.tencent.supersonic.chat.service.ChatService;
import com.tencent.supersonic.chat.service.QueryService; import com.tencent.supersonic.chat.service.QueryService;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.chat.service.StatisticsService; import com.tencent.supersonic.chat.service.StatisticsService;
import com.tencent.supersonic.chat.utils.ComponentFactory; import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.SolvedQueryManager; import com.tencent.supersonic.chat.utils.SolvedQueryManager;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.QueryColumn; import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.FilterExpression; import com.tencent.supersonic.common.util.jsqlparser.FilterExpression;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import com.tencent.supersonic.knowledge.dictionary.MapResult; import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.service.SearchService; import com.tencent.supersonic.knowledge.service.SearchService;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp; import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum; import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import com.tencent.supersonic.semantic.query.utils.QueryStructUtils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Comparator; import java.util.Comparator;
import java.util.HashMap; import java.util.HashMap;
@@ -58,6 +59,8 @@ import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.sql.parser.SqlParseException; import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
@@ -96,24 +99,20 @@ public class QueryServiceImpl implements QueryService {
// in order to support multi-turn conversation, chat context is needed // in order to support multi-turn conversation, chat context is needed
ChatContext chatCtx = chatService.getOrCreateContext(queryReq.getChatId()); ChatContext chatCtx = chatService.getOrCreateContext(queryReq.getChatId());
List<StatisticsDO> timeCostDOList = new ArrayList<>(); List<StatisticsDO> timeCostDOList = new ArrayList<>();
for (SchemaMapper mapper : schemaMappers) { schemaMappers.stream().forEach(mapper -> {
Long startTime = System.currentTimeMillis(); Long startTime = System.currentTimeMillis();
mapper.map(queryCtx); mapper.map(queryCtx);
Long endTime = System.currentTimeMillis(); timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
String className = mapper.getClass().getSimpleName(); .interfaceName(mapper.getClass().getSimpleName()).type(CostType.MAPPER.getType()).build());
timeCostDOList.add(StatisticsDO.builder().cost((int) (endTime - startTime)) log.info("{} result:{}", mapper.getClass().getSimpleName(), JsonUtil.toString(queryCtx));
.interfaceName(className).type(CostType.MAPPER.getType()).build()); });
log.info("{} result:{}", className, JsonUtil.toString(queryCtx)); semanticParsers.stream().forEach(parser -> {
}
for (SemanticParser parser : semanticParsers) {
Long startTime = System.currentTimeMillis(); Long startTime = System.currentTimeMillis();
parser.parse(queryCtx, chatCtx); parser.parse(queryCtx, chatCtx);
Long endTime = System.currentTimeMillis(); timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
String className = parser.getClass().getSimpleName(); .interfaceName(parser.getClass().getSimpleName()).type(CostType.PARSER.getType()).build());
timeCostDOList.add(StatisticsDO.builder().cost((int) (endTime - startTime)) log.info("{} result:{}", parser.getClass().getSimpleName(), JsonUtil.toString(queryCtx));
.interfaceName(className).type(CostType.PARSER.getType()).build()); });
log.info("{} result:{}", className, JsonUtil.toString(queryCtx));
}
ParseResp parseResult; ParseResp parseResult;
if (queryCtx.getCandidateQueries().size() > 0) { if (queryCtx.getCandidateQueries().size() > 0) {
log.debug("pick before [{}]", queryCtx.getCandidateQueries().stream().collect( log.debug("pick before [{}]", queryCtx.getCandidateQueries().stream().collect(
@@ -124,6 +123,7 @@ public class QueryServiceImpl implements QueryService {
List<SemanticParseInfo> selectedParses = convertParseInfo(selectedQueries); List<SemanticParseInfo> selectedParses = convertParseInfo(selectedQueries);
List<SemanticParseInfo> candidateParses = convertParseInfo(queryCtx.getCandidateQueries()); List<SemanticParseInfo> candidateParses = convertParseInfo(queryCtx.getCandidateQueries());
candidateParses = getTop5CandidateParseInfo(selectedParses, candidateParses);
parseResult = ParseResp.builder() parseResult = ParseResp.builder()
.chatId(queryReq.getChatId()) .chatId(queryReq.getChatId())
.queryText(queryReq.getQueryText()) .queryText(queryReq.getQueryText())
@@ -154,6 +154,24 @@ public class QueryServiceImpl implements QueryService {
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
private List<SemanticParseInfo> getTop5CandidateParseInfo(List<SemanticParseInfo> selectedParses,
List<SemanticParseInfo> candidateParses) {
if (CollectionUtils.isEmpty(selectedParses) || CollectionUtils.isEmpty(candidateParses)) {
return candidateParses;
}
int selectParseSize = selectedParses.size();
int candidateParseSize = 5 - selectParseSize;
SemanticParseInfo semanticParseInfo = selectedParses.get(0);
Long modelId = semanticParseInfo.getModelId();
if (modelId == null || modelId <= 0) {
return candidateParses;
}
return candidateParses.stream()
.sorted(Comparator.comparing(parse -> !parse.getModelId().equals(modelId)))
.limit(candidateParseSize)
.collect(Collectors.toList());
}
@Override @Override
public QueryResult performExecution(ExecuteQueryReq queryReq) throws Exception { public QueryResult performExecution(ExecuteQueryReq queryReq) throws Exception {
ChatParseDO chatParseDO = chatService.getParseInfo(queryReq.getQueryId(), ChatParseDO chatParseDO = chatService.getParseInfo(queryReq.getQueryId(),
@@ -277,6 +295,7 @@ public class QueryServiceImpl implements QueryService {
if (DslQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) { if (DslQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>(); Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
Map<String, Map<String, String>> havingFiledNameToValueMap = new HashMap<>();
String json = JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT)); String json = JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT));
DSLParseResult dslParseResult = JsonUtil.toObject(json, DSLParseResult.class); DSLParseResult dslParseResult = JsonUtil.toObject(json, DSLParseResult.class);
LLMResp llmResp = dslParseResult.getLlmResp(); LLMResp llmResp = dslParseResult.getLlmResp();
@@ -288,41 +307,55 @@ public class QueryServiceImpl implements QueryService {
updateFilters(filedNameToValueMap, filterExpressionList, queryData.getDimensionFilters(), updateFilters(filedNameToValueMap, filterExpressionList, queryData.getDimensionFilters(),
parseInfo.getDimensionFilters()); parseInfo.getDimensionFilters());
updateFilters(filedNameToValueMap, filterExpressionList, queryData.getMetricFilters(), updateFilters(havingFiledNameToValueMap, filterExpressionList, queryData.getDimensionFilters(),
parseInfo.getMetricFilters()); parseInfo.getDimensionFilters());
updateDateInfo(queryData, parseInfo, filedNameToValueMap, filterExpressionList); updateDateInfo(queryData, parseInfo, filedNameToValueMap, filterExpressionList);
log.info("filedNameToValueMap:{}", filedNameToValueMap); log.info("filedNameToValueMap:{}", filedNameToValueMap);
correctorSql = SqlParserUpdateHelper.replaceValue(correctorSql, filedNameToValueMap); correctorSql = SqlParserUpdateHelper.replaceValue(correctorSql, filedNameToValueMap);
log.info("havingFiledNameToValueMap:{}", havingFiledNameToValueMap);
correctorSql = SqlParserUpdateHelper.replaceHavingValue(correctorSql, havingFiledNameToValueMap);
log.info("correctorSql after replacing:{}", correctorSql); log.info("correctorSql after replacing:{}", correctorSql);
llmResp.setCorrectorSql(correctorSql); llmResp.setCorrectorSql(correctorSql);
dslParseResult.setLlmResp(llmResp);
Map<String, Object> properties = new HashMap<>();
properties.put(Constants.CONTEXT, dslParseResult);
parseInfo.setProperties(properties);
parseInfo.getSqlInfo().setLogicSql(correctorSql); parseInfo.getSqlInfo().setLogicSql(correctorSql);
semanticQuery.setParseInfo(parseInfo);
ExplainResp explain = semanticQuery.explain(user); ExplainResp explain = semanticQuery.explain(user);
if (!Objects.isNull(explain)) { if (!Objects.isNull(explain)) {
parseInfo.getSqlInfo().setQuerySql(explain.getSql()); parseInfo.getSqlInfo().setQuerySql(explain.getSql());
} }
} }
log.info("parseInfo:{}", JsonUtil.toString(semanticQuery.getParseInfo().getProperties()));
semanticQuery.setParseInfo(parseInfo); semanticQuery.setParseInfo(parseInfo);
QueryResult queryResult = semanticQuery.execute(user); QueryResult queryResult = semanticQuery.execute(user);
queryResult.setChatContext(semanticQuery.getParseInfo()); queryResult.setChatContext(semanticQuery.getParseInfo());
return queryResult; return queryResult;
} }
@Override
public EntityInfo getEntityInfo(Long queryId, Integer parseId, User user) {
ChatParseDO chatParseDO = chatService.getParseInfo(queryId, user.getName(), parseId);
SemanticParseInfo parseInfo = JsonUtil.toObject(chatParseDO.getParseInfo(), SemanticParseInfo.class);
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
return semanticService.getEntityInfo(parseInfo, user);
}
private void updateDateInfo(QueryDataReq queryData, SemanticParseInfo parseInfo, private void updateDateInfo(QueryDataReq queryData, SemanticParseInfo parseInfo,
Map<String, Map<String, String>> filedNameToValueMap, List<FilterExpression> filterExpressionList) { Map<String, Map<String, String>> filedNameToValueMap, List<FilterExpression> filterExpressionList) {
if (Objects.isNull(queryData.getDateInfo())) { if (Objects.isNull(queryData.getDateInfo())) {
return; return;
} }
Map<String, String> map = new HashMap<>(); Map<String, String> map = new HashMap<>();
List<String> dateFields = new ArrayList<>(QueryStructUtils.internalTimeCols); //List<String> dateFields = new ArrayList<>(QueryStructUtils.internalTimeCols);
String dateField = TimeDimensionEnum.DAY.getName(); String dateField = "数据日期";
if (queryData.getDateInfo().getStartDate().equals(queryData.getDateInfo().getEndDate())) { if (queryData.getDateInfo().getStartDate().equals(queryData.getDateInfo().getEndDate())) {
for (FilterExpression filterExpression : filterExpressionList) { for (FilterExpression filterExpression : filterExpressionList) {
if (filterExpression.getFieldName() != null if (filterExpression.getFieldName() != null
&& dateFields.contains(filterExpression.getFieldName())) { && filterExpression.getFieldName().equals("数据日期")) {
dateField = filterExpression.getFieldName(); dateField = filterExpression.getFieldName();
map.put(filterExpression.getFieldValue().toString(), map.put(filterExpression.getFieldValue().toString(),
queryData.getDateInfo().getStartDate()); queryData.getDateInfo().getStartDate());
@@ -331,7 +364,7 @@ public class QueryServiceImpl implements QueryService {
} }
} else { } else {
for (FilterExpression filterExpression : filterExpressionList) { for (FilterExpression filterExpression : filterExpressionList) {
if (dateFields.contains(filterExpression.getFieldName())) { if (filterExpression.getFieldName().equals("数据日期")) {
dateField = filterExpression.getFieldName(); dateField = filterExpression.getFieldName();
if (FilterOperatorEnum.GREATER_THAN_EQUALS.getValue().equals(filterExpression.getOperator()) if (FilterOperatorEnum.GREATER_THAN_EQUALS.getValue().equals(filterExpression.getOperator())
|| FilterOperatorEnum.GREATER_THAN.getValue().equals(filterExpression.getOperator())) { || FilterOperatorEnum.GREATER_THAN.getValue().equals(filterExpression.getOperator())) {
@@ -360,7 +393,7 @@ public class QueryServiceImpl implements QueryService {
Map<String, String> map = new HashMap<>(); Map<String, String> map = new HashMap<>();
for (FilterExpression filterExpression : filterExpressionList) { for (FilterExpression filterExpression : filterExpressionList) {
if (filterExpression.getFieldName() != null if (filterExpression.getFieldName() != null
&& filterExpression.getFieldName().equals(dslQueryFilter.getName()) && filterExpression.getFieldName().contains(dslQueryFilter.getName())
&& dslQueryFilter.getOperator().getValue().equals(filterExpression.getOperator())) { && dslQueryFilter.getOperator().getValue().equals(filterExpression.getOperator())) {
map.put(filterExpression.getFieldValue().toString(), dslQueryFilter.getValue().toString()); map.put(filterExpression.getFieldValue().toString(), dslQueryFilter.getValue().toString());
contextMetricFilters.stream().forEach(o -> { contextMetricFilters.stream().forEach(o -> {
@@ -407,7 +440,7 @@ public class QueryServiceImpl implements QueryService {
queryStructReq.setDateInfo(dateConf); queryStructReq.setDateInfo(dateConf);
queryStructReq.setLimit(20L); queryStructReq.setLimit(20L);
queryStructReq.setModelId(dimensionValueReq.getModelId()); queryStructReq.setModelId(dimensionValueReq.getModelId());
queryStructReq.setNativeQuery(true); queryStructReq.setNativeQuery(false);
List<String> groups = new ArrayList<>(); List<String> groups = new ArrayList<>();
groups.add(dimensionValueReq.getBizName()); groups.add(dimensionValueReq.getBizName());
queryStructReq.setGroups(groups); queryStructReq.setGroups(groups);

View File

@@ -29,7 +29,7 @@
select * select *
from s2_chat_parse from s2_chat_parse
where question_id = #{questionId} and user_name = #{userName} where question_id = #{questionId} and user_name = #{userName}
and parse_id = #{parseId} and is_candidate = 0 limit 1 and parse_id = #{parseId} limit 1
</select> </select>
</mapper> </mapper>

View File

@@ -72,12 +72,12 @@
delete from s2_chat_query delete from s2_chat_query
where question_id = #{questionId,jdbcType=BIGINT} where question_id = #{questionId,jdbcType=BIGINT}
</delete> </delete>
<insert id="insert" parameterType="com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO"> <insert id="insert" parameterType="com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO" useGeneratedKeys="true" keyProperty="questionId">
insert into s2_chat_query (question_id, agent_id, create_time, user_name, insert into s2_chat_query (agent_id, create_time, user_name,
query_state, chat_id, score, query_state, chat_id, score,
feedback, query_text, query_result feedback, query_text, query_result
) )
values (#{questionId,jdbcType=BIGINT}, #{agentId,jdbcType=INTEGER}, #{createTime,jdbcType=TIMESTAMP}, #{userName,jdbcType=VARCHAR}, values (#{agentId,jdbcType=INTEGER}, #{createTime,jdbcType=TIMESTAMP}, #{userName,jdbcType=VARCHAR},
#{queryState,jdbcType=INTEGER}, #{chatId,jdbcType=BIGINT}, #{score,jdbcType=INTEGER}, #{queryState,jdbcType=INTEGER}, #{chatId,jdbcType=BIGINT}, #{score,jdbcType=INTEGER},
#{feedback,jdbcType=VARCHAR}, #{queryText,jdbcType=LONGVARCHAR}, #{queryResult,jdbcType=LONGVARCHAR} #{feedback,jdbcType=VARCHAR}, #{queryText,jdbcType=LONGVARCHAR}, #{queryResult,jdbcType=LONGVARCHAR}
) )

View File

@@ -1,13 +1,16 @@
package com.tencent.supersonic.common.util.jsqlparser; package com.tencent.supersonic.common.util.jsqlparser;
import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter; import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.LongValue; import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.DoubleValue; import net.sf.jsqlparser.expression.DoubleValue;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals; import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
import net.sf.jsqlparser.expression.operators.relational.GreaterThan; import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThan; import net.sf.jsqlparser.expression.operators.relational.MinorThan;
@@ -19,6 +22,7 @@ import net.sf.jsqlparser.schema.Column;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
@Slf4j
public class FieldlValueReplaceVisitor extends ExpressionVisitorAdapter { public class FieldlValueReplaceVisitor extends ExpressionVisitorAdapter {
ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper(); ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper();
@@ -55,7 +59,7 @@ public class FieldlValueReplaceVisitor extends ExpressionVisitorAdapter {
public <T extends Expression> void replaceComparisonExpression(T expression) { public <T extends Expression> void replaceComparisonExpression(T expression) {
Expression leftExpression = ((ComparisonOperator) expression).getLeftExpression(); Expression leftExpression = ((ComparisonOperator) expression).getLeftExpression();
Expression rightExpression = ((ComparisonOperator) expression).getRightExpression(); Expression rightExpression = ((ComparisonOperator) expression).getRightExpression();
if (!(leftExpression instanceof Column)) { if (!(leftExpression instanceof Column || leftExpression instanceof Function)) {
return; return;
} }
if (CollectionUtils.isEmpty(filedNameToValueMap)) { if (CollectionUtils.isEmpty(filedNameToValueMap)) {
@@ -64,18 +68,30 @@ public class FieldlValueReplaceVisitor extends ExpressionVisitorAdapter {
if (Objects.isNull(rightExpression) || Objects.isNull(leftExpression)) { if (Objects.isNull(rightExpression) || Objects.isNull(leftExpression)) {
return; return;
} }
Column leftColumnName = (Column) leftExpression; String columnName = "";
if (leftExpression instanceof Column) {
String columnName = leftColumnName.getColumnName(); Column leftColumnName = (Column) leftExpression;
columnName = leftColumnName.getColumnName();
}
if (leftExpression instanceof Function) {
Function function = (Function) leftExpression;
columnName = ((Column) function.getParameters().getExpressions().get(0)).getColumnName();
}
if (StringUtils.isEmpty(columnName)) { if (StringUtils.isEmpty(columnName)) {
return; return;
} }
Map<String, String> valueMap = filedNameToValueMap.get(columnName); Map<String, String> valueMap = new HashMap<>();
for (String key : filedNameToValueMap.keySet()) {
if (columnName.contains(key)) {
valueMap = filedNameToValueMap.get(key);
break;
}
}
//filedNameToValueMap.get(columnName);
if (Objects.isNull(valueMap) || valueMap.isEmpty()) { if (Objects.isNull(valueMap) || valueMap.isEmpty()) {
return; return;
} }
if (rightExpression instanceof LongValue) { if (rightExpression instanceof LongValue) {
LongValue rightStringValue = (LongValue) rightExpression; LongValue rightStringValue = (LongValue) rightExpression;
String replaceValue = getReplaceValue(valueMap, String.valueOf(rightStringValue.getValue())); String replaceValue = getReplaceValue(valueMap, String.valueOf(rightStringValue.getValue()));

View File

@@ -154,6 +154,19 @@ public class SqlParserSelectHelper {
return null; return null;
} }
public static List<FilterExpression> getHavingExpressions(String sql) {
PlainSelect plainSelect = getPlainSelect(sql);
if (Objects.isNull(plainSelect)) {
return new ArrayList<>();
}
Set<FilterExpression> result = new HashSet<>();
Expression having = plainSelect.getHaving();
if (Objects.nonNull(having)) {
having.accept(new FieldAndValueAcquireVisitor(result));
}
return new ArrayList<>(result);
}
public static List<String> getOrderByFields(String sql) { public static List<String> getOrderByFields(String sql) {
PlainSelect plainSelect = getPlainSelect(sql); PlainSelect plainSelect = getPlainSelect(sql);
if (Objects.isNull(plainSelect)) { if (Objects.isNull(plainSelect)) {

View File

@@ -59,6 +59,21 @@ public class SqlParserUpdateHelper {
return selectStatement.toString(); return selectStatement.toString();
} }
public static String replaceHavingValue(String sql, Map<String, Map<String, String>> filedNameToValueMap) {
Select selectStatement = SqlParserSelectHelper.getSelect(sql);
SelectBody selectBody = selectStatement.getSelectBody();
if (!(selectBody instanceof PlainSelect)) {
return sql;
}
PlainSelect plainSelect = (PlainSelect) selectBody;
Expression having = plainSelect.getHaving();
FieldlValueReplaceVisitor visitor = new FieldlValueReplaceVisitor(false, filedNameToValueMap);
if (Objects.nonNull(having)) {
having.accept(visitor);
}
return selectStatement.toString();
}
public static String replaceFieldNameByValue(String sql, Map<String, Set<String>> fieldValueToFieldNames) { public static String replaceFieldNameByValue(String sql, Map<String, Set<String>> fieldValueToFieldNames) {
Select selectStatement = SqlParserSelectHelper.getSelect(sql); Select selectStatement = SqlParserSelectHelper.getSelect(sql);
SelectBody selectBody = selectStatement.getSelectBody(); SelectBody selectBody = selectStatement.getSelectBody();