(improvement)(Chat) Move chat-core to headless (#805)

Co-authored-by: jolunoluo
This commit is contained in:
LXW
2024-03-12 22:20:30 +08:00
committed by GitHub
parent f152deeb81
commit f93bee81cb
301 changed files with 2256 additions and 4527 deletions

View File

@@ -11,19 +11,10 @@
<artifactId>chat-server</artifactId>
<dependencies>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-context</artifactId>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</dependency>
<dependency>
<groupId>com.tencent.supersonic</groupId>
<artifactId>common</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>com.tencent.supersonic</groupId>
<artifactId>auth-api</artifactId>
@@ -36,13 +27,7 @@
</dependency>
<dependency>
<groupId>com.tencent.supersonic</groupId>
<artifactId>headless-core</artifactId>
<version>${project.version}</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>com.tencent.supersonic</groupId>
<artifactId>chat-core</artifactId>
<artifactId>headless-server</artifactId>
<version>${project.version}</version>
<scope>compile</scope>
</dependency>
@@ -51,12 +36,6 @@
<artifactId>junit</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-inline</artifactId>
<version>${mockito-inline.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>

View File

@@ -0,0 +1,87 @@
package com.tencent.supersonic.chat.server.agent;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.common.pojo.RecordInfo;
import lombok.Data;
import org.springframework.util.CollectionUtils;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
@Data
public class Agent extends RecordInfo {
private Integer id;
private Integer enableSearch;
private String name;
private String description;
/**
* 0 offline, 1 online
*/
private Integer status;
private List<String> examples;
private String agentConfig;
public List<String> getTools(AgentToolType type) {
Map map = JSONObject.parseObject(agentConfig, Map.class);
if (CollectionUtils.isEmpty(map) || map.get("tools") == null) {
return Lists.newArrayList();
}
List<Map> toolList = (List) map.get("tools");
return toolList.stream()
.filter(tool -> {
if (Objects.isNull(type)) {
return true;
}
return type.name().equals(tool.get("type"));
}
)
.map(JSONObject::toJSONString)
.collect(Collectors.toList());
}
public boolean enableSearch() {
return enableSearch != null && enableSearch == 1;
}
public static boolean containsAllModel(Set<Long> detectViewIds) {
return !CollectionUtils.isEmpty(detectViewIds) && detectViewIds.contains(-1L);
}
public List<NL2SQLTool> getParserTools(AgentToolType agentToolType) {
List<String> tools = this.getTools(agentToolType);
if (CollectionUtils.isEmpty(tools)) {
return Lists.newArrayList();
}
return tools.stream().map(tool -> JSONObject.parseObject(tool, NL2SQLTool.class))
.collect(Collectors.toList());
}
public Set<Long> getDataSetIds() {
Set<Long> dataSetIds = getDataSetIds(null);
if (containsAllModel(dataSetIds)) {
return Sets.newHashSet();
}
return dataSetIds;
}
public Set<Long> getDataSetIds(AgentToolType agentToolType) {
List<NL2SQLTool> commonAgentTools = getParserTools(agentToolType);
if (CollectionUtils.isEmpty(commonAgentTools)) {
return new HashSet<>();
}
return commonAgentTools.stream().map(NL2SQLTool::getDataSetIds)
.filter(modelIds -> !CollectionUtils.isEmpty(modelIds))
.flatMap(Collection::stream)
.collect(Collectors.toSet());
}
}

View File

@@ -0,0 +1,17 @@
package com.tencent.supersonic.chat.server.agent;
import com.google.common.collect.Lists;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
@Data
@NoArgsConstructor
@AllArgsConstructor
public class AgentConfig {
List<AgentTool> tools = Lists.newArrayList();
}

View File

@@ -0,0 +1,17 @@
package com.tencent.supersonic.chat.server.agent;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@NoArgsConstructor
@AllArgsConstructor
public class AgentTool {
private String id;
private String name;
private AgentToolType type;
}

View File

@@ -0,0 +1,25 @@
package com.tencent.supersonic.chat.server.agent;
import java.util.HashMap;
import java.util.Map;
public enum AgentToolType {
NL2SQL_RULE("基于规则Text-to-SQL"),
NL2SQL_LLM("基于大模型Text-to-SQL"),
PLUGIN("第三方插件");
private String title;
AgentToolType(String title) {
this.title = title;
}
public static Map<AgentToolType, String> getToolTypes() {
Map<AgentToolType, String> map = new HashMap<>();
map.put(NL2SQL_RULE, NL2SQL_RULE.title);
map.put(NL2SQL_LLM, NL2SQL_LLM.title);
map.put(PLUGIN, PLUGIN.title);
return map;
}
}

View File

@@ -0,0 +1,12 @@
package com.tencent.supersonic.chat.server.agent;
import lombok.Data;
import java.util.List;
@Data
public class LLMParserTool extends NL2SQLTool {
private List<String> exampleQuestions;
}

View File

@@ -0,0 +1,17 @@
package com.tencent.supersonic.chat.server.agent;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
@Data
@NoArgsConstructor
@AllArgsConstructor
public class NL2SQLTool extends AgentTool {
protected List<Long> dataSetIds;
}

View File

@@ -0,0 +1,13 @@
package com.tencent.supersonic.chat.server.agent;
import lombok.Data;
import java.util.List;
@Data
public class PluginTool extends AgentTool {
private List<Long> plugins;
}

View File

@@ -0,0 +1,21 @@
package com.tencent.supersonic.chat.server.agent;
import lombok.Data;
import org.apache.commons.collections.CollectionUtils;
import java.util.List;
@Data
public class RuleParserTool extends NL2SQLTool {
private List<String> queryModes;
private List<String> queryTypes;
public boolean isContainsAllModel() {
return CollectionUtils.isNotEmpty(dataSetIds) && dataSetIds.contains(-1L);
}
}

View File

@@ -1,87 +0,0 @@
package com.tencent.supersonic.chat.server.listener;
import com.tencent.supersonic.headless.core.knowledge.DictWord;
import com.tencent.supersonic.chat.server.service.impl.SchemaService;
import com.tencent.supersonic.chat.server.service.impl.WordService;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import com.tencent.supersonic.headless.server.service.KnowledgeService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.CommandLineRunner;
import org.springframework.core.annotation.Order;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
@Slf4j
@Component
@Order(2)
public class ApplicationStartedListener implements CommandLineRunner {
@Autowired
private KnowledgeService knowledgeService;
@Autowired
private WordService wordService;
@Autowired
private SchemaService schemaService;
@Override
public void run(String... args) {
updateKnowledgeDimValue();
}
public Boolean updateKnowledgeDimValue() {
Boolean isOk = false;
try {
log.debug("ApplicationStartedInit start");
List<DictWord> dictWords = wordService.getAllDictWords();
wordService.setPreDictWords(dictWords);
knowledgeService.reloadAllData(dictWords);
log.debug("ApplicationStartedInit end");
isOk = true;
} catch (Exception e) {
log.error("ApplicationStartedInit error", e);
}
return isOk;
}
public Boolean updateKnowledgeDimValueAsync() {
CompletableFuture.supplyAsync(() -> {
updateKnowledgeDimValue();
return null;
});
return true;
}
/***
* reload knowledge task
*/
@Scheduled(cron = "${reload.knowledge.corn:0 0/1 * * * ?}")
public void reloadKnowledge() {
log.debug("reloadKnowledge start");
try {
List<DictWord> dictWords = wordService.getAllDictWords();
List<DictWord> preDictWords = wordService.getPreDictWords();
if (CollectionUtils.isEqualCollection(dictWords, preDictWords)) {
log.debug("dictWords has not changed, reloadKnowledge end");
return;
}
log.info("dictWords has changed");
wordService.setPreDictWords(dictWords);
knowledgeService.updateOnlineKnowledge(wordService.getAllDictWords());
schemaService.getCache().refresh(SchemaService.ALL_CACHE);
} catch (Exception e) {
log.error("reloadKnowledge error", e);
}
log.debug("reloadKnowledge end");
}
}

View File

@@ -1,51 +0,0 @@
package com.tencent.supersonic.chat.server.listener;
import com.tencent.supersonic.headless.core.knowledge.DictWord;
import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper;
import com.tencent.supersonic.chat.server.service.impl.SchemaService;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DataEvent;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.common.pojo.enums.EventType;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationListener;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
@Component
@Slf4j
public class SchemaDictUpdateListener implements ApplicationListener<DataEvent> {
@Autowired
private SchemaService schemaService;
@Async
@Override
public void onApplicationEvent(DataEvent dataEvent) {
if (CollectionUtils.isEmpty(dataEvent.getDataItems())) {
return;
}
schemaService.getCache().invalidateAll();
dataEvent.getDataItems().forEach(dataItem -> {
DictWord dictWord = new DictWord();
dictWord.setWord(dataItem.getName());
String sign = DictWordType.NATURE_SPILT;
String suffixNature = DictWordType.getSuffixNature(dataItem.getType());
String nature = sign + dataItem.getModelId() + dataItem.getId() + suffixNature;
String natureWithFrequency = nature + " " + Constants.DEFAULT_FREQUENCY;
dictWord.setNature(nature);
dictWord.setNatureWithFrequency(natureWithFrequency);
if (EventType.ADD.equals(dataEvent.getEventType())) {
HanlpHelper.addToCustomDictionary(dictWord);
} else if (EventType.DELETE.equals(dataEvent.getEventType())) {
HanlpHelper.removeFromCustomDictionary(dictWord);
} else if (EventType.UPDATE.equals(dataEvent.getEventType())) {
HanlpHelper.removeFromCustomDictionary(dictWord);
dictWord.setWord(dataItem.getNewName());
HanlpHelper.addToCustomDictionary(dictWord);
}
});
}
}

View File

@@ -1,15 +0,0 @@
package com.tencent.supersonic.chat.server.persistence.dataobject;
import java.io.Serializable;
import java.time.Instant;
import lombok.Data;
@Data
public class ChatContextDO implements Serializable {
private Integer chatId;
private Instant modifiedAt;
private String user;
private String queryText;
private String semanticParse;
}

View File

@@ -1,14 +1,14 @@
package com.tencent.supersonic.chat.server.persistence.dataobject;
import com.tencent.supersonic.chat.core.config.DefaultMetric;
import com.tencent.supersonic.chat.core.config.Dim4Dict;
import java.util.ArrayList;
import java.util.List;
import com.tencent.supersonic.headless.core.config.DefaultMetric;
import com.tencent.supersonic.headless.core.config.Dim4Dict;
import lombok.Data;
import lombok.ToString;
import java.util.ArrayList;
import java.util.List;
@Data
@ToString

View File

@@ -1,54 +0,0 @@
package com.tencent.supersonic.chat.server.persistence.dataobject;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import java.util.Date;
@Data
@Builder
@NoArgsConstructor
@Getter
@AllArgsConstructor
public class StatisticsDO {
/**
* questionId
*/
private Long questionId;
/**
* chatId
*/
private Long chatId;
/**
* createTime
*/
private Date createTime;
/**
* queryText
*/
private String queryText;
/**
* userName
*/
private String userName;
/**
* interface
*/
private String interfaceName;
/**
* cost
*/
private Integer cost;
private Integer type;
}

View File

@@ -1,14 +0,0 @@
package com.tencent.supersonic.chat.server.persistence.mapper;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO;
import org.apache.ibatis.annotations.Mapper;
@Mapper
public interface ChatContextMapper {
ChatContextDO getContextByChatId(int chatId);
int updateContext(ChatContextDO contextDO);
int addContext(ChatContextDO contextDO);
}

View File

@@ -1,12 +0,0 @@
package com.tencent.supersonic.chat.server.persistence.mapper;
import com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
import java.util.List;
@Mapper
public interface StatisticsMapper {
boolean batchSaveStatistics(@Param("list") List<StatisticsDO> list);
}

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.chat.server.persistence.repository;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.headless.core.pojo.ChatContext;
public interface ChatContextRepository {

View File

@@ -1,14 +1,14 @@
package com.tencent.supersonic.chat.server.persistence.repository;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResp;
import java.util.List;
public interface ChatQueryRepository {
@@ -25,8 +25,8 @@ public interface ChatQueryRepository {
int updateChatQuery(ChatQueryDO chatQueryDO);
List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryContext queryContext,
ParseResp parseResult, List<SemanticParseInfo> candidateParses);
List<ChatParseDO> batchSaveParseInfo(ChatParseReq chatParseReq, ParseResp parseResult,
List<SemanticParseInfo> candidateParses);
ChatParseDO getParseInfo(Long questionId, int parseId);

View File

@@ -1,11 +0,0 @@
package com.tencent.supersonic.chat.server.persistence.repository;
import com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO;
import java.util.List;
public interface StatisticsRepository {
void batchSaveStatistics(List<StatisticsDO> list);
}

View File

@@ -1,72 +0,0 @@
package com.tencent.supersonic.chat.server.persistence.repository.impl;
import com.google.gson.Gson;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO;
import com.tencent.supersonic.chat.server.persistence.mapper.ChatContextMapper;
import com.tencent.supersonic.chat.server.persistence.repository.ChatContextRepository;
import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Repository;
@Repository
@Primary
@Slf4j
public class ChatContextRepositoryImpl implements ChatContextRepository {
@Autowired(required = false)
private final ChatContextMapper chatContextMapper;
public ChatContextRepositoryImpl(ChatContextMapper chatContextMapper) {
this.chatContextMapper = chatContextMapper;
}
@Override
public ChatContext getOrCreateContext(int chatId) {
ChatContextDO context = chatContextMapper.getContextByChatId(chatId);
if (context == null) {
ChatContext chatContext = new ChatContext();
chatContext.setChatId(chatId);
return chatContext;
}
return cast(context);
}
@Override
public void updateContext(ChatContext chatCtx) {
ChatContextDO context = cast(chatCtx);
if (chatContextMapper.getContextByChatId(chatCtx.getChatId()) == null) {
chatContextMapper.addContext(context);
} else {
chatContextMapper.updateContext(context);
}
}
private ChatContext cast(ChatContextDO contextDO) {
ChatContext chatContext = new ChatContext();
chatContext.setChatId(contextDO.getChatId());
chatContext.setUser(contextDO.getUser());
chatContext.setQueryText(contextDO.getQueryText());
if (contextDO.getSemanticParse() != null && !contextDO.getSemanticParse().isEmpty()) {
SemanticParseInfo semanticParseInfo = JsonUtil.toObject(contextDO.getSemanticParse(),
SemanticParseInfo.class);
chatContext.setParseInfo(semanticParseInfo);
}
return chatContext;
}
private ChatContextDO cast(ChatContext chatContext) {
ChatContextDO chatContextDO = new ChatContextDO();
chatContextDO.setChatId(chatContext.getChatId());
chatContextDO.setQueryText(chatContext.getQueryText());
chatContextDO.setUser(chatContext.getUser());
if (chatContext.getParseInfo() != null) {
Gson g = new Gson();
chatContextDO.setSemanticParse(g.toJson(chatContext.getParseInfo()));
}
return chatContextDO;
}
}

View File

@@ -1,16 +1,9 @@
package com.tencent.supersonic.chat.server.persistence.repository.impl;
import com.alibaba.fastjson.JSONObject;
import com.github.pagehelper.PageHelper;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDOExample;
@@ -21,11 +14,10 @@ import com.tencent.supersonic.chat.server.persistence.mapper.custom.ShowCaseCust
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.PageUtils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
@@ -33,6 +25,12 @@ import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Repository;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
@Repository
@Primary
@Slf4j
@@ -108,21 +106,16 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
queryResult.setQueryId(chatQueryDO.getQuestionId());
queryResp.setQueryResult(queryResult);
}
if (StringUtils.isNotBlank(chatQueryDO.getSimilarQueries())) {
List<SimilarQueryRecallResp> similarQueries = JSONObject.parseArray(chatQueryDO.getSimilarQueries(),
SimilarQueryRecallResp.class);
queryResp.setSimilarQueries(similarQueries);
}
return queryResp;
}
public Long createChatQuery(ParseResp parseResult, ChatContext chatCtx, QueryContext queryContext) {
public Long createChatQuery(ParseResp parseResult, ChatParseReq chatParseReq) {
ChatQueryDO chatQueryDO = new ChatQueryDO();
chatQueryDO.setChatId(Long.valueOf(chatCtx.getChatId()));
chatQueryDO.setChatId(Long.valueOf(chatParseReq.getChatId()));
chatQueryDO.setCreateTime(new java.util.Date());
chatQueryDO.setUserName(queryContext.getUser().getName());
chatQueryDO.setQueryText(queryContext.getQueryText());
chatQueryDO.setAgentId(queryContext.getAgentId());
chatQueryDO.setUserName(chatParseReq.getUser().getName());
chatQueryDO.setQueryText(chatParseReq.getQueryText());
chatQueryDO.setAgentId(chatParseReq.getAgentId());
chatQueryDO.setQueryResult("");
try {
chatQueryDOMapper.insert(chatQueryDO);
@@ -135,24 +128,24 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
}
@Override
public List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryContext queryContext,
public List<ChatParseDO> batchSaveParseInfo(ChatParseReq chatParseReq,
ParseResp parseResult, List<SemanticParseInfo> candidateParses) {
Long queryId = createChatQuery(parseResult, chatCtx, queryContext);
Long queryId = createChatQuery(parseResult, chatParseReq);
List<ChatParseDO> chatParseDOList = new ArrayList<>();
getChatParseDO(chatCtx, queryContext, queryId, candidateParses, chatParseDOList);
getChatParseDO(chatParseReq, queryId, candidateParses, chatParseDOList);
if (!CollectionUtils.isEmpty(candidateParses)) {
chatParseMapper.batchSaveParseInfo(chatParseDOList);
}
return chatParseDOList;
}
public void getChatParseDO(ChatContext chatCtx, QueryContext queryContext, Long queryId,
public void getChatParseDO(ChatParseReq chatParseReq, Long queryId,
List<SemanticParseInfo> parses, List<ChatParseDO> chatParseDOList) {
for (int i = 0; i < parses.size(); i++) {
ChatParseDO chatParseDO = new ChatParseDO();
chatParseDO.setChatId(Long.valueOf(chatCtx.getChatId()));
chatParseDO.setChatId(Long.valueOf(chatParseReq.getChatId()));
chatParseDO.setQuestionId(queryId);
chatParseDO.setQueryText(queryContext.getQueryText());
chatParseDO.setQueryText(chatParseReq.getQueryText());
chatParseDO.setParseInfo(JsonUtil.toString(parses.get(i)));
chatParseDO.setIsCandidate(1);
if (i == 0) {
@@ -160,7 +153,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
}
chatParseDO.setParseId(parses.get(i).getId());
chatParseDO.setCreateTime(new java.util.Date());
chatParseDO.setUserName(queryContext.getUser().getName());
chatParseDO.setUserName(chatParseReq.getUser().getName());
chatParseDOList.add(chatParseDO);
}
}

View File

@@ -1,27 +0,0 @@
package com.tencent.supersonic.chat.server.persistence.repository.impl;
import com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.chat.server.persistence.mapper.StatisticsMapper;
import com.tencent.supersonic.chat.server.persistence.repository.StatisticsRepository;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Repository;
import java.util.List;
@Repository
@Primary
@Slf4j
public class StatisticsRepositoryImpl implements StatisticsRepository {
private final StatisticsMapper statisticsMapper;
public StatisticsRepositoryImpl(StatisticsMapper statisticsMapper) {
this.statisticsMapper = statisticsMapper;
}
public void batchSaveStatistics(List<StatisticsDO> list) {
statisticsMapper.batchSaveStatistics(list);
}
}

View File

@@ -0,0 +1,8 @@
package com.tencent.supersonic.chat.server.plugin;
public enum ParseMode {
EMBEDDING_RECALL,
FUNCTION_CALL;
}

View File

@@ -0,0 +1,61 @@
package com.tencent.supersonic.chat.server.plugin;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.RecordInfo;
import lombok.Data;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import java.util.List;
@Data
public class Plugin extends RecordInfo {
private Long id;
/***
* plugin type WEB_PAGE WEB_SERVICE
*/
private String type;
private List<Long> dataSetList = Lists.newArrayList();
/**
* description, for parsing
*/
private String pattern;
/**
* parse
*/
private ParseMode parseMode;
private String parseModeConfig;
private String name;
/**
* config for different plugin type
*/
private String config;
private String comment;
public List<String> getExampleQuestionList() {
if (StringUtils.isNotBlank(parseModeConfig)) {
PluginParseConfig pluginParseConfig = JSONObject.parseObject(parseModeConfig, PluginParseConfig.class);
return pluginParseConfig.getExamples();
}
return Lists.newArrayList();
}
public boolean isContainsAllModel() {
return CollectionUtils.isNotEmpty(dataSetList) && dataSetList.contains(-1L);
}
public Long getDefaultMode() {
return -1L;
}
}

View File

@@ -0,0 +1,285 @@
package com.tencent.supersonic.chat.server.plugin;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.PluginTool;
import com.tencent.supersonic.chat.server.plugin.build.ParamOption;
import com.tencent.supersonic.chat.server.plugin.build.WebBase;
import com.tencent.supersonic.chat.server.plugin.event.PluginAddEvent;
import com.tencent.supersonic.chat.server.plugin.event.PluginDelEvent;
import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.PluginService;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.ComponentFactory;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.event.EventListener;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
@Slf4j
@Component
public class PluginManager {
@Autowired
private EmbeddingConfig embeddingConfig;
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
public static List<Plugin> getPluginAgentCanSupport(ChatParseReq chatParseReq) {
PluginService pluginService = ContextUtils.getBean(PluginService.class);
AgentService agentService = ContextUtils.getBean(AgentService.class);
Agent agent = agentService.getAgent(chatParseReq.getAgentId());
List<Plugin> plugins = pluginService.getPluginList();
if (Objects.isNull(agent)) {
return plugins;
}
List<Long> pluginIds = getPluginTools(agent).stream().map(PluginTool::getPlugins)
.flatMap(Collection::stream).collect(Collectors.toList());
if (CollectionUtils.isEmpty(pluginIds)) {
return Lists.newArrayList();
}
plugins = plugins.stream().filter(plugin -> pluginIds.contains(plugin.getId()))
.collect(Collectors.toList());
log.info("plugins witch can be supported by cur agent :{} {}", agent.getName(),
plugins.stream().map(Plugin::getName).collect(Collectors.toList()));
return plugins;
}
private static List<PluginTool> getPluginTools(Agent agent) {
if (agent == null) {
return Lists.newArrayList();
}
List<String> tools = agent.getTools(AgentToolType.PLUGIN);
if (CollectionUtils.isEmpty(tools)) {
return Lists.newArrayList();
}
return tools.stream().map(tool -> JSONObject.parseObject(tool, PluginTool.class))
.collect(Collectors.toList());
}
@EventListener
public void addPlugin(PluginAddEvent pluginAddEvent) {
Plugin plugin = pluginAddEvent.getPlugin();
if (CollectionUtils.isNotEmpty(plugin.getExampleQuestionList())) {
requestEmbeddingPluginAdd(convert(Lists.newArrayList(plugin)));
}
}
@EventListener
public void updatePlugin(PluginUpdateEvent pluginUpdateEvent) {
Plugin oldPlugin = pluginUpdateEvent.getOldPlugin();
Plugin newPlugin = pluginUpdateEvent.getNewPlugin();
if (CollectionUtils.isNotEmpty(oldPlugin.getExampleQuestionList())) {
requestEmbeddingPluginDelete(getEmbeddingId(Lists.newArrayList(oldPlugin)));
}
if (CollectionUtils.isNotEmpty(newPlugin.getExampleQuestionList())) {
requestEmbeddingPluginAdd(convert(Lists.newArrayList(newPlugin)));
}
}
@EventListener
public void delPlugin(PluginDelEvent pluginDelEvent) {
Plugin plugin = pluginDelEvent.getPlugin();
if (CollectionUtils.isNotEmpty(plugin.getExampleQuestionList())) {
requestEmbeddingPluginDelete(getEmbeddingId(Lists.newArrayList(plugin)));
}
}
public void requestEmbeddingPluginDelete(Set<String> queryIds) {
if (CollectionUtils.isEmpty(queryIds)) {
return;
}
String presetCollection = embeddingConfig.getPresetCollection();
List<EmbeddingQuery> queries = new ArrayList<>();
for (String id : queryIds) {
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
embeddingQuery.setQueryId(id);
queries.add(embeddingQuery);
}
s2EmbeddingStore.deleteQuery(presetCollection, queries);
}
public void requestEmbeddingPluginAdd(List<EmbeddingQuery> queries) {
if (CollectionUtils.isEmpty(queries)) {
return;
}
String presetCollection = embeddingConfig.getPresetCollection();
s2EmbeddingStore.addQuery(presetCollection, queries);
}
public void requestEmbeddingPluginAddALL(List<Plugin> plugins) {
requestEmbeddingPluginAdd(convert(plugins));
}
public RetrieveQueryResult recognize(String embeddingText) {
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
.queryTextsList(Collections.singletonList(embeddingText))
.build();
List<RetrieveQueryResult> resultList = s2EmbeddingStore.retrieveQuery(embeddingConfig.getPresetCollection(),
retrieveQuery, embeddingConfig.getNResult());
if (CollectionUtils.isNotEmpty(resultList)) {
for (RetrieveQueryResult embeddingResp : resultList) {
List<Retrieval> embeddingRetrievals = embeddingResp.getRetrieval();
for (Retrieval embeddingRetrieval : embeddingRetrievals) {
embeddingRetrieval.setId(getPluginIdFromEmbeddingId(embeddingRetrieval.getId()));
}
}
return resultList.get(0);
}
throw new RuntimeException("get embedding result failed");
}
public List<EmbeddingQuery> convert(List<Plugin> plugins) {
List<EmbeddingQuery> queries = Lists.newArrayList();
for (Plugin plugin : plugins) {
List<String> exampleQuestions = plugin.getExampleQuestionList();
int num = 0;
for (String pattern : exampleQuestions) {
EmbeddingQuery query = new EmbeddingQuery();
query.setQueryId(generateUniqueEmbeddingId(num, plugin.getId()));
query.setQuery(pattern);
queries.add(query);
num++;
}
}
return queries;
}
private Set<String> getEmbeddingId(List<Plugin> plugins) {
Set<String> embeddingIdSet = new HashSet<>();
for (EmbeddingQuery query : convert(plugins)) {
embeddingIdSet.add(query.getQueryId());
}
return embeddingIdSet;
}
//num can not bigger than 100
private String generateUniqueEmbeddingId(int num, Long pluginId) {
if (num < 10) {
return String.format("%s00%s", pluginId, num);
} else {
return String.format("%s0%s", pluginId, num);
}
}
private String getPluginIdFromEmbeddingId(String id) {
return String.valueOf(Integer.parseInt(id) / 1000);
}
public static Pair<Boolean, Set<Long>> resolve(Plugin plugin, QueryContext queryContext) {
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
Set<Long> pluginMatchedModel = getPluginMatchedModel(plugin, queryContext);
if (CollectionUtils.isEmpty(pluginMatchedModel) && !plugin.isContainsAllModel()) {
return Pair.of(false, Sets.newHashSet());
}
List<ParamOption> paramOptions = getSemanticOption(plugin);
if (CollectionUtils.isEmpty(paramOptions)) {
return Pair.of(true, pluginMatchedModel);
}
Set<Long> matchedModel = Sets.newHashSet();
Map<Long, List<ParamOption>> paramOptionMap = paramOptions.stream()
.collect(Collectors.groupingBy(ParamOption::getModelId));
for (Long modelId : paramOptionMap.keySet()) {
List<ParamOption> params = paramOptionMap.get(modelId);
if (CollectionUtils.isEmpty(params)) {
matchedModel.add(modelId);
continue;
}
boolean matched = true;
for (ParamOption paramOption : params) {
Set<Long> elementIdSet = getSchemaElementMatch(modelId, schemaMapInfo);
if (CollectionUtils.isEmpty(elementIdSet)) {
matched = false;
break;
}
if (!elementIdSet.contains(paramOption.getElementId())) {
matched = false;
break;
}
}
if (matched) {
matchedModel.add(modelId);
}
}
if (CollectionUtils.isEmpty(matchedModel)) {
return Pair.of(false, Sets.newHashSet());
}
return Pair.of(true, matchedModel);
}
private static Set<Long> getSchemaElementMatch(Long modelId, SchemaMapInfo schemaMapInfo) {
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(modelId);
if (org.springframework.util.CollectionUtils.isEmpty(schemaElementMatches)) {
return Sets.newHashSet();
}
return schemaElementMatches.stream().filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
.map(SchemaElementMatch::getElement)
.map(SchemaElement::getId)
.collect(Collectors.toSet());
}
private static List<ParamOption> getSemanticOption(Plugin plugin) {
WebBase webBase = JSONObject.parseObject(plugin.getConfig(), WebBase.class);
if (Objects.isNull(webBase)) {
return null;
}
List<ParamOption> paramOptions = webBase.getParamOptions();
if (org.springframework.util.CollectionUtils.isEmpty(paramOptions)) {
return Lists.newArrayList();
}
return paramOptions.stream()
.filter(paramOption -> ParamOption.ParamType.SEMANTIC.equals(paramOption.getParamType()))
.collect(Collectors.toList());
}
private static Set<Long> getPluginMatchedModel(Plugin plugin, QueryContext queryContext) {
Set<Long> matchedDataSets = queryContext.getMapInfo().getMatchedDataSetInfos();
if (plugin.isContainsAllModel()) {
return Sets.newHashSet(plugin.getDefaultMode());
}
List<Long> modelIds = plugin.getDataSetList();
Set<Long> pluginMatchedModel = Sets.newHashSet();
for (Long modelId : modelIds) {
if (matchedDataSets.contains(modelId)) {
pluginMatchedModel.add(modelId);
}
}
return pluginMatchedModel;
}
}

View File

@@ -0,0 +1,26 @@
package com.tencent.supersonic.chat.server.plugin;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString;
import java.io.Serializable;
import java.util.List;
@Data
@Builder
@AllArgsConstructor
@ToString
@NoArgsConstructor
public class PluginParseConfig implements Serializable {
public List<String> examples;
private String name;
private String description;
}

View File

@@ -0,0 +1,13 @@
package com.tencent.supersonic.chat.server.plugin;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import lombok.Data;
@Data
public class PluginParseResult {
private Plugin plugin;
private QueryFilters queryFilters;
private double distance;
private String queryText;
}

View File

@@ -0,0 +1,24 @@
package com.tencent.supersonic.chat.server.plugin;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.Set;
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class PluginRecallResult {
private Plugin plugin;
private Set<Long> dataSetIds;
private double score;
private double distance;
}

View File

@@ -0,0 +1,37 @@
package com.tencent.supersonic.chat.server.plugin.build;
import lombok.Data;
@Data
public class ParamOption {
private ParamType paramType;
private OptionType optionType;
private String key;
private String name;
private String keyAlias;
private Long modelId;
private Long elementId;
private Object value;
/**
* CUSTOM: the value is specified by the user
* SEMANTIC: the value of element
* FORWARD: only forward
*/
public enum ParamType {
CUSTOM, SEMANTIC, FORWARD
}
public enum OptionType {
REQUIRED, OPTIONAL
}
}

View File

@@ -0,0 +1,94 @@
package com.tencent.supersonic.chat.server.plugin.build;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.server.plugin.PluginParseResult;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import com.tencent.supersonic.headless.core.chat.query.BaseSemanticQuery;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@Slf4j
public abstract class PluginSemanticQuery extends BaseSemanticQuery {
@Override
public void initS2Sql(SemanticSchema semanticSchema, User user) {
}
private Map<Long, Object> getFilterMap(PluginParseResult pluginParseResult) {
Map<Long, Object> map = new HashMap<>();
QueryFilters queryFilters = pluginParseResult.getQueryFilters();
if (queryFilters == null) {
return map;
}
List<QueryFilter> queryFilterList = queryFilters.getFilters();
if (CollectionUtils.isEmpty(queryFilterList)) {
return map;
}
for (QueryFilter queryFilter : queryFilterList) {
map.put(queryFilter.getElementID(), queryFilter.getValue());
}
return map;
}
protected Map<String, Object> getElementMap(PluginParseResult pluginParseResult) {
Map<String, Object> elementValueMap = new HashMap<>();
Map<Long, Object> filterValueMap = getFilterMap(pluginParseResult);
List<SchemaElementMatch> schemaElementMatchList = parseInfo.getElementMatches();
if (!CollectionUtils.isEmpty(schemaElementMatchList)) {
schemaElementMatchList.stream().filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
.filter(schemaElementMatch -> schemaElementMatch.getSimilarity() == 1.0)
.forEach(schemaElementMatch -> {
Object queryFilterValue = filterValueMap.get(schemaElementMatch.getElement().getId());
if (queryFilterValue != null) {
if (String.valueOf(queryFilterValue).equals(String.valueOf(schemaElementMatch.getWord()))) {
elementValueMap.put(
String.valueOf(schemaElementMatch.getElement().getId()),
schemaElementMatch.getWord());
}
} else {
elementValueMap.computeIfAbsent(
String.valueOf(schemaElementMatch.getElement().getId()),
k -> schemaElementMatch.getWord());
}
});
}
return elementValueMap;
}
protected WebBase fillWebBaseResult(WebBase webPage, PluginParseResult pluginParseResult) {
WebBase webBaseResult = new WebBase();
webBaseResult.setUrl(webPage.getUrl());
Map<String, Object> elementValueMap = getElementMap(pluginParseResult);
List<ParamOption> paramOptions = Lists.newArrayList();
if (!CollectionUtils.isEmpty(webPage.getParamOptions()) && !CollectionUtils.isEmpty(elementValueMap)) {
for (ParamOption paramOption : webPage.getParamOptions()) {
if (paramOption.getModelId() != null
&& !parseInfo.getDataSetId().equals(paramOption.getModelId())) {
continue;
}
paramOptions.add(paramOption);
if (!ParamOption.ParamType.SEMANTIC.equals(paramOption.getParamType())) {
continue;
}
String elementId = String.valueOf(paramOption.getElementId());
Object elementValue = elementValueMap.get(elementId);
paramOption.setValue(elementValue);
}
}
webBaseResult.setParamOptions(paramOptions);
return webBaseResult;
}
}

View File

@@ -0,0 +1,19 @@
package com.tencent.supersonic.chat.server.plugin.build;
import com.google.common.collect.Lists;
import lombok.Data;
import java.util.List;
@Data
public class WebBase {
private String url;
private List<ParamOption> paramOptions = Lists.newArrayList();
public List<ParamOption> getParams() {
return paramOptions;
}
}

View File

@@ -0,0 +1,46 @@
package com.tencent.supersonic.chat.server.plugin.build.webpage;
import com.tencent.supersonic.chat.server.plugin.Plugin;
import com.tencent.supersonic.chat.server.plugin.PluginParseResult;
import com.tencent.supersonic.chat.server.plugin.build.PluginSemanticQuery;
import com.tencent.supersonic.chat.server.plugin.build.WebBase;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.core.chat.query.QueryManager;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.sql.parser.SqlParseException;
import org.springframework.stereotype.Component;
@Slf4j
@Component
public class WebPageQuery extends PluginSemanticQuery {
public static String QUERY_MODE = "WEB_PAGE";
public WebPageQuery() {
QueryManager.register(this);
}
@Override
public String getQueryMode() {
return QUERY_MODE;
}
@Override
public SemanticQueryReq buildSemanticQueryReq() throws SqlParseException {
return null;
}
protected WebPageResp buildResponse(PluginParseResult pluginParseResult) {
Plugin plugin = pluginParseResult.getPlugin();
WebPageResp webPageResponse = new WebPageResp();
webPageResponse.setName(plugin.getName());
webPageResponse.setPluginId(plugin.getId());
webPageResponse.setPluginType(plugin.getType());
WebBase webPage = JsonUtil.toObject(plugin.getConfig(), WebBase.class);
WebBase webBase = fillWebBaseResult(webPage, pluginParseResult);
webPageResponse.setWebPage(webBase);
return webPageResponse;
}
}

View File

@@ -0,0 +1,25 @@
package com.tencent.supersonic.chat.server.plugin.build.webpage;
import com.tencent.supersonic.chat.server.plugin.build.WebBase;
import lombok.Data;
import java.util.List;
@Data
public class WebPageResp {
private Long pluginId;
private String pluginType;
private String name;
private String description;
private WebBase webPage;
private List<WebBase> moreWebPage;
}

View File

@@ -0,0 +1,79 @@
package com.tencent.supersonic.chat.server.plugin.build.webservice;
import com.alibaba.fastjson.JSON;
import com.tencent.supersonic.chat.server.plugin.Plugin;
import com.tencent.supersonic.chat.server.plugin.PluginParseResult;
import com.tencent.supersonic.chat.server.plugin.build.ParamOption;
import com.tencent.supersonic.chat.server.plugin.build.PluginSemanticQuery;
import com.tencent.supersonic.chat.server.plugin.build.WebBase;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.core.chat.query.QueryManager;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.sql.parser.SqlParseException;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;
import java.net.URI;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@Slf4j
@Component
public class WebServiceQuery extends PluginSemanticQuery {
public static String QUERY_MODE = "WEB_SERVICE";
private RestTemplate restTemplate;
public WebServiceQuery() {
QueryManager.register(this);
}
@Override
public String getQueryMode() {
return QUERY_MODE;
}
@Override
public SemanticQueryReq buildSemanticQueryReq() throws SqlParseException {
return null;
}
protected WebServiceResp buildResponse(PluginParseResult pluginParseResult) {
WebServiceResp webServiceResponse = new WebServiceResp();
Plugin plugin = pluginParseResult.getPlugin();
WebBase webBase = fillWebBaseResult(JsonUtil.toObject(plugin.getConfig(), WebBase.class), pluginParseResult);
webServiceResponse.setWebBase(webBase);
List<ParamOption> paramOptions = webBase.getParamOptions();
Map<String, Object> params = new HashMap<>();
paramOptions.forEach(o -> params.put(o.getKey(), o.getValue()));
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
HttpEntity<String> entity = new HttpEntity<>(JSON.toJSONString(params), headers);
URI requestUrl = UriComponentsBuilder.fromHttpUrl(webBase.getUrl()).build().encode().toUri();
ResponseEntity responseEntity = null;
Object objectResponse = null;
restTemplate = ContextUtils.getBean(RestTemplate.class);
try {
responseEntity = restTemplate.exchange(requestUrl, HttpMethod.POST, entity, Object.class);
objectResponse = responseEntity.getBody();
log.info("objectResponse:{}", objectResponse);
Map<String, Object> response = JsonUtil.objectToMap(objectResponse);
webServiceResponse.setResult(response);
} catch (Exception e) {
log.info("Exception:{}", e.getMessage());
}
return webServiceResponse;
}
}

View File

@@ -0,0 +1,15 @@
package com.tencent.supersonic.chat.server.plugin.build.webservice;
import com.tencent.supersonic.chat.server.plugin.build.WebBase;
import lombok.Data;
@Data
public class WebServiceResp {
private WebBase webBase;
private Object result;
}

View File

@@ -0,0 +1,18 @@
package com.tencent.supersonic.chat.server.plugin.event;
import com.tencent.supersonic.chat.server.plugin.Plugin;
import org.springframework.context.ApplicationEvent;
public class PluginAddEvent extends ApplicationEvent {
private Plugin plugin;
public PluginAddEvent(Object source, Plugin plugin) {
super(source);
this.plugin = plugin;
}
public Plugin getPlugin() {
return plugin;
}
}

View File

@@ -0,0 +1,19 @@
package com.tencent.supersonic.chat.server.plugin.event;
import com.tencent.supersonic.chat.server.plugin.Plugin;
import org.springframework.context.ApplicationEvent;
public class PluginDelEvent extends ApplicationEvent {
private Plugin plugin;
public PluginDelEvent(Object source, Plugin plugin) {
super(source);
this.plugin = plugin;
}
public Plugin getPlugin() {
return plugin;
}
}

View File

@@ -0,0 +1,26 @@
package com.tencent.supersonic.chat.server.plugin.event;
import com.tencent.supersonic.chat.server.plugin.Plugin;
import org.springframework.context.ApplicationEvent;
public class PluginUpdateEvent extends ApplicationEvent {
private Plugin oldPlugin;
private Plugin newPlugin;
public PluginUpdateEvent(Object source, Plugin oldPlugin, Plugin newPlugin) {
super(source);
this.oldPlugin = oldPlugin;
this.newPlugin = newPlugin;
}
public Plugin getOldPlugin() {
return oldPlugin;
}
public Plugin getNewPlugin() {
return newPlugin;
}
}

View File

@@ -0,0 +1,114 @@
package com.tencent.supersonic.chat.server.plugin.recall;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.server.plugin.Plugin;
import com.tencent.supersonic.chat.server.plugin.PluginManager;
import com.tencent.supersonic.chat.server.plugin.PluginParseResult;
import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.server.plugin.build.PluginSemanticQuery;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import org.springframework.util.CollectionUtils;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
* PluginParser defines the basic process and common methods for recalling plugins.
*/
public abstract class PluginParser {
public void parse(ChatParseReq chatParseReq) {
if (!checkPreCondition(chatParseReq)) {
return;
}
PluginRecallResult pluginRecallResult = recallPlugin(chatParseReq);
if (pluginRecallResult == null) {
return;
}
buildQuery(chatParseReq, pluginRecallResult);
}
public abstract boolean checkPreCondition(ChatParseReq chatParseReq);
public abstract PluginRecallResult recallPlugin(ChatParseReq chatParseReq);
public void buildQuery(ChatParseReq chatParseReq, PluginRecallResult pluginRecallResult) {
Plugin plugin = pluginRecallResult.getPlugin();
Set<Long> dataSetIds = pluginRecallResult.getDataSetIds();
if (plugin.isContainsAllModel()) {
dataSetIds = Sets.newHashSet(-1L);
}
for (Long dataSetId : dataSetIds) {
//todo
PluginSemanticQuery pluginQuery = null;
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(dataSetId, plugin,
null, pluginRecallResult.getDistance());
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
semanticParseInfo.setScore(pluginRecallResult.getScore());
pluginQuery.setParseInfo(semanticParseInfo);
//chatParseReq.getCandidateQueries().add(pluginQuery);
}
}
protected List<Plugin> getPluginList(ChatParseReq chatParseReq) {
return PluginManager.getPluginAgentCanSupport(chatParseReq);
}
protected SemanticParseInfo buildSemanticParseInfo(Long dataSetId, Plugin plugin,
QueryContext queryContext, double distance) {
List<SchemaElementMatch> schemaElementMatches = queryContext.getMapInfo().getMatchedElements(dataSetId);
QueryFilters queryFilters = queryContext.getQueryFilters();
if (dataSetId == null && !CollectionUtils.isEmpty(plugin.getDataSetList())) {
dataSetId = plugin.getDataSetList().get(0);
}
if (schemaElementMatches == null) {
schemaElementMatches = Lists.newArrayList();
}
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.setElementMatches(schemaElementMatches);
semanticParseInfo.setDataSet(queryContext.getSemanticSchema().getDataSet(dataSetId));
Map<String, Object> properties = new HashMap<>();
PluginParseResult pluginParseResult = new PluginParseResult();
pluginParseResult.setPlugin(plugin);
pluginParseResult.setQueryFilters(queryFilters);
pluginParseResult.setDistance(distance);
pluginParseResult.setQueryText(queryContext.getQueryText());
properties.put(Constants.CONTEXT, pluginParseResult);
properties.put("type", "plugin");
properties.put("name", plugin.getName());
semanticParseInfo.setProperties(properties);
semanticParseInfo.setScore(distance);
fillSemanticParseInfo(semanticParseInfo);
return semanticParseInfo;
}
private void fillSemanticParseInfo(SemanticParseInfo semanticParseInfo) {
List<SchemaElementMatch> schemaElementMatches = semanticParseInfo.getElementMatches();
if (CollectionUtils.isEmpty(schemaElementMatches)) {
return;
}
schemaElementMatches.stream().filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
.forEach(schemaElementMatch -> {
QueryFilter queryFilter = new QueryFilter();
queryFilter.setValue(schemaElementMatch.getWord());
queryFilter.setElementID(schemaElementMatch.getElement().getId());
queryFilter.setName(schemaElementMatch.getElement().getName());
queryFilter.setOperator(FilterOperatorEnum.EQUALS);
queryFilter.setBizName(schemaElementMatch.getElement().getBizName());
semanticParseInfo.getDimensionFilters().add(queryFilter);
});
}
}

View File

@@ -0,0 +1,91 @@
package com.tencent.supersonic.chat.server.plugin.recall.embedding;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.server.plugin.ParseMode;
import com.tencent.supersonic.chat.server.plugin.Plugin;
import com.tencent.supersonic.chat.server.plugin.PluginManager;
import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.server.plugin.recall.PluginParser;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.headless.core.chat.parser.PythonLLMProxy;
import com.tencent.supersonic.headless.core.utils.ComponentFactory;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.util.CollectionUtils;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
/**
* EmbeddingRecallParser is an implementation of a recall plugin based on Embedding
*/
@Slf4j
public class EmbeddingRecallParser extends PluginParser {
public boolean checkPreCondition(ChatParseReq chatParseReq) {
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
if (StringUtils.isBlank(embeddingConfig.getUrl()) && ComponentFactory.getLLMProxy() instanceof PythonLLMProxy) {
return false;
}
List<Plugin> plugins = getPluginList(chatParseReq);
return !CollectionUtils.isEmpty(plugins);
}
public PluginRecallResult recallPlugin(ChatParseReq chatParseReq) {
String text = chatParseReq.getQueryText();
List<Retrieval> embeddingRetrievals = embeddingRecall(text);
if (CollectionUtils.isEmpty(embeddingRetrievals)) {
return null;
}
List<Plugin> plugins = getPluginList(chatParseReq);
Map<Long, Plugin> pluginMap = plugins.stream().collect(Collectors.toMap(Plugin::getId, p -> p));
for (Retrieval embeddingRetrieval : embeddingRetrievals) {
Plugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId()));
if (plugin == null) {
continue;
}
//todo
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, null);
log.info("embedding plugin resolve: {}", pair);
if (pair.getLeft()) {
Set<Long> dataSetList = pair.getRight();
if (CollectionUtils.isEmpty(dataSetList)) {
continue;
}
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
double distance = embeddingRetrieval.getDistance();
double score = chatParseReq.getQueryText().length() * (1 - distance);
return PluginRecallResult.builder()
.plugin(plugin).dataSetIds(dataSetList).score(score).distance(distance).build();
}
}
return null;
}
public List<Retrieval> embeddingRecall(String embeddingText) {
try {
PluginManager pluginManager = ContextUtils.getBean(PluginManager.class);
RetrieveQueryResult embeddingResp = pluginManager.recognize(embeddingText);
List<Retrieval> embeddingRetrievals = embeddingResp.getRetrieval();
if (!CollectionUtils.isEmpty(embeddingRetrievals)) {
embeddingRetrievals = embeddingRetrievals.stream().sorted(Comparator.comparingDouble(o ->
Math.abs(o.getDistance()))).collect(Collectors.toList());
embeddingResp.setRetrieval(embeddingRetrievals);
}
return embeddingRetrievals;
} catch (Exception e) {
log.warn("get embedding result error ", e);
}
return Lists.newArrayList();
}
}

View File

@@ -0,0 +1,19 @@
package com.tencent.supersonic.chat.server.plugin.recall.embedding;
import lombok.Data;
@Data
public class RecallRetrieval {
private String id;
private String distance;
private String presetQuery;
private String presetId;
private String query;
}

View File

@@ -0,0 +1,15 @@
package com.tencent.supersonic.chat.server.plugin.recall.embedding;
import lombok.Data;
import java.util.List;
@Data
public class RecallRetrievalResp {
private String query;
private List<RecallRetrieval> retrieval;
}

View File

@@ -0,0 +1,15 @@
package com.tencent.supersonic.chat.server.plugin.recall.function;
import lombok.Data;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
@Configuration
@Data
public class FunctionCallConfig {
@Value("${functionCall.url:}")
private String url;
@Value("${funtionCall.plugin.select.path:/plugin_selection}")
private String pluginSelectPath;
}

View File

@@ -0,0 +1,13 @@
package com.tencent.supersonic.chat.server.plugin.recall.function;
import lombok.Data;
@Data
public class FunctionFiled {
private String type;
private String description;
}

View File

@@ -0,0 +1,75 @@
package com.tencent.supersonic.chat.server.plugin.recall.function;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.tencent.supersonic.chat.server.plugin.PluginParseConfig;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.core.chat.parser.llm.InputFormat;
import dev.langchain4j.model.chat.ChatLanguageModel;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import java.util.List;
import java.util.stream.Collectors;
@Component
@Slf4j
public class FunctionPromptGenerator {
public String generateFunctionCallPrompt(String queryText, List<PluginParseConfig> toolConfigList) {
List<String> toolExplainList = toolConfigList.stream()
.map(this::constructPluginPrompt)
.collect(Collectors.toList());
String functionList = String.join(InputFormat.SEPERATOR, toolExplainList);
return constructTaskPrompt(queryText, functionList);
}
public String constructPluginPrompt(PluginParseConfig parseConfig) {
String toolName = parseConfig.getName();
String toolDescription = parseConfig.getDescription();
List<String> toolExamples = parseConfig.getExamples();
StringBuilder prompt = new StringBuilder();
prompt.append("【工具名称】\n").append(toolName).append("\n");
prompt.append("【工具描述】\n").append(toolDescription).append("\n");
prompt.append("【工具适用问题示例】\n");
for (String example : toolExamples) {
prompt.append(example).append("\n");
}
return prompt.toString();
}
public String constructTaskPrompt(String queryText, String functionList) {
String instruction = String.format("问题为:%s\n请根据问题和工具的描述选择对应的工具完成任务。"
+ "请注意只能选择1个工具。请一步一步地分析选择工具的原因(每个工具的【工具适用问题示例】是选择的重要参考依据)"
+ "并给出最终选择输出格式为json,key为分析过程, ’选择工具‘", queryText);
return String.format("工具选择如下:\n\n%s\n\n【任务说明】\n%s", functionList, instruction);
}
public FunctionResp requestFunction(FunctionReq functionReq) {
FunctionPromptGenerator promptGenerator = ContextUtils.getBean(FunctionPromptGenerator.class);
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
String functionCallPrompt = promptGenerator.generateFunctionCallPrompt(functionReq.getQueryText(),
functionReq.getPluginConfigs());
String response = chatLanguageModel.generate(functionCallPrompt);
return functionCallParse(response);
}
public static FunctionResp functionCallParse(String llmOutput) {
try {
ObjectMapper objectMapper = new ObjectMapper();
JsonNode jsonNode = objectMapper.readTree(llmOutput);
String selectedTool = jsonNode.get("选择工具").asText();
FunctionResp resp = new FunctionResp();
resp.setToolSelection(selectedTool);
return resp;
} catch (Exception e) {
log.error("", e);
}
return null;
}
}

View File

@@ -0,0 +1,17 @@
package com.tencent.supersonic.chat.server.plugin.recall.function;
import com.tencent.supersonic.chat.server.plugin.PluginParseConfig;
import lombok.Builder;
import lombok.Data;
import java.util.List;
@Data
@Builder
public class FunctionReq {
private String queryText;
private List<PluginParseConfig> pluginConfigs;
}

View File

@@ -0,0 +1,10 @@
package com.tencent.supersonic.chat.server.plugin.recall.function;
import lombok.Data;
@Data
public class FunctionResp {
private String toolSelection;
}

View File

@@ -0,0 +1,18 @@
package com.tencent.supersonic.chat.server.plugin.recall.function;
import lombok.Data;
import java.util.List;
import java.util.Map;
@Data
public class Parameters {
//default: object
private String type = "object";
private Map<String, FunctionFiled> properties;
private List<String> required;
}

View File

@@ -5,4 +5,5 @@ package com.tencent.supersonic.chat.server.processor;
*/
public interface ResultProcessor {
}

View File

@@ -3,11 +3,11 @@ package com.tencent.supersonic.chat.server.processor.execute;
import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.RelatedSchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.service.SemanticService;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.server.service.impl.SemanticService;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.util.ContextUtils;
import org.springframework.util.CollectionUtils;
@@ -33,9 +33,9 @@ public class DimensionRecommendProcessor implements ExecuteResultProcessor {
|| CollectionUtils.isEmpty(semanticParseInfo.getMetrics())) {
return;
}
SchemaElement element = semanticParseInfo.getMetrics().iterator().next();
List<SchemaElement> dimensionRecommended = getDimensions(element.getId(), element.getDataSet());
queryResult.setRecommendedDimensions(dimensionRecommended);
//SchemaElement element = semanticParseInfo.getMetrics().iterator().next();
//List<SchemaElement> dimensionRecommended = getDimensions(element.getId(), element.getDataSet());
//queryResult.setRecommendedDimensions(dimensionRecommended);
}
private List<SchemaElement> getDimensions(Long metricId, Long dataSetId) {

View File

@@ -1,8 +1,8 @@
package com.tencent.supersonic.chat.server.processor.execute;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.processor.ResultProcessor;
/**

View File

@@ -1,25 +1,8 @@
package com.tencent.supersonic.chat.server.processor.execute;
import static com.tencent.supersonic.common.pojo.Constants.DAY;
import static com.tencent.supersonic.common.pojo.Constants.DAY_FORMAT;
import static com.tencent.supersonic.common.pojo.Constants.DAY_FORMAT_INT;
import static com.tencent.supersonic.common.pojo.Constants.MONTH;
import static com.tencent.supersonic.common.pojo.Constants.MONTH_FORMAT;
import static com.tencent.supersonic.common.pojo.Constants.MONTH_FORMAT_INT;
import static com.tencent.supersonic.common.pojo.Constants.TIMES_FORMAT;
import static com.tencent.supersonic.common.pojo.Constants.TIME_FORMAT;
import static com.tencent.supersonic.common.pojo.Constants.WEEK;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.AggregateInfo;
import com.tencent.supersonic.chat.api.pojo.response.MetricInfo;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.core.config.AggregatorConfig;
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
import com.tencent.supersonic.chat.core.utils.QueryReqBuilder;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
import com.tencent.supersonic.common.pojo.QueryColumn;
@@ -29,8 +12,16 @@ import com.tencent.supersonic.common.pojo.enums.RatioOverType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.core.config.AggregatorConfig;
import com.tencent.supersonic.headless.core.utils.QueryReqBuilder;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
import java.text.DecimalFormat;
import java.time.DayOfWeek;
import java.time.LocalDate;
@@ -48,8 +39,16 @@ import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
import static com.tencent.supersonic.common.pojo.Constants.DAY;
import static com.tencent.supersonic.common.pojo.Constants.DAY_FORMAT;
import static com.tencent.supersonic.common.pojo.Constants.DAY_FORMAT_INT;
import static com.tencent.supersonic.common.pojo.Constants.MONTH;
import static com.tencent.supersonic.common.pojo.Constants.MONTH_FORMAT;
import static com.tencent.supersonic.common.pojo.Constants.MONTH_FORMAT_INT;
import static com.tencent.supersonic.common.pojo.Constants.TIMES_FORMAT;
import static com.tencent.supersonic.common.pojo.Constants.TIME_FORMAT;
import static com.tencent.supersonic.common.pojo.Constants.WEEK;
/**
* Add ratio queries for metric queries.
@@ -57,7 +56,7 @@ import org.springframework.util.CollectionUtils;
@Slf4j
public class MetricRatioProcessor implements ExecuteResultProcessor {
private SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
//private SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
@Override
public void process(QueryResult queryResult, SemanticParseInfo semanticParseInfo, ExecuteQueryReq queryReq) {
@@ -68,8 +67,8 @@ public class MetricRatioProcessor implements ExecuteResultProcessor {
|| !QueryType.METRIC.equals(semanticParseInfo.getQueryType())) {
return;
}
AggregateInfo aggregateInfo = getAggregateInfo(queryReq.getUser(), semanticParseInfo, queryResult);
queryResult.setAggregateInfo(aggregateInfo);
//AggregateInfo aggregateInfo = getAggregateInfo(queryReq.getUser(), semanticParseInfo, queryResult);
//queryResult.setAggregateInfo(aggregateInfo);
}
public AggregateInfo getAggregateInfo(User user, SemanticParseInfo semanticParseInfo, QueryResult queryResult) {
@@ -133,7 +132,7 @@ public class MetricRatioProcessor implements ExecuteResultProcessor {
queryStructReq.setDateInfo(getRatioDateConf(aggOperatorEnum, semanticParseInfo, queryResult));
queryStructReq.setConvertToSql(false);
SemanticQueryResp queryResp = semanticInterpreter.queryByStruct(queryStructReq, user);
SemanticQueryResp queryResp = null;
MetricInfo metricInfo = new MetricInfo();
metricInfo.setStatistics(new HashMap<>());
if (Objects.isNull(queryResp) || CollectionUtils.isEmpty(queryResp.getResultList())) {

View File

@@ -1,21 +1,18 @@
package com.tencent.supersonic.chat.server.processor.execute;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.util.ComponentFactory;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.server.service.MetaEmbeddingService;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.core.knowledge.MetaEmbeddingService;
import org.springframework.util.CollectionUtils;
import java.util.Collections;
@@ -33,8 +30,6 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
private static final int METRIC_RECOMMEND_SIZE = 5;
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
@Override
public void process(QueryResult queryResult, SemanticParseInfo semanticParseInfo, ExecuteQueryReq queryReq) {
fillSimilarMetric(queryResult.getChatContext());
@@ -54,8 +49,7 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
.filterCondition(filterCondition).queryEmbeddings(null).build();
MetaEmbeddingService metaEmbeddingService = ContextUtils.getBean(MetaEmbeddingService.class);
List<RetrieveQueryResult> retrieveQueryResults =
metaEmbeddingService.retrieveQuery(Lists.newArrayList(parseInfo.getDataSetId()),
retrieveQuery, METRIC_RECOMMEND_SIZE + 1);
metaEmbeddingService.retrieveQuery(retrieveQuery, METRIC_RECOMMEND_SIZE + 1, new HashMap<>());
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
return;
}

View File

@@ -1,50 +0,0 @@
package com.tencent.supersonic.chat.server.processor.parse;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.QueryManager;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.core.query.llm.analytics.MetricAnalyzeQuery;
import com.tencent.supersonic.chat.server.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.stream.Collectors;
/**
* EntityInfoProcessor fills core attributes of an entity so that
* users get to know which entity is parsed out.
*/
public class EntityInfoProcessor implements ParseResultProcessor {
@Override
public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
List<SemanticQuery> semanticQueries = queryContext.getCandidateQueries();
if (CollectionUtils.isEmpty(semanticQueries)) {
return;
}
List<SemanticParseInfo> selectedParses = semanticQueries.stream().map(SemanticQuery::getParseInfo)
.collect(Collectors.toList());
selectedParses.forEach(parseInfo -> {
String queryMode = parseInfo.getQueryMode();
if (QueryManager.containsPluginQuery(queryMode)
|| MetricAnalyzeQuery.QUERY_MODE.equalsIgnoreCase(queryMode)) {
return;
}
//1. set entity info
DataSetSchema dataSetSchema =
queryContext.getSemanticSchema().getDataSetSchemaMap().get(parseInfo.getDataSetId());
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, queryContext.getUser());
if (QueryManager.isTagQuery(queryMode)
|| QueryManager.isMetricQuery(queryMode)) {
parseInfo.setEntityInfo(entityInfo);
}
});
}
}

View File

@@ -1,222 +0,0 @@
package com.tencent.supersonic.chat.server.processor.parse;
import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.server.service.impl.SchemaService;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.FieldExpression;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/**
* ParseInfoProcessor extracts structured info from S2SQL so that
* users get to know the details.
**/
@Slf4j
public class ParseInfoProcessor implements ParseResultProcessor {
@Override
public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries();
if (CollectionUtils.isEmpty(candidateQueries)) {
return;
}
List<SemanticParseInfo> candidateParses = candidateQueries.stream()
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
candidateParses.forEach(this::updateParseInfo);
}
public void updateParseInfo(SemanticParseInfo parseInfo) {
SqlInfo sqlInfo = parseInfo.getSqlInfo();
String correctS2SQL = sqlInfo.getCorrectS2SQL();
if (StringUtils.isBlank(correctS2SQL)) {
return;
}
// if S2SQL equals correctS2SQL, then not update the parseInfo.
if (correctS2SQL.equals(sqlInfo.getS2SQL())) {
return;
}
List<FieldExpression> expressions = SqlSelectHelper.getFilterExpression(correctS2SQL);
//set dataInfo
try {
if (!org.apache.commons.collections.CollectionUtils.isEmpty(expressions)) {
DateConf dateInfo = getDateInfo(expressions);
if (dateInfo != null && parseInfo.getDateInfo() == null) {
parseInfo.setDateInfo(dateInfo);
}
}
} catch (Exception e) {
log.error("set dateInfo error :", e);
}
//set filter
Long dataSetId = parseInfo.getDataSetId();
try {
Map<String, SchemaElement> fieldNameToElement = getNameToElement(dataSetId);
List<QueryFilter> result = getDimensionFilter(fieldNameToElement, expressions);
parseInfo.getDimensionFilters().addAll(result);
} catch (Exception e) {
log.error("set dimensionFilter error :", e);
}
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
if (Objects.isNull(semanticSchema)) {
return;
}
List<String> allFields = getFieldsExceptDate(SqlSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL()));
Set<SchemaElement> metrics = getElements(dataSetId, allFields, semanticSchema.getMetrics());
parseInfo.setMetrics(metrics);
if (QueryType.METRIC.equals(parseInfo.getQueryType())) {
List<String> groupByFields = SqlSelectHelper.getGroupByFields(sqlInfo.getCorrectS2SQL());
List<String> groupByDimensions = getFieldsExceptDate(groupByFields);
parseInfo.setDimensions(getElements(dataSetId, groupByDimensions, semanticSchema.getDimensions()));
} else if (QueryType.TAG.equals(parseInfo.getQueryType())) {
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getCorrectS2SQL());
List<String> selectDimensions = getFieldsExceptDate(selectFields);
parseInfo.setDimensions(getElements(dataSetId, selectDimensions, semanticSchema.getDimensions()));
}
}
private Set<SchemaElement> getElements(Long dataSetId, List<String> allFields, List<SchemaElement> elements) {
return elements.stream()
.filter(schemaElement -> {
if (CollectionUtils.isEmpty(schemaElement.getAlias())) {
return dataSetId.equals(schemaElement.getDataSet()) && allFields.contains(
schemaElement.getName());
}
Set<String> allFieldsSet = new HashSet<>(allFields);
Set<String> aliasSet = new HashSet<>(schemaElement.getAlias());
List<String> intersection = allFieldsSet.stream()
.filter(aliasSet::contains).collect(Collectors.toList());
return dataSetId.equals(schemaElement.getDataSet()) && (allFields.contains(
schemaElement.getName()) || !CollectionUtils.isEmpty(intersection));
}
).collect(Collectors.toSet());
}
private List<String> getFieldsExceptDate(List<String> allFields) {
if (org.springframework.util.CollectionUtils.isEmpty(allFields)) {
return new ArrayList<>();
}
return allFields.stream()
.filter(entry -> !TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(entry))
.collect(Collectors.toList());
}
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> fieldNameToElement,
List<FieldExpression> fieldExpressions) {
List<QueryFilter> result = Lists.newArrayList();
for (FieldExpression expression : fieldExpressions) {
QueryFilter dimensionFilter = new QueryFilter();
dimensionFilter.setValue(expression.getFieldValue());
SchemaElement schemaElement = fieldNameToElement.get(expression.getFieldName());
if (Objects.isNull(schemaElement)) {
continue;
}
dimensionFilter.setName(schemaElement.getName());
dimensionFilter.setBizName(schemaElement.getBizName());
dimensionFilter.setElementID(schemaElement.getId());
FilterOperatorEnum operatorEnum = FilterOperatorEnum.getSqlOperator(expression.getOperator());
dimensionFilter.setOperator(operatorEnum);
dimensionFilter.setFunction(expression.getFunction());
result.add(dimensionFilter);
}
return result;
}
private DateConf getDateInfo(List<FieldExpression> fieldExpressions) {
List<FieldExpression> dateExpressions = fieldExpressions.stream()
.filter(expression -> TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(expression.getFieldName()))
.collect(Collectors.toList());
if (org.apache.commons.collections.CollectionUtils.isEmpty(dateExpressions)) {
return null;
}
DateConf dateInfo = new DateConf();
dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
FieldExpression firstExpression = dateExpressions.get(0);
FilterOperatorEnum firstOperator = FilterOperatorEnum.getSqlOperator(firstExpression.getOperator());
if (FilterOperatorEnum.EQUALS.equals(firstOperator) && Objects.nonNull(firstExpression.getFieldValue())) {
dateInfo.setStartDate(firstExpression.getFieldValue().toString());
dateInfo.setEndDate(firstExpression.getFieldValue().toString());
dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
return dateInfo;
}
if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.GREATER_THAN,
FilterOperatorEnum.GREATER_THAN_EQUALS)) {
dateInfo.setStartDate(firstExpression.getFieldValue().toString());
if (hasSecondDate(dateExpressions)) {
dateInfo.setEndDate(dateExpressions.get(1).getFieldValue().toString());
}
}
if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.MINOR_THAN,
FilterOperatorEnum.MINOR_THAN_EQUALS)) {
dateInfo.setEndDate(firstExpression.getFieldValue().toString());
if (hasSecondDate(dateExpressions)) {
dateInfo.setStartDate(dateExpressions.get(1).getFieldValue().toString());
}
}
return dateInfo;
}
private boolean containOperators(FieldExpression expression, FilterOperatorEnum firstOperator,
FilterOperatorEnum... operatorEnums) {
return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(
expression.getFieldValue()));
}
private boolean hasSecondDate(List<FieldExpression> dateExpressions) {
return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue());
}
protected Map<String, SchemaElement> getNameToElement(Long dataSetId) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
List<SchemaElement> dimensions = semanticSchema.getDimensions(dataSetId);
List<SchemaElement> metrics = semanticSchema.getMetrics(dataSetId);
List<SchemaElement> allElements = Lists.newArrayList();
allElements.addAll(dimensions);
allElements.addAll(metrics);
//support alias
return allElements.stream()
.flatMap(schemaElement -> {
Set<Pair<String, SchemaElement>> result = new HashSet<>();
result.add(Pair.of(schemaElement.getName(), schemaElement));
List<String> aliasList = schemaElement.getAlias();
if (!org.springframework.util.CollectionUtils.isEmpty(aliasList)) {
for (String alias : aliasList) {
result.add(Pair.of(alias, schemaElement));
}
}
return result.stream();
})
.collect(Collectors.toMap(pair -> pair.getLeft(), pair -> pair.getRight(),
(value1, value2) -> value2));
}
}

View File

@@ -1,14 +1,10 @@
package com.tencent.supersonic.chat.server.processor.parse;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.server.processor.ResultProcessor;
/**
* A ParseResultProcessor wraps things up before returning results to users in parse stage.
*/
public interface ParseResultProcessor extends ResultProcessor {
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext);
public interface ParseResultProcessor {
void process(ParseResp parseResp, ChatParseReq chatParseReq);
}

View File

@@ -1,83 +0,0 @@
package com.tencent.supersonic.chat.server.processor.parse;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
/**
* QueryRankProcessor ranks candidate parsing results based on
* a heuristic scoring algorithm and then takes topN.
**/
@Slf4j
public class QueryRankProcessor implements ParseResultProcessor {
private static final int candidateTopSize = 5;
@Override
public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries();
candidateQueries = rank(candidateQueries);
queryContext.setCandidateQueries(candidateQueries);
}
public List<SemanticQuery> rank(List<SemanticQuery> candidateQueries) {
log.debug("pick before [{}]", candidateQueries);
if (CollectionUtils.isEmpty(candidateQueries)) {
return candidateQueries;
}
List<SemanticQuery> selectedQueries = new ArrayList<>();
if (candidateQueries.size() == 1) {
selectedQueries.addAll(candidateQueries);
} else {
selectedQueries = getTopCandidateQuery(candidateQueries);
}
generateParseInfoId(selectedQueries);
log.debug("pick after [{}]", selectedQueries);
return selectedQueries;
}
public List<SemanticQuery> getTopCandidateQuery(List<SemanticQuery> semanticQueries) {
return semanticQueries.stream()
.filter(query -> !checkFullyInherited(query))
.sorted((o1, o2) -> {
if (o1.getParseInfo().getScore() < o2.getParseInfo().getScore()) {
return 1;
} else if (o1.getParseInfo().getScore() > o2.getParseInfo().getScore()) {
return -1;
}
return 0;
}).limit(candidateTopSize)
.collect(Collectors.toList());
}
private void generateParseInfoId(List<SemanticQuery> semanticQueries) {
for (int i = 0; i < semanticQueries.size(); i++) {
SemanticQuery query = semanticQueries.get(i);
query.getParseInfo().setId(i + 1);
}
}
private boolean checkFullyInherited(SemanticQuery query) {
SemanticParseInfo parseInfo = query.getParseInfo();
if (!(query instanceof RuleSemanticQuery)) {
return false;
}
for (SchemaElementMatch match : parseInfo.getElementMatches()) {
if (!match.isInherited()) {
return false;
}
}
return parseInfo.getDateInfo() == null || parseInfo.getDateInfo().isInherited();
}
}

View File

@@ -3,29 +3,30 @@ package com.tencent.supersonic.chat.server.processor.parse;
import com.alibaba.fastjson.JSONObject;
import com.github.pagehelper.PageInfo;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
import com.tencent.supersonic.chat.core.utils.SimilarQueryManager;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResp;
import com.tencent.supersonic.headless.core.pojo.ChatContext;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.server.processor.ResultProcessor;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
/**
* MetricRecommendProcessor fills recommended query based on embedding similarity.
*/
@Slf4j
public class QueryRecommendProcessor implements ParseResultProcessor {
public class QueryRecommendProcessor implements ResultProcessor {
@Override
public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
@@ -35,8 +36,9 @@ public class QueryRecommendProcessor implements ParseResultProcessor {
@SneakyThrows
private void doProcess(ParseResp parseResp, QueryContext queryContext) {
Long queryId = parseResp.getQueryId();
//TODO
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(queryContext.getQueryText(),
queryContext.getAgentId());
null);
ChatQueryDO chatQueryDO = getChatQuery(queryId);
chatQueryDO.setSimilarQueries(JSONObject.toJSONString(solvedQueries));
updateChatQuery(chatQueryDO);
@@ -44,8 +46,8 @@ public class QueryRecommendProcessor implements ParseResultProcessor {
public List<SimilarQueryRecallResp> getSimilarQueries(String queryText, Integer agentId) {
//1. recall solved query by queryText
SimilarQueryManager solvedQueryManager = ContextUtils.getBean(SimilarQueryManager.class);
List<SimilarQueryRecallResp> similarQueries = solvedQueryManager.recallSimilarQuery(queryText, agentId);
//SimilarQueryManager solvedQueryManager = ContextUtils.getBean(SimilarQueryManager.class);
List<SimilarQueryRecallResp> similarQueries = Lists.newArrayList();
if (CollectionUtils.isEmpty(similarQueries)) {
return Lists.newArrayList();
}

View File

@@ -1,37 +0,0 @@
package com.tencent.supersonic.chat.server.processor.parse;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.List;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
/**
* RespBuildProcessor fill response object with parsing results.
**/
@Slf4j
public class RespBuildProcessor implements ParseResultProcessor {
@Override
public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
parseResp.setChatId(queryContext.getChatId());
parseResp.setQueryText(queryContext.getQueryText());
List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries();
ChatService chatService = ContextUtils.getBean(ChatService.class);
if (candidateQueries.size() > 0) {
List<SemanticParseInfo> candidateParses = candidateQueries.stream()
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
parseResp.setSelectedParses(candidateParses);
parseResp.setState(ParseResp.ParseState.COMPLETED);
} else {
parseResp.setState(ParseResp.ParseState.FAILED);
}
chatService.batchAddParse(chatContext, queryContext, parseResp);
}
}

View File

@@ -1,67 +0,0 @@
package com.tencent.supersonic.chat.server.processor.parse;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.QueryManager;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;
/**
* SqlInfoProcessor adds S2SQL to the parsing results so that
* technical users could verify SQL by themselves.
**/
public class SqlInfoProcessor implements ParseResultProcessor {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
@Override
public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
List<SemanticQuery> semanticQueries = queryContext.getCandidateQueries();
if (CollectionUtils.isEmpty(semanticQueries)) {
return;
}
List<SemanticParseInfo> selectedParses = semanticQueries.stream().map(SemanticQuery::getParseInfo)
.collect(Collectors.toList());
long startTime = System.currentTimeMillis();
addSqlInfo(queryContext, selectedParses);
parseResp.getParseTimeCost().setSqlTime(System.currentTimeMillis() - startTime);
}
private void addSqlInfo(QueryContext queryContext, List<SemanticParseInfo> semanticParseInfos) {
if (CollectionUtils.isEmpty(semanticParseInfos)) {
return;
}
semanticParseInfos.forEach(parseInfo -> {
addSqlInfo(queryContext, parseInfo);
});
}
private void addSqlInfo(QueryContext queryContext, SemanticParseInfo parseInfo) {
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
if (Objects.isNull(semanticQuery)) {
return;
}
semanticQuery.setParseInfo(parseInfo);
String explainSql = semanticQuery.explain(queryContext.getUser());
if (StringUtils.isBlank(explainSql)) {
return;
}
SqlInfo sqlInfo = parseInfo.getSqlInfo();
if (semanticQuery instanceof LLMSqlQuery) {
keyPipelineLog.info("\ns2sql:{}\ncorrectS2SQL:{}\nquerySQL:{}", sqlInfo.getS2SQL(),
sqlInfo.getCorrectS2SQL(), explainSql);
}
sqlInfo.setQuerySQL(explainSql);
}
}

View File

@@ -1,21 +0,0 @@
package com.tencent.supersonic.chat.server.processor.parse;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import lombok.extern.slf4j.Slf4j;
/**
* TimeCostProcessor adds time cost of parsing.
**/
@Slf4j
public class TimeCostProcessor implements ParseResultProcessor {
@Override
public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
long parseStartTime = parseResp.getParseTimeCost().getParseStartTime();
parseResp.getParseTimeCost().setParseTime(
System.currentTimeMillis() - parseStartTime - parseResp.getParseTimeCost().getSqlTime());
}
}

View File

@@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.server.rest;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.core.agent.Agent;
import com.tencent.supersonic.chat.core.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.service.AgentService;
import org.springframework.web.bind.annotation.DeleteMapping;
import org.springframework.web.bind.annotation.PathVariable;

View File

@@ -1,6 +1,5 @@
package com.tencent.supersonic.chat.server.rest;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigBaseReq;
@@ -8,16 +7,9 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatConfigEditReqReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
import com.tencent.supersonic.chat.server.service.ConfigService;
import com.tencent.supersonic.headless.api.pojo.request.PageDimensionReq;
import com.tencent.supersonic.headless.api.pojo.request.PageMetricReq;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
import com.tencent.supersonic.headless.api.pojo.response.DomainResp;
import com.tencent.supersonic.headless.api.pojo.response.ItemResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
import com.tencent.supersonic.headless.server.service.SchemaService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
@@ -39,8 +31,8 @@ public class ChatConfigController {
@Autowired
private ConfigService configService;
private SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
@Autowired
private SchemaService schemaService;
@PostMapping
public Long addChatConfig(@RequestBody ChatConfigBaseReq extendBaseCmd,
@@ -76,40 +68,9 @@ public class ChatConfigController {
return configService.getAllChatRichConfig();
}
@GetMapping("/domainList")
public List<DomainResp> getDomainList(HttpServletRequest request,
HttpServletResponse response) {
User user = UserHolder.findUser(request, response);
return semanticInterpreter.getDomainList(user);
}
//Compatible with front-end
@GetMapping("/dataSetList")
public List<DataSetResp> getDataSetList() {
return semanticInterpreter.getDataSetList(null);
}
@GetMapping("/dataSetList/{domainId}")
public List<DataSetResp> getDataSetList(@PathVariable("domainId") Long domainId) {
return semanticInterpreter.getDataSetList(domainId);
}
@PostMapping("/dimension/page")
public PageInfo<DimensionResp> getDimension(@RequestBody PageDimensionReq pageDimensionReq) {
return semanticInterpreter.getDimensionPage(pageDimensionReq);
}
@PostMapping("/metric/page")
public PageInfo<MetricResp> getMetric(@RequestBody PageMetricReq pageMetricReq,
HttpServletRequest request,
HttpServletResponse response) {
User user = UserHolder.findUser(request, response);
return semanticInterpreter.getMetricPage(pageMetricReq, user);
}
@GetMapping("/getDomainDataSetTree")
public List<ItemResp> getDomainDataSetTree() {
return semanticInterpreter.getDomainDataSetTree();
return schemaService.getDomainDataSetTree();
}
}

View File

@@ -4,7 +4,7 @@ package com.tencent.supersonic.chat.server.rest;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.server.service.ChatService;

View File

@@ -1,24 +1,23 @@
package com.tencent.supersonic.chat.server.rest;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.chat.server.service.QueryService;
import com.tencent.supersonic.chat.server.service.SearchService;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.validation.Valid;
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.validation.Valid;
/**
* query controller
*/
@@ -27,62 +26,49 @@ import org.springframework.web.bind.annotation.RestController;
public class ChatQueryController {
@Autowired
@Qualifier("chatQueryService")
private QueryService queryService;
@Autowired
private SearchService searchService;
private ChatService chatService;
@PostMapping("search")
public Object search(@RequestBody QueryReq queryCtx, HttpServletRequest request,
public Object search(@RequestBody ChatParseReq chatParseReq, HttpServletRequest request,
HttpServletResponse response) {
queryCtx.setUser(UserHolder.findUser(request, response));
return searchService.search(queryCtx);
chatParseReq.setUser(UserHolder.findUser(request, response));
return chatService.search(chatParseReq);
}
@PostMapping("parse")
public Object parse(@RequestBody QueryReq queryCtx, HttpServletRequest request, HttpServletResponse response)
throws Exception {
queryCtx.setUser(UserHolder.findUser(request, response));
return queryService.performParsing(queryCtx);
public Object parse(@RequestBody ChatParseReq chatParseReq,
HttpServletRequest request, HttpServletResponse response) throws Exception {
chatParseReq.setUser(UserHolder.findUser(request, response));
return chatService.performParsing(chatParseReq);
}
@PostMapping("execute")
public Object execute(@RequestBody ExecuteQueryReq queryReq,
public Object execute(@RequestBody ChatExecuteReq chatExecuteReq,
HttpServletRequest request, HttpServletResponse response)
throws Exception {
queryReq.setUser(UserHolder.findUser(request, response));
return queryService.performExecution(queryReq);
chatExecuteReq.setUser(UserHolder.findUser(request, response));
return chatService.performExecution(chatExecuteReq);
}
@PostMapping("queryContext")
public Object queryContext(@RequestBody QueryReq queryCtx, HttpServletRequest request,
HttpServletResponse response) throws Exception {
public Object queryContext(@RequestBody QueryReq queryCtx,
HttpServletRequest request, HttpServletResponse response) {
queryCtx.setUser(UserHolder.findUser(request, response));
return queryService.queryContext(queryCtx);
return chatService.queryContext(queryCtx.getChatId());
}
@PostMapping("queryData")
public Object queryData(@RequestBody QueryDataReq queryData,
HttpServletRequest request, HttpServletResponse response)
throws Exception {
HttpServletRequest request, HttpServletResponse response) throws Exception {
queryData.setUser(UserHolder.findUser(request, response));
return queryService.executeDirectQuery(queryData, UserHolder.findUser(request, response));
return chatService.queryData(queryData, UserHolder.findUser(request, response));
}
@PostMapping("queryDimensionValue")
public Object queryDimensionValue(@RequestBody @Valid DimensionValueReq dimensionValueReq,
HttpServletRequest request, HttpServletResponse response)
throws Exception {
return queryService.queryDimensionValue(dimensionValueReq, UserHolder.findUser(request, response));
HttpServletRequest request, HttpServletResponse response) throws Exception {
return chatService.queryDimensionValue(dimensionValueReq, UserHolder.findUser(request, response));
}
@RequestMapping("/getEntityInfo")
public Object getEntityInfo(Long queryId, Integer parseId,
HttpServletRequest request,
HttpServletResponse response) {
User user = UserHolder.findUser(request, response);
return queryService.getEntityInfo(queryId, parseId, user);
}
}

View File

@@ -3,7 +3,7 @@ package com.tencent.supersonic.chat.server.rest;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.api.pojo.request.PluginQueryReq;
import com.tencent.supersonic.chat.core.plugin.Plugin;
import com.tencent.supersonic.chat.server.plugin.Plugin;
import com.tencent.supersonic.chat.server.service.PluginService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.RestController;

View File

@@ -1,48 +0,0 @@
package com.tencent.supersonic.chat.server.rest;
import com.tencent.supersonic.chat.api.pojo.request.RecommendReq;
import com.tencent.supersonic.chat.api.pojo.response.RecommendQuestionResp;
import com.tencent.supersonic.chat.api.pojo.response.RecommendResp;
import com.tencent.supersonic.chat.server.service.RecommendService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestParam;
import java.util.List;
/**
* recommend controller
*/
@RestController
@RequestMapping({"/api/chat/", "/openapi/chat/"})
public class RecommendController {
@Autowired
private RecommendService recommendService;
@GetMapping("recommend/{modelId}")
public RecommendResp recommend(@PathVariable("modelId") Long modelId,
@RequestParam(value = "limit", required = false) Long limit) {
RecommendReq recommendReq = new RecommendReq();
recommendReq.setModelId(modelId);
return recommendService.recommend(recommendReq, limit);
}
@GetMapping("recommend/metric/{modelId}")
public RecommendResp recommendMetricMode(@PathVariable("modelId") Long modelId,
@RequestParam(value = "metricId", required = false) Long metricId,
@RequestParam(value = "limit", required = false) Long limit) {
RecommendReq recommendReq = new RecommendReq();
recommendReq.setModelId(modelId);
recommendReq.setMetricId(metricId);
return recommendService.recommendMetricMode(recommendReq, limit);
}
@GetMapping("recommend/question")
public List<RecommendQuestionResp> recommendQuestion(
@RequestParam(value = "modelId", required = false) Long modelId) {
return recommendService.recommendQuestion(modelId);
}
}

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.chat.server.service;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.core.agent.Agent;
import com.tencent.supersonic.chat.server.agent.Agent;
import java.util.List;
public interface AgentService {

View File

@@ -2,30 +2,36 @@ package com.tencent.supersonic.chat.server.service;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
import java.util.List;
public interface ChatService {
/***
* get the model from context
* @param chatId
* @return
*/
Long getContextModel(Integer chatId);
List<SearchResult> search(ChatParseReq chatParseReq);
ChatContext getOrCreateContext(int chatId);
ParseResp performParsing(ChatParseReq chatParseReq);
void updateContext(ChatContext chatCtx);
QueryResult performExecution(ChatExecuteReq chatExecuteReq) throws Exception;
Object queryData(QueryDataReq queryData, User user) throws Exception;
SemanticParseInfo queryContext(Integer chatId);
Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception;
Boolean addChat(User user, String chatName, Integer agentId);
@@ -45,13 +51,13 @@ public interface ChatService {
ShowCaseResp queryShowCase(PageQueryInfoReq pageQueryInfoCommend, int agentId);
List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryContext queryContext, ParseResp parseResult);
List<ChatParseDO> batchAddParse(ChatParseReq chatParseReq, ParseResp parseResult);
ChatQueryDO getLastQuery(long chatId);
int updateQuery(ChatQueryDO chatQueryDO);
void updateQuery(Long questionId, int parseId, QueryResult queryResult, ChatContext chatCtx);
void saveQueryResult(ChatExecuteReq chatExecuteReq, QueryResult queryResult);
ChatParseDO getParseInfo(Long questionId, int parseId);

View File

@@ -3,7 +3,7 @@ package com.tencent.supersonic.chat.server.service;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.request.PluginQueryReq;
import com.tencent.supersonic.chat.core.plugin.Plugin;
import com.tencent.supersonic.chat.server.plugin.Plugin;
import java.util.List;
import java.util.Map;

View File

@@ -1,31 +0,0 @@
package com.tencent.supersonic.chat.server.service;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
import org.apache.calcite.sql.parser.SqlParseException;
/***
* QueryService for query and search
*/
public interface QueryService {
ParseResp performParsing(QueryReq queryReq);
QueryResult performExecution(ExecuteQueryReq queryReq) throws Exception;
SemanticParseInfo queryContext(QueryReq queryReq);
QueryResult executeDirectQuery(QueryDataReq queryData, User user) throws SqlParseException;
EntityInfo getEntityInfo(Long queryId, Integer parseId, User user);
Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception;
}

View File

@@ -1,14 +0,0 @@
package com.tencent.supersonic.chat.server.service;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.SearchResult;
import java.util.List;
/**
* search service
*/
public interface SearchService {
List<SearchResult> search(QueryReq queryCtx);
}

View File

@@ -1,229 +0,0 @@
package com.tencent.supersonic.chat.server.service;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.DataInfo;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.DataSetInfo;
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
import com.tencent.supersonic.chat.core.utils.QueryReqBuilder;
import com.tencent.supersonic.chat.server.service.impl.SchemaService;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.time.LocalDate;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
@Service
@Slf4j
public class SemanticService {
@Autowired
private SchemaService schemaService;
private SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
public SemanticSchema getSemanticSchema() {
return schemaService.getSemanticSchema();
}
public DataSetSchema getDataSetSchema(Long id) {
return schemaService.getDataSetSchema(id);
}
public EntityInfo getEntityInfo(SemanticParseInfo parseInfo, DataSetSchema dataSetSchema, User user) {
if (parseInfo != null && parseInfo.getDataSetId() > 0) {
EntityInfo entityInfo = getEntityBasicInfo(dataSetSchema);
if (parseInfo.getDimensionFilters().size() <= 0 || entityInfo.getDataSetInfo() == null) {
entityInfo.setMetrics(null);
entityInfo.setDimensions(null);
return entityInfo;
}
String primaryKey = entityInfo.getDataSetInfo().getPrimaryKey();
if (StringUtils.isNotBlank(primaryKey)) {
String entityId = "";
for (QueryFilter chatFilter : parseInfo.getDimensionFilters()) {
if (chatFilter != null && chatFilter.getBizName() != null && chatFilter.getBizName()
.equals(primaryKey)) {
if (chatFilter.getOperator().equals(FilterOperatorEnum.EQUALS)) {
entityId = chatFilter.getValue().toString();
}
}
}
entityInfo.setEntityId(entityId);
try {
fillEntityInfoValue(entityInfo, dataSetSchema, user);
return entityInfo;
} catch (Exception e) {
log.error("setMainModel error", e);
}
}
}
return null;
}
private EntityInfo getEntityBasicInfo(DataSetSchema dataSetSchema) {
EntityInfo entityInfo = new EntityInfo();
if (dataSetSchema == null) {
return entityInfo;
}
Long dataSetId = dataSetSchema.getDataSet().getDataSet();
DataSetInfo dataSetInfo = new DataSetInfo();
dataSetInfo.setItemId(dataSetId.intValue());
dataSetInfo.setName(dataSetSchema.getDataSet().getName());
dataSetInfo.setWords(dataSetSchema.getDataSet().getAlias());
dataSetInfo.setBizName(dataSetSchema.getDataSet().getBizName());
if (Objects.nonNull(dataSetSchema.getEntity())) {
dataSetInfo.setPrimaryKey(dataSetSchema.getEntity().getBizName());
}
entityInfo.setDataSetInfo(dataSetInfo);
TagTypeDefaultConfig tagTypeDefaultConfig = dataSetSchema.getTagTypeDefaultConfig();
if (tagTypeDefaultConfig == null || tagTypeDefaultConfig.getDefaultDisplayInfo() == null) {
return entityInfo;
}
List<DataInfo> dimensions = tagTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds().stream()
.map(id -> {
SchemaElement element = dataSetSchema.getElement(SchemaElementType.DIMENSION, id);
if (element == null) {
return null;
}
return new DataInfo(element.getId().intValue(), element.getName(), element.getBizName(), null);
}).filter(Objects::nonNull).collect(Collectors.toList());
List<DataInfo> metrics = tagTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds().stream()
.map(id -> {
SchemaElement element = dataSetSchema.getElement(SchemaElementType.METRIC, id);
if (element == null) {
return null;
}
return new DataInfo(element.getId().intValue(), element.getName(), element.getBizName(), null);
}).filter(Objects::nonNull).collect(Collectors.toList());
entityInfo.setDimensions(dimensions);
entityInfo.setMetrics(metrics);
return entityInfo;
}
public void fillEntityInfoValue(EntityInfo entityInfo, DataSetSchema dataSetSchema, User user) {
SemanticQueryResp queryResultWithColumns =
getQueryResultWithSchemaResp(entityInfo, dataSetSchema, user);
if (queryResultWithColumns != null) {
if (!CollectionUtils.isEmpty(queryResultWithColumns.getResultList())
&& queryResultWithColumns.getResultList().size() > 0) {
Map<String, Object> result = queryResultWithColumns.getResultList().get(0);
for (Map.Entry<String, Object> entry : result.entrySet()) {
String entryKey = getEntryKey(entry);
if (entry.getValue() == null || entryKey == null) {
continue;
}
entityInfo.getDimensions().stream().filter(i -> entryKey.equals(i.getBizName()))
.forEach(i -> i.setValue(entry.getValue().toString()));
entityInfo.getMetrics().stream().filter(i -> entryKey.equals(i.getBizName()))
.forEach(i -> i.setValue(entry.getValue().toString()));
}
}
}
}
public SemanticQueryResp getQueryResultWithSchemaResp(EntityInfo entityInfo,
DataSetSchema dataSetSchema, User user) {
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.setDataSet(dataSetSchema.getDataSet());
semanticParseInfo.setQueryType(QueryType.TAG);
semanticParseInfo.setMetrics(getMetrics(entityInfo));
semanticParseInfo.setDimensions(getDimensions(entityInfo));
DateConf dateInfo = new DateConf();
int unit = 1;
TimeDefaultConfig timeDefaultConfig = dataSetSchema.getTagTypeTimeDefaultConfig();
if (Objects.nonNull(timeDefaultConfig)) {
unit = timeDefaultConfig.getUnit();
String date = LocalDate.now().plusDays(-unit).toString();
dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
dateInfo.setStartDate(date);
dateInfo.setEndDate(date);
} else {
dateInfo.setUnit(unit);
dateInfo.setDateMode(DateConf.DateMode.RECENT);
}
semanticParseInfo.setDateInfo(dateInfo);
// add filter
QueryFilter chatFilter = getQueryFilter(entityInfo);
Set<QueryFilter> chatFilters = new LinkedHashSet();
chatFilters.add(chatFilter);
semanticParseInfo.setDimensionFilters(chatFilters);
SemanticQueryResp queryResultWithColumns = null;
try {
QueryStructReq queryStructReq = QueryReqBuilder.buildStructReq(semanticParseInfo);
queryResultWithColumns = semanticInterpreter.queryByStruct(queryStructReq, user);
} catch (Exception e) {
log.warn("setMainModel queryByStruct error, e:", e);
}
return queryResultWithColumns;
}
private QueryFilter getQueryFilter(EntityInfo entityInfo) {
QueryFilter chatFilter = new QueryFilter();
chatFilter.setValue(entityInfo.getEntityId());
chatFilter.setOperator(FilterOperatorEnum.EQUALS);
chatFilter.setBizName(getEntityPrimaryName(entityInfo));
return chatFilter;
}
private Set<SchemaElement> getDimensions(EntityInfo modelInfo) {
Set<SchemaElement> dimensions = new LinkedHashSet();
for (DataInfo mainEntityDimension : modelInfo.getDimensions()) {
SchemaElement dimension = new SchemaElement();
dimension.setBizName(mainEntityDimension.getBizName());
dimensions.add(dimension);
}
return dimensions;
}
private String getEntryKey(Map.Entry<String, Object> entry) {
// metric parser special handle, TODO delete
String entryKey = entry.getKey();
if (entryKey.contains("__")) {
entryKey = entryKey.split("__")[1];
}
return entryKey;
}
private Set<SchemaElement> getMetrics(EntityInfo modelInfo) {
Set<SchemaElement> metrics = new LinkedHashSet();
for (DataInfo metricValue : modelInfo.getMetrics()) {
SchemaElement metric = new SchemaElement();
BeanUtils.copyProperties(metricValue, metric);
metrics.add(metric);
}
return metrics;
}
private String getEntityPrimaryName(EntityInfo entityInfo) {
return entityInfo.getDataSetInfo().getPrimaryKey();
}
}

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.server.service;
import com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO;
import java.util.List;

View File

@@ -1,14 +0,0 @@
package com.tencent.supersonic.chat.server.service;
import java.lang.annotation.Documented;
import java.lang.annotation.Target;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
@Target({ElementType.PARAMETER, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface TimeCost {
}

View File

@@ -1,33 +0,0 @@
package com.tencent.supersonic.chat.server.service;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.springframework.stereotype.Component;
@Slf4j
@Component
@Aspect
public class TimeCostAOP {
@Pointcut("@annotation(com.tencent.supersonic.chat.server.service.TimeCost)")
private void timeCostAdvicePointcut() {
}
@Around("timeCostAdvicePointcut()")
public Object timeCostAdvice(ProceedingJoinPoint joinPoint) throws Throwable {
log.info("begin to add time cost!");
Long startTime = System.currentTimeMillis();
Object object = joinPoint.proceed();
if (object instanceof QueryResult) {
QueryResult queryResult = (QueryResult) object;
queryResult.setQueryTimeCost(System.currentTimeMillis() - startTime);
return queryResult;
}
return object;
}
}

View File

@@ -2,7 +2,7 @@ package com.tencent.supersonic.chat.server.service.impl;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.core.agent.Agent;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
import com.tencent.supersonic.chat.server.persistence.repository.AgentRepository;
import com.tencent.supersonic.chat.server.service.AgentService;

View File

@@ -1,26 +1,36 @@
package com.tencent.supersonic.chat.server.service.impl;
import com.alibaba.fastjson.JSONObject;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.QueryDO;
import com.tencent.supersonic.chat.server.persistence.repository.ChatContextRepository;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.persistence.repository.ChatRepository;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.common.util.BeanMapper;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
import com.tencent.supersonic.headless.server.service.ChatQueryService;
import com.tencent.supersonic.headless.server.service.SearchService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Primary;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@@ -30,51 +40,87 @@ import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
@Service("ChatService")
@Primary
@Slf4j
@Service
public class ChatServiceImpl implements ChatService {
private ChatContextRepository chatContextRepository;
@Autowired
private ChatRepository chatRepository;
@Autowired
private ChatQueryRepository chatQueryRepository;
@Autowired
private ChatQueryService chatQueryService;
@Autowired
private AgentService agentService;
@Autowired
private SearchService searchService;
public ChatServiceImpl(ChatContextRepository chatContextRepository, ChatRepository chatRepository,
ChatQueryRepository chatQueryRepository) {
this.chatContextRepository = chatContextRepository;
this.chatRepository = chatRepository;
this.chatQueryRepository = chatQueryRepository;
@Override
public List<SearchResult> search(ChatParseReq chatParseReq) {
QueryReq queryReq = buildSqlQueryReq(chatParseReq);
return searchService.search(queryReq);
}
@Override
public Long getContextModel(Integer chatId) {
if (Objects.isNull(chatId)) {
return null;
}
ChatContext chatContext = getOrCreateContext(chatId);
if (Objects.isNull(chatContext)) {
return null;
}
SemanticParseInfo originalSemanticParse = chatContext.getParseInfo();
if (Objects.nonNull(originalSemanticParse) && Objects.nonNull(originalSemanticParse.getDataSetId())) {
return originalSemanticParse.getDataSetId();
}
return null;
public ParseResp performParsing(ChatParseReq chatParseReq) {
QueryReq queryReq = buildSqlQueryReq(chatParseReq);
ParseResp parseResp = chatQueryService.performParsing(queryReq);
batchAddParse(chatParseReq, parseResp);
return parseResp;
}
@Override
public ChatContext getOrCreateContext(int chatId) {
return chatContextRepository.getOrCreateContext(chatId);
public QueryResult performExecution(ChatExecuteReq chatExecuteReq) throws Exception {
ExecuteQueryReq executeQueryReq = buildExecuteReq(chatExecuteReq);
QueryResult queryResult = chatQueryService.performExecution(executeQueryReq);
saveQueryResult(chatExecuteReq, queryResult);
return queryResult;
}
@Override
public void updateContext(ChatContext chatCtx) {
log.debug("save ChatContext {}", chatCtx);
chatContextRepository.updateContext(chatCtx);
public Object queryData(QueryDataReq queryData, User user) throws Exception {
return chatQueryService.executeDirectQuery(queryData, user);
}
@Override
public SemanticParseInfo queryContext(Integer chatId) {
return chatQueryService.queryContext(chatId);
}
@Override
public Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception {
return chatQueryService.queryDimensionValue(dimensionValueReq, user);
}
private QueryReq buildSqlQueryReq(ChatParseReq chatParseReq) {
QueryReq queryReq = new QueryReq();
BeanMapper.mapper(chatParseReq, queryReq);
if (chatParseReq.getAgentId() == null) {
return queryReq;
}
Agent agent = agentService.getAgent(chatParseReq.getAgentId());
if (agent == null) {
return queryReq;
}
queryReq.setDataSetIds(agent.getDataSetIds());
return queryReq;
}
private ExecuteQueryReq buildExecuteReq(ChatExecuteReq chatExecuteReq) {
ChatParseDO chatParseDO = getParseInfo(chatExecuteReq.getQueryId(), chatExecuteReq.getParseId());
SemanticParseInfo parseInfo = JSONObject.parseObject(chatParseDO.getParseInfo(), SemanticParseInfo.class);
return ExecuteQueryReq.builder()
.queryId(chatExecuteReq.getQueryId())
.chatId(chatExecuteReq.getChatId())
.queryText(chatExecuteReq.getQueryText())
.parseInfo(parseInfo)
.saveAnswer(chatExecuteReq.isSaveAnswer())
.user(chatExecuteReq.getUser())
.build();
}
@Override
@@ -190,18 +236,18 @@ public class ChatServiceImpl implements ChatService {
}
@Override
public void updateQuery(Long questionId, int parseId, QueryResult queryResult, ChatContext chatCtx) {
public void saveQueryResult(ChatExecuteReq chatExecuteReq, QueryResult queryResult) {
//The history record only retains the query result of the first parse
if (parseId > 1) {
if (chatExecuteReq.getParseId() > 1) {
return;
}
ChatQueryDO chatQueryDO = new ChatQueryDO();
chatQueryDO.setQuestionId(questionId);
chatQueryDO.setQuestionId(chatExecuteReq.getQueryId());
chatQueryDO.setQueryResult(JsonUtil.toString(queryResult));
chatQueryDO.setQueryState(1);
updateQuery(chatQueryDO);
chatRepository.updateLastQuestion(chatCtx.getChatId().longValue(),
chatCtx.getQueryText(), getCurrentTime());
chatRepository.updateLastQuestion(chatExecuteReq.getChatId().longValue(),
chatExecuteReq.getQueryText(), getCurrentTime());
}
@Override
@@ -210,9 +256,9 @@ public class ChatServiceImpl implements ChatService {
}
@Override
public List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryContext queryContext, ParseResp parseResult) {
public List<ChatParseDO> batchAddParse(ChatParseReq chatParseReq, ParseResp parseResult) {
List<SemanticParseInfo> candidateParses = parseResult.getSelectedParses();
return chatQueryRepository.batchSaveParseInfo(chatCtx, queryContext, parseResult, candidateParses);
return chatQueryRepository.batchSaveParseInfo(chatParseReq, parseResult, candidateParses);
}
@Override

View File

@@ -3,9 +3,6 @@ package com.tencent.supersonic.chat.server.service.impl;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigBaseReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigEditReqReq;
@@ -24,23 +21,22 @@ import com.tencent.supersonic.chat.api.pojo.response.ChatDetailRichConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.EntityRichInfoResp;
import com.tencent.supersonic.chat.api.pojo.response.ItemVisibilityInfo;
import com.tencent.supersonic.chat.server.config.ChatConfig;
import com.tencent.supersonic.chat.server.util.ChatConfigHelper;
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
import com.tencent.supersonic.chat.server.util.VisibilityEvent;
import com.tencent.supersonic.chat.server.persistence.repository.ChatConfigRepository;
import com.tencent.supersonic.chat.server.service.ConfigService;
import com.tencent.supersonic.chat.server.service.SemanticService;
import com.tencent.supersonic.chat.server.util.ChatConfigHelper;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
import com.tencent.supersonic.headless.server.service.DimensionService;
import com.tencent.supersonic.headless.server.service.MetricService;
import com.tencent.supersonic.headless.server.service.impl.SemanticService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@@ -62,10 +58,6 @@ public class ConfigServiceImpl implements ConfigService {
private final MetricService metricService;
@Autowired
private SemanticService semanticService;
@Autowired
private ApplicationEventPublisher applicationEventPublisher;
private SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
public ConfigServiceImpl(ChatConfigRepository chatConfigRepository,
@@ -83,9 +75,7 @@ public class ConfigServiceImpl implements ConfigService {
log.info("[create model extend] object:{}", JsonUtil.toString(configBaseCmd, true));
duplicateCheck(configBaseCmd.getModelId());
ChatConfig chaConfig = chatConfigHelper.newChatConfig(configBaseCmd, user);
Long id = chatConfigRepository.createConfig(chaConfig);
applicationEventPublisher.publishEvent(new VisibilityEvent(this, chaConfig));
return id;
return chatConfigRepository.createConfig(chaConfig);
}
private void duplicateCheck(Long modelId) {
@@ -106,7 +96,6 @@ public class ConfigServiceImpl implements ConfigService {
}
ChatConfig chaConfig = chatConfigHelper.editChatConfig(configEditCmd, user);
chatConfigRepository.updateConfig(chaConfig);
applicationEventPublisher.publishEvent(new VisibilityEvent(this, chaConfig));
return configEditCmd.getId();
}
@@ -350,15 +339,7 @@ public class ConfigServiceImpl implements ConfigService {
@Override
public List<ChatConfigRichResp> getAllChatRichConfig() {
List<ChatConfigRichResp> chatConfigRichInfoList = new ArrayList<>();
List<DataSetSchema> modelSchemas = semanticInterpreter.getDataSetSchema();
modelSchemas.stream().forEach(modelSchema -> {
ChatConfigRichResp chatConfigRichInfo = getConfigRichInfo(modelSchema.getDataSet().getId());
if (Objects.nonNull(chatConfigRichInfo)) {
chatConfigRichInfoList.add(chatConfigRichInfo);
}
});
return chatConfigRichInfoList;
return new ArrayList<>();
}
@Override
@@ -367,4 +348,5 @@ public class ConfigServiceImpl implements ConfigService {
return allChatRichConfig.stream()
.collect(Collectors.toMap(ChatConfigRichResp::getModelId, value -> value, (k1, k2) -> k1));
}
}

View File

@@ -4,11 +4,11 @@ import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.request.PluginQueryReq;
import com.tencent.supersonic.chat.core.plugin.Plugin;
import com.tencent.supersonic.chat.core.plugin.PluginParseConfig;
import com.tencent.supersonic.chat.core.plugin.event.PluginAddEvent;
import com.tencent.supersonic.chat.core.plugin.event.PluginDelEvent;
import com.tencent.supersonic.chat.core.plugin.event.PluginUpdateEvent;
import com.tencent.supersonic.chat.server.plugin.Plugin;
import com.tencent.supersonic.chat.server.plugin.PluginParseConfig;
import com.tencent.supersonic.chat.server.plugin.event.PluginAddEvent;
import com.tencent.supersonic.chat.server.plugin.event.PluginDelEvent;
import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent;
import com.tencent.supersonic.chat.server.persistence.dataobject.PluginDO;
import com.tencent.supersonic.chat.server.persistence.repository.PluginRepository;
import com.tencent.supersonic.chat.server.service.PluginService;

View File

@@ -1,713 +0,0 @@
package com.tencent.supersonic.chat.server.service.impl;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.request.SimilarQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.core.agent.Agent;
import com.tencent.supersonic.chat.core.corrector.SemanticCorrector;
import com.tencent.supersonic.headless.core.knowledge.HanlpMapResult;
import com.tencent.supersonic.headless.core.knowledge.SearchService;
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
import com.tencent.supersonic.chat.core.mapper.SchemaMapper;
import com.tencent.supersonic.chat.core.parser.SemanticParser;
import com.tencent.supersonic.chat.core.plugin.Plugin;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.QueryManager;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper;
import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper;
import com.tencent.supersonic.chat.core.utils.SimilarQueryManager;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.CostType;
import com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor;
import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.chat.server.service.ConfigService;
import com.tencent.supersonic.chat.server.service.PluginService;
import com.tencent.supersonic.chat.server.service.QueryService;
import com.tencent.supersonic.chat.server.service.SemanticService;
import com.tencent.supersonic.chat.server.service.StatisticsService;
import com.tencent.supersonic.chat.server.service.TimeCost;
import com.tencent.supersonic.chat.server.util.ComponentFactory;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.FieldExpression;
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.server.service.KnowledgeService;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
import net.sf.jsqlparser.schema.Column;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Component;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
@Service
@Component("chatQueryService")
@Primary
@Slf4j
public class QueryServiceImpl implements QueryService {
@Autowired
private ChatService chatService;
@Autowired
private StatisticsService statisticsService;
@Autowired
private SimilarQueryManager similarQueryManager;
@Autowired
private SchemaService schemaService;
@Autowired
private AgentService agentService;
@Autowired
private ConfigService configService;
@Autowired
private PluginService pluginService;
@Autowired
private KnowledgeService knowledgeService;
@Value("${time.threshold: 100}")
private Integer timeThreshold;
private List<SchemaMapper> schemaMappers = ComponentFactory.getSchemaMappers();
private List<SemanticParser> semanticParsers = ComponentFactory.getSemanticParsers();
private List<ParseResultProcessor> parseProcessors = ComponentFactory.getParseProcessors();
private List<ExecuteResultProcessor> executeProcessors = ComponentFactory.getExecuteProcessors();
private List<SemanticCorrector> semanticCorrectors = ComponentFactory.getSemanticCorrectors();
@Override
public ParseResp performParsing(QueryReq queryReq) {
ParseResp parseResult = new ParseResp();
// build queryContext and chatContext
QueryContext queryCtx = buildQueryContext(queryReq);
// in order to support multi-turn conversation, chat context is needed
ChatContext chatCtx = chatService.getOrCreateContext(queryReq.getChatId());
List<StatisticsDO> timeCostDOList = new ArrayList<>();
// 1. mapper
schemaMappers.forEach(mapper -> {
long startTime = System.currentTimeMillis();
mapper.map(queryCtx);
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
.interfaceName(mapper.getClass().getSimpleName()).type(CostType.MAPPER.getType()).build());
});
// 2. parser
semanticParsers.forEach(parser -> {
long startTime = System.currentTimeMillis();
parser.parse(queryCtx, chatCtx);
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
.interfaceName(parser.getClass().getSimpleName()).type(CostType.PARSER.getType()).build());
log.debug("{} result:{}", parser.getClass().getSimpleName(), JsonUtil.toString(queryCtx));
});
// 3. corrector
List<SemanticQuery> candidateQueries = queryCtx.getCandidateQueries();
if (CollectionUtils.isNotEmpty(candidateQueries)) {
for (SemanticQuery semanticQuery : candidateQueries) {
// the rules are not being corrected.
if (semanticQuery instanceof RuleSemanticQuery) {
continue;
}
semanticCorrectors.forEach(corrector -> {
corrector.correct(queryCtx, semanticQuery.getParseInfo());
});
}
}
// 4. processor
parseProcessors.forEach(processor -> {
long startTime = System.currentTimeMillis();
processor.process(parseResult, queryCtx, chatCtx);
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
.interfaceName(processor.getClass().getSimpleName())
.type(CostType.PROCESSOR.getType()).build());
log.debug("{} result:{}", processor.getClass().getSimpleName(), JsonUtil.toString(queryCtx));
});
if (Objects.nonNull(parseResult.getQueryId()) && timeCostDOList.size() > 0) {
saveTimeCostInfo(timeCostDOList, queryReq.getQueryText(), parseResult.getQueryId(),
queryReq.getUser().getName(), queryReq.getChatId().longValue());
}
return parseResult;
}
private QueryContext buildQueryContext(QueryReq queryReq) {
Integer agentId = queryReq.getAgentId();
Agent agent = agentService.getAgent(agentId);
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
Map<Long, ChatConfigRichResp> modelIdToChatRichConfig = configService.getModelIdToChatRichConfig();
Map<String, Plugin> nameToPlugin = pluginService.getNameToPlugin();
List<Plugin> pluginList = pluginService.getPluginList();
QueryContext queryCtx = QueryContext.builder()
.queryFilters(queryReq.getQueryFilters())
.semanticSchema(semanticSchema)
.candidateQueries(new ArrayList<>())
.mapInfo(new SchemaMapInfo())
.agent(agent)
.modelIdToChatRichConfig(modelIdToChatRichConfig)
.nameToPlugin(nameToPlugin)
.pluginList(pluginList)
.build();
BeanUtils.copyProperties(queryReq, queryCtx);
return queryCtx;
}
@Override
@TimeCost
public QueryResult performExecution(ExecuteQueryReq queryReq) throws Exception {
ChatParseDO chatParseDO = chatService.getParseInfo(queryReq.getQueryId(), queryReq.getParseId());
ChatQueryDO chatQueryDO = chatService.getLastQuery(queryReq.getChatId());
List<StatisticsDO> timeCostDOList = new ArrayList<>();
SemanticParseInfo parseInfo = JsonUtil.toObject(chatParseDO.getParseInfo(), SemanticParseInfo.class);
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
if (semanticQuery == null) {
return null;
}
semanticQuery.setParseInfo(parseInfo);
// in order to support multi-turn conversation, chat context is needed
ChatContext chatCtx = chatService.getOrCreateContext(queryReq.getChatId());
chatCtx.setAgentId(queryReq.getAgentId());
Long startTime = System.currentTimeMillis();
QueryResult queryResult = semanticQuery.execute(queryReq.getUser());
if (queryResult != null) {
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
.interfaceName(semanticQuery.getClass().getSimpleName()).type(CostType.QUERY.getType()).build());
queryResult.setQueryTimeCost(timeCostDOList.get(0).getCost().longValue());
saveTimeCostInfo(timeCostDOList, queryReq.getQueryText(), queryReq.getQueryId(),
queryReq.getUser().getName(), queryReq.getChatId().longValue());
queryResult.setChatContext(parseInfo);
// update chat context after a successful semantic query
if (QueryState.SUCCESS.equals(queryResult.getQueryState())) {
chatCtx.setParseInfo(parseInfo);
chatService.updateContext(chatCtx);
saveSolvedQuery(queryReq, parseInfo, chatQueryDO, queryResult);
}
chatCtx.setQueryText(queryReq.getQueryText());
chatCtx.setUser(queryReq.getUser().getName());
for (ExecuteResultProcessor executeResultProcessor : executeProcessors) {
executeResultProcessor.process(queryResult, parseInfo, queryReq);
}
chatService.updateQuery(queryReq.getQueryId(), queryReq.getParseId(), queryResult, chatCtx);
} else {
chatService.deleteChatQuery(queryReq.getQueryId());
}
return queryResult;
}
/**
* save time cost data
*
* @param timeCostDOList
* @param queryText
* @param queryId
* @param userName
* @param chatId
*/
private void saveTimeCostInfo(List<StatisticsDO> timeCostDOList,
String queryText, Long queryId,
String userName, Long chatId) {
List<StatisticsDO> list = timeCostDOList.stream()
.filter(o -> o.getCost() > timeThreshold).collect(Collectors.toList());
list.forEach(o -> {
o.setQueryText(queryText);
o.setQuestionId(queryId);
o.setUserName(userName);
o.setChatId(chatId);
o.setCreateTime(new java.util.Date());
});
if (list.size() > 0) {
log.info("filterStatistics size:{},data:{}", list.size(), JsonUtil.toString(list));
statisticsService.batchSaveStatistics(list);
}
}
private void saveSolvedQuery(ExecuteQueryReq queryReq, SemanticParseInfo parseInfo,
ChatQueryDO chatQueryDO, QueryResult queryResult) {
if (queryResult.getResponse() == null && CollectionUtils.isEmpty(queryResult.getQueryResults())) {
return;
}
similarQueryManager.saveSimilarQuery(SimilarQueryReq.builder().parseId(queryReq.getParseId())
.queryId(queryReq.getQueryId())
.agentId(chatQueryDO.getAgentId())
.dataSetId(parseInfo.getDataSetId())
.queryText(queryReq.getQueryText()).build());
}
@Override
public SemanticParseInfo queryContext(QueryReq queryCtx) {
ChatContext context = chatService.getOrCreateContext(queryCtx.getChatId());
return context.getParseInfo();
}
//mainly used for executing after revising filters,for example:"fans_cnt>=100000"->"fans_cnt>500000",
//"style='流行'"->"style in ['流行','爱国']"
@Override
@TimeCost
public QueryResult executeDirectQuery(QueryDataReq queryData, User user) throws SqlParseException {
ChatParseDO chatParseDO = chatService.getParseInfo(queryData.getQueryId(),
queryData.getParseId());
SemanticParseInfo parseInfo = getSemanticParseInfo(queryData, chatParseDO);
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
semanticQuery.setParseInfo(parseInfo);
List<String> fields = new ArrayList<>();
if (Objects.nonNull(parseInfo.getSqlInfo())
&& StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectS2SQL())) {
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
fields = SqlSelectHelper.getAllFields(correctorSql);
}
if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())
&& checkMetricReplace(fields, queryData.getMetrics())) {
//replace metrics
log.info("llm begin replace metrics!");
SchemaElement metricToReplace = queryData.getMetrics().iterator().next();
replaceMetrics(parseInfo, metricToReplace);
} else if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())) {
log.info("llm begin revise filters!");
String correctorSql = reviseCorrectS2SQL(queryData, parseInfo);
parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql);
semanticQuery.setParseInfo(parseInfo);
String explainSql = semanticQuery.explain(user);
if (StringUtils.isNotBlank(explainSql)) {
parseInfo.getSqlInfo().setQuerySQL(explainSql);
}
} else {
log.info("rule begin replace metrics and revise filters!");
//remove unvalid filters
validFilter(semanticQuery.getParseInfo().getDimensionFilters());
validFilter(semanticQuery.getParseInfo().getMetricFilters());
//init s2sql
semanticQuery.initS2Sql(semanticSchema, user);
QueryReq queryReq = new QueryReq();
queryReq.setQueryFilters(new QueryFilters());
queryReq.setUser(user);
}
QueryResult queryResult = semanticQuery.execute(user);
queryResult.setChatContext(semanticQuery.getParseInfo());
DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(parseInfo.getDataSetId());
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, user);
queryResult.setEntityInfo(entityInfo);
return queryResult;
}
private boolean checkMetricReplace(List<String> oriFields, Set<SchemaElement> metrics) {
if (CollectionUtils.isEmpty(oriFields)) {
return false;
}
if (CollectionUtils.isEmpty(metrics)) {
return false;
}
List<String> metricNames = metrics.stream().map(SchemaElement::getName).collect(Collectors.toList());
return !oriFields.containsAll(metricNames);
}
public String reviseCorrectS2SQL(QueryDataReq queryData, SemanticParseInfo parseInfo) {
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
Map<String, Map<String, String>> havingFiledNameToValueMap = new HashMap<>();
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
log.info("correctorSql before replacing:{}", correctorSql);
// get where filter and having filter
List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctorSql);
List<FieldExpression> havingExpressionList = SqlSelectHelper.getHavingExpressions(correctorSql);
List<Expression> addWhereConditions = new ArrayList<>();
List<Expression> addHavingConditions = new ArrayList<>();
Set<String> removeWhereFieldNames = new HashSet<>();
Set<String> removeHavingFieldNames = new HashSet<>();
// replace where filter
updateFilters(whereExpressionList, queryData.getDimensionFilters(),
parseInfo.getDimensionFilters(), addWhereConditions, removeWhereFieldNames);
updateDateInfo(queryData, parseInfo, filedNameToValueMap,
whereExpressionList, addWhereConditions, removeWhereFieldNames);
correctorSql = SqlReplaceHelper.replaceValue(correctorSql, filedNameToValueMap);
correctorSql = SqlRemoveHelper.removeWhereCondition(correctorSql, removeWhereFieldNames);
// replace having filter
updateFilters(havingExpressionList, queryData.getDimensionFilters(),
parseInfo.getDimensionFilters(), addHavingConditions, removeHavingFieldNames);
correctorSql = SqlReplaceHelper.replaceHavingValue(correctorSql, havingFiledNameToValueMap);
correctorSql = SqlRemoveHelper.removeHavingCondition(correctorSql, removeHavingFieldNames);
correctorSql = SqlAddHelper.addWhere(correctorSql, addWhereConditions);
correctorSql = SqlAddHelper.addHaving(correctorSql, addHavingConditions);
log.info("correctorSql after replacing:{}", correctorSql);
return correctorSql;
}
private void replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) {
List<String> oriMetrics = parseInfo.getMetrics().stream()
.map(SchemaElement::getName).collect(Collectors.toList());
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
log.info("before replaceMetrics:{}", correctorSql);
log.info("filteredMetrics:{},metrics:{}", oriMetrics, metric);
Map<String, Pair<String, String>> fieldMap = new HashMap<>();
if (CollectionUtils.isNotEmpty(oriMetrics) && !oriMetrics.contains(metric.getName())) {
fieldMap.put(oriMetrics.get(0), Pair.of(metric.getName(), metric.getDefaultAgg()));
correctorSql = SqlReplaceHelper.replaceAggFields(correctorSql, fieldMap);
}
log.info("after replaceMetrics:{}", correctorSql);
parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql);
}
@Override
public EntityInfo getEntityInfo(Long queryId, Integer parseId, User user) {
ChatParseDO chatParseDO = chatService.getParseInfo(queryId, parseId);
SemanticParseInfo parseInfo = JsonUtil.toObject(chatParseDO.getParseInfo(), SemanticParseInfo.class);
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
DataSetSchema dataSetSchema =
schemaService.getSemanticSchema().getDataSetSchemaMap().get(parseInfo.getDataSetId());
return semanticService.getEntityInfo(parseInfo, dataSetSchema, user);
}
private void updateDateInfo(QueryDataReq queryData, SemanticParseInfo parseInfo,
Map<String, Map<String, String>> filedNameToValueMap,
List<FieldExpression> fieldExpressionList,
List<Expression> addConditions,
Set<String> removeFieldNames) {
if (Objects.isNull(queryData.getDateInfo())) {
return;
}
Map<String, String> map = new HashMap<>();
String dateField = TimeDimensionEnum.DAY.getChName();
if (queryData.getDateInfo().getUnit() > 1) {
queryData.getDateInfo().setStartDate(DateUtils.getBeforeDate(queryData.getDateInfo().getUnit() + 1));
queryData.getDateInfo().setEndDate(DateUtils.getBeforeDate(1));
}
// startDate equals to endDate
if (queryData.getDateInfo().getStartDate().equals(queryData.getDateInfo().getEndDate())) {
for (FieldExpression fieldExpression : fieldExpressionList) {
if (TimeDimensionEnum.DAY.getChName().equals(fieldExpression.getFieldName())) {
//sql where condition exists 'equals' operator about date,just replace
if (fieldExpression.getOperator().equals(FilterOperatorEnum.EQUALS)) {
dateField = fieldExpression.getFieldName();
map.put(fieldExpression.getFieldValue().toString(),
queryData.getDateInfo().getStartDate());
filedNameToValueMap.put(dateField, map);
} else {
// first remove,then add
removeFieldNames.add(TimeDimensionEnum.DAY.getChName());
EqualsTo equalsTo = new EqualsTo();
Column column = new Column(TimeDimensionEnum.DAY.getChName());
StringValue stringValue = new StringValue(queryData.getDateInfo().getStartDate());
equalsTo.setLeftExpression(column);
equalsTo.setRightExpression(stringValue);
addConditions.add(equalsTo);
}
break;
}
}
} else {
for (FieldExpression fieldExpression : fieldExpressionList) {
if (TimeDimensionEnum.DAY.getChName().equals(fieldExpression.getFieldName())) {
dateField = fieldExpression.getFieldName();
//just replace
if (FilterOperatorEnum.GREATER_THAN_EQUALS.getValue().equals(fieldExpression.getOperator())
|| FilterOperatorEnum.GREATER_THAN.getValue().equals(fieldExpression.getOperator())) {
map.put(fieldExpression.getFieldValue().toString(),
queryData.getDateInfo().getStartDate());
}
if (FilterOperatorEnum.MINOR_THAN_EQUALS.getValue().equals(fieldExpression.getOperator())
|| FilterOperatorEnum.MINOR_THAN.getValue().equals(fieldExpression.getOperator())) {
map.put(fieldExpression.getFieldValue().toString(),
queryData.getDateInfo().getEndDate());
}
filedNameToValueMap.put(dateField, map);
// first remove,then add
if (FilterOperatorEnum.EQUALS.getValue().equals(fieldExpression.getOperator())) {
removeFieldNames.add(TimeDimensionEnum.DAY.getChName());
GreaterThanEquals greaterThanEquals = new GreaterThanEquals();
addTimeFilters(queryData.getDateInfo().getStartDate(), greaterThanEquals, addConditions);
MinorThanEquals minorThanEquals = new MinorThanEquals();
addTimeFilters(queryData.getDateInfo().getEndDate(), minorThanEquals, addConditions);
}
}
}
}
parseInfo.setDateInfo(queryData.getDateInfo());
}
private <T extends ComparisonOperator> void addTimeFilters(String date,
T comparisonExpression,
List<Expression> addConditions) {
Column column = new Column(TimeDimensionEnum.DAY.getChName());
StringValue stringValue = new StringValue(date);
comparisonExpression.setLeftExpression(column);
comparisonExpression.setRightExpression(stringValue);
addConditions.add(comparisonExpression);
}
private void updateFilters(List<FieldExpression> fieldExpressionList,
Set<QueryFilter> metricFilters,
Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions,
Set<String> removeFieldNames) {
if (CollectionUtils.isEmpty(metricFilters)) {
return;
}
for (QueryFilter dslQueryFilter : metricFilters) {
for (FieldExpression fieldExpression : fieldExpressionList) {
if (fieldExpression.getFieldName() != null
&& fieldExpression.getFieldName().contains(dslQueryFilter.getName())) {
removeFieldNames.add(dslQueryFilter.getName());
if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.EQUALS)) {
EqualsTo equalsTo = new EqualsTo();
addWhereFilters(dslQueryFilter, equalsTo, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.GREATER_THAN_EQUALS)) {
GreaterThanEquals greaterThanEquals = new GreaterThanEquals();
addWhereFilters(dslQueryFilter, greaterThanEquals, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.GREATER_THAN)) {
GreaterThan greaterThan = new GreaterThan();
addWhereFilters(dslQueryFilter, greaterThan, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.MINOR_THAN_EQUALS)) {
MinorThanEquals minorThanEquals = new MinorThanEquals();
addWhereFilters(dslQueryFilter, minorThanEquals, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.MINOR_THAN)) {
MinorThan minorThan = new MinorThan();
addWhereFilters(dslQueryFilter, minorThan, contextMetricFilters, addConditions);
} else if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.IN)) {
InExpression inExpression = new InExpression();
addWhereInFilters(dslQueryFilter, inExpression, contextMetricFilters, addConditions);
}
break;
}
}
}
}
// add in condition to sql where condition
private void addWhereInFilters(QueryFilter dslQueryFilter,
InExpression inExpression,
Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions) {
Column column = new Column(dslQueryFilter.getName());
ExpressionList expressionList = new ExpressionList();
List<Expression> expressions = new ArrayList<>();
List<String> valueList = JsonUtil.toList(
JsonUtil.toString(dslQueryFilter.getValue()), String.class);
if (CollectionUtils.isEmpty(valueList)) {
return;
}
valueList.stream().forEach(o -> {
StringValue stringValue = new StringValue(o);
expressions.add(stringValue);
});
expressionList.setExpressions(expressions);
inExpression.setLeftExpression(column);
inExpression.setRightItemsList(expressionList);
addConditions.add(inExpression);
contextMetricFilters.stream().forEach(o -> {
if (o.getName().equals(dslQueryFilter.getName())) {
o.setValue(dslQueryFilter.getValue());
o.setOperator(dslQueryFilter.getOperator());
}
});
}
// add where filter
private <T extends ComparisonOperator> void addWhereFilters(QueryFilter dslQueryFilter,
T comparisonExpression,
Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions) {
String columnName = dslQueryFilter.getName();
if (StringUtils.isNotBlank(dslQueryFilter.getFunction())) {
columnName = dslQueryFilter.getFunction() + "(" + dslQueryFilter.getName() + ")";
}
if (Objects.isNull(dslQueryFilter.getValue())) {
return;
}
Column column = new Column(columnName);
comparisonExpression.setLeftExpression(column);
if (StringUtils.isNumeric(dslQueryFilter.getValue().toString())) {
LongValue longValue = new LongValue(Long.parseLong(dslQueryFilter.getValue().toString()));
comparisonExpression.setRightExpression(longValue);
} else {
StringValue stringValue = new StringValue(dslQueryFilter.getValue().toString());
comparisonExpression.setRightExpression(stringValue);
}
addConditions.add(comparisonExpression);
contextMetricFilters.stream().forEach(o -> {
if (o.getName().equals(dslQueryFilter.getName())) {
o.setValue(dslQueryFilter.getValue());
o.setOperator(dslQueryFilter.getOperator());
}
});
}
private SemanticParseInfo getSemanticParseInfo(QueryDataReq queryData, ChatParseDO chatParseDO) {
SemanticParseInfo parseInfo = JsonUtil.toObject(chatParseDO.getParseInfo(), SemanticParseInfo.class);
if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
return parseInfo;
}
if (CollectionUtils.isNotEmpty(queryData.getDimensions())) {
parseInfo.setDimensions(queryData.getDimensions());
}
if (CollectionUtils.isNotEmpty(queryData.getMetrics())) {
parseInfo.setMetrics(queryData.getMetrics());
}
if (CollectionUtils.isNotEmpty(queryData.getDimensionFilters())) {
parseInfo.setDimensionFilters(queryData.getDimensionFilters());
}
if (CollectionUtils.isNotEmpty(queryData.getMetricFilters())) {
parseInfo.setMetricFilters(queryData.getMetricFilters());
}
if (Objects.nonNull(queryData.getDateInfo())) {
parseInfo.setDateInfo(queryData.getDateInfo());
}
return parseInfo;
}
private void validFilter(Set<QueryFilter> filters) {
for (QueryFilter queryFilter : filters) {
if (Objects.isNull(queryFilter.getValue())) {
filters.remove(queryFilter);
}
if (queryFilter.getOperator().equals(FilterOperatorEnum.IN) && CollectionUtils.isEmpty(
JsonUtil.toList(JsonUtil.toString(queryFilter.getValue()), String.class))) {
filters.remove(queryFilter);
}
}
}
@Override
public Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception {
SemanticQueryResp semanticQueryResp = new SemanticQueryResp();
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
SemanticSchema semanticSchema = semanticService.getSemanticSchema();
SchemaElement schemaElement = semanticSchema.getDimension(dimensionValueReq.getElementID());
Set<Long> detectDataSetIds = new HashSet<>();
detectDataSetIds.add(schemaElement.getDataSet());
dimensionValueReq.setModelId(schemaElement.getDataSet());
List<String> dimensionValues = getDimensionValues(dimensionValueReq, detectDataSetIds);
// if the search results is null,search dimensionValue from database
if (CollectionUtils.isEmpty(dimensionValues)) {
semanticQueryResp = queryDatabase(dimensionValueReq, user);
return semanticQueryResp;
}
List<QueryColumn> columns = new ArrayList<>();
QueryColumn queryColumn = new QueryColumn();
queryColumn.setNameEn(dimensionValueReq.getBizName());
queryColumn.setShowType("CATEGORY");
queryColumn.setAuthorized(true);
queryColumn.setType("CHAR");
columns.add(queryColumn);
List<Map<String, Object>> resultList = new ArrayList<>();
dimensionValues.stream().forEach(o -> {
Map<String, Object> map = new HashMap<>();
map.put(dimensionValueReq.getBizName(), o);
resultList.add(map);
});
semanticQueryResp.setColumns(columns);
semanticQueryResp.setResultList(resultList);
return semanticQueryResp;
}
private List<String> getDimensionValues(DimensionValueReq dimensionValueReq, Set<Long> dataSetIds) {
//if value is null ,then search from NATURE_TO_VALUES
if (StringUtils.isBlank(dimensionValueReq.getValue())) {
return SearchService.getDimensionValue(dimensionValueReq);
}
//search from prefixSearch
List<HanlpMapResult> hanlpMapResultList = knowledgeService.prefixSearch(dimensionValueReq.getValue(),
2000, dataSetIds);
HanlpHelper.transLetterOriginal(hanlpMapResultList);
return hanlpMapResultList.stream()
.filter(o -> {
for (String nature : o.getNatures()) {
Long elementID = NatureHelper.getElementID(nature);
if (dimensionValueReq.getElementID().equals(elementID)) {
return true;
}
}
return false;
})
.map(mapResult -> mapResult.getName())
.collect(Collectors.toList());
}
private SemanticQueryResp queryDatabase(DimensionValueReq dimensionValueReq, User user) {
QueryStructReq queryStructReq = new QueryStructReq();
DateConf dateConf = new DateConf();
dateConf.setDateMode(DateConf.DateMode.RECENT);
dateConf.setUnit(1);
dateConf.setPeriod("DAY");
queryStructReq.setDateInfo(dateConf);
queryStructReq.setLimit(20L);
queryStructReq.setDataSetId(dimensionValueReq.getModelId());
queryStructReq.setQueryType(QueryType.ID);
List<String> groups = new ArrayList<>();
groups.add(dimensionValueReq.getBizName());
queryStructReq.setGroups(groups);
SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
return semanticInterpreter.queryByStruct(queryStructReq, user);
}
}

View File

@@ -1,142 +0,0 @@
package com.tencent.supersonic.chat.server.service.impl;
import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.RelatedSchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
import com.tencent.supersonic.chat.api.pojo.request.RecommendReq;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.RecommendQuestionResp;
import com.tencent.supersonic.chat.api.pojo.response.RecommendResp;
import com.tencent.supersonic.chat.server.service.ConfigService;
import com.tencent.supersonic.chat.server.service.RecommendService;
import com.tencent.supersonic.chat.server.service.SemanticService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
/***
* Recommend Service impl
*/
@Service
@Slf4j
public class RecommendServiceImpl implements RecommendService {
@Autowired
private ConfigService configService;
@Autowired
private SemanticService semanticService;
@Override
public RecommendResp recommend(RecommendReq recommendReq, Long limit) {
if (Objects.isNull(limit) || limit <= 0) {
limit = Long.MAX_VALUE;
}
Long modelId = recommendReq.getModelId();
if (Objects.isNull(modelId)) {
return new RecommendResp();
}
DataSetSchema modelSchema = semanticService.getDataSetSchema(modelId);
if (Objects.isNull(modelSchema)) {
return new RecommendResp();
}
List<Long> drillDownDimensions = Lists.newArrayList();
Set<SchemaElement> metricElements = modelSchema.getMetrics();
if (recommendReq.getMetricId() != null && !CollectionUtils.isEmpty(metricElements)) {
Optional<SchemaElement> metric = metricElements.stream().filter(schemaElement ->
recommendReq.getMetricId().equals(schemaElement.getId())
&& !CollectionUtils.isEmpty(schemaElement.getRelatedSchemaElements()))
.findFirst();
if (metric.isPresent()) {
drillDownDimensions = metric.get().getRelatedSchemaElements().stream()
.map(RelatedSchemaElement::getDimensionId).collect(Collectors.toList());
}
}
final List<Long> drillDownDimensionsFinal = drillDownDimensions;
List<SchemaElement> dimensions = modelSchema.getDimensions().stream()
.filter(dim -> {
if (Objects.isNull(dim)) {
return false;
}
if (!CollectionUtils.isEmpty(drillDownDimensionsFinal)) {
return drillDownDimensionsFinal.contains(dim.getId());
} else {
return Objects.nonNull(dim.getUseCnt());
}
})
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(limit)
.map(dimSchemaDesc -> {
SchemaElement item = new SchemaElement();
item.setDataSet(modelId);
item.setName(dimSchemaDesc.getName());
item.setBizName(dimSchemaDesc.getBizName());
item.setId(dimSchemaDesc.getId());
item.setAlias(dimSchemaDesc.getAlias());
return item;
}).collect(Collectors.toList());
List<SchemaElement> metrics = modelSchema.getMetrics().stream()
.filter(metric -> Objects.nonNull(metric) && Objects.nonNull(metric.getUseCnt()))
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(limit)
.map(metricSchemaDesc -> {
SchemaElement item = new SchemaElement();
item.setDataSet(modelId);
item.setName(metricSchemaDesc.getName());
item.setBizName(metricSchemaDesc.getBizName());
item.setId(metricSchemaDesc.getId());
item.setAlias(metricSchemaDesc.getAlias());
return item;
}).collect(Collectors.toList());
RecommendResp response = new RecommendResp();
response.setDimensions(dimensions);
response.setMetrics(metrics);
return response;
}
@Override
public RecommendResp recommendMetricMode(RecommendReq recommendReq, Long limit) {
return recommend(recommendReq, limit);
}
@Override
public List<RecommendQuestionResp> recommendQuestion(Long modelId) {
List<RecommendQuestionResp> recommendQuestions = new ArrayList<>();
ChatConfigFilter chatConfigFilter = new ChatConfigFilter();
chatConfigFilter.setModelId(modelId);
List<ChatConfigResp> chatConfigRespList = configService.search(chatConfigFilter, null);
if (!CollectionUtils.isEmpty(chatConfigRespList)) {
chatConfigRespList.stream().forEach(chatConfigResp -> {
if (Objects.nonNull(chatConfigResp)
&& !CollectionUtils.isEmpty(chatConfigResp.getRecommendedQuestions())) {
recommendQuestions.add(
new RecommendQuestionResp(chatConfigResp.getModelId(),
chatConfigResp.getRecommendedQuestions()));
}
});
return recommendQuestions;
}
return new ArrayList<>();
}
private List<SchemaElement> filterBlackItem(List<SchemaElement> itemList, List<Long> blackDimIdList) {
if (CollectionUtils.isEmpty(blackDimIdList) || CollectionUtils.isEmpty(itemList)) {
return itemList;
}
return itemList.stream().filter(dim -> !blackDimIdList.contains(dim.getId())).collect(Collectors.toList());
}
}

View File

@@ -1,47 +0,0 @@
package com.tencent.supersonic.chat.server.service.impl;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.concurrent.TimeUnit;
@Service
@Slf4j
public class SchemaService {
public static final String ALL_CACHE = "all";
private static final Integer META_CACHE_TIME = 30;
private SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
private LoadingCache<String, SemanticSchema> cache = CacheBuilder.newBuilder()
.expireAfterWrite(META_CACHE_TIME, TimeUnit.SECONDS)
.build(
new CacheLoader<String, SemanticSchema>() {
@Override
public SemanticSchema load(String key) {
log.info("load getDomainSchemaInfo cache [{}]", key);
return new SemanticSchema(semanticInterpreter.getDataSetSchema());
}
}
);
public DataSetSchema getDataSetSchema(Long id) {
return semanticInterpreter.getDataSetSchema(id, true);
}
public SemanticSchema getSemanticSchema() {
return cache.getUnchecked(ALL_CACHE);
}
public LoadingCache<String, SemanticSchema> getCache() {
return cache;
}
}

View File

@@ -1,364 +0,0 @@
package com.tencent.supersonic.chat.server.service.impl;
import com.github.benmanes.caffeine.cache.Cache;
import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.ItemNameVisibilityInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.SearchResult;
import com.tencent.supersonic.chat.core.agent.Agent;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.knowledge.DictWord;
import com.tencent.supersonic.headless.core.knowledge.HanlpMapResult;
import com.tencent.supersonic.headless.core.knowledge.DataSetInfoStat;
import com.tencent.supersonic.chat.core.mapper.MapperHelper;
import com.tencent.supersonic.chat.core.mapper.MatchText;
import com.tencent.supersonic.chat.core.mapper.ModelWithSemanticType;
import com.tencent.supersonic.chat.core.mapper.SearchMatchStrategy;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper;
import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.chat.server.service.ConfigService;
import com.tencent.supersonic.chat.server.service.SearchService;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.server.service.KnowledgeService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
/**
* search service impl
*/
@Service
@Slf4j
public class SearchServiceImpl implements SearchService {
private static final int RESULT_SIZE = 10;
@Autowired
private SchemaService schemaService;
@Autowired
private ChatService chatService;
@Autowired
private SearchMatchStrategy searchMatchStrategy;
@Autowired
private AgentService agentService;
@Autowired
@Qualifier("searchCaffeineCache")
private Cache<Long, Object> caffeineCache;
@Autowired
private ConfigService configService;
@Autowired
private KnowledgeService knowledgeService;
@Override
public List<SearchResult> search(QueryReq queryReq) {
// 1. check search enable
Integer agentId = queryReq.getAgentId();
if (agentId != null) {
Agent agent = agentService.getAgent(agentId);
if (!agent.enableSearch()) {
return Lists.newArrayList();
}
}
String queryText = queryReq.getQueryText();
// 2.get meta info
SemanticSchema semanticSchemaDb = schemaService.getSemanticSchema();
List<SchemaElement> metricsDb = semanticSchemaDb.getMetrics();
final Map<Long, String> modelToName = semanticSchemaDb.getDataSetIdToName();
// 3.detect by segment
List<S2Term> originals = knowledgeService.getTerms(queryText);
log.info("hanlp parse result: {}", originals);
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
Set<Long> detectDataSetIds = mapperHelper.getDataSetIds(queryReq.getDataSetId(),
agentService.getAgent(agentId));
QueryContext queryContext = new QueryContext();
BeanUtils.copyProperties(queryReq, queryContext);
Map<MatchText, List<HanlpMapResult>> regTextMap =
searchMatchStrategy.match(queryContext, originals, detectDataSetIds);
regTextMap.entrySet().stream().forEach(m -> HanlpHelper.transLetterOriginal(m.getValue()));
// 4.get the most matching data
Optional<Entry<MatchText, List<HanlpMapResult>>> mostSimilarSearchResult = regTextMap.entrySet()
.stream()
.filter(entry -> CollectionUtils.isNotEmpty(entry.getValue()))
.reduce((entry1, entry2) ->
entry1.getKey().getDetectSegment().length() >= entry2.getKey().getDetectSegment().length()
? entry1 : entry2);
// 5.optimize the results after the query
if (!mostSimilarSearchResult.isPresent()) {
return Lists.newArrayList();
}
Map.Entry<MatchText, List<HanlpMapResult>> searchTextEntry = mostSimilarSearchResult.get();
log.info("searchTextEntry:{},queryReq:{}", searchTextEntry, queryReq);
Set<SearchResult> searchResults = new LinkedHashSet();
DataSetInfoStat modelStat = NatureHelper.getDataSetStat(originals);
List<Long> possibleModels = getPossibleModels(queryReq, originals, modelStat, queryReq.getDataSetId());
// 5.1 priority dimension metric
boolean existMetricAndDimension = searchMetricAndDimension(new HashSet<>(possibleModels), modelToName,
searchTextEntry, searchResults);
// 5.2 process based on dimension values
MatchText matchText = searchTextEntry.getKey();
Map<String, String> natureToNameMap = getNatureToNameMap(searchTextEntry, new HashSet<>(possibleModels));
log.debug("possibleModels:{},natureToNameMap:{}", possibleModels, natureToNameMap);
for (Map.Entry<String, String> natureToNameEntry : natureToNameMap.entrySet()) {
Set<SearchResult> searchResultSet = searchDimensionValue(metricsDb, modelToName,
modelStat.getMetricDataSetCount(), existMetricAndDimension,
matchText, natureToNameMap, natureToNameEntry, queryReq.getQueryFilters());
searchResults.addAll(searchResultSet);
}
return searchResults.stream().limit(RESULT_SIZE).collect(Collectors.toList());
}
private List<Long> getPossibleModels(QueryReq queryCtx, List<S2Term> originals,
DataSetInfoStat modelStat, Long webModelId) {
if (Objects.nonNull(webModelId) && webModelId > 0) {
List<Long> result = new ArrayList<>();
result.add(webModelId);
return result;
}
List<Long> possibleModels = NatureHelper.selectPossibleDataSets(originals);
Long contextModel = chatService.getContextModel(queryCtx.getChatId());
log.debug("possibleModels:{},modelStat:{},contextModel:{}", possibleModels, modelStat, contextModel);
// If nothing is recognized or only metric are present, then add the contextModel.
if (nothingOrOnlyMetric(modelStat)) {
return Lists.newArrayList(contextModel);
}
return possibleModels;
}
private boolean nothingOrOnlyMetric(DataSetInfoStat modelStat) {
return modelStat.getMetricDataSetCount() >= 0 && modelStat.getDimensionDataSetCount() <= 0
&& modelStat.getDimensionValueDataSetCount() <= 0 && modelStat.getDataSetCount() <= 0;
}
private boolean effectiveModel(Long contextModel) {
return Objects.nonNull(contextModel) && contextModel > 0;
}
private Set<SearchResult> searchDimensionValue(List<SchemaElement> metricsDb,
Map<Long, String> modelToName,
long metricModelCount,
boolean existMetricAndDimension,
MatchText matchText,
Map<String, String> natureToNameMap,
Map.Entry<String, String> natureToNameEntry,
QueryFilters queryFilters) {
Set<SearchResult> searchResults = new LinkedHashSet();
String nature = natureToNameEntry.getKey();
String wordName = natureToNameEntry.getValue();
Long modelId = NatureHelper.getDataSetId(nature);
SchemaElementType schemaElementType = NatureHelper.convertToElementType(nature);
if (SchemaElementType.ENTITY.equals(schemaElementType)) {
return searchResults;
}
// If there are no metric/dimension, complete the metric information
SearchResult searchResult = SearchResult.builder()
.modelId(modelId)
.modelName(modelToName.get(modelId))
.recommend(matchText.getRegText() + wordName)
.schemaElementType(schemaElementType)
.subRecommend(wordName)
.build();
ItemNameVisibilityInfo visibility = (ItemNameVisibilityInfo) caffeineCache.getIfPresent(modelId);
if (visibility == null) {
visibility = configService.getVisibilityByModelId(modelId);
caffeineCache.put(modelId, visibility);
}
if (visibility.getBlackMetricNameList().contains(searchResult.getRecommend())
|| visibility.getBlackDimNameList().contains(searchResult.getRecommend())) {
return searchResults;
}
if (metricModelCount <= 0 && !existMetricAndDimension) {
if (filterByQueryFilter(wordName, queryFilters)) {
return searchResults;
}
searchResults.add(searchResult);
int metricSize = getMetricSize(natureToNameMap);
//invisibility to filter metrics
List<String> blackMetricNameList = visibility.getBlackMetricNameList();
List<String> metrics = filerMetricsByModel(metricsDb, modelId, metricSize * 3)
.stream().filter(o -> !blackMetricNameList.contains(o))
.limit(metricSize).collect(Collectors.toList());
for (String metric : metrics) {
SearchResult result = SearchResult.builder()
.modelId(modelId)
.modelName(modelToName.get(modelId))
.recommend(matchText.getRegText() + wordName + DictWordType.SPACE + metric)
.subRecommend(wordName + DictWordType.SPACE + metric)
.isComplete(false)
.build();
searchResults.add(result);
}
} else {
searchResults.add(searchResult);
}
return searchResults;
}
private int getMetricSize(Map<String, String> natureToNameMap) {
int metricSize = RESULT_SIZE / (natureToNameMap.entrySet().size());
if (metricSize <= 1) {
metricSize = 1;
}
return metricSize;
}
private boolean filterByQueryFilter(String wordName, QueryFilters queryFilters) {
if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) {
return false;
}
List<QueryFilter> filters = queryFilters.getFilters();
for (QueryFilter filter : filters) {
if (wordName.equalsIgnoreCase(String.valueOf(filter.getValue()))) {
return false;
}
}
return true;
}
protected List<String> filerMetricsByModel(List<SchemaElement> metricsDb, Long model, int metricSize) {
if (CollectionUtils.isEmpty(metricsDb)) {
return Lists.newArrayList();
}
return metricsDb.stream()
.filter(mapDO -> Objects.nonNull(mapDO) && model.equals(mapDO.getDataSet()))
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.flatMap(entry -> {
List<String> result = new ArrayList<>();
result.add(entry.getName());
return result.stream();
})
.limit(metricSize).collect(Collectors.toList());
}
/***
* convert nature to name
* @param recommendTextListEntry
* @return
*/
private Map<String, String> getNatureToNameMap(Map.Entry<MatchText, List<HanlpMapResult>> recommendTextListEntry,
Set<Long> possibleModels) {
List<HanlpMapResult> recommendValues = recommendTextListEntry.getValue();
return recommendValues.stream()
.flatMap(entry -> entry.getNatures().stream()
.filter(nature -> {
if (CollectionUtils.isEmpty(possibleModels)) {
return true;
}
Long model = NatureHelper.getDataSetId(nature);
return possibleModels.contains(model);
})
.map(nature -> {
DictWord posDO = new DictWord();
posDO.setWord(entry.getName());
posDO.setNature(nature);
return posDO;
})).sorted(Comparator.comparingInt(a -> a.getWord().length()))
.collect(Collectors.toMap(DictWord::getNature, DictWord::getWord, (value1, value2) -> value1,
LinkedHashMap::new));
}
private boolean searchMetricAndDimension(Set<Long> possibleModels, Map<Long, String> modelToName,
Map.Entry<MatchText, List<HanlpMapResult>> searchTextEntry, Set<SearchResult> searchResults) {
boolean existMetric = false;
log.info("searchMetricAndDimension searchTextEntry:{}", searchTextEntry);
MatchText matchText = searchTextEntry.getKey();
List<HanlpMapResult> hanlpMapResults = searchTextEntry.getValue();
for (HanlpMapResult hanlpMapResult : hanlpMapResults) {
List<ModelWithSemanticType> dimensionMetricClassIds = hanlpMapResult.getNatures().stream()
.map(nature -> new ModelWithSemanticType(NatureHelper.getDataSetId(nature),
NatureHelper.convertToElementType(nature)))
.filter(entry -> matchCondition(entry, possibleModels)).collect(Collectors.toList());
if (CollectionUtils.isEmpty(dimensionMetricClassIds)) {
continue;
}
for (ModelWithSemanticType modelWithSemanticType : dimensionMetricClassIds) {
existMetric = true;
Long modelId = modelWithSemanticType.getModel();
SchemaElementType schemaElementType = modelWithSemanticType.getSchemaElementType();
SearchResult searchResult = SearchResult.builder()
.modelId(modelId)
.modelName(modelToName.get(modelId))
.recommend(matchText.getRegText() + hanlpMapResult.getName())
.subRecommend(hanlpMapResult.getName())
.schemaElementType(schemaElementType)
.build();
//visibility to filter metrics
ItemNameVisibilityInfo visibility = (ItemNameVisibilityInfo) caffeineCache.getIfPresent(modelId);
if (visibility == null) {
visibility = configService.getVisibilityByModelId(modelId);
caffeineCache.put(modelId, visibility);
}
if (!visibility.getBlackMetricNameList().contains(hanlpMapResult.getName())
&& !visibility.getBlackDimNameList().contains(hanlpMapResult.getName())) {
searchResults.add(searchResult);
}
}
log.info("parseResult:{},dimensionMetricClassIds:{},possibleModels:{}", hanlpMapResult,
dimensionMetricClassIds, possibleModels);
}
log.info("searchMetricAndDimension searchResults:{}", searchResults);
return existMetric;
}
private boolean matchCondition(ModelWithSemanticType entry, Set<Long> possibleModels) {
if (!(SchemaElementType.METRIC.equals(entry.getSchemaElementType()) || SchemaElementType.DIMENSION.equals(
entry.getSchemaElementType()))) {
return false;
}
if (CollectionUtils.isEmpty(possibleModels)) {
return true;
}
return possibleModels.contains(entry.getModel());
}
}

View File

@@ -1,8 +1,8 @@
package com.tencent.supersonic.chat.server.service.impl;
import com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.chat.server.persistence.repository.StatisticsRepository;
import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.chat.server.service.StatisticsService;
import com.tencent.supersonic.headless.server.persistence.mapper.StatisticsMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Async;
@@ -15,11 +15,11 @@ import java.util.List;
public class StatisticsServiceImpl implements StatisticsService {
@Autowired
private StatisticsRepository statisticsRepository;
private StatisticsMapper statisticsMapper;
@Async
@Override
public void batchSaveStatistics(List<StatisticsDO> list) {
statisticsRepository.batchSaveStatistics(list);
statisticsMapper.batchSaveStatistics(list);
}
}

View File

@@ -1,63 +0,0 @@
package com.tencent.supersonic.chat.server.service.impl;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.core.knowledge.DictWord;
import com.tencent.supersonic.headless.core.knowledge.builder.WordBuilderFactory;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
@Service
@Slf4j
public class WordService {
private List<DictWord> preDictWords = new ArrayList<>();
public List<DictWord> getAllDictWords() {
SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
SemanticSchema semanticSchema = new SemanticSchema(semanticInterpreter.getDataSetSchema());
List<DictWord> words = new ArrayList<>();
addWordsByType(DictWordType.DIMENSION, semanticSchema.getDimensions(), words);
addWordsByType(DictWordType.METRIC, semanticSchema.getMetrics(), words);
addWordsByType(DictWordType.ENTITY, semanticSchema.getEntities(), words);
addWordsByType(DictWordType.VALUE, semanticSchema.getDimensionValues(), words);
addWordsByType(DictWordType.TAG, semanticSchema.getTags(), words);
return words;
}
private void addWordsByType(DictWordType value, List<SchemaElement> metas, List<DictWord> natures) {
metas = distinct(metas);
List<DictWord> natureList = WordBuilderFactory.get(value).getDictWords(metas);
log.debug("nature type:{} , nature size:{}", value.name(), natureList.size());
natures.addAll(natureList);
}
public List<DictWord> getPreDictWords() {
return preDictWords;
}
public void setPreDictWords(List<DictWord> preDictWords) {
this.preDictWords = preDictWords;
}
private List<SchemaElement> distinct(List<SchemaElement> metas) {
return metas.stream()
.collect(Collectors.toMap(SchemaElement::getId, Function.identity(), (e1, e2) -> e1))
.values()
.stream()
.collect(Collectors.toList());
}
}

View File

@@ -3,7 +3,7 @@ package com.tencent.supersonic.chat.server.util;
import static com.tencent.supersonic.common.pojo.Constants.ADMIN_LOWER;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigBaseReq;

View File

@@ -1,43 +1,21 @@
package com.tencent.supersonic.chat.server.util;
import com.tencent.supersonic.chat.core.corrector.SemanticCorrector;
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
import com.tencent.supersonic.chat.core.mapper.SchemaMapper;
import com.tencent.supersonic.chat.core.parser.SemanticParser;
import com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor;
import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import com.tencent.supersonic.headless.server.processor.ResultProcessor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.core.io.support.SpringFactoriesLoader;
import java.util.ArrayList;
import java.util.List;
@Slf4j
public class ComponentFactory {
private static List<SchemaMapper> schemaMappers = new ArrayList<>();
private static List<SemanticParser> semanticParsers = new ArrayList<>();
private static List<SemanticCorrector> semanticCorrectors = new ArrayList<>();
private static SemanticInterpreter semanticInterpreter;
private static List<ParseResultProcessor> parseProcessors = new ArrayList<>();
private static List<ResultProcessor> parseProcessors = new ArrayList<>();
private static List<ExecuteResultProcessor> executeProcessors = new ArrayList<>();
public static List<SchemaMapper> getSchemaMappers() {
return CollectionUtils.isEmpty(schemaMappers) ? init(SchemaMapper.class, schemaMappers) : schemaMappers;
}
public static List<SemanticParser> getSemanticParsers() {
return CollectionUtils.isEmpty(semanticParsers) ? init(SemanticParser.class, semanticParsers) : semanticParsers;
}
public static List<SemanticCorrector> getSemanticCorrectors() {
return CollectionUtils.isEmpty(semanticCorrectors) ? init(SemanticCorrector.class,
semanticCorrectors) : semanticCorrectors;
}
public static List<ParseResultProcessor> getParseProcessors() {
return CollectionUtils.isEmpty(parseProcessors) ? init(ParseResultProcessor.class,
public static List<ResultProcessor> getParseProcessors() {
return CollectionUtils.isEmpty(parseProcessors) ? init(ResultProcessor.class,
parseProcessors) : parseProcessors;
}
@@ -46,13 +24,6 @@ public class ComponentFactory {
? init(ExecuteResultProcessor.class, executeProcessors) : executeProcessors;
}
public static SemanticInterpreter getSemanticLayer() {
if (Objects.isNull(semanticInterpreter)) {
semanticInterpreter = init(SemanticInterpreter.class);
}
return semanticInterpreter;
}
private static <T> List<T> init(Class<T> factoryType, List list) {
list.addAll(SpringFactoriesLoader.loadFactories(factoryType,
Thread.currentThread().getContextClassLoader()));

View File

@@ -1,22 +0,0 @@
package com.tencent.supersonic.chat.server.util;
import com.tencent.supersonic.chat.server.config.ChatConfig;
import org.springframework.context.ApplicationEvent;
public class VisibilityEvent extends ApplicationEvent {
private static final long serialVersionUID = 1L;
private ChatConfig chatConfig;
public VisibilityEvent(Object source, ChatConfig chatConfig) {
super(source);
this.chatConfig = chatConfig;
}
public void setChatConfig(ChatConfig chatConfig) {
this.chatConfig = chatConfig;
}
public ChatConfig getChatConfig() {
return chatConfig;
}
}

View File

@@ -1,30 +0,0 @@
package com.tencent.supersonic.chat.server.util;
import com.github.benmanes.caffeine.cache.Cache;
import com.tencent.supersonic.chat.api.pojo.request.ItemNameVisibilityInfo;
import com.tencent.supersonic.chat.server.service.ConfigService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.ApplicationListener;
import org.springframework.stereotype.Component;
@Component
@Slf4j
public class VisibilityListener implements ApplicationListener<VisibilityEvent> {
@Autowired
@Qualifier("searchCaffeineCache")
private Cache<Long, Object> caffeineCache;
@Autowired
private ConfigService configService;
@Override
public void onApplicationEvent(VisibilityEvent event) {
log.info("visibility has changed,so update cache!");
ItemNameVisibilityInfo itemNameVisibility = configService.getItemNameVisibility(event.getChatConfig());
log.info("itemNameVisibility :{}", itemNameVisibility);
caffeineCache.put(event.getChatConfig().getModelId(), itemNameVisibility);
}
}

View File

@@ -1,30 +0,0 @@
<?xml version="1.0" encoding="UTF-8" ?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.tencent.supersonic.chat.server.persistence.mapper.ChatContextMapper">
<resultMap id="ChatContextDO"
type="com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO">
<id column="chat_id" property="chatId"/>
<result column="modified_at" property="modifiedAt"/>
<result column="user" property="user"/>
<result column="query_text" property="queryText"/>
<result column="semantic_parse" property="semanticParse"/>
<!--<result column="ext_data" property="extData"/>-->
</resultMap>
<select id="getContextByChatId" resultMap="ChatContextDO">
select *
from s2_chat_context where chat_id=#{chatId} limit 1
</select>
<insert id="addContext" parameterType="com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO" >
insert into s2_chat_context (chat_id,user,query_text,semantic_parse) values (#{chatId}, #{user},#{queryText}, #{semanticParse})
</insert>
<update id="updateContext">
update s2_chat_context set query_text=#{queryText},semantic_parse=#{semanticParse} where chat_id=#{chatId}
</update>
</mapper>

View File

@@ -1,35 +0,0 @@
<?xml version="1.0" encoding="UTF-8" ?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.tencent.supersonic.chat.server.persistence.mapper.StatisticsMapper">
<resultMap id="Statistics" type="com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO">
<id column="question_id" property="questionId"/>
<result column="chat_id" property="chatId"/>
<result column="user_name" property="userName"/>
<result column="query_text" property="queryText"/>
<result column="interface_name" property="interfaceName"/>
<result column="cost" property="cost"/>
<result column="type" property="type"/>
<result column="create_time" property="createTime"/>
</resultMap>
<insert id="batchSaveStatistics" parameterType="com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO">
insert into s2_chat_statistics
(question_id,chat_id, user_name, query_text, interface_name,cost,type ,create_time)
values
<foreach collection="list" item="item" index="index" separator=",">
(#{item.questionId}, #{item.chatId}, #{item.userName}, #{item.queryText}, #{item.interfaceName}, #{item.cost}, #{item.type},#{item.createTime})
</foreach>
</insert>
<select id="getStatistics" resultMap="Statistics">
select *
from s2_chat_statistics
where question_id = #{questionId} and user_name = #{userName}
limit 1
</select>
</mapper>

View File

@@ -1,18 +0,0 @@
package com.tencent.supersonic.chat.server.test;
import org.mybatis.spring.annotation.MapperScan;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.ComponentScan;
@SpringBootApplication(scanBasePackages = {"com.tencent.supersonic.chat"})
@ComponentScan("com.tencent.supersonic.chat")
@MapperScan("com.tencent.supersonic.chat")
public class ChatBizLauncher {
public static void main(String[] args) {
SpringApplication.run(ChatBizLauncher.class, args);
}
}

View File

@@ -1,150 +0,0 @@
package com.tencent.supersonic.chat.server.test.context;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.when;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.core.config.DefaultMetric;
import com.tencent.supersonic.chat.core.config.DefaultMetricInfo;
import com.tencent.supersonic.chat.core.config.EntityInternalDetail;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.server.persistence.mapper.ChatContextMapper;
import com.tencent.supersonic.chat.server.persistence.repository.impl.ChatContextRepositoryImpl;
import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.chat.server.service.QueryService;
import com.tencent.supersonic.chat.server.service.impl.ConfigServiceImpl;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelSchemaResp;
import com.tencent.supersonic.headless.server.pojo.DimensionFilter;
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
import com.tencent.supersonic.headless.server.service.DimensionService;
import com.tencent.supersonic.headless.server.service.MetricService;
import com.tencent.supersonic.headless.server.service.ModelService;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.mockito.Mockito;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.client.RestTemplate;
@Configuration
public class MockBeansConfiguration {
public static void getOrCreateContextMock(ChatService chatService) {
ChatContext context = new ChatContext();
context.setChatId(1);
when(chatService.getOrCreateContext(1)).thenReturn(context);
}
public static void buildHttpSemanticServiceImpl(List<DimSchemaResp> dimensionDescs,
List<MetricSchemaResp> metricDescs) {
DefaultMetric defaultMetricDesc = new DefaultMetric();
defaultMetricDesc.setUnit(3);
defaultMetricDesc.setPeriod(Constants.DAY);
List<DimSchemaResp> dimensionDescs1 = new ArrayList<>();
DimSchemaResp dimensionDesc = new DimSchemaResp();
dimensionDesc.setId(162L);
dimensionDescs1.add(dimensionDesc);
DimSchemaResp dimensionDesc2 = new DimSchemaResp();
dimensionDesc2.setId(163L);
dimensionDesc2.setBizName("song_name");
dimensionDesc2.setName("歌曲名");
EntityInternalDetail entityInternalDetailDesc = new EntityInternalDetail();
entityInternalDetailDesc.setDimensionList(new ArrayList<>(Arrays.asList(dimensionDesc2)));
MetricSchemaResp metricDesc = new MetricSchemaResp();
metricDesc.setId(877L);
metricDesc.setBizName("js_play_cnt");
metricDesc.setName("结算播放量");
entityInternalDetailDesc.setMetricList(new ArrayList<>(Arrays.asList(metricDesc)));
ModelSchemaResp modelSchemaDesc = new ModelSchemaResp();
modelSchemaDesc.setDimensions(dimensionDescs);
modelSchemaDesc.setMetrics(metricDescs);
}
public static void getModelExtendMock(ConfigServiceImpl configService) {
DefaultMetricInfo defaultMetricInfo = new DefaultMetricInfo();
defaultMetricInfo.setUnit(3);
defaultMetricInfo.setPeriod(Constants.DAY);
List<DefaultMetricInfo> defaultMetricInfos = new ArrayList<>();
defaultMetricInfos.add(defaultMetricInfo);
ChatConfigResp chaConfigDesc = new ChatConfigResp();
when(configService.fetchConfigByModelId(anyLong())).thenReturn(chaConfigDesc);
}
public static void dimensionDescBuild(DimensionService dimensionService, List<DimensionResp> dimensionDescs) {
when(dimensionService.getDimensions(any(DimensionFilter.class))).thenReturn(dimensionDescs);
}
public static void metricDescBuild(MetricService metricService, List<MetricResp> metricDescs) {
when(metricService.getMetrics(any(MetaFilter.class))).thenReturn(metricDescs);
}
public static DimSchemaResp getDimensionDesc(Long id, String bizName, String name) {
DimSchemaResp dimensionDesc = new DimSchemaResp();
dimensionDesc.setId(id);
dimensionDesc.setName(name);
dimensionDesc.setBizName(bizName);
return dimensionDesc;
}
public static MetricSchemaResp getMetricDesc(Long id, String bizName, String name) {
MetricSchemaResp dimensionDesc = new MetricSchemaResp();
dimensionDesc.setId(id);
dimensionDesc.setName(name);
dimensionDesc.setBizName(bizName);
return dimensionDesc;
}
@Bean
public ChatContextRepositoryImpl getChatContextRepository() {
return Mockito.mock(ChatContextRepositoryImpl.class);
}
@Bean
public QueryService getQueryService() {
return Mockito.mock(QueryService.class);
}
@Bean
public DimensionService getDimensionService() {
return Mockito.mock(DimensionService.class);
}
@Bean
public MetricService getMetricService() {
return Mockito.mock(MetricService.class);
}
//queryDimensionDescs
@Bean
public ModelService getModelService() {
return Mockito.mock(ModelService.class);
}
@Bean
public ChatContextMapper getChatContextMapper() {
return Mockito.mock(ChatContextMapper.class);
}
@Bean
public ConfigServiceImpl getModelExtendService() {
return Mockito.mock(ConfigServiceImpl.class);
}
@Bean
public RestTemplate restTemplate() {
return new RestTemplate();
}
}

View File

@@ -1,118 +0,0 @@
package com.tencent.supersonic.chat.server.test.context;
import com.google.gson.Gson;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.DateConf;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import lombok.Data;
public class SemanticParseObjectHelper {
public static SemanticParseInfo copy(SemanticParseInfo semanticParseInfo) {
Gson g = new Gson();
return g.fromJson(g.toJson(semanticParseInfo), SemanticParseInfo.class);
}
public static SemanticParseInfo getSemanticParseInfo(String json) {
Gson gson = new Gson();
SemanticParseJson semanticParseJson = gson.fromJson(json, SemanticParseJson.class);
if (semanticParseJson != null) {
return getSemanticParseInfo(semanticParseJson);
}
return null;
}
private static SemanticParseInfo getSemanticParseInfo(SemanticParseJson semanticParseJson) {
Long model = semanticParseJson.getModel();
Set<SchemaElement> dimensionList = new LinkedHashSet();
Set<SchemaElement> metricList = new LinkedHashSet();
Set<QueryFilter> chatFilters = new LinkedHashSet();
if (semanticParseJson.getFilter() != null && semanticParseJson.getFilter().size() > 0) {
for (List<String> filter : semanticParseJson.getFilter()) {
chatFilters.add(getChatFilter(filter));
}
}
for (String dim : semanticParseJson.getDimensions()) {
dimensionList.add(getDimension(dim, model));
}
for (String metric : semanticParseJson.getMetrics()) {
metricList.add(getMetric(metric, model));
}
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.setDimensionFilters(chatFilters);
semanticParseInfo.setAggType(semanticParseJson.getAggregateType());
semanticParseInfo.setQueryMode(semanticParseJson.getQueryMode());
semanticParseInfo.setMetrics(metricList);
semanticParseInfo.setDimensions(dimensionList);
DateConf dateInfo = getDateInfoAgo(semanticParseJson.getDay());
semanticParseInfo.setDateInfo(dateInfo);
return semanticParseInfo;
}
private static DateConf getDateInfoAgo(int dayAgo) {
if (dayAgo > 0) {
DateConf dateInfo = new DateConf();
dateInfo.setUnit(dayAgo);
dateInfo.setDateMode(DateConf.DateMode.RECENT);
return dateInfo;
}
return null;
}
private static QueryFilter getChatFilter(List<String> filters) {
if (filters.size() > 1) {
QueryFilter chatFilter = new QueryFilter();
chatFilter.setBizName(filters.get(1));
chatFilter.setOperator(FilterOperatorEnum.getSqlOperator(filters.get(2)));
if (filters.size() > 4) {
List<String> valuse = new ArrayList<>();
valuse.addAll(filters.subList(3, filters.size()));
chatFilter.setValue(valuse);
} else {
chatFilter.setValue(filters.get(3));
}
return chatFilter;
}
return null;
}
private static SchemaElement getMetric(String bizName, Long modelId) {
SchemaElement metric = new SchemaElement();
metric.setBizName(bizName);
return metric;
}
private static SchemaElement getDimension(String bizName, Long modelId) {
SchemaElement dimension = new SchemaElement();
dimension.setBizName(bizName);
return dimension;
}
@Data
public static class SemanticParseJson {
private Long model;
private String queryMode;
private AggregateTypeEnum aggregateType;
private Integer day;
private List<String> dimensions;
private List<String> metrics;
private List<List<String>> filter;
}
}

View File

@@ -1,80 +0,0 @@
package com.tencent.supersonic.chat.server.utils;
import com.tencent.supersonic.common.pojo.Aggregator;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateModeUtils;
import com.tencent.supersonic.common.util.SqlFilterUtils;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
* QueryReqBuilderTest
*/
class QueryReqBuilderTest {
@Test
void buildS2SQLReq() {
init();
QueryStructReq queryStructReq = new QueryStructReq();
queryStructReq.setDataSetId(1L);
queryStructReq.setDataSetName("内容库");
queryStructReq.setQueryType(QueryType.METRIC);
Aggregator aggregator = new Aggregator();
aggregator.setFunc(AggOperatorEnum.UNKNOWN);
aggregator.setColumn("pv");
queryStructReq.setAggregators(Arrays.asList(aggregator));
queryStructReq.setGroups(Arrays.asList("department"));
DateConf dateConf = new DateConf();
dateConf.setDateMode(DateMode.LIST);
dateConf.setDateList(Arrays.asList("2023-08-01"));
queryStructReq.setDateInfo(dateConf);
List<Order> orders = new ArrayList<>();
Order order = new Order();
order.setColumn("uv");
orders.add(order);
queryStructReq.setOrders(orders);
QuerySqlReq querySQLReq = queryStructReq.convert();
Assert.assertEquals(
"SELECT department, SUM(pv) AS pv FROM 内容库 "
+ "WHERE (sys_imp_date IN ('2023-08-01')) GROUP "
+ "BY department ORDER BY uv LIMIT 2000", querySQLReq.getSql());
queryStructReq.setQueryType(QueryType.TAG);
querySQLReq = queryStructReq.convert();
Assert.assertEquals(
"SELECT department, pv FROM 内容库 WHERE (sys_imp_date IN ('2023-08-01')) "
+ "ORDER BY uv LIMIT 2000",
querySQLReq.getSql());
}
private void init() {
MockedStatic<ContextUtils> mockContextUtils = Mockito.mockStatic(ContextUtils.class);
SqlFilterUtils sqlFilterUtils = new SqlFilterUtils();
mockContextUtils.when(() -> ContextUtils.getBean(SqlFilterUtils.class)).thenReturn(sqlFilterUtils);
DateModeUtils dateModeUtils = new DateModeUtils();
mockContextUtils.when(() -> ContextUtils.getBean(DateModeUtils.class)).thenReturn(dateModeUtils);
dateModeUtils.setSysDateCol("sys_imp_date");
dateModeUtils.setSysDateWeekCol("sys_imp_week");
dateModeUtils.setSysDateMonthCol("sys_imp_month");
}
}