(improvement)(headless&chat)Move ChatContext from Headless module to Chat module.

This commit is contained in:
jerryjzhang
2024-07-12 16:56:25 +08:00
parent e365a36749
commit baff30550e
78 changed files with 399 additions and 459 deletions

View File

@@ -11,7 +11,6 @@ import lombok.Data;
public class ExecuteQueryReq {
private User user;
private Long queryId;
private Integer chatId;
private String queryText;
private SemanticParseInfo parseInfo;
private boolean saveAnswer;

View File

@@ -9,6 +9,7 @@ import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
import lombok.Data;
@@ -18,7 +19,6 @@ import java.util.Set;
@Data
public class QueryNLReq {
private String queryText;
private Integer chatId;
private Set<Long> dataSetIds = Sets.newHashSet();
private User user;
private QueryFilters queryFilters;
@@ -30,4 +30,5 @@ public class QueryNLReq {
private ChatModelConfig modelConfig;
private PromptConfig promptConfig;
private List<SqlExemplar> dynamicExemplars = Lists.newArrayList();
private SemanticParseInfo contextParseInfo;
}

View File

@@ -10,7 +10,6 @@ import java.util.stream.Collectors;
@Data
public class ParseResp {
private Integer chatId;
private String queryText;
private Long queryId;
private ParseState state = ParseState.PENDING;
@@ -24,8 +23,7 @@ public class ParseResp {
FAILED
}
public ParseResp(Integer chatId, String queryText) {
this.chatId = chatId;
public ParseResp(String queryText) {
this.queryText = queryText;
parseTimeCost.setParseStartTime(System.currentTimeMillis());
}

View File

@@ -1,14 +0,0 @@
package com.tencent.supersonic.headless.chat;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import lombok.Data;
@Data
public class ChatContext {
private Integer chatId;
private String queryText;
private SemanticParseInfo parseInfo = new SemanticParseInfo();
private String user;
}

View File

@@ -8,6 +8,7 @@ import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
@@ -35,7 +36,6 @@ import java.util.stream.Collectors;
public class ChatQueryContext {
private String queryText;
private Integer chatId;
private Set<Long> dataSetIds;
private Map<Long, List<Long>> modelIdToDataSetIds;
private User user;
@@ -54,6 +54,7 @@ public class ChatQueryContext {
private ChatModelConfig modelConfig;
private PromptConfig promptConfig;
private List<SqlExemplar> dynamicExemplars;
private SemanticParseInfo contextParseInfo;
public List<SemanticQuery> getCandidateQueries() {
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);

View File

@@ -11,7 +11,6 @@ import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
@@ -29,7 +28,7 @@ import java.util.stream.Collectors;
public class QueryTypeParser implements SemanticParser {
@Override
public void parse(ChatQueryContext chatQueryContext, ChatContext chatContext) {
public void parse(ChatQueryContext chatQueryContext) {
List<SemanticQuery> candidateQueries = chatQueryContext.getCandidateQueries();
User user = chatQueryContext.getUser();

View File

@@ -1,6 +1,5 @@
package com.tencent.supersonic.headless.chat.parser;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
/**
@@ -10,5 +9,5 @@ import com.tencent.supersonic.headless.chat.ChatQueryContext;
*/
public interface SemanticParser {
void parse(ChatQueryContext chatQueryContext, ChatContext chatContext);
void parse(ChatQueryContext chatQueryContext);
}

View File

@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.chat.parser.llm;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
@@ -23,7 +22,7 @@ import org.apache.commons.collections.MapUtils;
public class LLMSqlParser implements SemanticParser {
@Override
public void parse(ChatQueryContext queryCtx, ChatContext chatCtx) {
public void parse(ChatQueryContext queryCtx) {
try {
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
//1.determine whether to skip this parser.

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.headless.chat.parser.rule;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
@@ -41,7 +40,7 @@ public class AggregateTypeParser implements SemanticParser {
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (k1, k2) -> k2));
@Override
public void parse(ChatQueryContext chatQueryContext, ChatContext chatContext) {
public void parse(ChatQueryContext chatQueryContext) {
String queryText = chatQueryContext.getQueryText();
AggregateConf aggregateConf = resolveAggregateConf(queryText);

View File

@@ -7,7 +7,6 @@ import com.tencent.supersonic.headless.chat.query.QueryManager;
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricModelQuery;
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricSemanticQuery;
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricIdQuery;
@@ -43,11 +42,11 @@ public class ContextInheritParser implements SemanticParser {
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
@Override
public void parse(ChatQueryContext chatQueryContext, ChatContext chatContext) {
public void parse(ChatQueryContext chatQueryContext) {
if (!shouldInherit(chatQueryContext)) {
return;
}
Long dataSetId = getMatchedDataSet(chatQueryContext, chatContext);
Long dataSetId = getMatchedDataSet(chatQueryContext);
if (dataSetId == null) {
return;
}
@@ -55,10 +54,11 @@ public class ContextInheritParser implements SemanticParser {
List<SchemaElementMatch> elementMatches = chatQueryContext.getMapInfo().getMatchedElements(dataSetId);
List<SchemaElementMatch> matchesToInherit = new ArrayList<>();
for (SchemaElementMatch match : chatContext.getParseInfo().getElementMatches()) {
for (SchemaElementMatch match : chatQueryContext.getContextParseInfo().getElementMatches()) {
SchemaElementType matchType = match.getElement().getType();
// mutual exclusive element types should not be inherited
RuleSemanticQuery ruleQuery = QueryManager.getRuleQuery(chatContext.getParseInfo().getQueryMode());
RuleSemanticQuery ruleQuery = QueryManager.getRuleQuery(
chatQueryContext.getContextParseInfo().getQueryMode());
if (!containsTypes(elementMatches, matchType, ruleQuery)) {
match.setInherited(true);
matchesToInherit.add(match);
@@ -68,7 +68,7 @@ public class ContextInheritParser implements SemanticParser {
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(dataSetId, elementMatches, chatQueryContext);
for (RuleSemanticQuery query : queries) {
query.fillParseInfo(chatQueryContext, chatContext);
query.fillParseInfo(chatQueryContext);
if (existSameQuery(query.getParseInfo().getDataSetId(), query.getQueryMode(), chatQueryContext)) {
continue;
}
@@ -108,8 +108,8 @@ public class ContextInheritParser implements SemanticParser {
return metricModelQueries.size() == chatQueryContext.getCandidateQueries().size();
}
protected Long getMatchedDataSet(ChatQueryContext chatQueryContext, ChatContext chatContext) {
Long dataSetId = chatContext.getParseInfo().getDataSetId();
protected Long getMatchedDataSet(ChatQueryContext chatQueryContext) {
Long dataSetId = chatQueryContext.getContextParseInfo().getDataSetId();
if (dataSetId == null) {
return null;
}

View File

@@ -5,7 +5,6 @@ import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
import com.tencent.supersonic.headless.chat.ChatContext;
import lombok.extern.slf4j.Slf4j;
import java.util.Arrays;
import java.util.List;
@@ -24,7 +23,7 @@ public class RuleSqlParser implements SemanticParser {
);
@Override
public void parse(ChatQueryContext chatQueryContext, ChatContext chatContext) {
public void parse(ChatQueryContext chatQueryContext) {
if (!chatQueryContext.getText2SQLType().enableRule()) {
return;
}
@@ -34,11 +33,11 @@ public class RuleSqlParser implements SemanticParser {
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(dataSetId);
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(dataSetId, elementMatches, chatQueryContext);
for (RuleSemanticQuery query : queries) {
query.fillParseInfo(chatQueryContext, chatContext);
query.fillParseInfo(chatQueryContext);
chatQueryContext.getCandidateQueries().add(query);
}
}
auxiliaryParsers.stream().forEach(p -> p.parse(chatQueryContext, chatContext));
auxiliaryParsers.stream().forEach(p -> p.parse(chatQueryContext));
}
}

View File

@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.chat.parser.rule;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
import com.tencent.supersonic.headless.chat.query.QueryManager;
@@ -42,7 +41,7 @@ public class TimeRangeParser implements SemanticParser {
private static final DateFormat DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd");
@Override
public void parse(ChatQueryContext queryContext, ChatContext chatContext) {
public void parse(ChatQueryContext queryContext) {
String queryText = queryContext.getQueryText();
DateConf dateConf = parseRecent(queryText);
if (dateConf == null) {
@@ -59,14 +58,14 @@ public class TimeRangeParser implements SemanticParser {
query.getParseInfo().setScore(query.getParseInfo().getScore()
+ dateConf.getDetectWord().length());
}
} else if (QueryManager.containsRuleQuery(chatContext.getParseInfo().getQueryMode())) {
} else if (QueryManager.containsRuleQuery(queryContext.getContextParseInfo().getQueryMode())) {
RuleSemanticQuery semanticQuery = QueryManager.createRuleQuery(
chatContext.getParseInfo().getQueryMode());
queryContext.getContextParseInfo().getQueryMode());
// inherit parse info from context
chatContext.getParseInfo().setDateInfo(dateConf);
chatContext.getParseInfo().setScore(chatContext.getParseInfo().getScore()
queryContext.getContextParseInfo().setDateInfo(dateConf);
queryContext.getContextParseInfo().setScore(queryContext.getContextParseInfo().getScore()
+ dateConf.getDetectWord().length());
semanticQuery.setParseInfo(chatContext.getParseInfo());
semanticQuery.setParseInfo(queryContext.getContextParseInfo());
queryContext.getCandidateQueries().add(semanticQuery);
}
}

View File

@@ -16,7 +16,6 @@ import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.headless.chat.query.BaseSemanticQuery;
import com.tencent.supersonic.headless.chat.query.QueryManager;
import com.tencent.supersonic.headless.chat.ChatContext;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
@@ -49,13 +48,13 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
initS2SqlByStruct(semanticSchema);
}
public void fillParseInfo(ChatQueryContext chatQueryContext, ChatContext chatContext) {
public void fillParseInfo(ChatQueryContext chatQueryContext) {
parseInfo.setQueryMode(getQueryMode());
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
fillSchemaElement(parseInfo, semanticSchema);
fillScore(parseInfo);
fillDateConf(parseInfo, chatContext.getParseInfo());
fillDateConf(parseInfo, chatQueryContext.getContextParseInfo());
}
private void fillDateConf(SemanticParseInfo queryParseInfo, SemanticParseInfo chatParseInfo) {

View File

@@ -7,7 +7,6 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import org.apache.commons.collections.CollectionUtils;
@@ -19,8 +18,8 @@ import java.util.stream.Collectors;
public abstract class DetailListQuery extends DetailSemanticQuery {
@Override
public void fillParseInfo(ChatQueryContext chatQueryContext, ChatContext chatContext) {
super.fillParseInfo(chatQueryContext, chatContext);
public void fillParseInfo(ChatQueryContext chatQueryContext) {
super.fillParseInfo(chatQueryContext);
this.addEntityDetailAndOrderByMetric(chatQueryContext, parseInfo);
}

View File

@@ -10,7 +10,6 @@ import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption;
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.headless.chat.ChatContext;
import lombok.extern.slf4j.Slf4j;
import java.time.LocalDate;
@@ -35,8 +34,8 @@ public abstract class DetailSemanticQuery extends RuleSemanticQuery {
}
@Override
public void fillParseInfo(ChatQueryContext chatQueryContext, ChatContext chatContext) {
super.fillParseInfo(chatQueryContext, chatContext);
public void fillParseInfo(ChatQueryContext chatQueryContext) {
super.fillParseInfo(chatQueryContext);
parseInfo.setQueryType(QueryType.DETAIL);
parseInfo.setLimit(DETAIL_MAX_RESULTS);

View File

@@ -8,7 +8,6 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.headless.chat.ChatContext;
import lombok.extern.slf4j.Slf4j;
import java.time.LocalDate;
@@ -36,8 +35,8 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery {
}
@Override
public void fillParseInfo(ChatQueryContext chatQueryContext, ChatContext chatContext) {
super.fillParseInfo(chatQueryContext, chatContext);
public void fillParseInfo(ChatQueryContext chatQueryContext) {
super.fillParseInfo(chatQueryContext);
parseInfo.setLimit(METRIC_MAX_RESULTS);
if (parseInfo.getDateInfo() == null) {
DataSetSchema dataSetSchema =

View File

@@ -6,7 +6,6 @@ import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.ChatContext;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
@@ -50,8 +49,8 @@ public class MetricTopNQuery extends MetricSemanticQuery {
}
@Override
public void fillParseInfo(ChatQueryContext chatQueryContext, ChatContext chatContext) {
super.fillParseInfo(chatQueryContext, chatContext);
public void fillParseInfo(ChatQueryContext chatQueryContext) {
super.fillParseInfo(chatQueryContext);
parseInfo.setLimit(ORDERBY_MAX_RESULTS);
parseInfo.setScore(parseInfo.getScore() + 2.0);

View File

@@ -6,7 +6,7 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import com.tencent.supersonic.headless.server.facade.service.RetrieveService;
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
import lombok.extern.slf4j.Slf4j;
@@ -24,7 +24,7 @@ import javax.servlet.http.HttpServletResponse;
public class ChatQueryApiController {
@Autowired
private ChatQueryService chatQueryService;
private ChatLayerService chatLayerService;
@Autowired
private RetrieveService retrieveService;
@@ -45,7 +45,7 @@ public class ChatQueryApiController {
HttpServletRequest request,
HttpServletResponse response) {
queryNLReq.setUser(UserHolder.findUser(request, response));
return chatQueryService.performMapping(queryNLReq);
return chatLayerService.performMapping(queryNLReq);
}
@PostMapping("/chat/parse")
@@ -53,7 +53,7 @@ public class ChatQueryApiController {
HttpServletRequest request,
HttpServletResponse response) throws Exception {
queryNLReq.setUser(UserHolder.findUser(request, response));
return chatQueryService.performParsing(queryNLReq);
return chatLayerService.performParsing(queryNLReq);
}
@PostMapping("/chat")
@@ -61,7 +61,7 @@ public class ChatQueryApiController {
HttpServletRequest request,
HttpServletResponse response) throws Exception {
User user = UserHolder.findUser(request, response);
ParseResp parseResp = chatQueryService.performParsing(queryNLReq);
ParseResp parseResp = chatLayerService.performParsing(queryNLReq);
if (parseResp.getState().equals(ParseResp.ParseState.COMPLETED)) {
SemanticParseInfo parseInfo = parseResp.getSelectedParses().get(0);
QuerySqlReq sqlReq = new QuerySqlReq();

View File

@@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.server.facade.rest;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.headless.api.pojo.request.QueryMapReq;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.PostMapping;
@@ -20,14 +20,14 @@ import javax.servlet.http.HttpServletResponse;
public class MetaDiscoveryApiController {
@Autowired
private ChatQueryService chatQueryService;
private ChatLayerService chatLayerService;
@PostMapping("map")
public Object map(@RequestBody QueryMapReq queryMapReq,
HttpServletRequest request, HttpServletResponse response) throws Exception {
User user = UserHolder.findUser(request, response);
queryMapReq.setUser(user);
return chatQueryService.map(queryMapReq);
return chatLayerService.map(queryMapReq);
}
}

View File

@@ -7,7 +7,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlsReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
@@ -33,7 +33,7 @@ public class SqlQueryApiController {
private SemanticLayerService semanticLayerService;
@Autowired
private ChatQueryService chatQueryService;
private ChatLayerService chatLayerService;
@PostMapping("/sql")
public Object queryBySql(@RequestBody QuerySqlReq querySqlReq,
@@ -42,7 +42,7 @@ public class SqlQueryApiController {
User user = UserHolder.findUser(request, response);
String sql = querySqlReq.getSql();
querySqlReq.setSql(StringUtil.replaceBackticks(sql));
chatQueryService.correct(querySqlReq, user);
chatLayerService.correct(querySqlReq, user);
return semanticLayerService.queryByReq(querySqlReq, user);
}
@@ -56,7 +56,7 @@ public class SqlQueryApiController {
QuerySqlReq querySqlReq = new QuerySqlReq();
BeanUtils.copyProperties(querySqlsReq, querySqlReq);
querySqlReq.setSql(StringUtil.replaceBackticks(sql));
chatQueryService.correct(querySqlReq, user);
chatLayerService.correct(querySqlReq, user);
return querySqlReq;
}).collect(Collectors.toList());
@@ -82,7 +82,7 @@ public class SqlQueryApiController {
QuerySqlReq querySqlReq = new QuerySqlReq();
BeanUtils.copyProperties(querySqlsReq, querySqlReq);
querySqlReq.setSql(StringUtil.replaceBackticks(sql));
chatQueryService.correct(querySqlReq, user);
chatLayerService.correct(querySqlReq, user);
return querySqlReq;
}).collect(Collectors.toList());
List<SemanticQueryResp> semanticQueryRespList = new ArrayList<>();
@@ -104,7 +104,7 @@ public class SqlQueryApiController {
User user = UserHolder.findUser(request, response);
String sql = querySqlReq.getSql();
querySqlReq.setSql(StringUtil.replaceBackticks(sql));
return chatQueryService.validate(querySqlReq, user);
return chatLayerService.validate(querySqlReq, user);
}
}

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.headless.server.facade.service;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SqlEvaluation;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
@@ -16,14 +15,12 @@ import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
/***dd
* SemanticLayerService for query and search
*/
public interface ChatQueryService {
public interface ChatLayerService {
MapResp performMapping(QueryNLReq queryNLReq);
ParseResp performParsing(QueryNLReq queryNLReq);
SemanticParseInfo queryContext(Integer chatId);
QueryResult executeDirectQuery(QueryDataReq queryData, User user) throws Exception;
Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception;

View File

@@ -22,7 +22,6 @@ import com.tencent.supersonic.headless.chat.mapper.MatchText;
import com.tencent.supersonic.headless.chat.mapper.ModelWithSemanticType;
import com.tencent.supersonic.headless.chat.mapper.SearchMatchStrategy;
import com.tencent.supersonic.headless.server.facade.service.RetrieveService;
import com.tencent.supersonic.headless.server.web.service.ChatContextService;
import com.tencent.supersonic.headless.server.web.service.DataSetService;
import com.tencent.supersonic.headless.server.web.service.SchemaService;
import lombok.extern.slf4j.Slf4j;
@@ -52,9 +51,6 @@ public class RetrieveServiceImpl implements RetrieveService {
@Autowired
private DataSetService dataSetService;
@Autowired
private ChatContextService chatContextService;
@Autowired
private SchemaService schemaService;
@@ -135,7 +131,7 @@ public class RetrieveServiceImpl implements RetrieveService {
List<Long> possibleDataSets = NatureHelper.selectPossibleDataSets(originals);
Long contextDataset = chatContextService.getContextDataset(queryCtx.getChatId());
Long contextDataset = queryCtx.getContextParseInfo().getDataSetId();
log.debug("possibleDataSets:{},dataSetInfoStat:{},contextDataset:{}",
possibleDataSets, dataSetInfoStat, contextDataset);

View File

@@ -44,7 +44,6 @@ import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.corrector.GrammarCorrector;
import com.tencent.supersonic.headless.chat.corrector.SchemaCorrector;
@@ -57,12 +56,11 @@ import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
import com.tencent.supersonic.headless.chat.query.QueryManager;
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
import com.tencent.supersonic.headless.server.utils.ChatWorkflowEngine;
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
import com.tencent.supersonic.headless.server.utils.ComponentFactory;
import com.tencent.supersonic.headless.server.web.service.ChatContextService;
import com.tencent.supersonic.headless.server.web.service.DataSetService;
import com.tencent.supersonic.headless.server.web.service.SchemaService;
import lombok.extern.slf4j.Slf4j;
@@ -98,14 +96,12 @@ import java.util.stream.Collectors;
@Service
@Slf4j
public class ChatQueryServiceImpl implements ChatQueryService {
public class S2ChatLayerService implements ChatLayerService {
@Autowired
private SemanticLayerService semanticLayerService;
@Autowired
private SchemaService schemaService;
@Autowired
private ChatContextService chatContextService;
@Autowired
private KnowledgeBaseService knowledgeBaseService;
@Autowired
private DataSetService dataSetService;
@@ -141,14 +137,11 @@ public class ChatQueryServiceImpl implements ChatQueryService {
@Override
public ParseResp performParsing(QueryNLReq queryNLReq) {
ParseResp parseResult = new ParseResp(queryNLReq.getChatId(), queryNLReq.getQueryText());
// build queryContext and chatContext
ParseResp parseResult = new ParseResp(queryNLReq.getQueryText());
// build queryContext
ChatQueryContext queryCtx = buildChatQueryContext(queryNLReq);
// in order to support multi-turn conversation, chat context is needed
ChatContext chatCtx = chatContextService.getOrCreateContext(queryNLReq.getChatId());
chatWorkflowEngine.execute(queryCtx, chatCtx, parseResult);
chatWorkflowEngine.execute(queryCtx, parseResult);
List<SemanticParseInfo> parseInfos = queryCtx.getCandidateQueries().stream()
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
@@ -173,12 +166,6 @@ public class ChatQueryServiceImpl implements ChatQueryService {
return queryCtx;
}
@Override
public SemanticParseInfo queryContext(Integer chatId) {
ChatContext context = chatContextService.getOrCreateContext(chatId);
return context.getParseInfo();
}
//mainly used for executing after revising filters,for example:"fans_cnt>=100000"->"fans_cnt>500000",
//"style='流行'"->"style in ['流行','爱国']"
@Override

View File

@@ -1,16 +0,0 @@
package com.tencent.supersonic.headless.server.persistence.dataobject;
import lombok.Data;
import java.io.Serializable;
import java.time.Instant;
@Data
public class ChatContextDO implements Serializable {
private Integer chatId;
private Instant modifiedAt;
private String user;
private String queryText;
private String semanticParse;
}

View File

@@ -1,54 +0,0 @@
package com.tencent.supersonic.headless.server.persistence.dataobject;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import java.util.Date;
@Data
@Builder
@NoArgsConstructor
@Getter
@AllArgsConstructor
public class StatisticsDO {
/**
* questionId
*/
private Long questionId;
/**
* chatId
*/
private Long chatId;
/**
* createTime
*/
private Date createTime;
/**
* queryText
*/
private String queryText;
/**
* userName
*/
private String userName;
/**
* interface
*/
private String interfaceName;
/**
* cost
*/
private Integer cost;
private Integer type;
}

View File

@@ -1,14 +0,0 @@
package com.tencent.supersonic.headless.server.persistence.mapper;
import com.tencent.supersonic.headless.server.persistence.dataobject.ChatContextDO;
import org.apache.ibatis.annotations.Mapper;
@Mapper
public interface ChatContextMapper {
ChatContextDO getContextByChatId(Integer chatId);
int updateContext(ChatContextDO contextDO);
int addContext(ChatContextDO contextDO);
}

View File

@@ -1,12 +0,0 @@
package com.tencent.supersonic.headless.server.persistence.mapper;
import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
import java.util.List;
@Mapper
public interface StatisticsMapper {
boolean batchSaveStatistics(@Param("list") List<StatisticsDO> list);
}

View File

@@ -1,11 +0,0 @@
package com.tencent.supersonic.headless.server.persistence.repository;
import com.tencent.supersonic.headless.chat.ChatContext;
public interface ChatContextRepository {
ChatContext getOrCreateContext(Integer chatId);
void updateContext(ChatContext chatCtx);
}

View File

@@ -1,71 +0,0 @@
package com.tencent.supersonic.headless.server.persistence.repository.impl;
import com.google.gson.Gson;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.server.persistence.dataobject.ChatContextDO;
import com.tencent.supersonic.headless.server.persistence.mapper.ChatContextMapper;
import com.tencent.supersonic.headless.server.persistence.repository.ChatContextRepository;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Repository;
@Repository
@Primary
@Slf4j
public class ChatContextRepositoryImpl implements ChatContextRepository {
private final ChatContextMapper chatContextMapper;
public ChatContextRepositoryImpl(ChatContextMapper chatContextMapper) {
this.chatContextMapper = chatContextMapper;
}
@Override
public ChatContext getOrCreateContext(Integer chatId) {
ChatContextDO context = chatContextMapper.getContextByChatId(chatId);
if (context == null) {
ChatContext chatContext = new ChatContext();
chatContext.setChatId(chatId);
return chatContext;
}
return cast(context);
}
@Override
public void updateContext(ChatContext chatCtx) {
ChatContextDO context = cast(chatCtx);
if (chatContextMapper.getContextByChatId(chatCtx.getChatId()) == null) {
chatContextMapper.addContext(context);
} else {
chatContextMapper.updateContext(context);
}
}
private ChatContext cast(ChatContextDO contextDO) {
ChatContext chatContext = new ChatContext();
chatContext.setChatId(contextDO.getChatId());
chatContext.setUser(contextDO.getUser());
chatContext.setQueryText(contextDO.getQueryText());
if (contextDO.getSemanticParse() != null && !contextDO.getSemanticParse().isEmpty()) {
SemanticParseInfo semanticParseInfo = JsonUtil.toObject(contextDO.getSemanticParse(),
SemanticParseInfo.class);
chatContext.setParseInfo(semanticParseInfo);
}
return chatContext;
}
private ChatContextDO cast(ChatContext chatContext) {
ChatContextDO chatContextDO = new ChatContextDO();
chatContextDO.setChatId(chatContext.getChatId());
chatContextDO.setQueryText(chatContext.getQueryText());
chatContextDO.setUser(chatContext.getUser());
if (chatContext.getParseInfo() != null) {
Gson g = new Gson();
chatContextDO.setSemanticParse(g.toJson(chatContext.getParseInfo()));
}
return chatContextDO;
}
}

View File

@@ -14,7 +14,6 @@ import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.server.web.service.SchemaService;
@@ -40,7 +39,7 @@ import java.util.stream.Collectors;
public class ParseInfoProcessor implements ResultProcessor {
@Override
public void process(ParseResp parseResp, ChatQueryContext chatQueryContext, ChatContext chatContext) {
public void process(ParseResp parseResp, ChatQueryContext chatQueryContext) {
List<SemanticQuery> candidateQueries = chatQueryContext.getCandidateQueries();
if (CollectionUtils.isEmpty(candidateQueries)) {
return;

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.headless.server.processor;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
/**
@@ -9,6 +8,6 @@ import com.tencent.supersonic.headless.chat.ChatQueryContext;
*/
public interface ResultProcessor {
void process(ParseResp parseResp, ChatQueryContext chatQueryContext, ChatContext chatContext);
void process(ParseResp parseResp, ChatQueryContext chatQueryContext);
}

View File

@@ -7,7 +7,6 @@ import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.corrector.SemanticCorrector;
import com.tencent.supersonic.headless.chat.mapper.SchemaMapper;
@@ -39,7 +38,7 @@ public class ChatWorkflowEngine {
private List<SemanticCorrector> semanticCorrectors = ComponentFactory.getSemanticCorrectors();
private List<ResultProcessor> resultProcessors = ComponentFactory.getResultProcessors();
public void execute(ChatQueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult) {
public void execute(ChatQueryContext queryCtx, ParseResp parseResult) {
queryCtx.setChatWorkflowState(ChatWorkflowState.MAPPING);
while (queryCtx.getChatWorkflowState() != ChatWorkflowState.FINISHED) {
switch (queryCtx.getChatWorkflowState()) {
@@ -54,7 +53,7 @@ public class ChatWorkflowEngine {
}
break;
case PARSING:
performParsing(queryCtx, chatCtx);
performParsing(queryCtx);
if (queryCtx.getCandidateQueries().size() == 0) {
parseResult.setState(ParseResp.ParseState.FAILED);
parseResult.setErrorMsg("No semantic queries can be parsed out.");
@@ -75,7 +74,7 @@ public class ChatWorkflowEngine {
break;
case PROCESSING:
default:
performProcessing(queryCtx, chatCtx, parseResult);
performProcessing(queryCtx, parseResult);
if (parseResult.getState().equals(ParseResp.ParseState.PENDING)) {
parseResult.setState(ParseResp.ParseState.COMPLETED);
}
@@ -92,9 +91,9 @@ public class ChatWorkflowEngine {
}
}
private void performParsing(ChatQueryContext queryCtx, ChatContext chatCtx) {
private void performParsing(ChatQueryContext queryCtx) {
semanticParsers.forEach(parser -> {
parser.parse(queryCtx, chatCtx);
parser.parse(queryCtx);
log.debug("{} result:{}", parser.getClass().getSimpleName(), JsonUtil.toString(queryCtx));
});
}
@@ -116,9 +115,9 @@ public class ChatWorkflowEngine {
}
}
private void performProcessing(ChatQueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult) {
private void performProcessing(ChatQueryContext queryCtx, ParseResp parseResult) {
resultProcessors.forEach(processor -> {
processor.process(parseResult, queryCtx, chatCtx);
processor.process(parseResult, queryCtx);
});
}

View File

@@ -1,18 +0,0 @@
package com.tencent.supersonic.headless.server.web.service;
import com.tencent.supersonic.headless.chat.ChatContext;
public interface ChatContextService {
/***
* get the model from context
* @param chatId
* @return
*/
Long getContextDataset(Integer chatId);
ChatContext getOrCreateContext(Integer chatId);
void updateContext(ChatContext chatCtx);
}

View File

@@ -1,50 +0,0 @@
package com.tencent.supersonic.headless.server.web.service.impl;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.server.persistence.repository.ChatContextRepository;
import com.tencent.supersonic.headless.server.web.service.ChatContextService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.Objects;
@Slf4j
@Service
public class ChatContextServiceImpl implements ChatContextService {
private ChatContextRepository chatContextRepository;
public ChatContextServiceImpl(ChatContextRepository chatContextRepository) {
this.chatContextRepository = chatContextRepository;
}
@Override
public Long getContextDataset(Integer chatId) {
if (Objects.isNull(chatId)) {
return null;
}
ChatContext chatContext = getOrCreateContext(chatId);
if (Objects.isNull(chatContext)) {
return null;
}
SemanticParseInfo originalSemanticParse = chatContext.getParseInfo();
if (Objects.nonNull(originalSemanticParse) && Objects.nonNull(originalSemanticParse.getDataSetId())) {
return originalSemanticParse.getDataSetId();
}
return null;
}
@Override
public ChatContext getOrCreateContext(Integer chatId) {
return chatContextRepository.getOrCreateContext(chatId);
}
@Override
public void updateContext(ChatContext chatCtx) {
log.debug("save ChatContext {}", chatCtx);
chatContextRepository.updateContext(chatCtx);
}
}

View File

@@ -44,7 +44,7 @@ import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.api.pojo.response.TagItem;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import com.tencent.supersonic.headless.server.persistence.dataobject.CollectDO;
import com.tencent.supersonic.headless.server.persistence.dataobject.MetricDO;
import com.tencent.supersonic.headless.server.persistence.dataobject.MetricQueryDefaultConfigDO;
@@ -111,7 +111,7 @@ public class MetricServiceImpl extends ServiceImpl<MetricDOMapper, MetricDO>
private TagMetaService tagMetaService;
private ChatQueryService chatQueryService;
private ChatLayerService chatLayerService;
public MetricServiceImpl(MetricRepository metricRepository,
ModelService modelService,
@@ -121,7 +121,7 @@ public class MetricServiceImpl extends ServiceImpl<MetricDOMapper, MetricDO>
ApplicationEventPublisher eventPublisher,
DimensionService dimensionService,
TagMetaService tagMetaService,
@Lazy ChatQueryService chatQueryService) {
@Lazy ChatLayerService chatLayerService) {
this.metricRepository = metricRepository;
this.modelService = modelService;
this.aliasGenerateHelper = aliasGenerateHelper;
@@ -130,7 +130,7 @@ public class MetricServiceImpl extends ServiceImpl<MetricDOMapper, MetricDO>
this.dataSetService = dataSetService;
this.dimensionService = dimensionService;
this.tagMetaService = tagMetaService;
this.chatQueryService = chatQueryService;
this.chatLayerService = chatLayerService;
}
@Override
@@ -299,7 +299,7 @@ public class MetricServiceImpl extends ServiceImpl<MetricDOMapper, MetricDO>
queryMapReq.setQueryText(pageMetricReq.getKey());
queryMapReq.setUser(user);
queryMapReq.setMapModeEnum(MapModeEnum.LOOSE);
MapInfoResp mapMeta = chatQueryService.map(queryMapReq);
MapInfoResp mapMeta = chatLayerService.map(queryMapReq);
Map<String, DataSetMapInfo> dataSetMapInfoMap = mapMeta.getDataSetMapInfo();
if (CollectionUtils.isEmpty(dataSetMapInfoMap)) {
return metricRespPageInfo;

View File

@@ -1,30 +0,0 @@
<?xml version="1.0" encoding="UTF-8" ?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.tencent.supersonic.headless.server.persistence.mapper.ChatContextMapper">
<resultMap id="ChatContextDO"
type="com.tencent.supersonic.headless.server.persistence.dataobject.ChatContextDO">
<id column="chat_id" property="chatId"/>
<result column="modified_at" property="modifiedAt"/>
<result column="user" property="user"/>
<result column="query_text" property="queryText"/>
<result column="semantic_parse" property="semanticParse"/>
<!--<result column="ext_data" property="extData"/>-->
</resultMap>
<select id="getContextByChatId" resultMap="ChatContextDO">
select *
from s2_chat_context where chat_id=#{chatId} limit 1
</select>
<insert id="addContext" parameterType="com.tencent.supersonic.headless.server.persistence.dataobject.ChatContextDO" >
insert into s2_chat_context (chat_id,user,query_text,semantic_parse) values (#{chatId}, #{user},#{queryText}, #{semanticParse})
</insert>
<update id="updateContext">
update s2_chat_context set query_text=#{queryText},semantic_parse=#{semanticParse} where chat_id=#{chatId}
</update>
</mapper>

View File

@@ -1,28 +0,0 @@
<?xml version="1.0" encoding="UTF-8" ?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.tencent.supersonic.headless.server.persistence.mapper.StatisticsMapper">
<resultMap id="Statistics" type="com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO">
<id column="question_id" property="questionId"/>
<result column="chat_id" property="chatId"/>
<result column="user_name" property="userName"/>
<result column="query_text" property="queryText"/>
<result column="interface_name" property="interfaceName"/>
<result column="cost" property="cost"/>
<result column="type" property="type"/>
<result column="create_time" property="createTime"/>
</resultMap>
<insert id="batchSaveStatistics" parameterType="com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO">
insert into s2_chat_statistics
(question_id,chat_id, user_name, query_text, interface_name,cost,type ,create_time)
values
<foreach collection="list" item="item" index="index" separator=",">
(#{item.questionId}, #{item.chatId}, #{item.userName}, #{item.queryText}, #{item.interfaceName}, #{item.cost}, #{item.type},#{item.createTime})
</foreach>
</insert>
</mapper>

View File

@@ -16,7 +16,7 @@ import com.tencent.supersonic.headless.api.pojo.enums.MetricType;
import com.tencent.supersonic.headless.api.pojo.request.MetricReq;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import com.tencent.supersonic.headless.server.persistence.dataobject.MetricDO;
import com.tencent.supersonic.headless.server.persistence.repository.MetricRepository;
import com.tencent.supersonic.headless.server.utils.AliasGenerateHelper;
@@ -77,10 +77,10 @@ public class MetricServiceImplTest {
DataSetService dataSetService = Mockito.mock(DataSetServiceImpl.class);
DimensionService dimensionService = Mockito.mock(DimensionService.class);
TagMetaService tagMetaService = Mockito.mock(TagMetaService.class);
ChatQueryService chatQueryService = Mockito.mock(ChatQueryService.class);
ChatLayerService chatLayerService = Mockito.mock(ChatLayerService.class);
return new MetricServiceImpl(metricRepository, modelService, aliasGenerateHelper,
collectService, dataSetService, eventPublisher, dimensionService,
tagMetaService, chatQueryService);
tagMetaService, chatLayerService);
}
private MetricReq buildMetricReq() {