mirror of
https://github.com/tencentmusic/supersonic.git
synced 2026-04-28 03:14:18 +08:00
Compare commits
33 Commits
v0.9.8
...
df17c27704
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
df17c27704 | ||
|
|
5b45cfbad7 | ||
|
|
0fc29304a8 | ||
|
|
639d1a78da | ||
|
|
82c63a7f22 | ||
|
|
593597fe26 | ||
|
|
224c114d20 | ||
|
|
722f40cdf7 | ||
|
|
cb183b7ac8 | ||
|
|
244052e806 | ||
|
|
e990b37433 | ||
|
|
534da49309 | ||
|
|
5a8c20a00b | ||
|
|
e8c9855163 | ||
|
|
ba1938f04b | ||
|
|
5e22b412c6 | ||
|
|
87729956e8 | ||
|
|
ea6a9ebc5f | ||
|
|
14a19a901f | ||
|
|
ca4545bb15 | ||
|
|
e0e167fd40 | ||
|
|
d4a9d5a7e6 | ||
|
|
c9c6dc4e44 | ||
|
|
524ec38edc | ||
|
|
9edcb9f91c | ||
|
|
6f2af79756 | ||
|
|
7be885d9c8 | ||
|
|
9a05b5cce6 | ||
|
|
3b65b1c80b | ||
|
|
1e5bf7909e | ||
|
|
6a4458a572 | ||
|
|
1867447b6e | ||
|
|
ff7fb50030 |
34
.github/workflows/docker-publish.yml
vendored
Normal file
34
.github/workflows/docker-publish.yml
vendored
Normal file
@@ -0,0 +1,34 @@
|
||||
name: Docker Publish
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: 'Version of the Docker image'
|
||||
required: true
|
||||
default: 'latest'
|
||||
|
||||
jobs:
|
||||
build-and-publish:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v1
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_PASSWORD }}
|
||||
|
||||
- name: Build and publish Docker image
|
||||
run: |
|
||||
VERSION=${{ github.event.inputs.version }}
|
||||
chmod +x docker/docker-build.sh
|
||||
chmod +x docker/docker-publish.sh
|
||||
sh docker/docker-build.sh $VERSION
|
||||
sh docker/docker-publish.sh $VERSION
|
||||
@@ -7,11 +7,11 @@
|
||||
## SuperSonic [0.9.8] - 2024-11-01
|
||||
- Add LLM management module to reuse connection across agents.
|
||||
- Add ChatAPP configuration sub-module in Agent Management.
|
||||
- Add dimension value management sub-module.
|
||||
- Enhance dimension value management sub-module.
|
||||
- Enhance memory management and term management sub-module.
|
||||
- Support semantic translation of complex S2SQL.
|
||||
- Enhance semantic translation of complex S2SQL.
|
||||
- Enhance user experience in Chat UI.
|
||||
- Introduce LLM-based semantic corrector and data interpreter.
|
||||
- Introduce new experience in Chat UI.
|
||||
|
||||
## SuperSonic [0.9.2] - 2024-06-01
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
# SuperSonic
|
||||
|
||||
SuperSonic is the next-generation BI platform that unifies **Chat BI** (powered by LLM) and **Headless BI** (powered by semantic layer) paradigms. This unification ensures that Chat BI has access to the same curated and governed semantic data models as traditional BI. Furthermore, the implementation of both paradigms benefit from each other:
|
||||
SuperSonic is the next-generation AI+BI platform that unifies **Chat BI** (powered by LLM) and **Headless BI** (powered by semantic layer) paradigms. This unification ensures that Chat BI has access to the same curated and governed semantic data models as traditional BI. Furthermore, the implementation of both paradigms benefit from each other:
|
||||
|
||||
- Chat BI's Text2SQL gets augmented with context-retrieval from semantic models.
|
||||
- Headless BI's query interface gets extended with natural language API.
|
||||
|
||||
@@ -4,7 +4,13 @@ sbinDir=$(cd "$(dirname "$0")"; pwd)
|
||||
chmod +x $sbinDir/supersonic-common.sh
|
||||
source $sbinDir/supersonic-common.sh
|
||||
cd $projectDir
|
||||
MVN_VERSION=$(mvn help:evaluate -Dexpression=project.version | grep -e '^[^\[]')
|
||||
|
||||
MVN_VERSION=$(mvn help:evaluate -Dexpression=project.version -q -DforceStdout | grep -v '^\[' | sed -n '/^[0-9]/p')
|
||||
if [ -z "$MVN_VERSION" ]; then
|
||||
echo "Failed to retrieve Maven project version."
|
||||
exit 1
|
||||
fi
|
||||
echo "Maven project version: $MVN_VERSION"
|
||||
|
||||
cd $baseDir
|
||||
service=$1
|
||||
|
||||
@@ -11,6 +11,7 @@ import java.util.Map;
|
||||
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_CREATE_TIME;
|
||||
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_IS_ADMIN;
|
||||
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_USER_DISPLAY_NAME;
|
||||
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_USER_EMAIL;
|
||||
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_USER_ID;
|
||||
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_USER_NAME;
|
||||
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_USER_PASSWORD;
|
||||
@@ -38,6 +39,7 @@ public class UserWithPassword extends User {
|
||||
claims.put(TOKEN_USER_NAME, StringUtils.isEmpty(user.getName()) ? "" : user.getName());
|
||||
claims.put(TOKEN_USER_PASSWORD,
|
||||
StringUtils.isEmpty(user.getPassword()) ? "" : user.getPassword());
|
||||
claims.put(TOKEN_USER_EMAIL, StringUtils.isEmpty(user.getEmail()) ? "" : user.getEmail());
|
||||
claims.put(TOKEN_USER_DISPLAY_NAME, user.getDisplayName());
|
||||
claims.put(TOKEN_CREATE_TIME, System.currentTimeMillis());
|
||||
claims.put(TOKEN_IS_ADMIN, user.getIsAdmin());
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/** the entity info about the model */
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@ToString
|
||||
@NoArgsConstructor
|
||||
public class Entity {
|
||||
|
||||
/** uniquely identifies an entity */
|
||||
private Long entityId;
|
||||
|
||||
/** entity name list */
|
||||
private List<String> names;
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class EntityRichInfoResp {
|
||||
/** entity alias */
|
||||
private List<String> names;
|
||||
|
||||
private SchemaElement dimItem;
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package com.tencent.supersonic.chat.api.pojo.response;
|
||||
import com.tencent.supersonic.common.pojo.QueryAuthorization;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
import com.tencent.supersonic.headless.api.pojo.AggregateInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
@@ -26,7 +25,6 @@ public class QueryResult {
|
||||
private String textResult;
|
||||
private String textSummary;
|
||||
private Long queryTimeCost;
|
||||
private EntityInfo entityInfo;
|
||||
private List<SchemaElement> recommendedDimensions;
|
||||
private AggregateInfo aggregateInfo;
|
||||
private String errorMsg;
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package com.tencent.supersonic.chat.server.agent;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.server.memory.MemoryReviewTask;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import lombok.Data;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@@ -12,6 +14,7 @@ import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Data
|
||||
@@ -33,6 +36,8 @@ public class Agent extends RecordInfo {
|
||||
private String toolConfig;
|
||||
private Map<String, ChatApp> chatAppConfig = Collections.emptyMap();
|
||||
private VisualConfig visualConfig;
|
||||
private List<String> admins = Lists.newArrayList();
|
||||
private List<String> viewers = Lists.newArrayList();
|
||||
|
||||
public List<String> getTools(AgentToolType type) {
|
||||
Map<String, Object> map = JSONObject.parseObject(toolConfig, Map.class);
|
||||
@@ -105,4 +110,9 @@ public class Agent extends RecordInfo {
|
||||
.filter(dataSetIds -> !CollectionUtils.isEmpty(dataSetIds))
|
||||
.flatMap(Collection::stream).collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
public boolean contains(User user, Function<Agent, List<String>> list) {
|
||||
return list.apply(this).contains(user.getName());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ import java.util.stream.Collectors;
|
||||
public class PlainTextExecutor implements ChatQueryExecutor {
|
||||
|
||||
public static final String APP_KEY = "SMALL_TALK";
|
||||
private static final String INSTRUCTION = "" + "#Role: You are a nice person to talk to."
|
||||
private static final String INSTRUCTION = "#Role: You are a nice person to talk to."
|
||||
+ "\n#Task: Respond quickly and nicely to the user."
|
||||
+ "\n#Rules: 1.ALWAYS use the same language as the `#Current Input`."
|
||||
+ "\n#History Inputs: %s" + "\n#Current Input: %s" + "\n#Response: ";
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.chat.server.memory;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
@@ -21,6 +22,7 @@ import org.springframework.scheduling.annotation.Scheduled;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
@@ -57,28 +59,44 @@ public class MemoryReviewTask {
|
||||
|
||||
@Scheduled(fixedDelay = 60 * 1000)
|
||||
public void review() {
|
||||
memoryService.getMemoriesForLlmReview().stream().forEach(memory -> {
|
||||
try {
|
||||
processMemory(memory);
|
||||
} catch (Exception e) {
|
||||
log.error("Exception occurred while processing memory with id {}: {}",
|
||||
memory.getId(), e.getMessage(), e);
|
||||
List<Agent> agentList = agentService.getAgents();
|
||||
for (Agent agent : agentList) {
|
||||
if (!agent.enableMemoryReview()) {
|
||||
continue;
|
||||
}
|
||||
});
|
||||
ChatMemoryFilter chatMemoryFilter =
|
||||
ChatMemoryFilter.builder().agentId(agent.getId()).build();
|
||||
memoryService.getMemories(chatMemoryFilter).stream().forEach(memory -> {
|
||||
try {
|
||||
processMemory(memory, agent);
|
||||
} catch (Exception e) {
|
||||
log.error("Exception occurred while processing memory with id {}: {}",
|
||||
memory.getId(), e.getMessage(), e);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
private void processMemory(ChatMemoryDO m) {
|
||||
Agent chatAgent = agentService.getAgent(m.getAgentId());
|
||||
if (Objects.isNull(chatAgent)) {
|
||||
private void processMemory(ChatMemoryDO m, Agent agent) {
|
||||
if (Objects.isNull(agent)) {
|
||||
log.warn("Agent id {} not found or memory review disabled", m.getAgentId());
|
||||
return;
|
||||
}
|
||||
|
||||
ChatApp chatApp = chatAgent.getChatAppConfig().get(APP_KEY);
|
||||
ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY);
|
||||
if (Objects.isNull(chatApp) || !chatApp.isEnable()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// 如果大模型已经评估过,则不再评估
|
||||
if (Objects.nonNull(m.getLlmReviewRet())) {
|
||||
// directly enable memory if the LLM determines it positive
|
||||
if (MemoryReviewResult.POSITIVE.equals(m.getLlmReviewRet())) {
|
||||
memoryService.enableMemory(m);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
String promptStr = createPromptString(m, chatApp.getPrompt());
|
||||
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
||||
|
||||
@@ -90,7 +108,7 @@ public class MemoryReviewTask {
|
||||
response);
|
||||
processResponse(response, m);
|
||||
} else {
|
||||
log.debug("ChatLanguageModel not found for agent:{}", chatAgent.getId());
|
||||
log.debug("ChatLanguageModel not found for agent:{}", agent.getId());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -83,6 +83,9 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
if (Objects.isNull(parseContext.getRequest().getSelectedParse())) {
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
|
||||
queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE);
|
||||
if (parseContext.enableLLM()) {
|
||||
queryNLReq.setText2SQLType(Text2SQLType.NONE);
|
||||
}
|
||||
|
||||
// for every requested dataSet, recursively invoke rule-based parser with different
|
||||
// mapModes
|
||||
@@ -100,6 +103,9 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
queryNLReq.setMapModeEnum(MapModeEnum.LOOSE);
|
||||
doParse(queryNLReq, parseResp);
|
||||
}
|
||||
if (parseResp.getSelectedParses().isEmpty()) {
|
||||
continue;
|
||||
}
|
||||
// for one dataset select the top 1 parse after sorting
|
||||
SemanticParseInfo.sort(parseResp.getSelectedParses());
|
||||
candidateParses.add(parseResp.getSelectedParses().get(0));
|
||||
|
||||
@@ -40,4 +40,8 @@ public class AgentDO {
|
||||
private String chatModelConfig;
|
||||
|
||||
private String visualConfig;
|
||||
|
||||
private String admin;
|
||||
|
||||
private String viewer;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
||||
|
||||
import com.baomidou.mybatisplus.annotation.IdType;
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
@@ -7,9 +10,10 @@ import java.util.Date;
|
||||
|
||||
@Data
|
||||
@ToString
|
||||
@TableName("s2_chat_config")
|
||||
public class ChatConfigDO {
|
||||
|
||||
/** database auto-increment primary key */
|
||||
@TableId(type = IdType.AUTO)
|
||||
private Long id;
|
||||
|
||||
private Long modelId;
|
||||
|
||||
@@ -1,15 +1,21 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
||||
|
||||
import com.baomidou.mybatisplus.annotation.TableField;
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import lombok.Data;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.time.Instant;
|
||||
|
||||
@Data
|
||||
@TableName("s2_chat_context")
|
||||
public class ChatContextDO implements Serializable {
|
||||
|
||||
@TableId
|
||||
private Integer chatId;
|
||||
private Instant modifiedAt;
|
||||
@TableField("query_user")
|
||||
private String user;
|
||||
private String queryText;
|
||||
private String semanticParse;
|
||||
|
||||
@@ -5,9 +5,6 @@ import lombok.Data;
|
||||
@Data
|
||||
public class QueryDO {
|
||||
|
||||
public String aggregator = "trend";
|
||||
public String startTime;
|
||||
public String endTime;
|
||||
private long id;
|
||||
private long questionId;
|
||||
private String createTime;
|
||||
@@ -25,7 +22,6 @@ public class QueryDO {
|
||||
private int topNum;
|
||||
private String querySql;
|
||||
private Object queryColumn;
|
||||
private Object entityInfo;
|
||||
private int score;
|
||||
private String feedback;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.mapper;
|
||||
|
||||
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
|
||||
import com.tencent.supersonic.chat.server.config.ChatConfigFilterInternal;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatConfigDO;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
@@ -7,11 +8,7 @@ import org.apache.ibatis.annotations.Mapper;
|
||||
import java.util.List;
|
||||
|
||||
@Mapper
|
||||
public interface ChatConfigMapper {
|
||||
|
||||
Long addConfig(ChatConfigDO chaConfigPO);
|
||||
|
||||
Long editConfig(ChatConfigDO chaConfigPO);
|
||||
public interface ChatConfigMapper extends BaseMapper<ChatConfigDO> {
|
||||
|
||||
List<ChatConfigDO> search(ChatConfigFilterInternal filterInternal);
|
||||
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.mapper;
|
||||
|
||||
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
|
||||
@Mapper
|
||||
public interface ChatContextMapper {
|
||||
public interface ChatContextMapper extends BaseMapper<ChatContextDO> {
|
||||
|
||||
ChatContextDO getContextByChatId(Integer chatId);
|
||||
|
||||
int updateContext(ChatContextDO contextDO);
|
||||
|
||||
int addContext(ChatContextDO contextDO);
|
||||
}
|
||||
|
||||
@@ -32,15 +32,15 @@ public class ChatConfigRepositoryImpl implements ChatConfigRepository {
|
||||
@Override
|
||||
public Long createConfig(ChatConfig chaConfig) {
|
||||
ChatConfigDO chaConfigDO = chatConfigHelper.chatConfig2DO(chaConfig);
|
||||
chatConfigMapper.addConfig(chaConfigDO);
|
||||
chatConfigMapper.insert(chaConfigDO);
|
||||
return chaConfigDO.getId();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Long updateConfig(ChatConfig chaConfig) {
|
||||
ChatConfigDO chaConfigDO = chatConfigHelper.chatConfig2DO(chaConfig);
|
||||
|
||||
return chatConfigMapper.editConfig(chaConfigDO);
|
||||
chatConfigMapper.updateById(chaConfigDO);
|
||||
return chaConfigDO.getId();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -35,12 +35,7 @@ public class ChatContextRepositoryImpl implements ChatContextRepository {
|
||||
|
||||
@Override
|
||||
public void updateContext(ChatContext chatCtx) {
|
||||
ChatContextDO context = cast(chatCtx);
|
||||
if (chatContextMapper.getContextByChatId(chatCtx.getChatId()) == null) {
|
||||
chatContextMapper.addContext(context);
|
||||
} else {
|
||||
chatContextMapper.updateContext(context);
|
||||
}
|
||||
chatContextMapper.insertOrUpdate(cast(chatCtx));
|
||||
}
|
||||
|
||||
private ChatContext cast(ChatContextDO contextDO) {
|
||||
|
||||
@@ -7,7 +7,7 @@ import java.util.Map;
|
||||
|
||||
public class PluginQueryManager {
|
||||
|
||||
private static Map<String, PluginSemanticQuery> pluginQueries = new HashMap<>();
|
||||
private static final Map<String, PluginSemanticQuery> pluginQueries = new HashMap<>();
|
||||
|
||||
public static void register(String queryMode, PluginSemanticQuery pluginSemanticQuery) {
|
||||
pluginQueries.put(queryMode, pluginSemanticQuery);
|
||||
|
||||
@@ -117,8 +117,12 @@ public class MetricRatioCalcProcessor implements ExecuteResultProcessor {
|
||||
|
||||
CompletableFuture.allOf(metricInfoRoll, metricInfoOver).join();
|
||||
|
||||
metricInfo.setName(metricInfoRoll.get().getName());
|
||||
metricInfo.setValue(metricInfoRoll.get().getValue());
|
||||
if (metricInfoRoll.get().getName() != null) {
|
||||
metricInfo.setName(metricInfoRoll.get().getName());
|
||||
}
|
||||
if (metricInfoOver.get().getValue() != null) {
|
||||
metricInfo.setValue(metricInfoRoll.get().getValue());
|
||||
}
|
||||
metricInfo.getStatistics().putAll(metricInfoRoll.get().getStatistics());
|
||||
metricInfo.getStatistics().putAll(metricInfoOver.get().getStatistics());
|
||||
|
||||
|
||||
@@ -39,7 +39,8 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
|
||||
}
|
||||
|
||||
private void fillSimilarMetric(SemanticParseInfo parseInfo) {
|
||||
if (!parseInfo.getQueryType().equals(QueryType.AGGREGATE)
|
||||
if (Objects.isNull(parseInfo.getQueryType())
|
||||
|| !parseInfo.getQueryType().equals(QueryType.AGGREGATE)
|
||||
|| parseInfo.getMetrics().size() > METRIC_RECOMMEND_SIZE
|
||||
|| CollectionUtils.isEmpty(parseInfo.getMetrics())) {
|
||||
return;
|
||||
|
||||
@@ -1,13 +1,30 @@
|
||||
package com.tencent.supersonic.chat.server.processor.parse;
|
||||
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.common.jsqlparser.FieldExpression;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
@@ -15,20 +32,21 @@ import java.util.stream.Collectors;
|
||||
/**
|
||||
* ParseInfoFormatProcessor formats parse info to make it more readable to the users.
|
||||
**/
|
||||
@Slf4j
|
||||
public class ParseInfoFormatProcessor implements ParseResultProcessor {
|
||||
@Override
|
||||
public void process(ParseContext parseContext) {
|
||||
parseContext.getResponse().getSelectedParses().forEach(p -> {
|
||||
if (PluginQueryManager.isPluginQuery(p.getQueryMode())
|
||||
|| "PLAIN_TEXT".equals(p.getQueryMode())) {
|
||||
if (Objects.isNull(p.getDataSet()) || Objects.isNull(p.getSqlInfo().getParsedS2SQL())) {
|
||||
return;
|
||||
}
|
||||
|
||||
formatNL2SQLParseInfo(p);
|
||||
buildParseInfoFromSQL(p);
|
||||
buildTextInfo(p);
|
||||
});
|
||||
}
|
||||
|
||||
private static void formatNL2SQLParseInfo(SemanticParseInfo parseInfo) {
|
||||
private void buildTextInfo(SemanticParseInfo parseInfo) {
|
||||
StringBuilder textBuilder = new StringBuilder();
|
||||
textBuilder.append("**数据集:** ").append(parseInfo.getDataSet().getName()).append(" ");
|
||||
List<String> metricNames = parseInfo.getMetrics().stream().map(SchemaElement::getName)
|
||||
@@ -60,4 +78,198 @@ public class ParseInfoFormatProcessor implements ParseResultProcessor {
|
||||
}
|
||||
parseInfo.setTextInfo(textBuilder.toString());
|
||||
}
|
||||
|
||||
|
||||
private void buildParseInfoFromSQL(SemanticParseInfo parseInfo) {
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
String s2SQL = sqlInfo.getCorrectedS2SQL();
|
||||
if (StringUtils.isBlank(s2SQL)) {
|
||||
return;
|
||||
}
|
||||
|
||||
parseQueryType(parseInfo);
|
||||
List<FieldExpression> expressions = SqlSelectHelper.getFilterExpression(s2SQL);
|
||||
Long dataSetId = parseInfo.getDataSetId();
|
||||
SemanticLayerService semanticLayerService =
|
||||
ContextUtils.getBean(SemanticLayerService.class);
|
||||
DataSetSchema dsSchema = semanticLayerService.getDataSetSchema(dataSetId);
|
||||
|
||||
// extract date filter from S2SQL
|
||||
try {
|
||||
if (parseInfo.getDateInfo() == null && !CollectionUtils.isEmpty(expressions)) {
|
||||
parseInfo.setDateInfo(extractDateFilter(expressions, dsSchema));
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("failed to extract date range:", e);
|
||||
}
|
||||
|
||||
// extract dimension filters from S2SQL
|
||||
try {
|
||||
List<QueryFilter> queryFilters = extractDimensionFilter(dsSchema, expressions);
|
||||
parseInfo.getDimensionFilters().addAll(queryFilters);
|
||||
} catch (Exception e) {
|
||||
log.error("failed to extract dimension filters:", e);
|
||||
}
|
||||
|
||||
// extract metrics from S2SQL
|
||||
List<String> allFields =
|
||||
filterDateField(dsSchema, SqlSelectHelper.getAllSelectFields(s2SQL));
|
||||
Set<SchemaElement> metrics = matchSchemaElements(allFields, dsSchema.getMetrics());
|
||||
parseInfo.setMetrics(metrics);
|
||||
|
||||
// extract dimensions from S2SQL
|
||||
if (QueryType.AGGREGATE.equals(parseInfo.getQueryType())) {
|
||||
List<String> groupByFields = SqlSelectHelper.getGroupByFields(s2SQL);
|
||||
List<String> groupByDimensions = filterDateField(dsSchema, groupByFields);
|
||||
parseInfo.setDimensions(
|
||||
matchSchemaElements(groupByDimensions, dsSchema.getDimensions()));
|
||||
} else if (QueryType.DETAIL.equals(parseInfo.getQueryType())) {
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(s2SQL);
|
||||
List<String> selectDimensions = filterDateField(dsSchema, selectFields);
|
||||
parseInfo
|
||||
.setDimensions(matchSchemaElements(selectDimensions, dsSchema.getDimensions()));
|
||||
}
|
||||
}
|
||||
|
||||
private Set<SchemaElement> matchSchemaElements(List<String> allFields,
|
||||
Set<SchemaElement> elements) {
|
||||
return elements.stream().filter(schemaElement -> {
|
||||
if (CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||
return allFields.contains(schemaElement.getName());
|
||||
}
|
||||
Set<String> allFieldsSet = new HashSet<>(allFields);
|
||||
Set<String> aliasSet = new HashSet<>(schemaElement.getAlias());
|
||||
List<String> intersection =
|
||||
allFieldsSet.stream().filter(aliasSet::contains).collect(Collectors.toList());
|
||||
return allFields.contains(schemaElement.getName())
|
||||
|| !CollectionUtils.isEmpty(intersection);
|
||||
}).collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
private List<String> filterDateField(DataSetSchema dataSetSchema, List<String> allFields) {
|
||||
return allFields.stream().filter(entry -> !isPartitionDimension(dataSetSchema, entry))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private List<QueryFilter> extractDimensionFilter(DataSetSchema dsSchema,
|
||||
List<FieldExpression> fieldExpressions) {
|
||||
|
||||
Map<String, SchemaElement> fieldNameToElement = getNameToElement(dsSchema);
|
||||
List<QueryFilter> result = Lists.newArrayList();
|
||||
for (FieldExpression expression : fieldExpressions) {
|
||||
QueryFilter dimensionFilter = new QueryFilter();
|
||||
dimensionFilter.setValue(expression.getFieldValue());
|
||||
SchemaElement schemaElement = fieldNameToElement.get(expression.getFieldName());
|
||||
if (Objects.isNull(schemaElement)
|
||||
|| isPartitionDimension(dsSchema, schemaElement.getName())) {
|
||||
continue;
|
||||
}
|
||||
dimensionFilter.setName(schemaElement.getName());
|
||||
dimensionFilter.setBizName(schemaElement.getBizName());
|
||||
dimensionFilter.setElementID(schemaElement.getId());
|
||||
|
||||
FilterOperatorEnum operatorEnum =
|
||||
FilterOperatorEnum.getSqlOperator(expression.getOperator());
|
||||
dimensionFilter.setOperator(operatorEnum);
|
||||
dimensionFilter.setFunction(expression.getFunction());
|
||||
result.add(dimensionFilter);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private DateConf extractDateFilter(List<FieldExpression> fieldExpressions,
|
||||
DataSetSchema dataSetSchema) {
|
||||
List<FieldExpression> dateExpressions = fieldExpressions.stream().filter(
|
||||
expression -> isPartitionDimension(dataSetSchema, expression.getFieldName()))
|
||||
.collect(Collectors.toList());
|
||||
if (CollectionUtils.isEmpty(dateExpressions)) {
|
||||
return null;
|
||||
}
|
||||
DateConf dateInfo = new DateConf();
|
||||
dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
|
||||
FieldExpression firstExpression = dateExpressions.get(0);
|
||||
|
||||
FilterOperatorEnum firstOperator =
|
||||
FilterOperatorEnum.getSqlOperator(firstExpression.getOperator());
|
||||
if (FilterOperatorEnum.EQUALS.equals(firstOperator)
|
||||
&& Objects.nonNull(firstExpression.getFieldValue())) {
|
||||
dateInfo.setStartDate(firstExpression.getFieldValue().toString());
|
||||
dateInfo.setEndDate(firstExpression.getFieldValue().toString());
|
||||
dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
|
||||
return dateInfo;
|
||||
}
|
||||
if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.GREATER_THAN,
|
||||
FilterOperatorEnum.GREATER_THAN_EQUALS)) {
|
||||
dateInfo.setStartDate(firstExpression.getFieldValue().toString());
|
||||
if (hasSecondDate(dateExpressions)) {
|
||||
dateInfo.setEndDate(dateExpressions.get(1).getFieldValue().toString());
|
||||
}
|
||||
}
|
||||
if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.MINOR_THAN,
|
||||
FilterOperatorEnum.MINOR_THAN_EQUALS)) {
|
||||
dateInfo.setEndDate(firstExpression.getFieldValue().toString());
|
||||
if (hasSecondDate(dateExpressions)) {
|
||||
dateInfo.setStartDate(dateExpressions.get(1).getFieldValue().toString());
|
||||
}
|
||||
}
|
||||
return dateInfo;
|
||||
}
|
||||
|
||||
private static boolean isPartitionDimension(DataSetSchema dataSetSchema, String sqlFieldName) {
|
||||
if (TimeDimensionEnum.containsTimeDimension(sqlFieldName)) {
|
||||
return true;
|
||||
}
|
||||
if (Objects.isNull(dataSetSchema) || Objects.isNull(dataSetSchema.getPartitionDimension())
|
||||
|| Objects.isNull(dataSetSchema.getPartitionDimension().getName())) {
|
||||
return false;
|
||||
}
|
||||
return sqlFieldName.equalsIgnoreCase(dataSetSchema.getPartitionDimension().getName());
|
||||
}
|
||||
|
||||
private boolean containOperators(FieldExpression expression, FilterOperatorEnum firstOperator,
|
||||
FilterOperatorEnum... operatorEnums) {
|
||||
return (Arrays.asList(operatorEnums).contains(firstOperator)
|
||||
&& Objects.nonNull(expression.getFieldValue()));
|
||||
}
|
||||
|
||||
private boolean hasSecondDate(List<FieldExpression> dateExpressions) {
|
||||
return dateExpressions.size() > 1
|
||||
&& Objects.nonNull(dateExpressions.get(1).getFieldValue());
|
||||
}
|
||||
|
||||
private Map<String, SchemaElement> getNameToElement(DataSetSchema dsSchema) {
|
||||
Set<SchemaElement> dimensions = dsSchema.getDimensions();
|
||||
Set<SchemaElement> metrics = dsSchema.getMetrics();
|
||||
|
||||
List<SchemaElement> allElements = Lists.newArrayList();
|
||||
allElements.addAll(dimensions);
|
||||
allElements.addAll(metrics);
|
||||
// support alias
|
||||
return allElements.stream().flatMap(schemaElement -> {
|
||||
Set<Pair<String, SchemaElement>> result = new HashSet<>();
|
||||
result.add(Pair.of(schemaElement.getName(), schemaElement));
|
||||
List<String> aliasList = schemaElement.getAlias();
|
||||
if (!org.springframework.util.CollectionUtils.isEmpty(aliasList)) {
|
||||
for (String alias : aliasList) {
|
||||
result.add(Pair.of(alias, schemaElement));
|
||||
}
|
||||
}
|
||||
return result.stream();
|
||||
}).collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (value1, value2) -> value2));
|
||||
}
|
||||
|
||||
private void parseQueryType(SemanticParseInfo parseInfo) {
|
||||
parseInfo.setQueryType(QueryType.DETAIL);
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getCorrectedS2SQL())) {
|
||||
parseInfo.setQueryType(QueryType.DETAIL);
|
||||
}
|
||||
|
||||
// 2. AGG queryType
|
||||
if (Objects.nonNull(sqlInfo) && StringUtils.isNotBlank(sqlInfo.getParsedS2SQL())
|
||||
&& SqlSelectFunctionHelper.hasAggregateFunction(sqlInfo.getCorrectedS2SQL())) {
|
||||
parseInfo.setQueryType(QueryType.AGGREGATE);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.AuthType;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.DeleteMapping;
|
||||
import org.springframework.web.bind.annotation.PathVariable;
|
||||
@@ -15,6 +16,7 @@ import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.PutMapping;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RequestParam;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import java.util.List;
|
||||
@@ -48,8 +50,11 @@ public class AgentController {
|
||||
}
|
||||
|
||||
@RequestMapping("/getAgentList")
|
||||
public List<Agent> getAgentList() {
|
||||
return agentService.getAgents();
|
||||
public List<Agent> getAgentList(
|
||||
@RequestParam(value = "authType", required = false) AuthType authType,
|
||||
HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
|
||||
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
|
||||
return agentService.getAgents(user, authType);
|
||||
}
|
||||
|
||||
@RequestMapping("/getToolTypes")
|
||||
|
||||
@@ -2,10 +2,12 @@ package com.tencent.supersonic.chat.server.service;
|
||||
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.AuthType;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface AgentService {
|
||||
List<Agent> getAgents(User user, AuthType authType);
|
||||
|
||||
List<Agent> getAgents();
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.common.config.ChatModel;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.AuthType;
|
||||
import com.tencent.supersonic.common.service.ChatModelService;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -43,6 +44,27 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
|
||||
|
||||
private ExecutorService executorService = Executors.newFixedThreadPool(1);
|
||||
|
||||
@Override
|
||||
public List<Agent> getAgents(User user, AuthType authType) {
|
||||
return getAgentDOList().stream().map(this::convert)
|
||||
.filter(agent -> filterByAuth(agent, user, authType)).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private boolean filterByAuth(Agent agent, User user, AuthType authType) {
|
||||
if (user.isSuperAdmin() || user.getName().equals(agent.getCreatedBy())) {
|
||||
return true;
|
||||
}
|
||||
authType = authType == null ? AuthType.VIEWER : authType;
|
||||
switch (authType) {
|
||||
case ADMIN:
|
||||
return agent.contains(user, Agent::getAdmins);
|
||||
case VIEWER:
|
||||
default:
|
||||
return agent.contains(user, Agent::getAdmins)
|
||||
|| agent.contains(user, Agent::getViewers);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Agent> getAgents() {
|
||||
return getAgentDOList().stream().map(this::convert).collect(Collectors.toList());
|
||||
@@ -135,6 +157,8 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
|
||||
c.setChatModelConfig(chatModelService.getChatModel(c.getChatModelId()).getConfig());
|
||||
}
|
||||
});
|
||||
agent.setAdmins(JsonUtil.toList(agentDO.getAdmin(), String.class));
|
||||
agent.setViewers(JsonUtil.toList(agentDO.getViewer(), String.class));
|
||||
return agent;
|
||||
}
|
||||
|
||||
@@ -145,6 +169,8 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
|
||||
agentDO.setExamples(JsonUtil.toString(agent.getExamples()));
|
||||
agentDO.setChatModelConfig(JsonUtil.toString(agent.getChatAppConfig()));
|
||||
agentDO.setVisualConfig(JsonUtil.toString(agent.getVisualConfig()));
|
||||
agentDO.setAdmin(JsonUtil.toString(agent.getAdmins()));
|
||||
agentDO.setViewer(JsonUtil.toString(agent.getViewers()));
|
||||
if (agentDO.getStatus() == null) {
|
||||
agentDO.setStatus(1);
|
||||
}
|
||||
|
||||
@@ -25,11 +25,9 @@ import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
@@ -198,7 +196,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
handleRuleQueryMode(semanticQuery, dataSetSchema, user);
|
||||
}
|
||||
|
||||
return executeQuery(semanticQuery, user, dataSetSchema);
|
||||
return executeQuery(semanticQuery, user);
|
||||
}
|
||||
|
||||
private List<String> getFieldsFromSql(SemanticParseInfo parseInfo) {
|
||||
@@ -233,18 +231,14 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
log.info("rule begin replace metrics and revise filters!");
|
||||
validFilter(semanticQuery.getParseInfo().getDimensionFilters());
|
||||
validFilter(semanticQuery.getParseInfo().getMetricFilters());
|
||||
semanticQuery.initS2Sql(dataSetSchema, user);
|
||||
semanticQuery.buildS2Sql(dataSetSchema);
|
||||
}
|
||||
|
||||
private QueryResult executeQuery(SemanticQuery semanticQuery, User user,
|
||||
DataSetSchema dataSetSchema) throws Exception {
|
||||
private QueryResult executeQuery(SemanticQuery semanticQuery, User user) throws Exception {
|
||||
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
QueryResult queryResult = doExecution(semanticQueryReq, parseInfo.getQueryMode(), user);
|
||||
queryResult.setChatContext(semanticQuery.getParseInfo());
|
||||
SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class);
|
||||
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, user);
|
||||
queryResult.setEntityInfo(entityInfo);
|
||||
parseInfo.getSqlInfo().setQuerySQL(queryResult.getQuerySql());
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatConfigEditReqReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatDetailConfigReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.Entity;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ItemNameVisibilityInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ItemVisibility;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.KnowledgeInfoReq;
|
||||
@@ -16,7 +15,6 @@ import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatDefaultRichConfigResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatDetailRichConfigResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.EntityRichInfoResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ItemVisibilityInfo;
|
||||
import com.tencent.supersonic.chat.server.config.ChatConfig;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatConfigRepository;
|
||||
@@ -88,16 +86,6 @@ public class ConfigServiceImpl implements ConfigService {
|
||||
return configEditCmd.getId();
|
||||
}
|
||||
|
||||
public ItemNameVisibilityInfo getVisibilityByModelId(Long modelId) {
|
||||
ChatConfigResp chatConfigResp = fetchConfigByModelId(modelId);
|
||||
ChatConfig chatConfig = new ChatConfig();
|
||||
chatConfig.setModelId(modelId);
|
||||
chatConfig.setChatAggConfig(chatConfigResp.getChatAggConfig());
|
||||
chatConfig.setChatDetailConfig(chatConfigResp.getChatDetailConfig());
|
||||
ItemNameVisibilityInfo itemNameVisibility = getItemNameVisibility(chatConfig);
|
||||
return itemNameVisibility;
|
||||
}
|
||||
|
||||
public ItemNameVisibilityInfo getItemNameVisibility(ChatConfig chatConfig) {
|
||||
Long modelId = chatConfig.getModelId();
|
||||
|
||||
@@ -240,19 +228,6 @@ public class ConfigServiceImpl implements ConfigService {
|
||||
return detailRichConfig;
|
||||
}
|
||||
|
||||
private EntityRichInfoResp generateRichEntity(Entity entity, DataSetSchema modelSchema) {
|
||||
EntityRichInfoResp entityRichInfo = new EntityRichInfoResp();
|
||||
if (Objects.isNull(entity) || Objects.isNull(entity.getEntityId())) {
|
||||
return entityRichInfo;
|
||||
}
|
||||
BeanUtils.copyProperties(entity, entityRichInfo);
|
||||
Map<Long, SchemaElement> dimIdAndRespPair = modelSchema.getDimensions().stream().collect(
|
||||
Collectors.toMap(SchemaElement::getId, Function.identity(), (k1, k2) -> k1));
|
||||
|
||||
entityRichInfo.setDimItem(dimIdAndRespPair.get(entity.getEntityId()));
|
||||
return entityRichInfo;
|
||||
}
|
||||
|
||||
private ChatAggRichConfigResp fillChatAggRichConfig(DataSetSchema modelSchema,
|
||||
ChatConfigResp chatConfigResp) {
|
||||
if (Objects.isNull(chatConfigResp) || Objects.isNull(chatConfigResp.getChatAggConfig())) {
|
||||
@@ -327,7 +302,7 @@ public class ConfigServiceImpl implements ConfigService {
|
||||
}
|
||||
Map<Long, SchemaElement> dimIdAndRespPair = modelSchema.getDimensions().stream().collect(
|
||||
Collectors.toMap(SchemaElement::getId, Function.identity(), (k1, k2) -> k1));
|
||||
knowledgeInfos.stream().forEach(knowledgeInfo -> {
|
||||
knowledgeInfos.forEach(knowledgeInfo -> {
|
||||
if (Objects.nonNull(knowledgeInfo)) {
|
||||
SchemaElement dimSchemaResp = dimIdAndRespPair.get(knowledgeInfo.getItemId());
|
||||
if (Objects.nonNull(dimSchemaResp)) {
|
||||
|
||||
@@ -49,10 +49,10 @@ public class MemoryServiceImpl implements MemoryService {
|
||||
@Override
|
||||
public void updateMemory(ChatMemoryUpdateReq chatMemoryUpdateReq, User user) {
|
||||
ChatMemoryDO chatMemoryDO = chatMemoryRepository.getMemory(chatMemoryUpdateReq.getId());
|
||||
boolean hadEnabled = MemoryStatus.ENABLED.equals(chatMemoryDO.getStatus());
|
||||
chatMemoryDO.setUpdatedBy(user.getName());
|
||||
chatMemoryDO.setUpdatedAt(new Date());
|
||||
BeanMapper.mapper(chatMemoryUpdateReq, chatMemoryDO);
|
||||
boolean hadEnabled = MemoryStatus.ENABLED.equals(chatMemoryDO.getStatus());
|
||||
if (MemoryStatus.ENABLED.equals(chatMemoryUpdateReq.getStatus()) && !hadEnabled) {
|
||||
enableMemory(chatMemoryDO);
|
||||
} else if (MemoryStatus.DISABLED.equals(chatMemoryUpdateReq.getStatus()) && hadEnabled) {
|
||||
|
||||
@@ -20,54 +20,6 @@
|
||||
<result column="updated_at" property="updatedAt"/>
|
||||
</resultMap>
|
||||
|
||||
<insert id="addConfig"
|
||||
parameterType="com.tencent.supersonic.chat.server.persistence.dataobject.ChatConfigDO"
|
||||
useGeneratedKeys="true" keyProperty="id">
|
||||
insert into s2_chat_config
|
||||
(
|
||||
model_id, `chat_detail_config`, chat_agg_config, recommended_questions, status, llm_examples, created_by, updated_by, created_at, updated_at
|
||||
)
|
||||
values
|
||||
(
|
||||
#{modelId}, #{chatDetailConfig}, #{chatAggConfig}, #{recommendedQuestions}, #{status}, #{llmExamples}, #{createdBy}, #{updatedBy}, #{createdAt}, #{updatedAt}
|
||||
)
|
||||
</insert>
|
||||
|
||||
|
||||
<update id="editConfig">
|
||||
update s2_chat_config
|
||||
<set>
|
||||
`updated_at` = #{updatedAt} ,
|
||||
<if test="chatDetailConfig != null and chatDetailConfig != ''">
|
||||
`chat_detail_config` = #{chatDetailConfig} ,
|
||||
</if>
|
||||
<if test="chatAggConfig != null and chatAggConfig != ''">
|
||||
chat_agg_config = #{chatAggConfig} ,
|
||||
</if>
|
||||
<if test="recommendedQuestions != null and recommendedQuestions != ''">
|
||||
recommended_questions = #{recommendedQuestions} ,
|
||||
</if>
|
||||
<if test="status != null and status != ''">
|
||||
status = #{status} ,
|
||||
</if>
|
||||
<if test="updatedBy != null and updatedBy != ''">
|
||||
updated_by = #{updatedBy} ,
|
||||
</if>
|
||||
<if test="llmExamples != null and llmExamples != ''">
|
||||
llm_examples = #{llmExamples} ,
|
||||
</if>
|
||||
</set>
|
||||
|
||||
<where>
|
||||
<if test="id != null and id != ''">
|
||||
id = #{id}
|
||||
</if>
|
||||
<if test="modelId != null and modelId != ''">
|
||||
and model_id = #{modelId}
|
||||
</if>
|
||||
</where>
|
||||
</update>
|
||||
|
||||
<select id="search" resultMap="chaConfigDO">
|
||||
select *
|
||||
from s2_chat_config
|
||||
|
||||
@@ -20,11 +20,4 @@
|
||||
from s2_chat_context where chat_id=#{chatId} limit 1
|
||||
</select>
|
||||
|
||||
<insert id="addContext" parameterType="com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO" >
|
||||
insert into s2_chat_context (chat_id,user,query_text,semantic_parse) values (#{chatId}, #{user},#{queryText}, #{semanticParse})
|
||||
</insert>
|
||||
<update id="updateContext">
|
||||
update s2_chat_context set query_text=#{queryText},semantic_parse=#{semanticParse} where chat_id=#{chatId}
|
||||
</update>
|
||||
|
||||
</mapper>
|
||||
@@ -528,7 +528,7 @@ public class SqlReplaceHelper {
|
||||
}
|
||||
}
|
||||
|
||||
private static Select replaceAggAliasOrderItem(Select selectStatement) {
|
||||
private static Select replaceAggAliasOrderbyField(Select selectStatement) {
|
||||
if (selectStatement instanceof PlainSelect) {
|
||||
PlainSelect plainSelect = (PlainSelect) selectStatement;
|
||||
if (Objects.nonNull(plainSelect.getOrderByElements())) {
|
||||
@@ -564,15 +564,15 @@ public class SqlReplaceHelper {
|
||||
if (plainSelect.getFromItem() instanceof ParenthesedSelect) {
|
||||
ParenthesedSelect parenthesedSelect = (ParenthesedSelect) plainSelect.getFromItem();
|
||||
parenthesedSelect
|
||||
.setSelect(replaceAggAliasOrderItem(parenthesedSelect.getSelect()));
|
||||
.setSelect(replaceAggAliasOrderbyField(parenthesedSelect.getSelect()));
|
||||
}
|
||||
return selectStatement;
|
||||
}
|
||||
return selectStatement;
|
||||
}
|
||||
|
||||
public static String replaceAggAliasOrderItem(String sql) {
|
||||
Select selectStatement = replaceAggAliasOrderItem(SqlSelectHelper.getSelect(sql));
|
||||
public static String replaceAggAliasOrderbyField(String sql) {
|
||||
Select selectStatement = replaceAggAliasOrderbyField(SqlSelectHelper.getSelect(sql));
|
||||
return selectStatement.toString();
|
||||
}
|
||||
|
||||
|
||||
@@ -74,7 +74,6 @@ public class SqlValidHelper {
|
||||
CCJSqlParserUtil.parse(sql);
|
||||
return true;
|
||||
} catch (Exception e) {
|
||||
log.error("isValidSQL parse:{}", e);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.common.pojo;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@@ -10,6 +11,7 @@ import java.util.List;
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
@Builder
|
||||
public class Filter {
|
||||
|
||||
private Relation relation = Relation.FILTER;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
package com.tencent.supersonic.common.pojo.enums;
|
||||
|
||||
public enum AuthType {
|
||||
VISIBLE, ADMIN
|
||||
VIEWER, ADMIN
|
||||
}
|
||||
|
||||
@@ -14,8 +14,6 @@ public enum DictWordType {
|
||||
|
||||
DATASET("dataSet"),
|
||||
|
||||
ENTITY("entity"),
|
||||
|
||||
NUMBER("m"),
|
||||
|
||||
TAG("tag"),
|
||||
|
||||
@@ -8,7 +8,8 @@ public enum EngineType {
|
||||
KAFKA(4, "kafka"),
|
||||
H2(5, "h2"),
|
||||
POSTGRESQL(6, "postgresql"),
|
||||
OTHER(7, "other");
|
||||
OTHER(7, "other"),
|
||||
DUCKDB(8, "duckdb");
|
||||
|
||||
private Integer code;
|
||||
|
||||
|
||||
@@ -5,9 +5,7 @@ public enum QueryType {
|
||||
/** queries with aggregation (optionally slice and dice by dimensions) */
|
||||
AGGREGATE,
|
||||
/** queries with field selection */
|
||||
DETAIL,
|
||||
/** queries with ID-based entity selection */
|
||||
ID;
|
||||
DETAIL;
|
||||
|
||||
public boolean isNativeAggQuery() {
|
||||
return DETAIL.equals(this);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.common.pojo.enums;
|
||||
|
||||
public enum Text2SQLType {
|
||||
ONLY_RULE, LLM_OR_RULE;
|
||||
ONLY_RULE, LLM_OR_RULE, NONE;
|
||||
|
||||
public boolean enableLLM() {
|
||||
return this.equals(LLM_OR_RULE);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
package com.tencent.supersonic.common.pojo.enums;
|
||||
|
||||
public enum TypeEnums {
|
||||
METRIC, DIMENSION, TAG_OBJECT, TAG, DOMAIN, ENTITY, DATASET, MODEL, UNKNOWN
|
||||
METRIC, DIMENSION, TAG_OBJECT, TAG, DOMAIN, DATASET, MODEL, UNKNOWN
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.ItemDateResp;
|
||||
import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
@@ -32,14 +33,9 @@ import static com.tencent.supersonic.common.pojo.Constants.MONTH_FORMAT;
|
||||
@Data
|
||||
public class DateModeUtils {
|
||||
|
||||
@Value("${s2.query.parameter.sys.date:sys_imp_date}")
|
||||
private String sysDateCol;
|
||||
|
||||
@Value("${s2.query.parameter.sys.month:sys_imp_month}")
|
||||
private String sysDateMonthCol;
|
||||
|
||||
@Value("${s2.query.parameter.sys.month:sys_imp_week}")
|
||||
private String sysDateWeekCol;
|
||||
private final String sysDateCol = TimeDimensionEnum.DAY.getName();
|
||||
private final String sysDateMonthCol = TimeDimensionEnum.MONTH.getName();
|
||||
private final String sysDateWeekCol = TimeDimensionEnum.WEEK.getName();
|
||||
|
||||
@Value("${s2.query.parameter.sys.zipper.begin:start_}")
|
||||
private String sysZipperDateColBegin;
|
||||
|
||||
@@ -16,6 +16,7 @@ import java.time.temporal.TemporalAdjuster;
|
||||
import java.time.temporal.TemporalAdjusters;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Calendar;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
@@ -201,6 +202,13 @@ public class DateUtils {
|
||||
return false;
|
||||
}
|
||||
|
||||
public static Long calculateDiffMs(Date createAt) {
|
||||
Calendar calendar = Calendar.getInstance();
|
||||
Date now = calendar.getTime();
|
||||
long milliseconds = now.getTime() - createAt.getTime();
|
||||
return milliseconds;
|
||||
}
|
||||
|
||||
public static boolean isDateString(String value, String format) {
|
||||
try {
|
||||
DateTimeFormatter formatter = DateTimeFormatter.ofPattern(format);
|
||||
|
||||
@@ -325,10 +325,10 @@ class SqlReplaceHelperTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
void testReplaceAggAliasOrderItem() {
|
||||
void testReplaceAggAliasOrderbyField() {
|
||||
String sql = "SELECT SUM(访问次数) AS top10总播放量 FROM (SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数 "
|
||||
+ "GROUP BY 部门 ORDER BY SUM(访问次数) DESC LIMIT 10) AS top10";
|
||||
String replaceSql = SqlReplaceHelper.replaceAggAliasOrderItem(sql);
|
||||
String replaceSql = SqlReplaceHelper.replaceAggAliasOrderbyField(sql);
|
||||
Assert.assertEquals(
|
||||
"SELECT SUM(访问次数) AS top10总播放量 FROM (SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数 "
|
||||
+ "GROUP BY 部门 ORDER BY 2 DESC LIMIT 10) AS top10",
|
||||
|
||||
@@ -62,7 +62,9 @@ services:
|
||||
sleep 15 &&
|
||||
if ! mysql -h supersonic_mysql -usupersonic_user -psupersonic_password -e 'use supersonic_db; show tables;' | grep -q 's2_database'; then
|
||||
mysql -h supersonic_mysql -usupersonic_user -psupersonic_password supersonic_db < /usr/src/app/supersonic-standalone-${SUPERSONIC_VERSION:-latest}/conf/db/schema-mysql.sql &&
|
||||
mysql -h supersonic_mysql -usupersonic_user -psupersonic_password supersonic_db < /usr/src/app/supersonic-standalone-${SUPERSONIC_VERSION:-latest}/conf/db/data-mysql.sql
|
||||
mysql -h supersonic_mysql -usupersonic_user -psupersonic_password supersonic_db < /usr/src/app/supersonic-standalone-${SUPERSONIC_VERSION:-latest}/conf/db/schema-mysql-demo.sql &&
|
||||
mysql -h supersonic_mysql -usupersonic_user -psupersonic_password supersonic_db < /usr/src/app/supersonic-standalone-${SUPERSONIC_VERSION:-latest}/conf/db/data-mysql.sql &&
|
||||
mysql -h supersonic_mysql -usupersonic_user -psupersonic_password supersonic_db < /usr/src/app/supersonic-standalone-${SUPERSONIC_VERSION:-latest}/conf/db/data-mysql-demo.sql
|
||||
else
|
||||
echo 'Database already initialized.'
|
||||
fi
|
||||
|
||||
24
docker/docker-publish.sh
Normal file
24
docker/docker-publish.sh
Normal file
@@ -0,0 +1,24 @@
|
||||
#!/usr/bin/env bash
|
||||
# Exit immediately if a command exits with a non-zero status
|
||||
set -e
|
||||
VERSION=$1
|
||||
|
||||
# Image name
|
||||
IMAGE_NAME="supersonicbi/supersonic"
|
||||
|
||||
# Default tag is latest
|
||||
TAGS="latest"
|
||||
|
||||
# If VERSION is provided, add it to TAGS and tag the image as latest
|
||||
if [ -n "$VERSION" ]; then
|
||||
TAGS="$TAGS $VERSION"
|
||||
docker tag $IMAGE_NAME:$VERSION $IMAGE_NAME:latest
|
||||
fi
|
||||
|
||||
# Push Docker images
|
||||
for TAG in $TAGS; do
|
||||
echo "Pushing Docker image $IMAGE_NAME:$TAG"
|
||||
docker push $IMAGE_NAME:$TAG
|
||||
done
|
||||
|
||||
echo "Docker images pushed successfully."
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.headless.api.pojo;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
@@ -14,4 +15,6 @@ public class DBColumn {
|
||||
private String dataType;
|
||||
|
||||
private String comment;
|
||||
|
||||
private FieldType fieldType;
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.tencent.supersonic.headless.api.pojo;
|
||||
|
||||
import lombok.Data;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.io.Serializable;
|
||||
@@ -9,7 +8,6 @@ import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
@@ -24,16 +22,16 @@ public class DataSetSchema implements Serializable {
|
||||
private Set<SchemaElement> tags = new HashSet<>();
|
||||
private Set<SchemaElement> dimensionValues = new HashSet<>();
|
||||
private Set<SchemaElement> terms = new HashSet<>();
|
||||
private SchemaElement entity = new SchemaElement();
|
||||
private QueryConfig queryConfig;
|
||||
|
||||
public Long getDataSetId() {
|
||||
return dataSet.getDataSetId();
|
||||
}
|
||||
|
||||
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
|
||||
Optional<SchemaElement> element = Optional.empty();
|
||||
|
||||
switch (elementType) {
|
||||
case ENTITY:
|
||||
element = Optional.ofNullable(entity);
|
||||
break;
|
||||
case DATASET:
|
||||
element = Optional.of(dataSet);
|
||||
break;
|
||||
@@ -55,11 +53,7 @@ public class DataSetSchema implements Serializable {
|
||||
default:
|
||||
}
|
||||
|
||||
if (element.isPresent()) {
|
||||
return element.get();
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
return element.orElse(null);
|
||||
}
|
||||
|
||||
public Map<String, String> getBizNameToName() {
|
||||
@@ -70,7 +64,7 @@ public class DataSetSchema implements Serializable {
|
||||
SchemaElement::getName, (k1, k2) -> k1));
|
||||
}
|
||||
|
||||
public TimeDefaultConfig getTagTypeTimeDefaultConfig() {
|
||||
public TimeDefaultConfig getDetailTypeTimeDefaultConfig() {
|
||||
if (queryConfig == null) {
|
||||
return null;
|
||||
}
|
||||
@@ -90,45 +84,6 @@ public class DataSetSchema implements Serializable {
|
||||
return queryConfig.getAggregateTypeDefaultConfig().getTimeDefaultConfig();
|
||||
}
|
||||
|
||||
public DetailTypeDefaultConfig getTagTypeDefaultConfig() {
|
||||
if (queryConfig == null) {
|
||||
return null;
|
||||
}
|
||||
return queryConfig.getDetailTypeDefaultConfig();
|
||||
}
|
||||
|
||||
public List<SchemaElement> getTagDefaultDimensions() {
|
||||
DetailTypeDefaultConfig detailTypeDefaultConfig = getTagTypeDefaultConfig();
|
||||
if (Objects.isNull(detailTypeDefaultConfig)
|
||||
|| Objects.isNull(detailTypeDefaultConfig.getDefaultDisplayInfo())) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
if (CollectionUtils
|
||||
.isNotEmpty(detailTypeDefaultConfig.getDefaultDisplayInfo().getMetricIds())) {
|
||||
return detailTypeDefaultConfig.getDefaultDisplayInfo().getMetricIds().stream()
|
||||
.map(id -> {
|
||||
SchemaElement metric = getElement(SchemaElementType.METRIC, id);
|
||||
return metric;
|
||||
}).filter(Objects::nonNull).collect(Collectors.toList());
|
||||
}
|
||||
return new ArrayList<>();
|
||||
}
|
||||
|
||||
public List<SchemaElement> getTagDefaultMetrics() {
|
||||
DetailTypeDefaultConfig detailTypeDefaultConfig = getTagTypeDefaultConfig();
|
||||
if (Objects.isNull(detailTypeDefaultConfig)
|
||||
|| Objects.isNull(detailTypeDefaultConfig.getDefaultDisplayInfo())) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
if (CollectionUtils
|
||||
.isNotEmpty(detailTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds())) {
|
||||
return detailTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds().stream()
|
||||
.map(id -> getElement(SchemaElementType.DIMENSION, id)).filter(Objects::nonNull)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
return new ArrayList<>();
|
||||
}
|
||||
|
||||
public boolean containsPartitionDimensions() {
|
||||
return dimensions.stream().anyMatch(SchemaElement::isPartitionTime);
|
||||
}
|
||||
|
||||
@@ -17,5 +17,7 @@ public class DbSchema {
|
||||
|
||||
private String sql;
|
||||
|
||||
private String ddl;
|
||||
|
||||
private List<DBColumn> dbColumns;
|
||||
}
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
package com.tencent.supersonic.headless.api.pojo;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class DefaultDisplayInfo implements Serializable {
|
||||
|
||||
// When displaying tag selection results, the information displayed by default
|
||||
private List<Long> dimensionIds = new ArrayList<>();
|
||||
|
||||
private List<Long> metricIds = new ArrayList<>();
|
||||
}
|
||||
@@ -8,8 +8,6 @@ import java.io.Serializable;
|
||||
@Data
|
||||
public class DetailTypeDefaultConfig implements Serializable {
|
||||
|
||||
private DefaultDisplayInfo defaultDisplayInfo;
|
||||
|
||||
// default time to filter tag selection results
|
||||
private TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig();
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.headless.api.pojo;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.DimensionType;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
@@ -12,7 +13,7 @@ public class Dim {
|
||||
|
||||
private String name;
|
||||
|
||||
private String type;
|
||||
private DimensionType type;
|
||||
|
||||
private String expr;
|
||||
|
||||
@@ -28,23 +29,15 @@ public class Dim {
|
||||
|
||||
private int isTag;
|
||||
|
||||
public Dim(String name, String bizName, String type, Integer isCreateDimension) {
|
||||
public Dim(String name, String bizName, DimensionType type, Integer isCreateDimension) {
|
||||
this.name = name;
|
||||
this.type = type;
|
||||
this.isCreateDimension = isCreateDimension;
|
||||
this.bizName = bizName;
|
||||
}
|
||||
|
||||
public Dim(String name, String bizName, String type, Integer isCreateDimension, int isTag) {
|
||||
this.name = name;
|
||||
this.type = type;
|
||||
this.isCreateDimension = isCreateDimension;
|
||||
this.bizName = bizName;
|
||||
this.isTag = isTag;
|
||||
}
|
||||
|
||||
public Dim(String name, String type, String expr, String dateFormat,
|
||||
DimensionTimeTypeParams typeParams, Integer isCreateDimension, String bizName) {
|
||||
public Dim(String name, String bizName, DimensionType type, Integer isCreateDimension,
|
||||
String expr, String dateFormat, DimensionTimeTypeParams typeParams) {
|
||||
this.name = name;
|
||||
this.type = type;
|
||||
this.expr = expr;
|
||||
@@ -55,8 +48,8 @@ public class Dim {
|
||||
}
|
||||
|
||||
public static Dim getDefault() {
|
||||
return new Dim("日期", "time", "2023-05-28", Constants.DAY_FORMAT,
|
||||
new DimensionTimeTypeParams("true", "day"), 0, "imp_date");
|
||||
return new Dim("数据日期", "imp_date", DimensionType.partition_time, 0, "imp_date",
|
||||
Constants.DAY_FORMAT, new DimensionTimeTypeParams("false", "day"));
|
||||
}
|
||||
|
||||
public String getFieldName() {
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
package com.tencent.supersonic.headless.api.pojo;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@ToString
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class Entity {
|
||||
|
||||
/** uniquely identifies an entity */
|
||||
private Long entityId;
|
||||
|
||||
/** entity name list */
|
||||
private List<String> names;
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
package com.tencent.supersonic.headless.api.pojo;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class EntityInfo implements Serializable {
|
||||
|
||||
private DataSetInfo dataSetInfo = new DataSetInfo();
|
||||
private List<DataInfo> dimensions = new ArrayList<>();
|
||||
private List<DataInfo> metrics = new ArrayList<>();
|
||||
private String entityId;
|
||||
}
|
||||
@@ -18,8 +18,6 @@ public class Identify {
|
||||
|
||||
private String bizName;
|
||||
|
||||
private List<String> entityNames;
|
||||
|
||||
private Integer isCreateDimension = 0;
|
||||
|
||||
public Identify(String name, String type, String bizName) {
|
||||
@@ -28,13 +26,6 @@ public class Identify {
|
||||
this.bizName = bizName;
|
||||
}
|
||||
|
||||
public Identify(String name, String type, String bizName, Integer isCreateDimension) {
|
||||
this.name = name;
|
||||
this.type = type;
|
||||
this.bizName = bizName;
|
||||
this.isCreateDimension = isCreateDimension;
|
||||
}
|
||||
|
||||
public String getFieldName() {
|
||||
return bizName;
|
||||
}
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
package com.tencent.supersonic.headless.api.pojo;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.AggOption;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class MetricTable {
|
||||
|
||||
private String alias;
|
||||
private List<String> metrics = Lists.newArrayList();
|
||||
private List<String> dimensions = Lists.newArrayList();
|
||||
private String where;
|
||||
private AggOption aggOption = AggOption.DEFAULT;
|
||||
}
|
||||
@@ -43,8 +43,7 @@ public class ModelDetail {
|
||||
if (CollectionUtils.isEmpty(dimensions)) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
return dimensions.stream()
|
||||
.filter(dim -> DimensionType.partition_time.name().equalsIgnoreCase(dim.getType()))
|
||||
return dimensions.stream().filter(dim -> DimensionType.partition_time.equals(dim.getType()))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
package com.tencent.supersonic.headless.api.pojo;
|
||||
|
||||
public enum SchemaElementType {
|
||||
DATASET, METRIC, DIMENSION, VALUE, ENTITY, ID, DATE, TAG, TERM
|
||||
DATASET, METRIC, DIMENSION, VALUE, ID, DATE, TAG, TERM
|
||||
}
|
||||
|
||||
@@ -26,26 +26,27 @@ import static com.tencent.supersonic.common.pojo.Constants.DEFAULT_METRIC_LIMIT;
|
||||
public class SemanticParseInfo implements Serializable {
|
||||
|
||||
private Integer id;
|
||||
private String queryMode = "PLAIN_TEXT";
|
||||
private SchemaElement dataSet;
|
||||
private String queryMode = "";
|
||||
private QueryConfig queryConfig;
|
||||
private QueryType queryType;
|
||||
|
||||
private SchemaElement dataSet;
|
||||
private Set<SchemaElement> metrics = Sets.newTreeSet(new SchemaNameLengthComparator());
|
||||
private Set<SchemaElement> dimensions = Sets.newTreeSet(new SchemaNameLengthComparator());
|
||||
private SchemaElement entity;
|
||||
private AggregateTypeEnum aggType = AggregateTypeEnum.NONE;
|
||||
private FilterType filterType = FilterType.AND;
|
||||
|
||||
private Set<QueryFilter> dimensionFilters = Sets.newHashSet();
|
||||
private Set<QueryFilter> metricFilters = Sets.newHashSet();
|
||||
private FilterType filterType = FilterType.AND;
|
||||
|
||||
private AggregateTypeEnum aggType = AggregateTypeEnum.NONE;
|
||||
private Set<Order> orders = Sets.newHashSet();
|
||||
private DateConf dateInfo;
|
||||
private long limit = DEFAULT_DETAIL_LIMIT;
|
||||
private double score;
|
||||
private List<SchemaElementMatch> elementMatches = Lists.newArrayList();
|
||||
private DateConf dateInfo;
|
||||
private SqlInfo sqlInfo = new SqlInfo();
|
||||
private SqlEvaluation sqlEvaluation = new SqlEvaluation();
|
||||
private QueryType queryType = QueryType.ID;
|
||||
private EntityInfo entityInfo;
|
||||
private String textInfo;
|
||||
private SqlEvaluation sqlEvaluation = new SqlEvaluation();
|
||||
private Map<String, Object> properties = Maps.newHashMap();
|
||||
|
||||
@Data
|
||||
@@ -138,8 +139,7 @@ public class SemanticParseInfo implements Serializable {
|
||||
public long getDetailLimit() {
|
||||
long limit = DEFAULT_DETAIL_LIMIT;
|
||||
if (Objects.nonNull(queryConfig)
|
||||
&& Objects.nonNull(queryConfig.getDetailTypeDefaultConfig())
|
||||
&& Objects.nonNull(queryConfig.getDetailTypeDefaultConfig().getLimit())) {
|
||||
&& Objects.nonNull(queryConfig.getDetailTypeDefaultConfig())) {
|
||||
limit = queryConfig.getDetailTypeDefaultConfig().getLimit();
|
||||
}
|
||||
return limit;
|
||||
@@ -148,8 +148,7 @@ public class SemanticParseInfo implements Serializable {
|
||||
public long getMetricLimit() {
|
||||
long limit = DEFAULT_METRIC_LIMIT;
|
||||
if (Objects.nonNull(queryConfig)
|
||||
&& Objects.nonNull(queryConfig.getAggregateTypeDefaultConfig())
|
||||
&& Objects.nonNull(queryConfig.getAggregateTypeDefaultConfig().getLimit())) {
|
||||
&& Objects.nonNull(queryConfig.getAggregateTypeDefaultConfig())) {
|
||||
limit = queryConfig.getAggregateTypeDefaultConfig().getLimit();
|
||||
}
|
||||
return limit;
|
||||
|
||||
@@ -13,7 +13,7 @@ import java.util.stream.Collectors;
|
||||
|
||||
public class SemanticSchema implements Serializable {
|
||||
|
||||
private List<DataSetSchema> dataSetSchemaList;
|
||||
private final List<DataSetSchema> dataSetSchemaList;
|
||||
|
||||
public SemanticSchema(List<DataSetSchema> dataSetSchemaList) {
|
||||
this.dataSetSchemaList = dataSetSchemaList;
|
||||
@@ -27,9 +27,6 @@ public class SemanticSchema implements Serializable {
|
||||
Optional<SchemaElement> element = Optional.empty();
|
||||
|
||||
switch (elementType) {
|
||||
case ENTITY:
|
||||
element = getElementsById(elementID, getEntities());
|
||||
break;
|
||||
case DATASET:
|
||||
element = getElementsById(elementID, getDataSets());
|
||||
break;
|
||||
@@ -51,11 +48,7 @@ public class SemanticSchema implements Serializable {
|
||||
default:
|
||||
}
|
||||
|
||||
if (element.isPresent()) {
|
||||
return element.get();
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
return element.orElse(null);
|
||||
}
|
||||
|
||||
public Map<Long, String> getDataSetIdToName() {
|
||||
@@ -65,13 +58,13 @@ public class SemanticSchema implements Serializable {
|
||||
|
||||
public List<SchemaElement> getDimensionValues() {
|
||||
List<SchemaElement> dimensionValues = new ArrayList<>();
|
||||
dataSetSchemaList.stream().forEach(d -> dimensionValues.addAll(d.getDimensionValues()));
|
||||
dataSetSchemaList.forEach(d -> dimensionValues.addAll(d.getDimensionValues()));
|
||||
return dimensionValues;
|
||||
}
|
||||
|
||||
public List<SchemaElement> getDimensions() {
|
||||
List<SchemaElement> dimensions = new ArrayList<>();
|
||||
dataSetSchemaList.stream().forEach(d -> dimensions.addAll(d.getDimensions()));
|
||||
dataSetSchemaList.forEach(d -> dimensions.addAll(d.getDimensions()));
|
||||
return dimensions;
|
||||
}
|
||||
|
||||
@@ -97,26 +90,15 @@ public class SemanticSchema implements Serializable {
|
||||
return getElementsByDataSetId(dataSetId, metrics);
|
||||
}
|
||||
|
||||
public List<SchemaElement> getEntities() {
|
||||
List<SchemaElement> entities = new ArrayList<>();
|
||||
dataSetSchemaList.stream().forEach(d -> entities.add(d.getEntity()));
|
||||
return entities;
|
||||
}
|
||||
|
||||
public List<SchemaElement> getEntities(Long dataSetId) {
|
||||
List<SchemaElement> entities = getEntities();
|
||||
return getElementsByDataSetId(dataSetId, entities);
|
||||
}
|
||||
|
||||
public List<SchemaElement> getTags() {
|
||||
List<SchemaElement> tags = new ArrayList<>();
|
||||
dataSetSchemaList.stream().forEach(d -> tags.addAll(d.getTags()));
|
||||
dataSetSchemaList.forEach(d -> tags.addAll(d.getTags()));
|
||||
return tags;
|
||||
}
|
||||
|
||||
public List<SchemaElement> getTerms() {
|
||||
List<SchemaElement> terms = new ArrayList<>();
|
||||
dataSetSchemaList.stream().forEach(d -> terms.addAll(d.getTerms()));
|
||||
dataSetSchemaList.forEach(d -> terms.addAll(d.getTerms()));
|
||||
return terms;
|
||||
}
|
||||
|
||||
@@ -137,22 +119,26 @@ public class SemanticSchema implements Serializable {
|
||||
return getElementsById(dataSetId, dataSets).orElse(null);
|
||||
}
|
||||
|
||||
public QueryConfig getQueryConfig(Long dataSetId) {
|
||||
DataSetSchema first = dataSetSchemaList.stream().filter(
|
||||
dataSetSchema -> dataSetId.equals(dataSetSchema.getDataSet().getDataSetId()))
|
||||
.findFirst().orElse(null);
|
||||
if (Objects.nonNull(first)) {
|
||||
return first.getQueryConfig();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public List<SchemaElement> getDataSets() {
|
||||
List<SchemaElement> dataSets = new ArrayList<>();
|
||||
dataSetSchemaList.stream().forEach(d -> dataSets.add(d.getDataSet()));
|
||||
dataSetSchemaList.forEach(d -> dataSets.add(d.getDataSet()));
|
||||
return dataSets;
|
||||
}
|
||||
|
||||
public DataSetSchema getDataSetSchema(Long dataSetId) {
|
||||
return dataSetSchemaList.stream()
|
||||
.filter(dataSetSchema -> dataSetId.equals(dataSetSchema.getDataSetId())).findFirst()
|
||||
.orElse(null);
|
||||
}
|
||||
|
||||
public QueryConfig getQueryConfig(Long dataSetId) {
|
||||
DataSetSchema dataSetSchema = getDataSetSchema(dataSetId);
|
||||
if (Objects.nonNull(dataSetSchema)) {
|
||||
return dataSetSchema.getQueryConfig();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public Map<Long, DataSetSchema> getDataSetSchemaMap() {
|
||||
if (CollectionUtils.isEmpty(dataSetSchemaList)) {
|
||||
return new HashMap<>();
|
||||
|
||||
@@ -45,7 +45,8 @@ public enum DataType {
|
||||
|
||||
TDENGINE("TAOS", "TAOS", "com.taosdata.jdbc.TSDBDriver", "'", "'", "\"", "\""),
|
||||
|
||||
POSTGRESQL("postgresql", "postgresql", "org.postgresql.Driver", "'", "'", "\"", "\"");
|
||||
POSTGRESQL("postgresql", "postgresql", "org.postgresql.Driver", "'", "'", "\"", "\""),
|
||||
DUCKDB("duckdb", "duckdb", "org.duckdb.DuckDBDriver", "'", "'", "\"", "\"");
|
||||
|
||||
private String feature;
|
||||
private String desc;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
package com.tencent.supersonic.headless.api.pojo.enums;
|
||||
|
||||
public enum FieldType {
|
||||
primary_key, foreign_key, data_time, dimension, measure;
|
||||
primary_key, foreign_key, partition_time, time, dimension, measure;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.headless.api.pojo.request;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.DbSchema;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
@@ -10,12 +11,16 @@ public class ModelBuildReq {
|
||||
|
||||
private Long databaseId;
|
||||
|
||||
private Long domainId;
|
||||
|
||||
private String sql;
|
||||
|
||||
private String db;
|
||||
|
||||
private List<String> tables;
|
||||
|
||||
private List<DbSchema> dbSchemas;
|
||||
|
||||
private boolean buildByLLM;
|
||||
|
||||
private Integer chatModelId;
|
||||
|
||||
@@ -35,8 +35,6 @@ public class ModelReq extends SchemaItem {
|
||||
|
||||
private List<String> adminOrgs;
|
||||
|
||||
private Long tagObjectId;
|
||||
|
||||
private Map<String, Object> ext;
|
||||
|
||||
public String getViewer() {
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
package com.tencent.supersonic.headless.api.pojo.request;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.MetricTable;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
public class ParseSqlReq {
|
||||
private Map<String, String> variables;
|
||||
private String sql = "";
|
||||
private List<MetricTable> tables;
|
||||
private boolean supportWith = true;
|
||||
private boolean withAlias = true;
|
||||
|
||||
public Map<String, String> getVariables() {
|
||||
if (variables == null) {
|
||||
variables = new HashMap<>();
|
||||
}
|
||||
return variables;
|
||||
}
|
||||
}
|
||||
@@ -30,6 +30,6 @@ public class QueryDataSetReq {
|
||||
private List<Filter> metricFilters = new ArrayList<>();
|
||||
private DateConf dateInfo;
|
||||
private Long limit = 2000L;
|
||||
private QueryType queryType = QueryType.ID;
|
||||
private QueryType queryType = QueryType.DETAIL;
|
||||
private boolean innerLayerNative = false;
|
||||
}
|
||||
|
||||
@@ -34,12 +34,11 @@ public class QueryFilter implements Serializable {
|
||||
QueryFilter that = (QueryFilter) o;
|
||||
return Objects.equal(bizName, that.bizName) && Objects.equal(name, that.name)
|
||||
&& operator == that.operator && Objects.equal(value, that.value)
|
||||
&& Objects.equal(elementID, that.elementID)
|
||||
&& Objects.equal(function, that.function);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(bizName, name, operator, value, elementID, function);
|
||||
return Objects.hashCode(bizName, name, operator, value, function);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,6 +30,7 @@ public class QueryNLReq extends SemanticQueryReq implements Serializable {
|
||||
private List<Text2SQLExemplar> dynamicExemplars = Lists.newArrayList();
|
||||
private SemanticParseInfo contextParseInfo;
|
||||
private SemanticParseInfo selectedParseInfo;
|
||||
private boolean descriptionMapped;
|
||||
|
||||
@Override
|
||||
public String toCustomizedString() {
|
||||
|
||||
@@ -12,6 +12,7 @@ import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.DateModeUtils;
|
||||
import com.tencent.supersonic.common.util.SqlFilterUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
@@ -42,6 +43,7 @@ import java.util.stream.Collectors;
|
||||
@Slf4j
|
||||
public class QueryStructReq extends SemanticQueryReq {
|
||||
|
||||
private List<SchemaElement> dimensions = new ArrayList<>();
|
||||
private List<String> groups = new ArrayList<>();
|
||||
private List<Aggregator> aggregators = new ArrayList<>();
|
||||
private List<Order> orders = new ArrayList<>();
|
||||
@@ -49,7 +51,7 @@ public class QueryStructReq extends SemanticQueryReq {
|
||||
private List<Filter> metricFilters = new ArrayList<>();
|
||||
private DateConf dateInfo;
|
||||
private long limit = Constants.DEFAULT_DETAIL_LIMIT;
|
||||
private QueryType queryType = QueryType.ID;
|
||||
private QueryType queryType = QueryType.DETAIL;
|
||||
private boolean convertToSql = true;
|
||||
|
||||
public List<String> getGroups() {
|
||||
@@ -186,36 +188,31 @@ public class QueryStructReq extends SemanticQueryReq {
|
||||
List<Aggregator> aggregators = queryStructReq.getAggregators();
|
||||
if (!CollectionUtils.isEmpty(aggregators)) {
|
||||
for (Aggregator aggregator : aggregators) {
|
||||
selectItems.add(buildAggregatorSelectItem(aggregator, queryStructReq));
|
||||
selectItems.add(buildAggregatorSelectItem(aggregator));
|
||||
}
|
||||
}
|
||||
|
||||
return selectItems;
|
||||
}
|
||||
|
||||
private SelectItem buildAggregatorSelectItem(Aggregator aggregator,
|
||||
QueryStructReq queryStructReq) {
|
||||
private SelectItem buildAggregatorSelectItem(Aggregator aggregator) {
|
||||
String columnName = aggregator.getColumn();
|
||||
if (queryStructReq.getQueryType().isNativeAggQuery()) {
|
||||
return new SelectItem(new Column(columnName));
|
||||
} else {
|
||||
Function function = new Function();
|
||||
AggOperatorEnum func = aggregator.getFunc();
|
||||
if (AggOperatorEnum.UNKNOWN.equals(func)) {
|
||||
func = AggOperatorEnum.SUM;
|
||||
}
|
||||
function.setName(func.getOperator());
|
||||
if (AggOperatorEnum.COUNT_DISTINCT.equals(func)) {
|
||||
function.setName("count");
|
||||
function.setDistinct(true);
|
||||
}
|
||||
function.setParameters(new ExpressionList(new Column(columnName)));
|
||||
SelectItem selectExpressionItem = new SelectItem(function);
|
||||
String alias = StringUtils.isNotBlank(aggregator.getAlias()) ? aggregator.getAlias()
|
||||
: columnName;
|
||||
selectExpressionItem.setAlias(new Alias(alias));
|
||||
return selectExpressionItem;
|
||||
Function function = new Function();
|
||||
AggOperatorEnum func = aggregator.getFunc();
|
||||
if (AggOperatorEnum.UNKNOWN.equals(func)) {
|
||||
func = AggOperatorEnum.SUM;
|
||||
}
|
||||
function.setName(func.getOperator());
|
||||
if (AggOperatorEnum.COUNT_DISTINCT.equals(func)) {
|
||||
function.setName("count");
|
||||
function.setDistinct(true);
|
||||
}
|
||||
function.setParameters(new ExpressionList(new Column(columnName)));
|
||||
SelectItem selectExpressionItem = new SelectItem(function);
|
||||
String alias =
|
||||
StringUtils.isNotBlank(aggregator.getAlias()) ? aggregator.getAlias() : columnName;
|
||||
selectExpressionItem.setAlias(new Alias(alias));
|
||||
return selectExpressionItem;
|
||||
}
|
||||
|
||||
private List<OrderByElement> buildOrderByElements(QueryStructReq queryStructReq) {
|
||||
@@ -239,7 +236,7 @@ public class QueryStructReq extends SemanticQueryReq {
|
||||
|
||||
private GroupByElement buildGroupByElement(QueryStructReq queryStructReq) {
|
||||
List<String> groups = queryStructReq.getGroups();
|
||||
if (!CollectionUtils.isEmpty(groups) && !queryStructReq.getQueryType().isNativeAggQuery()) {
|
||||
if (!CollectionUtils.isEmpty(groups) && !queryStructReq.getAggregators().isEmpty()) {
|
||||
GroupByElement groupByElement = new GroupByElement();
|
||||
for (String group : groups) {
|
||||
groupByElement.addGroupByExpression(new Column(group));
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
package com.tencent.supersonic.headless.api.pojo.request;
|
||||
|
||||
import javax.validation.constraints.NotNull;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.PageBaseReq;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* @author: kanedai
|
||||
* @date: 2024/11/24
|
||||
*/
|
||||
@Data
|
||||
public class ValueTaskQueryReq extends PageBaseReq {
|
||||
|
||||
@NotNull
|
||||
private Long itemId;
|
||||
|
||||
private List<String> taskStatusList;
|
||||
|
||||
private String key;
|
||||
}
|
||||
@@ -19,19 +19,4 @@ public class DataSetSchemaResp extends DataSetResp {
|
||||
private List<ModelResp> modelResps = Lists.newArrayList();
|
||||
private List<TermResp> termResps = Lists.newArrayList();
|
||||
|
||||
public DimSchemaResp getPrimaryKey() {
|
||||
for (ModelResp modelResp : modelResps) {
|
||||
Identify identify = modelResp.getPrimaryIdentify();
|
||||
if (identify == null) {
|
||||
continue;
|
||||
}
|
||||
for (DimSchemaResp dimension : dimensions) {
|
||||
if (identify.getBizName().equals(dimension.getBizName())) {
|
||||
dimension.setEntityAlias(identify.getEntityNames());
|
||||
return dimension;
|
||||
}
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,5 +11,4 @@ public class DimSchemaResp extends DimensionResp {
|
||||
|
||||
private Long useCnt = 0L;
|
||||
|
||||
private List<String> entityAlias;
|
||||
}
|
||||
|
||||
@@ -62,21 +62,6 @@ public class ModelResp extends SchemaItem {
|
||||
return isOpen != null && isOpen == 1;
|
||||
}
|
||||
|
||||
public Identify getPrimaryIdentify() {
|
||||
if (modelDetail == null) {
|
||||
return null;
|
||||
}
|
||||
if (CollectionUtils.isEmpty(modelDetail.getIdentifiers())) {
|
||||
return null;
|
||||
}
|
||||
for (Identify identify : modelDetail.getIdentifiers()) {
|
||||
if (!CollectionUtils.isEmpty(identify.getEntityNames())) {
|
||||
return identify;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public List<Dim> getTimeDimension() {
|
||||
if (modelDetail == null) {
|
||||
return Lists.newArrayList();
|
||||
|
||||
@@ -7,14 +7,11 @@ import com.tencent.supersonic.headless.api.pojo.enums.SchemaType;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.Constants.UNDERLINE;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
@@ -32,13 +29,6 @@ public class SemanticSchemaResp {
|
||||
private DatabaseResp databaseResp;
|
||||
private QueryType queryType;
|
||||
|
||||
public String getSchemaKey() {
|
||||
if (dataSetId == null) {
|
||||
return String.format("%s_%s", schemaType, StringUtils.join(modelIds, UNDERLINE));
|
||||
}
|
||||
return String.format("%s_%s", schemaType, dataSetId);
|
||||
}
|
||||
|
||||
public MetricSchemaResp getMetric(String bizName) {
|
||||
return metrics.stream().filter(metric -> bizName.equalsIgnoreCase(metric.getBizName()))
|
||||
.findFirst().orElse(null);
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.headless.chat;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
@@ -41,6 +42,14 @@ public class ChatQueryContext implements Serializable {
|
||||
}
|
||||
}
|
||||
|
||||
public boolean needSQL() {
|
||||
return !request.getText2SQLType().equals(Text2SQLType.NONE);
|
||||
}
|
||||
|
||||
public DataSetSchema getDataSetSchema(Long dataSetId) {
|
||||
return semanticSchema.getDataSetSchema(dataSetId);
|
||||
}
|
||||
|
||||
public List<SemanticQuery> getCandidateQueries() {
|
||||
candidateQueries = candidateQueries.stream()
|
||||
.sorted(Comparator.comparing(
|
||||
|
||||
@@ -10,7 +10,6 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@@ -31,7 +30,8 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
|
||||
public void correct(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
try {
|
||||
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectedS2SQL())) {
|
||||
String s2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
if (Objects.isNull(s2SQL)) {
|
||||
return;
|
||||
}
|
||||
doCorrect(chatQueryContext, semanticParseInfo);
|
||||
|
||||
@@ -36,8 +36,8 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
|
||||
+ "\n2.NO NEED to check date filters as the junior engineer seldom makes mistakes in this regard."
|
||||
+ "\n3.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
|
||||
+ "\n4.ALWAYS use `with` statement if nested aggregation is needed."
|
||||
+ "\n5.ALWAYS enclose alias created by `AS` command in underscores."
|
||||
+ "\n6.ALWAYS translate alias created by `AS` command to the same language as the `#Question`."
|
||||
+ "\n5.ALWAYS enclose alias declared by `AS` command in underscores."
|
||||
+ "\n6.Alias created by `AS` command must be in the same language ast the `Question`."
|
||||
+ "\n#Question:{{question}} #InputSQL:{{sql}} #Response:";
|
||||
|
||||
public LLMSqlCorrector() {
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
package com.tencent.supersonic.headless.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -18,9 +14,7 @@ import org.springframework.util.CollectionUtils;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/** Perform SQL corrections on the "Select" section in S2SQL. */
|
||||
@Slf4j
|
||||
@@ -42,13 +36,11 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
||||
&& aggregateFields.size() == selectFields.size()) {
|
||||
return;
|
||||
}
|
||||
correctS2SQL = addFieldsToSelect(chatQueryContext, semanticParseInfo, correctS2SQL);
|
||||
correctS2SQL = addFieldsToSelect(semanticParseInfo, correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
|
||||
}
|
||||
|
||||
protected String addFieldsToSelect(ChatQueryContext chatQueryContext,
|
||||
SemanticParseInfo semanticParseInfo, String correctS2SQL) {
|
||||
correctS2SQL = addTagDefaultFields(chatQueryContext, semanticParseInfo, correctS2SQL);
|
||||
protected String addFieldsToSelect(SemanticParseInfo semanticParseInfo, String correctS2SQL) {
|
||||
|
||||
Set<String> selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL));
|
||||
Set<String> needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL));
|
||||
@@ -70,34 +62,4 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
||||
return addFieldsToSelectSql;
|
||||
}
|
||||
|
||||
private String addTagDefaultFields(ChatQueryContext chatQueryContext,
|
||||
SemanticParseInfo semanticParseInfo, String correctS2SQL) {
|
||||
// If it is in DETAIL mode and select *, add default metrics and dimensions.
|
||||
boolean hasAsterisk = SqlSelectFunctionHelper.hasAsterisk(correctS2SQL);
|
||||
if (!(hasAsterisk && QueryType.DETAIL.equals(semanticParseInfo.getQueryType()))) {
|
||||
return correctS2SQL;
|
||||
}
|
||||
Long dataSetId = semanticParseInfo.getDataSetId();
|
||||
DataSetSchema dataSetSchema =
|
||||
chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
||||
Set<String> needAddDefaultFields = new HashSet<>();
|
||||
if (Objects.nonNull(dataSetSchema)) {
|
||||
if (!CollectionUtils.isEmpty(dataSetSchema.getTagDefaultMetrics())) {
|
||||
Set<String> metrics = dataSetSchema.getTagDefaultMetrics().stream()
|
||||
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
|
||||
needAddDefaultFields.addAll(metrics);
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(dataSetSchema.getTagDefaultDimensions())) {
|
||||
Set<String> dimensions = dataSetSchema.getTagDefaultDimensions().stream()
|
||||
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
|
||||
needAddDefaultFields.addAll(dimensions);
|
||||
}
|
||||
}
|
||||
// remove * in sql and add default fields.
|
||||
if (!CollectionUtils.isEmpty(needAddDefaultFields)) {
|
||||
correctS2SQL =
|
||||
SqlRemoveHelper.removeAsteriskAndAddFields(correctS2SQL, needAddDefaultFields);
|
||||
}
|
||||
return correctS2SQL;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,10 +10,9 @@ import org.springframework.stereotype.Service;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
/** model word nature */
|
||||
@Service
|
||||
@Slf4j
|
||||
public class ModelWordBuilder extends BaseWordWithAliasBuilder {
|
||||
public class DataSetWordBuilder extends BaseWordWithAliasBuilder {
|
||||
|
||||
@Override
|
||||
public List<DictWord> doGet(String word, SchemaElement schemaElement) {
|
||||
@@ -1,37 +0,0 @@
|
||||
package com.tencent.supersonic.headless.chat.knowledge.builder;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.DictWord;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class EntityWordBuilder extends BaseWordWithAliasBuilder {
|
||||
|
||||
@Override
|
||||
public List<DictWord> doGet(String word, SchemaElement schemaElement) {
|
||||
List<DictWord> result = Lists.newArrayList();
|
||||
if (Objects.isNull(schemaElement)) {
|
||||
return result;
|
||||
}
|
||||
result.add(getOneWordNature(word, schemaElement, false));
|
||||
result.addAll(getOneWordNatureAlias(schemaElement, false));
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DictWord getOneWordNature(String word, SchemaElement schemaElement, boolean isSuffix) {
|
||||
String nature = DictWordType.NATURE_SPILT + schemaElement.getModel()
|
||||
+ DictWordType.NATURE_SPILT + schemaElement.getId() + DictWordType.ENTITY.getType();
|
||||
DictWord dictWord = new DictWord();
|
||||
dictWord.setWord(word);
|
||||
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY * 2, nature));
|
||||
return dictWord;
|
||||
}
|
||||
}
|
||||
@@ -13,8 +13,7 @@ public class WordBuilderFactory {
|
||||
static {
|
||||
wordNatures.put(DictWordType.DIMENSION, new DimensionWordBuilder());
|
||||
wordNatures.put(DictWordType.METRIC, new MetricWordBuilder());
|
||||
wordNatures.put(DictWordType.DATASET, new ModelWordBuilder());
|
||||
wordNatures.put(DictWordType.ENTITY, new EntityWordBuilder());
|
||||
wordNatures.put(DictWordType.DATASET, new DataSetWordBuilder());
|
||||
wordNatures.put(DictWordType.VALUE, new ValueWordBuilder());
|
||||
wordNatures.put(DictWordType.TERM, new TermWordBuilder());
|
||||
}
|
||||
|
||||
@@ -123,7 +123,7 @@ public class HanlpHelper {
|
||||
HanLP.Config.CustomDictionaryPath);
|
||||
|
||||
HanLP.Config.CoreDictionaryPath =
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.BiGramDictionaryPath;
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.CoreDictionaryPath;
|
||||
HanLP.Config.CoreDictionaryTransformMatrixDictionaryPath = hanlpPropertiesPath + FILE_SPILT
|
||||
+ HanLP.Config.CoreDictionaryTransformMatrixDictionaryPath;
|
||||
HanLP.Config.BiGramDictionaryPath =
|
||||
|
||||
@@ -18,7 +18,9 @@ import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/** nature parse helper */
|
||||
/**
|
||||
* nature parse helper
|
||||
*/
|
||||
@Slf4j
|
||||
public class NatureHelper {
|
||||
|
||||
@@ -35,9 +37,6 @@ public class NatureHelper {
|
||||
case DIMENSION:
|
||||
result = SchemaElementType.DIMENSION;
|
||||
break;
|
||||
case ENTITY:
|
||||
result = SchemaElementType.ENTITY;
|
||||
break;
|
||||
case DATASET:
|
||||
result = SchemaElementType.DATASET;
|
||||
break;
|
||||
@@ -53,10 +52,9 @@ public class NatureHelper {
|
||||
return result;
|
||||
}
|
||||
|
||||
private static boolean isDataSetOrEntity(S2Term term, Integer model) {
|
||||
private static boolean isDataSet(S2Term term, Integer model) {
|
||||
String natureStr = term.nature.toString();
|
||||
return (DictWordType.NATURE_SPILT + model).equals(natureStr)
|
||||
|| natureStr.endsWith(DictWordType.ENTITY.getType());
|
||||
return (DictWordType.NATURE_SPILT + model).equals(natureStr);
|
||||
}
|
||||
|
||||
public static Integer getDataSetByNature(Nature nature) {
|
||||
@@ -122,8 +120,8 @@ public class NatureHelper {
|
||||
}
|
||||
|
||||
private static long getDataSetCount(List<S2Term> terms) {
|
||||
return terms.stream()
|
||||
.filter(term -> isDataSetOrEntity(term, getDataSetByNature(term.nature))).count();
|
||||
return terms.stream().filter(term -> isDataSet(term, getDataSetByNature(term.nature)))
|
||||
.count();
|
||||
}
|
||||
|
||||
private static long getDimensionValueCount(List<S2Term> terms) {
|
||||
|
||||
@@ -20,22 +20,25 @@ import java.util.Objects;
|
||||
*/
|
||||
@Slf4j
|
||||
public class EmbeddingMapper extends BaseMapper {
|
||||
|
||||
@Override
|
||||
public void doMap(ChatQueryContext chatQueryContext) {
|
||||
// 1. query from embedding by queryText
|
||||
if (MapModeEnum.STRICT.equals(chatQueryContext.getRequest().getMapModeEnum())) {
|
||||
// Check if the map mode is LOOSE
|
||||
if (!MapModeEnum.LOOSE.equals(chatQueryContext.getRequest().getMapModeEnum())) {
|
||||
return;
|
||||
}
|
||||
|
||||
// 1. Query from embedding by queryText
|
||||
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
|
||||
List<EmbeddingResult> matchResults = getMatches(chatQueryContext, matchStrategy);
|
||||
|
||||
// Process match results
|
||||
HanlpHelper.transLetterOriginal(matchResults);
|
||||
|
||||
// 2. build SchemaElementMatch by info
|
||||
// 2. Build SchemaElementMatch based on match results
|
||||
for (EmbeddingResult matchResult : matchResults) {
|
||||
Long elementId = Retrieval.getLongId(matchResult.getId());
|
||||
Long dataSetId = Retrieval.getLongId(matchResult.getMetadata().get("dataSetId"));
|
||||
|
||||
// Skip if dataSetId is null
|
||||
if (Objects.isNull(dataSetId)) {
|
||||
continue;
|
||||
}
|
||||
@@ -43,14 +46,19 @@ public class EmbeddingMapper extends BaseMapper {
|
||||
SchemaElementType.valueOf(matchResult.getMetadata().get("type"));
|
||||
SchemaElement schemaElement = getSchemaElement(dataSetId, elementType, elementId,
|
||||
chatQueryContext.getSemanticSchema());
|
||||
|
||||
// Skip if schemaElement is null
|
||||
if (schemaElement == null) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Build SchemaElementMatch object
|
||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||
.element(schemaElement).frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
|
||||
.word(matchResult.getName()).similarity(matchResult.getSimilarity())
|
||||
.detectWord(matchResult.getDetectWord()).build();
|
||||
// 3. add to mapInfo
|
||||
|
||||
// 3. Add SchemaElementMatch to mapInfo
|
||||
addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
package com.tencent.supersonic.headless.chat.mapper;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/** A mapper capable of converting the VALUE of entity dimension values into ID types. */
|
||||
@Slf4j
|
||||
public class EntityMapper extends BaseMapper {
|
||||
|
||||
@Override
|
||||
public void doMap(ChatQueryContext chatQueryContext) {
|
||||
SchemaMapInfo schemaMapInfo = chatQueryContext.getMapInfo();
|
||||
for (Long dataSetId : schemaMapInfo.getMatchedDataSetInfos()) {
|
||||
List<SchemaElementMatch> schemaElementMatchList =
|
||||
schemaMapInfo.getMatchedElements(dataSetId);
|
||||
if (CollectionUtils.isEmpty(schemaElementMatchList)) {
|
||||
continue;
|
||||
}
|
||||
SchemaElement entity = getEntity(dataSetId, chatQueryContext);
|
||||
if (entity == null || entity.getId() == null) {
|
||||
continue;
|
||||
}
|
||||
List<SchemaElementMatch> valueSchemaElements = schemaElementMatchList.stream()
|
||||
.filter(schemaElementMatch -> SchemaElementType.VALUE
|
||||
.equals(schemaElementMatch.getElement().getType()))
|
||||
.collect(Collectors.toList());
|
||||
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
|
||||
if (!entity.getId().equals(schemaElementMatch.getElement().getId())) {
|
||||
continue;
|
||||
}
|
||||
if (!checkExistSameEntitySchemaElements(schemaElementMatch,
|
||||
schemaElementMatchList)) {
|
||||
SchemaElementMatch entitySchemaElementMath = new SchemaElementMatch();
|
||||
BeanUtils.copyProperties(schemaElementMatch, entitySchemaElementMath);
|
||||
entitySchemaElementMath.setElement(entity);
|
||||
schemaElementMatchList.add(entitySchemaElementMath);
|
||||
}
|
||||
schemaElementMatch.getElement().setType(SchemaElementType.ID);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private boolean checkExistSameEntitySchemaElements(SchemaElementMatch valueSchemaElementMatch,
|
||||
List<SchemaElementMatch> schemaElementMatchList) {
|
||||
List<SchemaElementMatch> entitySchemaElements = schemaElementMatchList.stream()
|
||||
.filter(schemaElementMatch -> SchemaElementType.ENTITY
|
||||
.equals(schemaElementMatch.getElement().getType()))
|
||||
.collect(Collectors.toList());
|
||||
for (SchemaElementMatch schemaElementMatch : entitySchemaElements) {
|
||||
if (schemaElementMatch.getElement().getId()
|
||||
.equals(valueSchemaElementMatch.getElement().getId())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private SchemaElement getEntity(Long dataSetId, ChatQueryContext chatQueryContext) {
|
||||
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||
DataSetSchema modelSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
|
||||
if (modelSchema != null && modelSchema.getEntity() != null) {
|
||||
return modelSchema.getEntity();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
@@ -33,32 +33,32 @@ public class KeywordMapper extends BaseMapper {
|
||||
@Override
|
||||
public void doMap(ChatQueryContext chatQueryContext) {
|
||||
String queryText = chatQueryContext.getRequest().getQueryText();
|
||||
// 1.hanlpDict Match
|
||||
|
||||
// 1. hanlpDict Match
|
||||
List<S2Term> terms =
|
||||
HanlpHelper.getTerms(queryText, chatQueryContext.getModelIdToDataSetIds());
|
||||
HanlpDictMatchStrategy hanlpMatchStrategy =
|
||||
ContextUtils.getBean(HanlpDictMatchStrategy.class);
|
||||
List<HanlpMapResult> hanlpMatchResults = getMatches(chatQueryContext, hanlpMatchStrategy);
|
||||
convertMapResultToMapInfo(hanlpMatchResults, chatQueryContext, terms);
|
||||
|
||||
List<HanlpMapResult> matchResults = getMatches(chatQueryContext, hanlpMatchStrategy);
|
||||
|
||||
convertHanlpMapResultToMapInfo(matchResults, chatQueryContext, terms);
|
||||
|
||||
// 2.database Match
|
||||
// 2. database Match
|
||||
DatabaseMatchStrategy databaseMatchStrategy =
|
||||
ContextUtils.getBean(DatabaseMatchStrategy.class);
|
||||
List<DatabaseMapResult> databaseResults =
|
||||
List<DatabaseMapResult> databaseMatchResults =
|
||||
getMatches(chatQueryContext, databaseMatchStrategy);
|
||||
convertDatabaseMapResultToMapInfo(chatQueryContext, databaseResults);
|
||||
convertMapResultToMapInfo(chatQueryContext, databaseMatchResults);
|
||||
}
|
||||
|
||||
private void convertHanlpMapResultToMapInfo(List<HanlpMapResult> mapResults,
|
||||
private void convertMapResultToMapInfo(List<HanlpMapResult> mapResults,
|
||||
ChatQueryContext chatQueryContext, List<S2Term> terms) {
|
||||
if (CollectionUtils.isEmpty(mapResults)) {
|
||||
return;
|
||||
}
|
||||
|
||||
HanlpHelper.transLetterOriginal(mapResults);
|
||||
Map<String, Long> wordNatureToFrequency = terms.stream()
|
||||
.collect(Collectors.toMap(entry -> entry.getWord() + entry.getNature(),
|
||||
Map<String, Long> wordNatureToFrequency =
|
||||
terms.stream().collect(Collectors.toMap(term -> term.getWord() + term.getNature(),
|
||||
term -> Long.valueOf(term.getFrequency()), (value1, value2) -> value2));
|
||||
|
||||
for (HanlpMapResult hanlpMapResult : mapResults) {
|
||||
@@ -74,9 +74,10 @@ public class KeywordMapper extends BaseMapper {
|
||||
Long elementID = NatureHelper.getElementID(nature);
|
||||
SchemaElement element = getSchemaElement(dataSetId, elementType, elementID,
|
||||
chatQueryContext.getSemanticSchema());
|
||||
if (element == null) {
|
||||
if (Objects.isNull(element)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
Long frequency = wordNatureToFrequency.get(hanlpMapResult.getName() + nature);
|
||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||
.element(element).frequency(frequency).word(hanlpMapResult.getName())
|
||||
@@ -88,7 +89,7 @@ public class KeywordMapper extends BaseMapper {
|
||||
}
|
||||
}
|
||||
|
||||
private void convertDatabaseMapResultToMapInfo(ChatQueryContext chatQueryContext,
|
||||
private void convertMapResultToMapInfo(ChatQueryContext chatQueryContext,
|
||||
List<DatabaseMapResult> mapResults) {
|
||||
for (DatabaseMapResult match : mapResults) {
|
||||
SchemaElement schemaElement = match.getSchemaElement();
|
||||
|
||||
@@ -94,9 +94,8 @@ public class MapFilter {
|
||||
SchemaElement element = schemaElementMatch.getElement();
|
||||
SchemaElementType type = element.getType();
|
||||
|
||||
boolean isEntityOrDatasetOrId = SchemaElementType.ENTITY.equals(type)
|
||||
|| SchemaElementType.DATASET.equals(type)
|
||||
|| SchemaElementType.ID.equals(type);
|
||||
boolean isEntityOrDatasetOrId =
|
||||
SchemaElementType.DATASET.equals(type) || SchemaElementType.ID.equals(type);
|
||||
|
||||
return !isEntityOrDatasetOrId && needRemovePredicate.test(element);
|
||||
});
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
package com.tencent.supersonic.headless.chat.mapper;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.KnowledgeBaseService;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.SearchService;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
@@ -17,7 +15,6 @@ import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* SearchMatchStrategy encapsulates a concrete matching algorithm executed during search process.
|
||||
@@ -66,16 +63,6 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
knowledgeBaseService.suffixSearch(detectSegment, SEARCH_SIZE,
|
||||
chatQueryContext.getModelIdToDataSetIds(), detectDataSetIds);
|
||||
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||
// remove entity name where search
|
||||
hanlpMapResults = hanlpMapResults.stream().filter(entry -> {
|
||||
List<String> natures = entry.getNatures().stream()
|
||||
.filter(nature -> !nature.endsWith(DictWordType.ENTITY.getType()))
|
||||
.collect(Collectors.toList());
|
||||
if (CollectionUtils.isEmpty(natures)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}).collect(Collectors.toList());
|
||||
MatchText matchText =
|
||||
MatchText.builder().regText(regText).detectSegment(detectSegment).build();
|
||||
regTextMap.put(matchText, hanlpMapResults);
|
||||
|
||||
@@ -20,7 +20,8 @@ public class TermDescMapper extends BaseMapper {
|
||||
public void doMap(ChatQueryContext chatQueryContext) {
|
||||
SchemaMapInfo mapInfo = chatQueryContext.getMapInfo();
|
||||
List<SchemaElement> termElements = mapInfo.getTermDescriptionToMap();
|
||||
if (CollectionUtils.isEmpty(termElements)) {
|
||||
if (CollectionUtils.isEmpty(termElements)
|
||||
|| chatQueryContext.getRequest().isDescriptionMapped()) {
|
||||
return;
|
||||
}
|
||||
for (SchemaElement schemaElement : termElements) {
|
||||
@@ -39,6 +40,7 @@ public class TermDescMapper extends BaseMapper {
|
||||
queryContext.setSemanticSchema(chatQueryContext.getSemanticSchema());
|
||||
queryContext.setModelIdToDataSetIds(chatQueryContext.getModelIdToDataSetIds());
|
||||
queryContext.setChatWorkflowState(chatQueryContext.getChatWorkflowState());
|
||||
queryContext.getRequest().setDescriptionMapped(true);
|
||||
return queryContext;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,95 +1,28 @@
|
||||
package com.tencent.supersonic.headless.chat.parser;
|
||||
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/** QueryTypeParser resolves query type as either AGGREGATE or DETAIL or ID. */
|
||||
/** QueryTypeParser resolves query type as either AGGREGATE or DETAIL */
|
||||
@Slf4j
|
||||
public class QueryTypeParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(ChatQueryContext chatQueryContext) {
|
||||
chatQueryContext.getCandidateQueries().forEach(query -> {
|
||||
SemanticParseInfo parseInfo = query.getParseInfo();
|
||||
String s2SQL = parseInfo.getSqlInfo().getParsedS2SQL();
|
||||
QueryType queryType = QueryType.DETAIL;
|
||||
|
||||
List<SemanticQuery> candidateQueries = chatQueryContext.getCandidateQueries();
|
||||
User user = chatQueryContext.getRequest().getUser();
|
||||
|
||||
for (SemanticQuery semanticQuery : candidateQueries) {
|
||||
// 1.init S2SQL
|
||||
Long dataSetId = semanticQuery.getParseInfo().getDataSetId();
|
||||
DataSetSchema dataSetSchema =
|
||||
chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
||||
semanticQuery.initS2Sql(dataSetSchema, user);
|
||||
// 2.set queryType
|
||||
QueryType queryType = getQueryType(chatQueryContext, semanticQuery);
|
||||
semanticQuery.getParseInfo().setQueryType(queryType);
|
||||
}
|
||||
}
|
||||
|
||||
private QueryType getQueryType(ChatQueryContext chatQueryContext, SemanticQuery semanticQuery) {
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getParsedS2SQL())) {
|
||||
return QueryType.DETAIL;
|
||||
}
|
||||
|
||||
// 1. entity queryType
|
||||
Long dataSetId = parseInfo.getDataSetId();
|
||||
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
|
||||
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getParsedS2SQL());
|
||||
List<String> whereFilterByTimeFields = filterByTimeFields(whereFields);
|
||||
if (CollectionUtils.isNotEmpty(whereFilterByTimeFields)) {
|
||||
Set<String> ids = semanticSchema.getEntities(dataSetId).stream()
|
||||
.map(SchemaElement::getName).collect(Collectors.toSet());
|
||||
if (CollectionUtils.isNotEmpty(ids)
|
||||
&& ids.stream().anyMatch(whereFilterByTimeFields::contains)) {
|
||||
return QueryType.ID;
|
||||
}
|
||||
if (SqlSelectFunctionHelper.hasAggregateFunction(s2SQL)) {
|
||||
queryType = QueryType.AGGREGATE;
|
||||
}
|
||||
}
|
||||
|
||||
// 2. AGG queryType
|
||||
if (SqlSelectFunctionHelper.hasAggregateFunction(sqlInfo.getParsedS2SQL())) {
|
||||
return QueryType.AGGREGATE;
|
||||
}
|
||||
|
||||
return QueryType.DETAIL;
|
||||
parseInfo.setQueryType(queryType);
|
||||
});
|
||||
}
|
||||
|
||||
private static List<String> filterByTimeFields(List<String> whereFields) {
|
||||
return whereFields.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static boolean selectContainsMetric(SqlInfo sqlInfo, Long dataSetId,
|
||||
SemanticSchema semanticSchema) {
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getParsedS2SQL());
|
||||
List<SchemaElement> metrics = semanticSchema.getMetrics(dataSetId);
|
||||
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||
Set<String> metricNameSet =
|
||||
metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
|
||||
return selectFields.stream().anyMatch(metricNameSet::contains);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,6 +49,7 @@ public class LLMResponseService {
|
||||
parseInfo.setScore(queryCtx.getRequest().getQueryText().length() * (1 + weight));
|
||||
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
||||
parseInfo.getSqlInfo().setParsedS2SQL(s2SQL);
|
||||
parseInfo.getSqlInfo().setCorrectedS2SQL(s2SQL);
|
||||
queryCtx.getCandidateQueries().add(semanticQuery);
|
||||
}
|
||||
|
||||
|
||||
@@ -36,15 +36,13 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
+ "\n#Task: You will be provided with a natural language question asked by users,"
|
||||
+ "please convert it to a SQL query so that relevant data could be returned "
|
||||
+ "by executing the SQL query against underlying database." + "\n#Rules:"
|
||||
+ "\n1.ALWAYS generate columns and values specified in the `Schema`, DO NOT hallucinate."
|
||||
+ "\n2.ALWAYS be cautious, word in the `Schema` does not mean it must appear in the SQL."
|
||||
+ "\n3.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator."
|
||||
+ "\n4.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
|
||||
+ "\n5.DO NOT calculate date range using functions."
|
||||
+ "\n6.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
|
||||
+ "\n7.ALWAYS use `with` statement if nested aggregation is needed."
|
||||
+ "\n8.ALWAYS enclose alias created by `AS` command in underscores."
|
||||
+ "\n9.ALWAYS translate alias created by `AS` command to the same language as the `#Question`."
|
||||
+ "\n1.SQL columns and values must be mentioned in the `Schema`, DO NOT hallucinate."
|
||||
+ "\n2.ALWAYS specify time range using `>`,`<`,`>=`,`<=` operator."
|
||||
+ "\n3.DO NOT include time range in the where clause if not explicitly expressed in the `Question`."
|
||||
+ "\n4.DO NOT calculate date range using functions."
|
||||
+ "\n5.ALWAYS use `with` statement if nested aggregation is needed."
|
||||
+ "\n6.ALWAYS enclose alias declared by `AS` command in underscores."
|
||||
+ "\n7.Alias created by `AS` command must be in the same language ast the `Question`."
|
||||
+ "\n#Exemplars: {{exemplar}}"
|
||||
+ "\n#Query: Question:{{question}},Schema:{{schema}},SideInfo:{{information}}";
|
||||
|
||||
|
||||
@@ -34,15 +34,13 @@ public class RuleSqlParser implements SemanticParser {
|
||||
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(dataSetId);
|
||||
List<RuleSemanticQuery> queries =
|
||||
RuleSemanticQuery.resolve(dataSetId, elementMatches, chatQueryContext);
|
||||
for (RuleSemanticQuery query : queries) {
|
||||
query.fillParseInfo(chatQueryContext);
|
||||
chatQueryContext.getCandidateQueries().add(query);
|
||||
}
|
||||
candidateQueries.addAll(chatQueryContext.getCandidateQueries());
|
||||
chatQueryContext.getCandidateQueries().clear();
|
||||
candidateQueries.addAll(queries);
|
||||
}
|
||||
chatQueryContext.setCandidateQueries(candidateQueries);
|
||||
|
||||
auxiliaryParsers.forEach(p -> p.parse(chatQueryContext));
|
||||
|
||||
candidateQueries.forEach(query -> query.buildS2Sql(
|
||||
chatQueryContext.getDataSetSchema(query.getParseInfo().getDataSetId())));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,87 +1,24 @@
|
||||
package com.tencent.supersonic.headless.chat.query;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Aggregator;
|
||||
import com.tencent.supersonic.common.pojo.Filter;
|
||||
import com.tencent.supersonic.common.pojo.Order;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
@ToString
|
||||
@Data
|
||||
public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
|
||||
|
||||
protected SemanticParseInfo parseInfo = new SemanticParseInfo();
|
||||
|
||||
@Override
|
||||
public SemanticParseInfo getParseInfo() {
|
||||
return parseInfo;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setParseInfo(SemanticParseInfo parseInfo) {
|
||||
this.parseInfo = parseInfo;
|
||||
}
|
||||
|
||||
protected QueryStructReq convertQueryStruct() {
|
||||
return QueryReqBuilder.buildStructReq(parseInfo);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SemanticQueryReq buildSemanticQueryReq() {
|
||||
return QueryReqBuilder.buildS2SQLReq(parseInfo.getSqlInfo(), parseInfo.getDataSetId());
|
||||
}
|
||||
|
||||
protected void initS2SqlByStruct(DataSetSchema dataSetSchema) {
|
||||
QueryStructReq queryStructReq = convertQueryStruct();
|
||||
convertBizNameToName(dataSetSchema, queryStructReq);
|
||||
QuerySqlReq querySQLReq = queryStructReq.convert();
|
||||
parseInfo.getSqlInfo().setParsedS2SQL(querySQLReq.getSql());
|
||||
parseInfo.getSqlInfo().setCorrectedS2SQL(querySQLReq.getSql());
|
||||
}
|
||||
|
||||
protected void convertBizNameToName(DataSetSchema dataSetSchema,
|
||||
QueryStructReq queryStructReq) {
|
||||
Map<String, String> bizNameToName = dataSetSchema.getBizNameToName();
|
||||
bizNameToName.putAll(TimeDimensionEnum.getNameToNameMap());
|
||||
|
||||
List<Order> orders = queryStructReq.getOrders();
|
||||
if (CollectionUtils.isNotEmpty(orders)) {
|
||||
for (Order order : orders) {
|
||||
order.setColumn(bizNameToName.get(order.getColumn()));
|
||||
}
|
||||
}
|
||||
List<Aggregator> aggregators = queryStructReq.getAggregators();
|
||||
if (CollectionUtils.isNotEmpty(aggregators)) {
|
||||
for (Aggregator aggregator : aggregators) {
|
||||
aggregator.setColumn(bizNameToName.get(aggregator.getColumn()));
|
||||
}
|
||||
}
|
||||
List<String> groups = queryStructReq.getGroups();
|
||||
if (CollectionUtils.isNotEmpty(groups)) {
|
||||
groups = groups.stream().map(bizNameToName::get).collect(Collectors.toList());
|
||||
queryStructReq.setGroups(groups);
|
||||
}
|
||||
List<Filter> dimensionFilters = queryStructReq.getDimensionFilters();
|
||||
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
|
||||
dimensionFilters
|
||||
.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
|
||||
}
|
||||
List<Filter> metricFilters = queryStructReq.getMetricFilters();
|
||||
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
|
||||
metricFilters.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.headless.chat.query;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
@@ -13,7 +12,7 @@ public interface SemanticQuery {
|
||||
|
||||
SemanticQueryReq buildSemanticQueryReq() throws SqlParseException;
|
||||
|
||||
void initS2Sql(DataSetSchema dataSetSchema, User user);
|
||||
void buildS2Sql(DataSetSchema dataSetSchema);
|
||||
|
||||
SemanticParseInfo getParseInfo();
|
||||
|
||||
|
||||
@@ -3,16 +3,17 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql;
|
||||
import com.fasterxml.jackson.annotation.JsonValue;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import lombok.Data;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Data
|
||||
@@ -45,22 +46,23 @@ public class LLMReq {
|
||||
private SchemaElement primaryKey;
|
||||
|
||||
public List<String> getFieldNameList() {
|
||||
List<String> fieldNameList = new ArrayList<>();
|
||||
Set<String> fieldNameList = new HashSet<>();
|
||||
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||
fieldNameList.addAll(metrics.stream().map(metric -> metric.getName())
|
||||
.collect(Collectors.toList()));
|
||||
fieldNameList.addAll(
|
||||
metrics.stream().map(SchemaElement::getName).collect(Collectors.toList()));
|
||||
}
|
||||
if (CollectionUtils.isNotEmpty(dimensions)) {
|
||||
fieldNameList.addAll(dimensions.stream().map(dimension -> dimension.getName())
|
||||
fieldNameList.addAll(dimensions.stream().map(SchemaElement::getName)
|
||||
.collect(Collectors.toList()));
|
||||
}
|
||||
if (CollectionUtils.isNotEmpty(values)) {
|
||||
fieldNameList.addAll(values.stream().map(ElementValue::getFieldName)
|
||||
.collect(Collectors.toList()));
|
||||
}
|
||||
if (Objects.nonNull(partitionTime)) {
|
||||
fieldNameList.add(partitionTime.getName());
|
||||
}
|
||||
if (Objects.nonNull(primaryKey)) {
|
||||
fieldNameList.add(primaryKey.getName());
|
||||
}
|
||||
return fieldNameList;
|
||||
return new ArrayList<>(fieldNameList);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,7 +76,7 @@ public class LLMReq {
|
||||
public enum SqlGenType {
|
||||
ONE_PASS_SELF_CONSISTENCY("1_pass_self_consistency");
|
||||
|
||||
private String name;
|
||||
private final String name;
|
||||
|
||||
SqlGenType(String name) {
|
||||
this.name = name;
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.headless.chat.query.llm.s2sql;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||
@@ -24,7 +23,7 @@ public class LLMSqlQuery extends LLMSemanticQuery {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initS2Sql(DataSetSchema dataSetSchema, User user) {
|
||||
public void buildS2Sql(DataSetSchema dataSetSchema) {
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
sqlInfo.setCorrectedS2SQL(sqlInfo.getParsedS2SQL());
|
||||
}
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package com.tencent.supersonic.headless.chat.query.rule;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.Aggregator;
|
||||
import com.tencent.supersonic.common.pojo.Filter;
|
||||
import com.tencent.supersonic.common.pojo.Order;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
@@ -10,6 +13,7 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
@@ -18,16 +22,16 @@ import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
|
||||
import lombok.ToString;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.TERM;
|
||||
@@ -51,14 +55,24 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initS2Sql(DataSetSchema dataSetSchema, User user) {
|
||||
initS2SqlByStruct(dataSetSchema);
|
||||
public void buildS2Sql(DataSetSchema dataSetSchema) {
|
||||
QueryStructReq queryStructReq = convertQueryStruct();
|
||||
convertBizNameToName(dataSetSchema, queryStructReq);
|
||||
QuerySqlReq querySQLReq = queryStructReq.convert();
|
||||
parseInfo.getSqlInfo().setParsedS2SQL(querySQLReq.getSql());
|
||||
parseInfo.getSqlInfo().setCorrectedS2SQL(querySQLReq.getSql());
|
||||
}
|
||||
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext) {
|
||||
parseInfo.setQueryMode(getQueryMode());
|
||||
protected QueryStructReq convertQueryStruct() {
|
||||
return QueryReqBuilder.buildStructReq(parseInfo);
|
||||
}
|
||||
|
||||
protected void fillParseInfo(ChatQueryContext chatQueryContext, Long dataSetId) {
|
||||
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||
|
||||
parseInfo.setQueryMode(getQueryMode());
|
||||
parseInfo.setDataSet(semanticSchema.getDataSet(dataSetId));
|
||||
parseInfo.setQueryConfig(semanticSchema.getQueryConfig(dataSetId));
|
||||
fillSchemaElement(parseInfo, semanticSchema);
|
||||
fillScore(parseInfo);
|
||||
fillDateConfByInherited(parseInfo, chatQueryContext);
|
||||
@@ -111,31 +125,13 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
}
|
||||
|
||||
private void fillSchemaElement(SemanticParseInfo parseInfo, SemanticSchema semanticSchema) {
|
||||
Set<Long> dataSetIds =
|
||||
parseInfo.getElementMatches().stream().map(SchemaElementMatch::getElement)
|
||||
.map(SchemaElement::getDataSetId).collect(Collectors.toSet());
|
||||
Long dataSetId = dataSetIds.iterator().next();
|
||||
parseInfo.setDataSet(semanticSchema.getDataSet(dataSetId));
|
||||
parseInfo.setQueryConfig(semanticSchema.getQueryConfig(dataSetId));
|
||||
|
||||
Map<Long, List<SchemaElementMatch>> dim2Values = new HashMap<>();
|
||||
Map<Long, List<SchemaElementMatch>> id2Values = new HashMap<>();
|
||||
|
||||
for (SchemaElementMatch schemaMatch : parseInfo.getElementMatches()) {
|
||||
SchemaElement element = schemaMatch.getElement();
|
||||
element.setOrder(1 - schemaMatch.getSimilarity());
|
||||
switch (element.getType()) {
|
||||
case ID:
|
||||
SchemaElement entityElement =
|
||||
semanticSchema.getElement(SchemaElementType.ENTITY, element.getId());
|
||||
if (entityElement != null) {
|
||||
if (id2Values.containsKey(element.getId())) {
|
||||
id2Values.get(element.getId()).add(schemaMatch);
|
||||
} else {
|
||||
id2Values.put(element.getId(),
|
||||
new ArrayList<>(Arrays.asList(schemaMatch)));
|
||||
}
|
||||
}
|
||||
break;
|
||||
case VALUE:
|
||||
SchemaElement dimElement =
|
||||
semanticSchema.getElement(SchemaElementType.DIMENSION, element.getId());
|
||||
@@ -144,7 +140,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
dim2Values.get(element.getId()).add(schemaMatch);
|
||||
} else {
|
||||
dim2Values.put(element.getId(),
|
||||
new ArrayList<>(Arrays.asList(schemaMatch)));
|
||||
new ArrayList<>(Collections.singletonList(schemaMatch)));
|
||||
}
|
||||
}
|
||||
break;
|
||||
@@ -154,23 +150,20 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
case METRIC:
|
||||
parseInfo.getMetrics().add(element);
|
||||
break;
|
||||
case ENTITY:
|
||||
parseInfo.setEntity(element);
|
||||
break;
|
||||
default:
|
||||
}
|
||||
}
|
||||
addToFilters(id2Values, parseInfo, semanticSchema, SchemaElementType.ENTITY);
|
||||
addToFilters(dim2Values, parseInfo, semanticSchema, SchemaElementType.DIMENSION);
|
||||
}
|
||||
|
||||
private void addToFilters(Map<Long, List<SchemaElementMatch>> id2Values,
|
||||
SemanticParseInfo parseInfo, SemanticSchema semanticSchema, SchemaElementType entity) {
|
||||
SemanticParseInfo parseInfo, SemanticSchema semanticSchema,
|
||||
SchemaElementType elementType) {
|
||||
if (id2Values == null || id2Values.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
for (Entry<Long, List<SchemaElementMatch>> entry : id2Values.entrySet()) {
|
||||
SchemaElement dimension = semanticSchema.getElement(entity, entry.getKey());
|
||||
SchemaElement dimension = semanticSchema.getElement(elementType, entry.getKey());
|
||||
if (dimension.isPartitionTime()) {
|
||||
continue;
|
||||
}
|
||||
@@ -182,13 +175,11 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
dimensionFilter.setName(dimension.getName());
|
||||
dimensionFilter.setOperator(FilterOperatorEnum.EQUALS);
|
||||
dimensionFilter.setElementID(schemaMatch.getElement().getId());
|
||||
parseInfo.setEntity(
|
||||
semanticSchema.getElement(SchemaElementType.ENTITY, entry.getKey()));
|
||||
parseInfo.getDimensionFilters().add(dimensionFilter);
|
||||
} else {
|
||||
QueryFilter dimensionFilter = new QueryFilter();
|
||||
List<String> values = new ArrayList<>();
|
||||
entry.getValue().stream().forEach(i -> values.add(i.getWord()));
|
||||
entry.getValue().forEach(i -> values.add(i.getWord()));
|
||||
dimensionFilter.setValue(values);
|
||||
dimensionFilter.setBizName(dimension.getBizName());
|
||||
dimensionFilter.setName(dimension.getName());
|
||||
@@ -216,33 +207,60 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
return convertQueryMultiStruct();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setParseInfo(SemanticParseInfo parseInfo) {
|
||||
this.parseInfo = parseInfo;
|
||||
}
|
||||
|
||||
public static List<RuleSemanticQuery> resolve(Long dataSetId,
|
||||
List<SchemaElementMatch> candidateElementMatches, ChatQueryContext chatQueryContext) {
|
||||
List<RuleSemanticQuery> matchedQueries = new ArrayList<>();
|
||||
|
||||
for (RuleSemanticQuery semanticQuery : QueryManager.getRuleQueries()) {
|
||||
List<SchemaElementMatch> matches =
|
||||
semanticQuery.match(candidateElementMatches, chatQueryContext);
|
||||
|
||||
if (matches.size() > 0) {
|
||||
if (!matches.isEmpty()) {
|
||||
RuleSemanticQuery query =
|
||||
QueryManager.createRuleQuery(semanticQuery.getQueryMode());
|
||||
query.getParseInfo().getElementMatches().addAll(matches);
|
||||
query.fillParseInfo(chatQueryContext, dataSetId);
|
||||
matchedQueries.add(query);
|
||||
}
|
||||
}
|
||||
return matchedQueries;
|
||||
}
|
||||
|
||||
protected QueryStructReq convertQueryStruct() {
|
||||
return QueryReqBuilder.buildStructReq(parseInfo);
|
||||
}
|
||||
|
||||
protected QueryMultiStructReq convertQueryMultiStruct() {
|
||||
return QueryReqBuilder.buildMultiStructReq(parseInfo);
|
||||
}
|
||||
|
||||
|
||||
protected void convertBizNameToName(DataSetSchema dataSetSchema,
|
||||
QueryStructReq queryStructReq) {
|
||||
Map<String, String> bizNameToName = dataSetSchema.getBizNameToName();
|
||||
bizNameToName.putAll(TimeDimensionEnum.getNameToNameMap());
|
||||
|
||||
List<Order> orders = queryStructReq.getOrders();
|
||||
if (CollectionUtils.isNotEmpty(orders)) {
|
||||
for (Order order : orders) {
|
||||
order.setColumn(bizNameToName.get(order.getColumn()));
|
||||
}
|
||||
}
|
||||
List<Aggregator> aggregators = queryStructReq.getAggregators();
|
||||
if (CollectionUtils.isNotEmpty(aggregators)) {
|
||||
for (Aggregator aggregator : aggregators) {
|
||||
aggregator.setColumn(bizNameToName.get(aggregator.getColumn()));
|
||||
}
|
||||
}
|
||||
List<String> groups = queryStructReq.getGroups();
|
||||
if (CollectionUtils.isNotEmpty(groups)) {
|
||||
groups = groups.stream().map(bizNameToName::get).collect(Collectors.toList());
|
||||
queryStructReq.setGroups(groups);
|
||||
}
|
||||
List<Filter> dimensionFilters = queryStructReq.getDimensionFilters();
|
||||
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
|
||||
dimensionFilters
|
||||
.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
|
||||
}
|
||||
List<Filter> metricFilters = queryStructReq.getMetricFilters();
|
||||
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
|
||||
metricFilters.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
package com.tencent.supersonic.headless.chat.query.rule.detail;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ENTITY;
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.VALUE;
|
||||
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class DetailFilterQuery extends DetailListQuery {
|
||||
|
||||
public static final String QUERY_MODE = "DETAIL_LIST_FILTER";
|
||||
|
||||
public DetailFilterQuery() {
|
||||
super();
|
||||
queryMatcher.addOption(VALUE, REQUIRED, AT_LEAST, 1);
|
||||
queryMatcher.addOption(ENTITY, REQUIRED, AT_LEAST, 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getQueryMode() {
|
||||
return QUERY_MODE;
|
||||
}
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
package com.tencent.supersonic.headless.chat.query.rule.detail;
|
||||
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ID;
|
||||
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
|
||||
@Component
|
||||
public class DetailIdQuery extends DetailListQuery {
|
||||
|
||||
public static final String QUERY_MODE = "DETAIL_ID";
|
||||
|
||||
public DetailIdQuery() {
|
||||
super();
|
||||
queryMatcher.addOption(ID, REQUIRED, AT_LEAST, 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getQueryMode() {
|
||||
return QUERY_MODE;
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user