From 26beff10802530bb0bbe3f7d1bfb6ba13fb066b0 Mon Sep 17 00:00:00 2001 From: mainmain <57514971+mainmainer@users.noreply.github.com> Date: Thu, 12 Oct 2023 21:45:40 +0800 Subject: [PATCH] (improvement)(chat) dsl supports revision (#200) --- .../impl/ChatQueryRepositoryImpl.java | 4 +- .../supersonic/chat/query/QueryManager.java | 6 +- .../execute/EntityInfoExecuteResponder.java | 6 +- .../parse/EntityInfoParseResponder.java | 14 ++- .../parse/ExplainSqlParseResponder.java | 14 +-- .../chat/rest/ChatQueryController.java | 17 +++- .../supersonic/chat/service/QueryService.java | 4 + .../chat/service/impl/QueryServiceImpl.java | 85 +++++++++++++------ .../main/resources/mapper/ChatParseMapper.xml | 2 +- .../resources/mapper/ChatQueryDOMapper.xml | 6 +- .../jsqlparser/FieldlValueReplaceVisitor.java | 30 +++++-- .../jsqlparser/SqlParserSelectHelper.java | 13 +++ .../jsqlparser/SqlParserUpdateHelper.java | 15 ++++ 13 files changed, 162 insertions(+), 54 deletions(-) diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/repository/impl/ChatQueryRepositoryImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/repository/impl/ChatQueryRepositoryImpl.java index 0d99937c2..33b83c28d 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/repository/impl/ChatQueryRepositoryImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/repository/impl/ChatQueryRepositoryImpl.java @@ -118,9 +118,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository { } catch (Exception e) { log.info("database insert has an exception:{}", e.toString()); } - - ChatQueryDO lastChatQuery = getLastChatQuery(chatCtx.getChatId()); - Long queryId = lastChatQuery.getQuestionId(); + Long queryId = chatQueryDO.getQuestionId(); parseResult.setQueryId(queryId); return queryId; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/QueryManager.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/QueryManager.java index 9b808f7f1..cd1a70d6d 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/QueryManager.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/QueryManager.java @@ -77,6 +77,10 @@ public class QueryManager { return ruleQueryMap.get(queryMode) instanceof EntitySemanticQuery; } + public static boolean isPluginQuery(String queryMode) { + return queryMode != null && pluginQueryMap.containsKey(queryMode); + } + public static RuleSemanticQuery getRuleQuery(String queryMode) { if (queryMode == null) { return null; @@ -92,4 +96,4 @@ public class QueryManager { return new ArrayList<>(pluginQueryMap.keySet()); } -} \ No newline at end of file +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/EntityInfoExecuteResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/EntityInfoExecuteResponder.java index 6ec83bec6..c89137af5 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/EntityInfoExecuteResponder.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/EntityInfoExecuteResponder.java @@ -5,6 +5,7 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.chat.api.pojo.response.EntityInfo; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; +import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.common.util.ContextUtils; import java.util.List; @@ -21,6 +22,9 @@ public class EntityInfoExecuteResponder implements ExecuteResponder { if (semanticParseInfo == null || semanticParseInfo.getModelId() <= 0L) { return; } + if (QueryManager.isPluginQuery(semanticParseInfo.getQueryMode())) { + return; + } SemanticService semanticService = ContextUtils.getBean(SemanticService.class); User user = queryReq.getUser(); EntityInfo entityInfo = semanticService.getEntityInfo(semanticParseInfo, user); @@ -50,4 +54,4 @@ public class EntityInfoExecuteResponder implements ExecuteResponder { } } -} \ No newline at end of file +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/EntityInfoParseResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/EntityInfoParseResponder.java index fc8646a9f..c04ba12f4 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/EntityInfoParseResponder.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/EntityInfoParseResponder.java @@ -6,12 +6,15 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.response.EntityInfo; import com.tencent.supersonic.chat.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.query.QueryManager; +import com.tencent.supersonic.chat.query.llm.dsl.DslQuery; import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.common.util.ContextUtils; import java.util.List; + +import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.util.CollectionUtils; - +@Slf4j public class EntityInfoParseResponder implements ParseResponder { @Override @@ -22,15 +25,20 @@ public class EntityInfoParseResponder implements ParseResponder { } QueryReq queryReq = queryContext.getRequest(); selectedParses.forEach(parseInfo -> { + if (QueryManager.isPluginQuery(parseInfo.getQueryMode()) + && !parseInfo.getQueryMode().equals(DslQuery.QUERY_MODE)) { + return; + } //1. set entity info SemanticService semanticService = ContextUtils.getBean(SemanticService.class); EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, queryReq.getUser()); - if (QueryManager.isEntityQuery(parseInfo.getQueryMode()) || QueryManager.isMetricQuery(parseInfo.getQueryMode())) { parseInfo.setEntityInfo(entityInfo); } //2. set native value + entityInfo = semanticService.getEntityInfo(parseInfo.getModelId()); + log.info("entityInfo:{}", entityInfo); String primaryEntityBizName = semanticService.getPrimaryEntityBizName(entityInfo); if (StringUtils.isNotEmpty(primaryEntityBizName)) { //if exist primaryEntityBizName in parseInfo's dimensions, set nativeQuery to true @@ -40,4 +48,4 @@ public class EntityInfoParseResponder implements ParseResponder { } }); } -} \ No newline at end of file +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/ExplainSqlParseResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/ExplainSqlParseResponder.java index 22a82e734..ab63a6cb7 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/ExplainSqlParseResponder.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/ExplainSqlParseResponder.java @@ -15,12 +15,16 @@ public class ExplainSqlParseResponder implements ParseResponder { @Override public void fillResponse(ParseResp parseResp, QueryContext queryContext) { - List selectedParses = parseResp.getSelectedParses(); - if (CollectionUtils.isEmpty(selectedParses)) { + QueryReq queryReq = queryContext.getRequest(); + addExplainSql(queryReq, parseResp.getSelectedParses()); + addExplainSql(queryReq, parseResp.getCandidateParses()); + } + + private void addExplainSql(QueryReq queryReq, List semanticParseInfos) { + if (CollectionUtils.isEmpty(semanticParseInfos)) { return; } - QueryReq queryReq = queryContext.getRequest(); - selectedParses.forEach(parseInfo -> { + semanticParseInfos.forEach(parseInfo -> { addExplainSql(queryReq, parseInfo); }); } @@ -38,4 +42,4 @@ public class ExplainSqlParseResponder implements ParseResponder { parseInfo.getSqlInfo().setQuerySql(explain.getSql()); } -} \ No newline at end of file +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatQueryController.java b/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatQueryController.java index f2485d50e..a448b07b4 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatQueryController.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatQueryController.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.chat.rest; +import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; import com.tencent.supersonic.chat.api.pojo.request.DimensionValueReq; import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq; @@ -34,7 +35,7 @@ public class ChatQueryController { @PostMapping("search") public Object search(@RequestBody QueryReq queryCtx, HttpServletRequest request, - HttpServletResponse response) { + HttpServletResponse response) { queryCtx.setUser(UserHolder.findUser(request, response)); return searchService.search(queryCtx); } @@ -55,7 +56,7 @@ public class ChatQueryController { @PostMapping("execute") public Object execute(@RequestBody ExecuteQueryReq queryReq, - HttpServletRequest request, HttpServletResponse response) + HttpServletRequest request, HttpServletResponse response) throws Exception { queryReq.setUser(UserHolder.findUser(request, response)); return queryService.performExecution(queryReq); @@ -63,14 +64,14 @@ public class ChatQueryController { @PostMapping("queryContext") public Object queryContext(@RequestBody QueryReq queryCtx, HttpServletRequest request, - HttpServletResponse response) throws Exception { + HttpServletResponse response) throws Exception { queryCtx.setUser(UserHolder.findUser(request, response)); return queryService.queryContext(queryCtx); } @PostMapping("queryData") public Object queryData(@RequestBody QueryDataReq queryData, - HttpServletRequest request, HttpServletResponse response) + HttpServletRequest request, HttpServletResponse response) throws Exception { queryData.setUser(UserHolder.findUser(request, response)); return queryService.executeDirectQuery(queryData, UserHolder.findUser(request, response)); @@ -83,4 +84,12 @@ public class ChatQueryController { return queryService.queryDimensionValue(dimensionValueReq, UserHolder.findUser(request, response)); } + @RequestMapping("/getEntityInfo") + public Object getEntityInfo(Long queryId, Integer parseId, + HttpServletRequest request, + HttpServletResponse response) { + User user = UserHolder.findUser(request, response); + return queryService.getEntityInfo(queryId, parseId, user); + } } + diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/QueryService.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/QueryService.java index 482f4787f..81296edf9 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/QueryService.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/QueryService.java @@ -5,6 +5,7 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.request.DimensionValueReq; import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; +import com.tencent.supersonic.chat.api.pojo.response.EntityInfo; import com.tencent.supersonic.chat.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq; @@ -25,5 +26,8 @@ public interface QueryService { QueryResult executeDirectQuery(QueryDataReq queryData, User user) throws SqlParseException; + EntityInfo getEntityInfo(Long queryId, Integer parseId, User user); + Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception; } + diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index 2c28aa7e3..d2e201e5b 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -15,6 +15,7 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq; import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.request.SolvedQueryReq; +import com.tencent.supersonic.chat.api.pojo.response.EntityInfo; import com.tencent.supersonic.chat.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryState; @@ -31,24 +32,24 @@ import com.tencent.supersonic.chat.responder.execute.ExecuteResponder; import com.tencent.supersonic.chat.responder.parse.ParseResponder; import com.tencent.supersonic.chat.service.ChatService; import com.tencent.supersonic.chat.service.QueryService; +import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.chat.service.StatisticsService; import com.tencent.supersonic.chat.utils.ComponentFactory; import com.tencent.supersonic.chat.utils.SolvedQueryManager; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.pojo.QueryColumn; +import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.jsqlparser.FilterExpression; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; import com.tencent.supersonic.knowledge.dictionary.MapResult; import com.tencent.supersonic.knowledge.service.SearchService; -import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum; import com.tencent.supersonic.semantic.api.model.response.ExplainResp; import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum; import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; -import com.tencent.supersonic.semantic.query.utils.QueryStructUtils; import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; @@ -58,6 +59,8 @@ import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; + + import lombok.extern.slf4j.Slf4j; import org.apache.calcite.sql.parser.SqlParseException; import org.apache.commons.collections.CollectionUtils; @@ -96,24 +99,20 @@ public class QueryServiceImpl implements QueryService { // in order to support multi-turn conversation, chat context is needed ChatContext chatCtx = chatService.getOrCreateContext(queryReq.getChatId()); List timeCostDOList = new ArrayList<>(); - for (SchemaMapper mapper : schemaMappers) { + schemaMappers.stream().forEach(mapper -> { Long startTime = System.currentTimeMillis(); mapper.map(queryCtx); - Long endTime = System.currentTimeMillis(); - String className = mapper.getClass().getSimpleName(); - timeCostDOList.add(StatisticsDO.builder().cost((int) (endTime - startTime)) - .interfaceName(className).type(CostType.MAPPER.getType()).build()); - log.info("{} result:{}", className, JsonUtil.toString(queryCtx)); - } - for (SemanticParser parser : semanticParsers) { + timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime)) + .interfaceName(mapper.getClass().getSimpleName()).type(CostType.MAPPER.getType()).build()); + log.info("{} result:{}", mapper.getClass().getSimpleName(), JsonUtil.toString(queryCtx)); + }); + semanticParsers.stream().forEach(parser -> { Long startTime = System.currentTimeMillis(); parser.parse(queryCtx, chatCtx); - Long endTime = System.currentTimeMillis(); - String className = parser.getClass().getSimpleName(); - timeCostDOList.add(StatisticsDO.builder().cost((int) (endTime - startTime)) - .interfaceName(className).type(CostType.PARSER.getType()).build()); - log.info("{} result:{}", className, JsonUtil.toString(queryCtx)); - } + timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime)) + .interfaceName(parser.getClass().getSimpleName()).type(CostType.PARSER.getType()).build()); + log.info("{} result:{}", parser.getClass().getSimpleName(), JsonUtil.toString(queryCtx)); + }); ParseResp parseResult; if (queryCtx.getCandidateQueries().size() > 0) { log.debug("pick before [{}]", queryCtx.getCandidateQueries().stream().collect( @@ -124,6 +123,7 @@ public class QueryServiceImpl implements QueryService { List selectedParses = convertParseInfo(selectedQueries); List candidateParses = convertParseInfo(queryCtx.getCandidateQueries()); + candidateParses = getTop5CandidateParseInfo(selectedParses, candidateParses); parseResult = ParseResp.builder() .chatId(queryReq.getChatId()) .queryText(queryReq.getQueryText()) @@ -154,6 +154,24 @@ public class QueryServiceImpl implements QueryService { .collect(Collectors.toList()); } + private List getTop5CandidateParseInfo(List selectedParses, + List candidateParses) { + if (CollectionUtils.isEmpty(selectedParses) || CollectionUtils.isEmpty(candidateParses)) { + return candidateParses; + } + int selectParseSize = selectedParses.size(); + int candidateParseSize = 5 - selectParseSize; + SemanticParseInfo semanticParseInfo = selectedParses.get(0); + Long modelId = semanticParseInfo.getModelId(); + if (modelId == null || modelId <= 0) { + return candidateParses; + } + return candidateParses.stream() + .sorted(Comparator.comparing(parse -> !parse.getModelId().equals(modelId))) + .limit(candidateParseSize) + .collect(Collectors.toList()); + } + @Override public QueryResult performExecution(ExecuteQueryReq queryReq) throws Exception { ChatParseDO chatParseDO = chatService.getParseInfo(queryReq.getQueryId(), @@ -277,6 +295,7 @@ public class QueryServiceImpl implements QueryService { if (DslQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) { Map> filedNameToValueMap = new HashMap<>(); + Map> havingFiledNameToValueMap = new HashMap<>(); String json = JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT)); DSLParseResult dslParseResult = JsonUtil.toObject(json, DSLParseResult.class); LLMResp llmResp = dslParseResult.getLlmResp(); @@ -288,41 +307,55 @@ public class QueryServiceImpl implements QueryService { updateFilters(filedNameToValueMap, filterExpressionList, queryData.getDimensionFilters(), parseInfo.getDimensionFilters()); - updateFilters(filedNameToValueMap, filterExpressionList, queryData.getMetricFilters(), - parseInfo.getMetricFilters()); + updateFilters(havingFiledNameToValueMap, filterExpressionList, queryData.getDimensionFilters(), + parseInfo.getDimensionFilters()); updateDateInfo(queryData, parseInfo, filedNameToValueMap, filterExpressionList); log.info("filedNameToValueMap:{}", filedNameToValueMap); correctorSql = SqlParserUpdateHelper.replaceValue(correctorSql, filedNameToValueMap); + log.info("havingFiledNameToValueMap:{}", havingFiledNameToValueMap); + correctorSql = SqlParserUpdateHelper.replaceHavingValue(correctorSql, havingFiledNameToValueMap); log.info("correctorSql after replacing:{}", correctorSql); llmResp.setCorrectorSql(correctorSql); - + dslParseResult.setLlmResp(llmResp); + Map properties = new HashMap<>(); + properties.put(Constants.CONTEXT, dslParseResult); + parseInfo.setProperties(properties); parseInfo.getSqlInfo().setLogicSql(correctorSql); - + semanticQuery.setParseInfo(parseInfo); ExplainResp explain = semanticQuery.explain(user); if (!Objects.isNull(explain)) { parseInfo.getSqlInfo().setQuerySql(explain.getSql()); } } + log.info("parseInfo:{}", JsonUtil.toString(semanticQuery.getParseInfo().getProperties())); semanticQuery.setParseInfo(parseInfo); QueryResult queryResult = semanticQuery.execute(user); queryResult.setChatContext(semanticQuery.getParseInfo()); return queryResult; } + @Override + public EntityInfo getEntityInfo(Long queryId, Integer parseId, User user) { + ChatParseDO chatParseDO = chatService.getParseInfo(queryId, user.getName(), parseId); + SemanticParseInfo parseInfo = JsonUtil.toObject(chatParseDO.getParseInfo(), SemanticParseInfo.class); + SemanticService semanticService = ContextUtils.getBean(SemanticService.class); + return semanticService.getEntityInfo(parseInfo, user); + } + private void updateDateInfo(QueryDataReq queryData, SemanticParseInfo parseInfo, Map> filedNameToValueMap, List filterExpressionList) { if (Objects.isNull(queryData.getDateInfo())) { return; } Map map = new HashMap<>(); - List dateFields = new ArrayList<>(QueryStructUtils.internalTimeCols); - String dateField = TimeDimensionEnum.DAY.getName(); + //List dateFields = new ArrayList<>(QueryStructUtils.internalTimeCols); + String dateField = "数据日期"; if (queryData.getDateInfo().getStartDate().equals(queryData.getDateInfo().getEndDate())) { for (FilterExpression filterExpression : filterExpressionList) { if (filterExpression.getFieldName() != null - && dateFields.contains(filterExpression.getFieldName())) { + && filterExpression.getFieldName().equals("数据日期")) { dateField = filterExpression.getFieldName(); map.put(filterExpression.getFieldValue().toString(), queryData.getDateInfo().getStartDate()); @@ -331,7 +364,7 @@ public class QueryServiceImpl implements QueryService { } } else { for (FilterExpression filterExpression : filterExpressionList) { - if (dateFields.contains(filterExpression.getFieldName())) { + if (filterExpression.getFieldName().equals("数据日期")) { dateField = filterExpression.getFieldName(); if (FilterOperatorEnum.GREATER_THAN_EQUALS.getValue().equals(filterExpression.getOperator()) || FilterOperatorEnum.GREATER_THAN.getValue().equals(filterExpression.getOperator())) { @@ -360,7 +393,7 @@ public class QueryServiceImpl implements QueryService { Map map = new HashMap<>(); for (FilterExpression filterExpression : filterExpressionList) { if (filterExpression.getFieldName() != null - && filterExpression.getFieldName().equals(dslQueryFilter.getName()) + && filterExpression.getFieldName().contains(dslQueryFilter.getName()) && dslQueryFilter.getOperator().getValue().equals(filterExpression.getOperator())) { map.put(filterExpression.getFieldValue().toString(), dslQueryFilter.getValue().toString()); contextMetricFilters.stream().forEach(o -> { @@ -407,7 +440,7 @@ public class QueryServiceImpl implements QueryService { queryStructReq.setDateInfo(dateConf); queryStructReq.setLimit(20L); queryStructReq.setModelId(dimensionValueReq.getModelId()); - queryStructReq.setNativeQuery(true); + queryStructReq.setNativeQuery(false); List groups = new ArrayList<>(); groups.add(dimensionValueReq.getBizName()); queryStructReq.setGroups(groups); diff --git a/chat/core/src/main/resources/mapper/ChatParseMapper.xml b/chat/core/src/main/resources/mapper/ChatParseMapper.xml index cede70423..1d6b7f344 100644 --- a/chat/core/src/main/resources/mapper/ChatParseMapper.xml +++ b/chat/core/src/main/resources/mapper/ChatParseMapper.xml @@ -29,7 +29,7 @@ select * from s2_chat_parse where question_id = #{questionId} and user_name = #{userName} - and parse_id = #{parseId} and is_candidate = 0 limit 1 + and parse_id = #{parseId} limit 1 diff --git a/chat/core/src/main/resources/mapper/ChatQueryDOMapper.xml b/chat/core/src/main/resources/mapper/ChatQueryDOMapper.xml index 1cba2a902..2c4554d26 100644 --- a/chat/core/src/main/resources/mapper/ChatQueryDOMapper.xml +++ b/chat/core/src/main/resources/mapper/ChatQueryDOMapper.xml @@ -72,12 +72,12 @@ delete from s2_chat_query where question_id = #{questionId,jdbcType=BIGINT} - - insert into s2_chat_query (question_id, agent_id, create_time, user_name, + + insert into s2_chat_query (agent_id, create_time, user_name, query_state, chat_id, score, feedback, query_text, query_result ) - values (#{questionId,jdbcType=BIGINT}, #{agentId,jdbcType=INTEGER}, #{createTime,jdbcType=TIMESTAMP}, #{userName,jdbcType=VARCHAR}, + values (#{agentId,jdbcType=INTEGER}, #{createTime,jdbcType=TIMESTAMP}, #{userName,jdbcType=VARCHAR}, #{queryState,jdbcType=INTEGER}, #{chatId,jdbcType=BIGINT}, #{score,jdbcType=INTEGER}, #{feedback,jdbcType=VARCHAR}, #{queryText,jdbcType=LONGVARCHAR}, #{queryResult,jdbcType=LONGVARCHAR} ) diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldlValueReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldlValueReplaceVisitor.java index 7396d3ce1..74c3c5454 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldlValueReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldlValueReplaceVisitor.java @@ -1,13 +1,16 @@ package com.tencent.supersonic.common.util.jsqlparser; +import java.util.HashMap; import java.util.Map; import java.util.Objects; +import lombok.extern.slf4j.Slf4j; +import net.sf.jsqlparser.expression.Function; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.ExpressionVisitorAdapter; -import net.sf.jsqlparser.expression.StringValue; import net.sf.jsqlparser.expression.LongValue; import net.sf.jsqlparser.expression.DoubleValue; +import net.sf.jsqlparser.expression.StringValue; import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals; import net.sf.jsqlparser.expression.operators.relational.GreaterThan; import net.sf.jsqlparser.expression.operators.relational.MinorThan; @@ -19,6 +22,7 @@ import net.sf.jsqlparser.schema.Column; import org.apache.commons.lang3.StringUtils; import org.springframework.util.CollectionUtils; +@Slf4j public class FieldlValueReplaceVisitor extends ExpressionVisitorAdapter { ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper(); @@ -55,7 +59,7 @@ public class FieldlValueReplaceVisitor extends ExpressionVisitorAdapter { public void replaceComparisonExpression(T expression) { Expression leftExpression = ((ComparisonOperator) expression).getLeftExpression(); Expression rightExpression = ((ComparisonOperator) expression).getRightExpression(); - if (!(leftExpression instanceof Column)) { + if (!(leftExpression instanceof Column || leftExpression instanceof Function)) { return; } if (CollectionUtils.isEmpty(filedNameToValueMap)) { @@ -64,18 +68,30 @@ public class FieldlValueReplaceVisitor extends ExpressionVisitorAdapter { if (Objects.isNull(rightExpression) || Objects.isNull(leftExpression)) { return; } - Column leftColumnName = (Column) leftExpression; - - String columnName = leftColumnName.getColumnName(); + String columnName = ""; + if (leftExpression instanceof Column) { + Column leftColumnName = (Column) leftExpression; + columnName = leftColumnName.getColumnName(); + } + if (leftExpression instanceof Function) { + Function function = (Function) leftExpression; + columnName = ((Column) function.getParameters().getExpressions().get(0)).getColumnName(); + } if (StringUtils.isEmpty(columnName)) { return; } - Map valueMap = filedNameToValueMap.get(columnName); + Map valueMap = new HashMap<>(); + for (String key : filedNameToValueMap.keySet()) { + if (columnName.contains(key)) { + valueMap = filedNameToValueMap.get(key); + break; + } + } + //filedNameToValueMap.get(columnName); if (Objects.isNull(valueMap) || valueMap.isEmpty()) { return; } - if (rightExpression instanceof LongValue) { LongValue rightStringValue = (LongValue) rightExpression; String replaceValue = getReplaceValue(valueMap, String.valueOf(rightStringValue.getValue())); diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java index cce8972bb..42d33bc97 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java @@ -154,6 +154,19 @@ public class SqlParserSelectHelper { return null; } + public static List getHavingExpressions(String sql) { + PlainSelect plainSelect = getPlainSelect(sql); + if (Objects.isNull(plainSelect)) { + return new ArrayList<>(); + } + Set result = new HashSet<>(); + Expression having = plainSelect.getHaving(); + if (Objects.nonNull(having)) { + having.accept(new FieldAndValueAcquireVisitor(result)); + } + return new ArrayList<>(result); + } + public static List getOrderByFields(String sql) { PlainSelect plainSelect = getPlainSelect(sql); if (Objects.isNull(plainSelect)) { diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java index a1f9ffa98..28cf55336 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java @@ -59,6 +59,21 @@ public class SqlParserUpdateHelper { return selectStatement.toString(); } + public static String replaceHavingValue(String sql, Map> filedNameToValueMap) { + Select selectStatement = SqlParserSelectHelper.getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + PlainSelect plainSelect = (PlainSelect) selectBody; + Expression having = plainSelect.getHaving(); + FieldlValueReplaceVisitor visitor = new FieldlValueReplaceVisitor(false, filedNameToValueMap); + if (Objects.nonNull(having)) { + having.accept(visitor); + } + return selectStatement.toString(); + } + public static String replaceFieldNameByValue(String sql, Map> fieldValueToFieldNames) { Select selectStatement = SqlParserSelectHelper.getSelect(sql); SelectBody selectBody = selectStatement.getSelectBody();