33 Commits

Author SHA1 Message Date
jerryjzhang
a2f54d4c80 [improvement][headless]Remove unnecessary Database class
2024-12-29 21:58:39 +08:00
jerryjzhang
6486257c9e [improvement][chat&headless]Remove deprecated system time fields. 2024-12-29 03:22:42 +08:00
HB
6f5e477e3c [fix][queryStat] Fields that are inconsistent between the table Filed and DO (#1986) 2024-12-28 15:15:06 +08:00
yxm-coding
683f01c33b Fix 404 errors on all front-end route resource requests in Network #1982 (#1985) 2024-12-27 20:54:53 +08:00
yxm-coding
3e1e5ae209 fix(launchers): update addViewController to correctly redirect to the front-end page when accessing the domain while logged in (#1981) 2024-12-27 15:35:50 +08:00
jerryjzhang
0612833618 [fix][chat]Fix logic in s2sql parsing. 2024-12-27 14:18:20 +08:00
lexluo09
a23d1071a3 [improvement][chat] Optimize the logic for obtaining the generic thread pool (#1979) 2024-12-26 23:37:55 +08:00
jerryjzhang
94267f6028 [improvement][chat]Introduce AllFieldMapper to increase parsing robustness when normal pipeline fails.
2024-12-26 23:20:43 +08:00
jerryjzhang
a4d2df4063 [improvement][project]Adapt docker related scripts to new version. 2024-12-26 15:02:28 +08:00
jerryjzhang
d04a086c88 [improvement][chat]Support reviewing query memory based on direct user feedback. 2024-12-26 09:47:13 +08:00
jerryjzhang
68963b9ec9 [improvement][project]Adjust files based on code style. 2024-12-26 09:12:12 +08:00
jerryjzhang
d40400d2a4 [fix][chat]Memory enabled by the review task should be stored in embedding store. 2024-12-26 00:11:12 +08:00
lexluo09
c483bb891a [fix][chat] Fix the issue with the order of parallel execution in the map. (#1976) 2024-12-25 20:52:52 +08:00
zehuiHuang
6738aba19e Issues 1974 (#1975)
* [fix][common]Support 'BETWEEN AND' query condition parameter parsing `CURRENT`. #1972
2024-12-25 19:33:40 +08:00
zehuiHuang
493a8035cd [fix][common]Support 'BETWEEN AND' query condition parameter parsing CURRENT. #1972 (#1973) 2024-12-25 19:33:26 +08:00
lwhy
b425c49c5b [fix][headless] Unexpected update of dimensions when modifying a model (#1971) 2024-12-25 19:33:10 +08:00
jerryjzhang
4dca6eec5a [improvement][project]Adjust github issue forms. 2024-12-23 20:24:35 +08:00
pisces
642d6a02e1 feat(chat-sdk/chatitem): support exporting chart images from messages (#1937) 2024-12-23 09:05:56 +08:00
jerryjzhang
5de5b0a5e2 [fix][headless]Fix issue in determining mysql version to support with statement. 2024-12-22 21:40:58 +08:00
lexluo09
8c6ae62522 [improvement][chat] Change the embedding to execute in parallel (#1967) 2024-12-21 20:32:03 +08:00
jerryjzhang
7dc013dfb3 [fix][project]Use SpringDoc to support swagger in Spring 3.x 2024-12-21 19:48:20 +08:00
lexluo09
72780f9acf [improvement][common] The thread pool adopts a generic thread pool configuration. (#1966) 2024-12-21 19:01:53 +08:00
jerryjzhang
7b49412bde Merge branch 'master' of github.com:tencentmusic/supersonic 2024-12-21 18:49:41 +08:00
jerryjzhang
9f63aca132 [fix][chat]Fix minor logic issue. 2024-12-21 18:49:27 +08:00
lexluo09
f7fce0217f [improvement][chat] Use a generic thread pool to perform concurrent mapping. (#1965) 2024-12-21 11:58:02 +08:00
jerryjzhang
c2d155705f Merge remote-tracking branch 'origin/master' 2024-12-20 12:43:36 +08:00
jerryjzhang
5faf5f3ac4 [improvement][project]Adjust github issue forms. 2024-12-20 12:43:22 +08:00
jerryjzhang
d88d8b3beb [project]Adjust the LICENSE to impose stricter restrictions in commercial scenarios. 2024-12-19 22:49:03 +08:00
jerryjzhang
4cb2256351 [improvement][headless]Merge function of QueryConverter abstraction to QueryParser. 2024-12-19 21:45:24 +08:00
lexluo09
8b69d57c4b [improvement][chat] Fix the issue with the DatabaseMatchStrategy variable under multi-threading (#1963) 2024-12-19 10:04:17 +08:00
jerryjzhang
91856ddebd [improvement][chat]Inject schema info into the prompt of LLMSqlCorrector. 2024-12-19 09:55:32 +08:00
jerryjzhang
94e97c9a1d [improvement][chat]Use accept() pattern to improve code readability. 2024-12-19 09:47:38 +08:00
wwsheng009
b57eed47e2 SAP HANA DATABASE Source support improvement [improve SAP HANA database support] (#1959)
* (improvement)(database) update the support for sap hana database source

* (fix)(common) add the default timeout for ZhipuAiEmbeddingModel,avoid the program error
2024-12-17 21:52:19 +08:00
190 changed files with 1840 additions and 2299 deletions

View File

@@ -11,59 +11,28 @@ body:
If it is an idea or help wanted, please go to:
[Github Discussion](https://github.com/tencentmusic/supersonic/discussions)
- type: checkboxes
- type: input
id: version
attributes:
label: Search before asking
description: >
Please make sure to search in the [issues](https://github.com/tencentmusic/supersonic/issues?q=is%3Aissue) first to see
whether the same issue was reported already.
options:
- label: >
I had searched in the [issues](https://github.com/tencentmusic/supersonic/issues?q=is%3Aissue) and found no similar
issues.
required: true
label: SuperSonic version
description: Please tell us which version you are using.
placeholder: "0.9.8"
validations:
required: true
- type: textarea
- type: input
id: organization
attributes:
label: Version
description: What is the current version
placeholder: >
Please provide the version you are using.
If it is the trunk version, please input commit id.
label: Your organization
description: Please tell us your organization so that we can provide you better support and advice.
placeholder: "TME..."
validations:
required: true
- type: textarea
attributes:
label: What's Wrong?
description: Describe the bug.
placeholder: >
Describe the specific problem, the more detailed the better.
validations:
required: true
- type: textarea
attributes:
label: What You Expected?
validations:
required: true
- type: textarea
attributes:
label: How to Reproduce?
placeholder: >
Please try to give reproducing steps to facilitate quick location of the problem.
- What actions were performed
- Table building statement
- Import statement
- Cluster information: number of nodes, configuration, etc.
If it is hard to reproduce, please also explain the general scene.
- type: textarea
attributes:
label: Anything Else?
label: Description
description: Describe the bug you met.
- type: checkboxes
attributes:
@@ -74,16 +43,6 @@ body:
options:
- label: Yes I am willing to submit a PR!
- type: checkboxes
attributes:
label: Code of Conduct
description: The Code of Conduct helps create a safe space for everyone. We require that everyone agrees to it.
options:
- label: >
I agree to follow this project's
[Code of Conduct](https://www.apache.org/foundation/policies/conduct)
required: true
- type: markdown
attributes:
value: "Thanks for completing our form!"

View File

@@ -8,30 +8,20 @@ body:
attributes:
value: |
Thank you very much for your good enhancement for SuperSonic.
- type: checkboxes
attributes:
label: Search before asking
description: >
Please make sure to search in the [issues](https://github.com/tencentmusic/supersonic/issues?q=is%3Aissue) first to see
whether the same issue was reported already.
options:
- label: >
I had searched in the [issues](https://github.com/tencentmusic/supersonic/issues?q=is%3Aissue) and found no similar
issues.
required: true
- type: textarea
attributes:
label: Description
description: Describe the enhancement what you want, including motivation if it exists.
- type: textarea
- type: input
id: organization
attributes:
label: Solution
placeholder: >
Add overview of proposed solution.
Add related materials like links if they exist.
label: Your organization
description: Please tell us your organization so that we can provide you better support and advice.
placeholder: "TME..."
validations:
required: true
- type: checkboxes
attributes:
@@ -42,16 +32,6 @@ body:
options:
- label: Yes I am willing to submit a PR!
- type: checkboxes
attributes:
label: Code of Conduct
description: The Code of Conduct helps create a safe space for everyone. We require that everyone agrees to it.
options:
- label: >
I agree to follow this project's
[Code of Conduct](https://www.apache.org/foundation/policies/conduct)
required: true
- type: markdown
attributes:
value: "Thanks for completing our form!"

View File

@@ -8,33 +8,19 @@ body:
value: |
Thank you very much for your good ideas and suggestions for SuperSonic
- type: checkboxes
attributes:
label: Search before asking
description: >
Please make sure to search in the [issues](https://github.com/tencentmusic/supersonic/issues?q=is%3Aissue) first to see
whether the same issue was reported already.
options:
- label: >
I had searched in the [issues](https://github.com/tencentmusic/supersonic/issues?q=is%3Aissue) and found no similar
issues.
required: true
- type: textarea
attributes:
label: Description
description: Describe your ideas and needs.
- type: textarea
- type: input
id: organization
attributes:
label: Use case
placeholder: >
What problem does this feature mainly solve, or what scenarios it is suitable for.
- type: textarea
attributes:
label: Related issues
description: Is there currently another issue associated with this?
label: Your organization
description: Please tell us your organization so that we can provide you better support and advice.
placeholder: "TME..."
validations:
required: true
- type: checkboxes
attributes:
@@ -45,16 +31,4 @@ body:
options:
- label: Yes I am willing to submit a PR!
- type: checkboxes
attributes:
label: Code of Conduct
description: The Code of Conduct helps create a safe space for everyone. We require that everyone agrees to it.
options:
- label: >
I agree to follow this project's
[Code of Conduct](https://www.apache.org/foundation/policies/conduct)
required: true
- type: markdown
attributes:
value: "Thanks for completing our form!"

View File

@@ -8,6 +8,7 @@ body:
value: |
## Ask a Question about SuperSonic
Please provide a detailed description of your question or the clarification you seek regarding the SuperSonic project.
- type: textarea
id: describe-question
attributes:
@@ -16,43 +17,12 @@ body:
placeholder: "Type your question here..."
validations:
required: true
- type: textarea
id: additional-context
- type: input
id: organization
attributes:
label: Provide any additional context or information
description: If your question is related to a specific part of the SuperSonic project or if you have already looked through certain documentation, please provide that information here.
placeholder: "Add context here..."
label: Your organization
description: Please tell us your organization so that we can provide you better support and advice.
placeholder: "TME..."
validations:
required: false
- type: textarea
id: tried-to-resolve
attributes:
label: What have you tried to resolve your question
description: Let us know what you have done to try and understand or resolve your question. This can help us provide you with the most useful guidance.
placeholder: "I've already tried..."
validations:
required: false
- type: textarea
id: environment
attributes:
label: Your environment
description: Share details about your environment to help us reproduce the issue. Include your operating system, version of SuperSonic, and any other relevant details.
placeholder: "OS, SuperSonic version, etc..."
validations:
required: false
- type: textarea
id: screenshots-logs
attributes:
label: Screenshots or Logs
description: If applicable, add screenshots or logs to help explain your problem.
placeholder: "Paste your logs or attach screenshots here..."
validations:
required: false
- type: textarea
id: additional-information
attributes:
label: Additional information
description: Add any other context or details you think might be helpful for understanding your question.
placeholder: "Any other information..."
validations:
required: false
required: true

3
.gitignore vendored
View File

@@ -19,4 +19,5 @@ assembly/runtime/*
chm_db/
__pycache__/
/dict
assembly/build/*-SNAPSHOT
assembly/build/*-SNAPSHOT
**/node_modules/

18
LICENSE
View File

@@ -1,19 +1,11 @@
SuperSonic is licensed under the MIT License, with the following additional conditions:
1. You may provide SuperSonic to third parties as a commercial software or service. However,
when the following conditions are met, you must contact the producer to obtain a commercial license:
a. Multi-tenant SaaS service: Unless explicitly authorized by SuperSonic in writing, you may not use the
SuperSonic source code to operate a multi-tenant SaaS service.
b. LOGO and copyright information: In the process of using SuperSonic, you may not remove or modify
the LOGO or copyright information on the SuperSonic UI. This restriction is inapplicable to uses of
SuperSonic that do not involve its frontend components.
SuperSonic is licensed under the MIT License, you can freely use or integrate SuperSonic within
your organization. However, if you want to provide or integrate SuperSonic to third parties
as a commercial software or service, you must contact the producer to obtain a commercial license.
Please contact jerryjzhang@tencent.com by email to inquire about licensing matters.
2. As a contributor, you should agree that:
As a SuperSonic contributor, you should agree that:
a. The producer can adjust the open-source agreement to be more strict or relaxed as deemed necessary.
a. The producer can adjust the open-source agreement to be stricter or relaxed as deemed necessary.
b. Your contributed code may be used for commercial purposes, including but not limited to its business operations.
Terms of the MIT License:

View File

@@ -17,6 +17,8 @@ public class ChatMemoryFilter {
private Integer agentId;
private Long queryId;
private String question;
private List<String> questions;

View File

@@ -26,4 +26,8 @@ public class ChatMemoryUpdateReq {
private MemoryReviewResult humanReviewRet;
private String humanReviewCmt;
private MemoryReviewResult llmReviewRet;
private String llmReviewCmt;
}

View File

@@ -44,7 +44,7 @@ public class SqlExecutor implements ChatQueryExecutor {
Text2SQLExemplar.class);
MemoryService memoryService = ContextUtils.getBean(MemoryService.class);
memoryService.createMemory(ChatMemory.builder()
memoryService.createMemory(ChatMemory.builder().queryId(queryResult.getQueryId())
.agentId(executeContext.getAgent().getId()).status(MemoryStatus.PENDING)
.question(exemplar.getQuestion()).sideInfo(exemplar.getSideInfo())
.dbSchema(exemplar.getDbSchema()).s2sql(exemplar.getSql())
@@ -77,6 +77,7 @@ public class SqlExecutor implements ChatQueryExecutor {
long startTime = System.currentTimeMillis();
QueryResult queryResult = new QueryResult();
queryResult.setQueryId(executeContext.getRequest().getQueryId());
queryResult.setChatContext(parseInfo);
queryResult.setQueryMode(parseInfo.getQueryMode());
queryResult.setQueryTimeCost(System.currentTimeMillis() - startTime);

View File

@@ -3,11 +3,13 @@ package com.tencent.supersonic.chat.server.memory;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.pojo.ChatMemory;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.headless.server.utils.ModelConfigHelper;
@@ -123,7 +125,10 @@ public class MemoryReviewTask {
if (MemoryReviewResult.POSITIVE.equals(m.getLlmReviewRet())) {
m.setStatus(MemoryStatus.ENABLED);
}
memoryService.updateMemory(m);
ChatMemoryUpdateReq memoryUpdateReq = ChatMemoryUpdateReq.builder().id(m.getId())
.status(m.getStatus()).llmReviewRet(m.getLlmReviewRet())
.llmReviewCmt(m.getLlmReviewCmt()).build();
memoryService.updateMemory(memoryUpdateReq, User.getDefaultUser());
}
}
}

View File

@@ -100,10 +100,12 @@ public class NL2SQLParser implements ChatQueryParser {
queryNLReq.setMapModeEnum(mode);
doParse(queryNLReq, parseResp);
}
if (parseResp.getSelectedParses().isEmpty()) {
if (parseResp.getSelectedParses().isEmpty() && candidateParses.isEmpty()) {
queryNLReq.setMapModeEnum(MapModeEnum.LOOSE);
doParse(queryNLReq, parseResp);
}
if (parseResp.getSelectedParses().isEmpty()) {
errMsg.append(parseResp.getErrorMsg());
continue;
@@ -137,11 +139,18 @@ public class NL2SQLParser implements ChatQueryParser {
SemanticParseInfo userSelectParse = parseContext.getRequest().getSelectedParse();
queryNLReq.setSelectedParseInfo(Objects.nonNull(userSelectParse) ? userSelectParse
: parseContext.getResponse().getSelectedParses().get(0));
parseContext.setResponse(new ChatParseResp(parseContext.getResponse().getQueryId()));
rewriteMultiTurn(parseContext, queryNLReq);
addDynamicExemplars(parseContext, queryNLReq);
doParse(queryNLReq, parseContext.getResponse());
// try again with all semantic fields passed to LLM
if (parseContext.getResponse().getState().equals(ParseResp.ParseState.FAILED)) {
queryNLReq.setSelectedParseInfo(null);
queryNLReq.setMapModeEnum(MapModeEnum.ALL);
doParse(queryNLReq, parseContext.getResponse());
}
}
}
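
The retry added in this hunk ("try again with all semantic fields passed to LLM") can be read as an escalation: parse with a narrower mapping mode first, and widen to ALL only if that attempt fails. The sketch below is a simplified, standalone illustration of that idea and not the project's code. `MapMode`, `parseWithFallback`, and `toyParser` are hypothetical stand-ins for `MapModeEnum`, the `doParse(...)` calls, and the LLM-backed parser; the real change also clears the selected parse info before retrying, which is omitted here.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.BiFunction;

public class FallbackParseSketch {

    /** Local stand-in for MapModeEnum; only LOOSE and ALL are taken from the diff. */
    enum MapMode { LOOSE, ALL }

    /**
     * Tries each mode in declaration order and returns the first successful parse.
     * The parser argument is a hypothetical replacement for doParse(...).
     * Every attempted mode is recorded so callers can see how far escalation went.
     */
    static Optional<String> parseWithFallback(String question,
            BiFunction<String, MapMode, Optional<String>> parser,
            List<MapMode> attempted) {
        for (MapMode mode : MapMode.values()) {
            attempted.add(mode);
            Optional<String> sql = parser.apply(question, mode);
            if (sql.isPresent()) {
                return sql;
            }
        }
        return Optional.empty();
    }

    /** Toy parser (assumption): succeeds only when all semantic fields are passed. */
    static Optional<String> toyParser(String question, MapMode mode) {
        return mode == MapMode.ALL
                ? Optional.of("SELECT 1 /* " + question + " */")
                : Optional.empty();
    }

    public static void main(String[] args) {
        List<MapMode> attempted = new ArrayList<>();
        Optional<String> sql =
                parseWithFallback("total plays by artist", FallbackParseSketch::toyParser, attempted);
        System.out.println(attempted + " -> " + sql.orElse("FAILED"));
    }
}
```

Keeping the widest mode as the last resort means the more expensive, less constrained prompt is only paid for when the normal pipeline has already failed.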

View File

@@ -23,6 +23,9 @@ public class ChatMemoryDO {
@TableField("agent_id")
private Integer agentId;
@TableField("query_id")
private Long queryId;
@TableField("question")
private String question;

View File

@@ -2,11 +2,7 @@ package com.tencent.supersonic.chat.server.pojo;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString;
import lombok.*;
import java.util.Date;
@@ -20,6 +16,8 @@ public class ChatMemory {
private Integer agentId;
private Long queryId;
private String question;
private String sideInfo;

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.chat.server.pojo;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import lombok.Data;
@@ -8,6 +9,7 @@ import lombok.Data;
@Data
public class ExecuteContext {
private ChatExecuteReq request;
private QueryResult response;
private Agent agent;
private SemanticParseInfo parseInfo;

View File

@@ -43,12 +43,17 @@ public class DataInterpretProcessor implements ExecuteResultProcessor {
@Override
public void process(ExecuteContext executeContext, QueryResult queryResult) {
public boolean accept(ExecuteContext executeContext) {
Agent agent = executeContext.getAgent();
ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY);
return Objects.nonNull(chatApp) && chatApp.isEnable();
}
@Override
public void process(ExecuteContext executeContext) {
QueryResult queryResult = executeContext.getResponse();
Agent agent = executeContext.getAgent();
ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY);
if (Objects.isNull(chatApp) || !chatApp.isEnable()) {
return;
}
Map<String, Object> variable = new HashMap<>();
variable.put("question", executeContext.getRequest().getQueryText());

View File

@@ -27,17 +27,18 @@ public class DimensionRecommendProcessor implements ExecuteResultProcessor {
private static final int recommend_dimension_size = 5;
@Override
public void process(ExecuteContext executeContext, QueryResult queryResult) {
public boolean accept(ExecuteContext executeContext) {
SemanticParseInfo semanticParseInfo = executeContext.getParseInfo();
return QueryType.AGGREGATE.equals(semanticParseInfo.getQueryType())
&& !CollectionUtils.isEmpty(semanticParseInfo.getMetrics());
}
@Override
public void process(ExecuteContext executeContext) {
QueryResult queryResult = executeContext.getResponse();
SemanticParseInfo semanticParseInfo = executeContext.getParseInfo();
if (!QueryType.AGGREGATE.equals(semanticParseInfo.getQueryType())
|| CollectionUtils.isEmpty(semanticParseInfo.getMetrics())) {
return;
}
Long dataSetId = semanticParseInfo.getDataSetId();
Optional<SchemaElement> firstMetric = semanticParseInfo.getMetrics().stream().findFirst();
if (!firstMetric.isPresent()) {
return;
}
List<SchemaElement> dimensionRecommended =
getDimensions(firstMetric.get().getId(), dataSetId);
queryResult.setRecommendedDimensions(dimensionRecommended);

View File

@@ -1,11 +1,12 @@
package com.tencent.supersonic.chat.server.processor.execute;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.chat.server.processor.ResultProcessor;
/** A ExecuteResultProcessor wraps things up before returning execution results to the users. */
public interface ExecuteResultProcessor extends ResultProcessor {
void process(ExecuteContext executeContext, QueryResult queryResult);
boolean accept(ExecuteContext executeContext);
void process(ExecuteContext executeContext);
}
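
The interface change above splits each processor into a guard (`accept`) and an action (`process`), so the early-return checks move out of `process`. The diff does not show the call site, so the following is a minimal sketch of how a caller might drive such processors, under stated assumptions: `Context`, `Processor`, the two sample processors, and `runAll` are hypothetical names invented for illustration, not classes from SuperSonic.

```java
import java.util.ArrayList;
import java.util.List;

public class AcceptPatternSketch {

    /** Hypothetical stand-in for ExecuteContext; carries only what the demo needs. */
    static class Context {
        final String queryType;
        final List<String> log = new ArrayList<>();

        Context(String queryType) {
            this.queryType = queryType;
        }
    }

    /** Mirrors the two-method shape of the interface in the diff. */
    interface Processor {
        boolean accept(Context ctx);

        void process(Context ctx);
    }

    /** Applies only to aggregate queries, similar in spirit to the recommend processors. */
    static class AggregateOnlyProcessor implements Processor {
        public boolean accept(Context ctx) {
            return "AGGREGATE".equals(ctx.queryType);
        }

        public void process(Context ctx) {
            ctx.log.add("aggregate-processed");
        }
    }

    /** Always applies, like the processors whose accept() simply returns true. */
    static class AlwaysProcessor implements Processor {
        public boolean accept(Context ctx) {
            return true;
        }

        public void process(Context ctx) {
            ctx.log.add("always-processed");
        }
    }

    /** Hypothetical caller: a processor runs only if it accepts the context. */
    static List<String> runAll(List<Processor> processors, Context ctx) {
        for (Processor p : processors) {
            if (p.accept(ctx)) {
                p.process(ctx);
            }
        }
        return ctx.log;
    }

    public static void main(String[] args) {
        List<Processor> processors = List.of(new AggregateOnlyProcessor(), new AlwaysProcessor());
        System.out.println(runAll(processors, new Context("DETAIL")));
        System.out.println(runAll(processors, new Context("AGGREGATE")));
    }
}
```

With the guard exposed as its own method, the applicability rule of each processor can be read and tested without executing its side effects.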

View File

@@ -59,14 +59,18 @@ import static com.tencent.supersonic.common.pojo.Constants.TIME_FORMAT;
public class MetricRatioCalcProcessor implements ExecuteResultProcessor {
@Override
public void process(ExecuteContext executeContext, QueryResult queryResult) {
public boolean accept(ExecuteContext executeContext) {
SemanticParseInfo semanticParseInfo = executeContext.getParseInfo();
AggregatorConfig aggregatorConfig = ContextUtils.getBean(AggregatorConfig.class);
if (CollectionUtils.isEmpty(semanticParseInfo.getMetrics())
|| !aggregatorConfig.getEnableRatio()
|| !QueryType.AGGREGATE.equals(semanticParseInfo.getQueryType())) {
return;
}
return !CollectionUtils.isEmpty(semanticParseInfo.getMetrics())
&& aggregatorConfig.getEnableRatio()
&& QueryType.AGGREGATE.equals(semanticParseInfo.getQueryType());
}
@Override
public void process(ExecuteContext executeContext) {
QueryResult queryResult = executeContext.getResponse();
SemanticParseInfo semanticParseInfo = executeContext.getParseInfo();
AggregateInfo aggregateInfo = getAggregateInfo(executeContext.getRequest().getUser(),
semanticParseInfo, queryResult);
queryResult.setAggregateInfo(aggregateInfo);

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.chat.server.processor.execute;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
@@ -16,14 +15,7 @@ import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
import org.springframework.util.CollectionUtils;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.*;
import java.util.stream.Collectors;
/**
@@ -34,17 +26,20 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
private static final int METRIC_RECOMMEND_SIZE = 5;
@Override
public void process(ExecuteContext executeContext, QueryResult queryResult) {
public boolean accept(ExecuteContext executeContext) {
SemanticParseInfo parseInfo = executeContext.getParseInfo();
return Objects.nonNull(parseInfo.getQueryType())
&& parseInfo.getQueryType().equals(QueryType.AGGREGATE)
&& !CollectionUtils.isEmpty(parseInfo.getMetrics())
&& parseInfo.getMetrics().size() <= METRIC_RECOMMEND_SIZE;
}
@Override
public void process(ExecuteContext executeContext) {
fillSimilarMetric(executeContext.getParseInfo());
}
private void fillSimilarMetric(SemanticParseInfo parseInfo) {
if (Objects.isNull(parseInfo.getQueryType())
|| !parseInfo.getQueryType().equals(QueryType.AGGREGATE)
|| parseInfo.getMetrics().size() > METRIC_RECOMMEND_SIZE
|| CollectionUtils.isEmpty(parseInfo.getMetrics())) {
return;
}
List<String> metricNames =
Collections.singletonList(parseInfo.getMetrics().iterator().next().getName());
Map<String, Object> filterCondition = new HashMap<>();

View File

@@ -43,14 +43,17 @@ public class ErrorMsgRewriteProcessor implements ParseResultProcessor {
.enable(false).build());
}
@Override
public boolean accept(ParseContext parseContext) {
ChatApp chatApp = parseContext.getAgent().getChatAppConfig().get(APP_KEY_ERROR_MESSAGE);
return StringUtils.isNotBlank(parseContext.getResponse().getErrorMsg())
&& Objects.nonNull(chatApp) && chatApp.isEnable();
}
@Override
public void process(ParseContext parseContext) {
String errMsg = parseContext.getResponse().getErrorMsg();
ChatApp chatApp = parseContext.getAgent().getChatAppConfig().get(APP_KEY_ERROR_MESSAGE);
if (StringUtils.isBlank(errMsg) || Objects.isNull(chatApp) || !chatApp.isEnable()) {
return;
}
Map<String, Object> variables = new HashMap<>();
variables.put("user_question", parseContext.getRequest().getQueryText());
variables.put("system_message", errMsg);

View File

@@ -8,7 +8,6 @@ import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
@@ -21,12 +20,7 @@ import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.util.CollectionUtils;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.*;
import java.util.stream.Collectors;
/**
@@ -34,6 +28,12 @@ import java.util.stream.Collectors;
**/
@Slf4j
public class ParseInfoFormatProcessor implements ParseResultProcessor {
@Override
public boolean accept(ParseContext parseContext) {
return !parseContext.getResponse().getSelectedParses().isEmpty();
}
@Override
public void process(ParseContext parseContext) {
parseContext.getResponse().getSelectedParses().forEach(p -> {
@@ -216,9 +216,6 @@ public class ParseInfoFormatProcessor implements ParseResultProcessor {
}
private static boolean isPartitionDimension(DataSetSchema dataSetSchema, String sqlFieldName) {
if (TimeDimensionEnum.containsTimeDimension(sqlFieldName)) {
return true;
}
if (Objects.isNull(dataSetSchema) || Objects.isNull(dataSetSchema.getPartitionDimension())
|| Objects.isNull(dataSetSchema.getPartitionDimension().getName())) {
return false;

View File

@@ -6,5 +6,7 @@ import com.tencent.supersonic.chat.server.processor.ResultProcessor;
/** A ParseResultProcessor wraps things up before returning parsing results to the users. */
public interface ParseResultProcessor extends ResultProcessor {
boolean accept(ParseContext parseContext);
void process(ParseContext parseContext);
}

View File

@@ -23,6 +23,11 @@ import java.util.stream.Collectors;
@Slf4j
public class QueryRecommendProcessor implements ParseResultProcessor {
@Override
public boolean accept(ParseContext parseContext) {
return true;
}
@Override
public void process(ParseContext parseContext) {
CompletableFuture.runAsync(() -> doProcess(parseContext));

View File

@@ -10,6 +10,11 @@ import lombok.extern.slf4j.Slf4j;
@Slf4j
public class TimeCostCalcProcessor implements ParseResultProcessor {
@Override
public boolean accept(ParseContext parseContext) {
return true;
}
@Override
public void process(ParseContext parseContext) {
ChatParseResp parseResp = parseContext.getResponse();

View File

@@ -53,7 +53,7 @@ public class ChatController {
}
@PostMapping("/updateQAFeedback")
public Boolean updateQAFeedback(@RequestParam(value = "id") Integer id,
public Boolean updateQAFeedback(@RequestParam(value = "id") Long id,
@RequestParam(value = "score") Integer score,
@RequestParam(value = "feedback", required = false) String feedback) {
return chatService.updateFeedback(id, score, feedback);
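
The score received by this endpoint is what the service-side change in this changeset uses to enable or disable the memory attached to a query: a score of 5 or more marks it ENABLED with a POSITIVE human review, a score of 1 or less marks it DISABLED with a NEGATIVE review, and anything in between leaves it untouched. The sketch below isolates that decision as a pure function. The enum constants are the ones visible in the diff; the `Review` holder and `reviewFor` are hypothetical names introduced only for this illustration.

```java
import java.util.Optional;

public class FeedbackReviewSketch {

    /** Local copies of the enum constants that appear in the diff. */
    enum MemoryStatus { PENDING, ENABLED, DISABLED }

    enum MemoryReviewResult { POSITIVE, NEGATIVE }

    /** Hypothetical holder for the two fields an update request would carry. */
    static final class Review {
        final MemoryStatus status;
        final MemoryReviewResult result;

        Review(MemoryStatus status, MemoryReviewResult result) {
            this.status = status;
            this.result = result;
        }
    }

    /**
     * Maps a feedback score to a review decision.
     * Returns empty for middling scores, meaning the memory is left as it is.
     */
    static Optional<Review> reviewFor(int score) {
        if (score >= 5) {
            return Optional.of(new Review(MemoryStatus.ENABLED, MemoryReviewResult.POSITIVE));
        }
        if (score <= 1) {
            return Optional.of(new Review(MemoryStatus.DISABLED, MemoryReviewResult.NEGATIVE));
        }
        return Optional.empty();
    }

    public static void main(String[] args) {
        for (int score : new int[] {1, 3, 5}) {
            String outcome = reviewFor(score)
                    .map(r -> r.status + "/" + r.result)
                    .orElse("unchanged");
            System.out.println("score " + score + " -> " + outcome);
        }
    }
}
```

Returning an empty `Optional` for middling scores keeps the "no change" case explicit instead of encoding it as a null status.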

View File

@@ -13,7 +13,6 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import java.util.List;
@@ -24,7 +23,7 @@ public interface ChatManageService {
boolean updateChatName(Long chatId, String chatName, String userName);
boolean updateFeedback(Integer id, Integer score, String feedback);
boolean updateFeedback(Long id, Integer score, String feedback);
boolean updateChatIsTop(Long chatId, int isTop);

View File

@@ -14,8 +14,6 @@ public interface MemoryService {
void updateMemory(ChatMemoryUpdateReq chatMemoryUpdateReq, User user);
void updateMemory(ChatMemory memory);
void batchDelete(List<Long> ids);
PageInfo<ChatMemory> pageMemories(PageMemoryReq pageMemoryReq);

View File

@@ -20,13 +20,13 @@ import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.stream.Collectors;
@Slf4j
@@ -42,7 +42,9 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
@Autowired
private ChatModelService chatModelService;
private ExecutorService executorService = Executors.newFixedThreadPool(1);
@Autowired
@Qualifier("chatExecutor")
private ThreadPoolExecutor executor;
@Override
public List<Agent> getAgents(User user, AuthType authType) {
@@ -108,7 +110,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
* @param agent
*/
private void executeAgentExamplesAsync(Agent agent) {
executorService.execute(() -> doExecuteAgentExamples(agent));
executor.execute(() -> doExecuteAgentExamples(agent));
}
private synchronized void doExecuteAgentExamples(Agent agent) {
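
This hunk replaces a privately created single-thread pool with an executor that is injected from outside (the `chatExecutor` bean). The sketch below shows the same idea without Spring, as plain constructor injection, assuming a shared `ThreadPoolExecutor` owned by the caller. `AgentExamples`, `runDemo`, and the pool sizes are hypothetical and chosen only for the demonstration; the real bean configuration is not part of this diff.

```java
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

public class SharedPoolSketch {

    /** Hypothetical service that borrows a shared pool instead of owning one. */
    static class AgentExamples {
        private final ThreadPoolExecutor executor;
        private final AtomicInteger executed = new AtomicInteger();

        AgentExamples(ThreadPoolExecutor executor) {
            this.executor = executor;
        }

        void executeAsync(String agent) {
            executor.execute(() -> doExecute(agent));
        }

        /** synchronized keeps runs serialized even when the shared pool has several threads. */
        private synchronized void doExecute(String agent) {
            executed.incrementAndGet();
        }

        int executedCount() {
            return executed.get();
        }
    }

    /** Submits the given number of tasks to a shared pool and returns how many ran. */
    static int runDemo(int tasks) {
        // Pool sizing here is an arbitrary assumption for the demo.
        ThreadPoolExecutor shared =
                new ThreadPoolExecutor(2, 2, 60L, TimeUnit.SECONDS, new LinkedBlockingQueue<>());
        AgentExamples service = new AgentExamples(shared);
        for (int i = 0; i < tasks; i++) {
            service.executeAsync("agent-" + i);
        }
        shared.shutdown();
        try {
            if (!shared.awaitTermination(5, TimeUnit.SECONDS)) {
                throw new IllegalStateException("tasks did not finish in time");
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return service.executedCount();
    }

    public static void main(String[] args) {
        System.out.println("executed tasks: " + runDemo(10));
    }
}
```

Because the worker method stays `synchronized`, moving from a one-thread pool to a shared multi-thread pool does not by itself let two example runs overlap.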

View File

@@ -2,9 +2,9 @@ package com.tencent.supersonic.chat.server.service.impl;
import com.alibaba.fastjson.JSONObject;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import com.tencent.supersonic.chat.api.pojo.request.*;
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
@@ -15,11 +15,12 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.QueryDO;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.persistence.repository.ChatRepository;
import com.tencent.supersonic.chat.server.pojo.ChatMemory;
import com.tencent.supersonic.chat.server.service.ChatManageService;
import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@@ -38,6 +39,8 @@ public class ChatManageServiceImpl implements ChatManageService {
private ChatRepository chatRepository;
@Autowired
private ChatQueryRepository chatQueryRepository;
@Autowired
private MemoryService memoryService;
@Override
public Long addChat(User user, String chatName, Integer agentId) {
@@ -64,11 +67,28 @@ public class ChatManageServiceImpl implements ChatManageService {
}
@Override
public boolean updateFeedback(Integer id, Integer score, String feedback) {
public boolean updateFeedback(Long id, Integer score, String feedback) {
QueryDO intelligentQueryDO = new QueryDO();
intelligentQueryDO.setId(id);
intelligentQueryDO.setQuestionId(id);
intelligentQueryDO.setScore(score);
intelligentQueryDO.setFeedback(feedback);
// enable or disable memory based on user feedback
if (score >= 5 || score <= 1) {
ChatMemoryFilter memoryFilter = ChatMemoryFilter.builder().queryId(id).build();
List<ChatMemory> memories = memoryService.getMemories(memoryFilter);
memories.forEach(m -> {
MemoryStatus status = score >= 5 ? MemoryStatus.ENABLED : MemoryStatus.DISABLED;
MemoryReviewResult reviewResult =
score >= 5 ? MemoryReviewResult.POSITIVE : MemoryReviewResult.NEGATIVE;
ChatMemoryUpdateReq memoryUpdateReq = ChatMemoryUpdateReq.builder().id(m.getId())
.status(status).humanReviewRet(reviewResult)
.humanReviewCmt("Reviewed as per user feedback").build();
memoryService.updateMemory(memoryUpdateReq, User.getDefaultUser());
});
}
return chatRepository.updateFeedback(intelligentQueryDO);
}
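
The feedback hunk above maps a user score onto a memory status: 5 or higher enables the memory, 1 or lower disables it, and anything in between leaves it untouched. A minimal, self-contained sketch of that mapping follows. `FeedbackReviewSketch`, its `Status` enum, and `statusFor` are hypothetical names used only for illustration; they are not part of the SuperSonic code base. The `null` handling is also an assumption: in the diff `score` is a boxed `Integer` compared with `>=`, so a `null` score would fail on unboxing there.

```java
import java.util.Optional;

// Hypothetical illustration of the score-to-status rule in updateFeedback.
class FeedbackReviewSketch {
    enum Status { ENABLED, DISABLED }

    // Returns the status a memory should move to, or empty when the score
    // is neutral (2..4) or missing, in which case nothing is updated.
    static Optional<Status> statusFor(Integer score) {
        if (score == null) {
            return Optional.empty(); // assumption: treat a missing score as "no change"
        }
        if (score >= 5) {
            return Optional.of(Status.ENABLED);
        }
        if (score <= 1) {
            return Optional.of(Status.DISABLED);
        }
        return Optional.empty();
    }

    public static void main(String[] args) {
        for (Integer score : new Integer[] {null, 0, 1, 3, 5}) {
            System.out.println(score + " -> " + statusFor(score));
        }
    }
}
```

Computing the status once, outside the per-memory loop, would also avoid re-evaluating `score >= 5` for every memory as the diff does.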


@@ -95,7 +95,12 @@ public class ChatQueryServiceImpl implements ChatQueryService {
ParseContext parseContext = buildParseContext(chatParseReq, new ChatParseResp(queryId));
chatQueryParsers.forEach(p -> p.parse(parseContext));
parseResultProcessors.forEach(p -> p.process(parseContext));
for (ParseResultProcessor processor : parseResultProcessors) {
if (processor.accept(parseContext)) {
processor.process(parseContext);
}
}
if (!parseContext.needFeedback()) {
chatManageService.batchAddParse(chatParseReq, parseContext.getResponse());
@@ -116,9 +121,12 @@ public class ChatQueryServiceImpl implements ChatQueryService {
}
}
executeContext.setResponse(queryResult);
if (queryResult != null) {
for (ExecuteResultProcessor processor : executeResultProcessors) {
processor.process(executeContext, queryResult);
if (processor.accept(executeContext)) {
processor.process(executeContext);
}
}
saveQueryResult(chatExecuteReq, queryResult);
}


@@ -1,6 +1,7 @@
package com.tencent.supersonic.chat.server.service.impl;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
import com.github.pagehelper.PageHelper;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
@@ -9,6 +10,7 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq;
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.persistence.mapper.ChatMemoryMapper;
import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository;
import com.tencent.supersonic.chat.server.pojo.ChatMemory;
import com.tencent.supersonic.chat.server.service.MemoryService;
@@ -16,7 +18,6 @@ import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.service.ExemplarService;
import com.tencent.supersonic.common.util.BeanMapper;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
@@ -34,6 +35,9 @@ public class MemoryServiceImpl implements MemoryService {
@Autowired
private ChatMemoryRepository chatMemoryRepository;
@Autowired
private ChatMemoryMapper chatMemoryMapper;
@Autowired
private ExemplarService exemplarService;
@@ -57,20 +61,36 @@ public class MemoryServiceImpl implements MemoryService {
ChatMemoryDO chatMemoryDO = chatMemoryRepository.getMemory(chatMemoryUpdateReq.getId());
boolean hadEnabled =
MemoryStatus.ENABLED.toString().equals(chatMemoryDO.getStatus().trim());
chatMemoryDO.setUpdatedBy(user.getName());
chatMemoryDO.setUpdatedAt(new Date());
BeanMapper.mapper(chatMemoryUpdateReq, chatMemoryDO);
if (MemoryStatus.ENABLED.equals(chatMemoryUpdateReq.getStatus()) && !hadEnabled) {
enableMemory(chatMemoryDO);
} else if (MemoryStatus.DISABLED.equals(chatMemoryUpdateReq.getStatus()) && hadEnabled) {
disableMemory(chatMemoryDO);
}
chatMemoryRepository.updateMemory(chatMemoryDO);
}
@Override
public void updateMemory(ChatMemory memory) {
chatMemoryRepository.updateMemory(getMemoryDO(memory));
LambdaUpdateWrapper<ChatMemoryDO> updateWrapper = new LambdaUpdateWrapper<>();
updateWrapper.eq(ChatMemoryDO::getId, chatMemoryDO.getId());
if (Objects.nonNull(chatMemoryUpdateReq.getStatus())) {
updateWrapper.set(ChatMemoryDO::getStatus, chatMemoryUpdateReq.getStatus());
}
if (Objects.nonNull(chatMemoryUpdateReq.getLlmReviewRet())) {
updateWrapper.set(ChatMemoryDO::getLlmReviewRet,
chatMemoryUpdateReq.getLlmReviewRet().toString());
}
if (Objects.nonNull(chatMemoryUpdateReq.getLlmReviewCmt())) {
updateWrapper.set(ChatMemoryDO::getLlmReviewCmt, chatMemoryUpdateReq.getLlmReviewCmt());
}
if (Objects.nonNull(chatMemoryUpdateReq.getHumanReviewRet())) {
updateWrapper.set(ChatMemoryDO::getHumanReviewRet,
chatMemoryUpdateReq.getHumanReviewRet().toString());
}
if (Objects.nonNull(chatMemoryUpdateReq.getHumanReviewCmt())) {
updateWrapper.set(ChatMemoryDO::getHumanReviewCmt,
chatMemoryUpdateReq.getHumanReviewCmt());
}
updateWrapper.set(ChatMemoryDO::getUpdatedAt, new Date());
updateWrapper.set(ChatMemoryDO::getUpdatedBy, user.getName());
chatMemoryMapper.update(updateWrapper);
}
@Override
@@ -120,7 +140,7 @@ public class MemoryServiceImpl implements MemoryService {
return chatMemoryDOS.stream().map(this::getMemory).collect(Collectors.toList());
}
private void enableMemory(ChatMemoryDO memory) {
public void enableMemory(ChatMemoryDO memory) {
memory.setStatus(MemoryStatus.ENABLED.toString());
exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
Text2SQLExemplar.builder().question(memory.getQuestion())
@@ -128,7 +148,7 @@ public class MemoryServiceImpl implements MemoryService {
.sql(memory.getS2sql()).build());
}
private void disableMemory(ChatMemoryDO memory) {
public void disableMemory(ChatMemoryDO memory) {
memory.setStatus(MemoryStatus.DISABLED.toString());
exemplarService.removeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
Text2SQLExemplar.builder().question(memory.getQuestion())


@@ -81,7 +81,7 @@ public class Configuration {
.setUnquotedCasing(Casing.TO_UPPER).setConformance(sqlDialect.getConformance())
.setLex(Lex.BIG_QUERY);
if (EngineType.HANADB.equals(engineType)) {
parserConfig = parserConfig.setQuoting(Quoting.DOUBLE_QUOTE);
parserConfig = parserConfig.setQuoting(Quoting.DOUBLE_QUOTE);
}
parserConfig = parserConfig.setQuotedCasing(Casing.UNCHANGED);
parserConfig = parserConfig.setUnquotedCasing(Casing.UNCHANGED);


@@ -21,10 +21,11 @@ public class SqlDialectFactory {
.withDatabaseProduct(DatabaseProduct.BIG_QUERY).withLiteralQuoteString("'")
.withLiteralEscapedQuoteString("''").withUnquotedCasing(Casing.UNCHANGED)
.withQuotedCasing(Casing.UNCHANGED).withCaseSensitive(false);
public static final Context HANADB_CONTEXT = SqlDialect.EMPTY_CONTEXT
.withDatabaseProduct(DatabaseProduct.BIG_QUERY).withLiteralQuoteString("'")
.withIdentifierQuoteString("\"").withLiteralEscapedQuoteString("''").withUnquotedCasing(Casing.UNCHANGED)
.withQuotedCasing(Casing.UNCHANGED).withCaseSensitive(true);
public static final Context HANADB_CONTEXT =
SqlDialect.EMPTY_CONTEXT.withDatabaseProduct(DatabaseProduct.BIG_QUERY)
.withLiteralQuoteString("'").withIdentifierQuoteString("\"")
.withLiteralEscapedQuoteString("''").withUnquotedCasing(Casing.UNCHANGED)
.withQuotedCasing(Casing.UNCHANGED).withCaseSensitive(true);
private static Map<EngineType, SemanticSqlDialect> sqlDialectMap;
static {


@@ -0,0 +1,36 @@
package com.tencent.supersonic.common.config;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.springframework.context.annotation.Bean;
import org.springframework.stereotype.Component;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
@Component
public class ThreadPoolConfig {
@Bean("commonExecutor")
public ThreadPoolExecutor getCommonExecutor() {
return new ThreadPoolExecutor(8, 16, 60 * 3, TimeUnit.SECONDS,
new LinkedBlockingQueue<>(1024),
new ThreadFactoryBuilder().setNameFormat("supersonic-common-pool-").build(),
new ThreadPoolExecutor.CallerRunsPolicy());
}
@Bean("mapExecutor")
public ThreadPoolExecutor getMapExecutor() {
return new ThreadPoolExecutor(8, 16, 60 * 3, TimeUnit.SECONDS, new LinkedBlockingQueue<>(),
new ThreadFactoryBuilder().setNameFormat("supersonic-map-pool-").build(),
new ThreadPoolExecutor.CallerRunsPolicy());
}
@Bean("chatExecutor")
public ThreadPoolExecutor getChatExecutor() {
return new ThreadPoolExecutor(8, 16, 60 * 3, TimeUnit.SECONDS,
new LinkedBlockingQueue<>(1024),
new ThreadFactoryBuilder().setNameFormat("supersonic-chat-pool-").build(),
new ThreadPoolExecutor.CallerRunsPolicy());
}
}
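
The three pools in `ThreadPoolConfig` differ in one detail that affects behaviour: `commonExecutor` and `chatExecutor` use a bounded queue, while `mapExecutor` uses an unbounded `LinkedBlockingQueue`. With the standard `java.util.concurrent.ThreadPoolExecutor`, an unbounded queue means the pool never grows past its core size and the rejection policy is never consulted. With a bounded queue, `CallerRunsPolicy` makes the submitting thread run the task once both queue and pool are full. The sketch below demonstrates both effects with deliberately tiny sizes (1 thread, queue of 1) instead of the 8/16/1024 values in the diff. `BoundedPoolSketch` and its method names are illustrative assumptions.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

// Illustrative sketch only; sizes are shrunk so saturation is easy to reach.
class BoundedPoolSketch {

    private static void awaitQuietly(CountDownLatch latch) {
        try {
            latch.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    // Bounded queue + CallerRunsPolicy: the third task runs on the submitting thread.
    static boolean overflowRunsOnCaller() {
        ThreadPoolExecutor pool = new ThreadPoolExecutor(1, 1, 60, TimeUnit.SECONDS,
                new LinkedBlockingQueue<>(1), new ThreadPoolExecutor.CallerRunsPolicy());
        CountDownLatch release = new CountDownLatch(1);
        Thread caller = Thread.currentThread();
        AtomicBoolean ranOnCaller = new AtomicBoolean(false);

        pool.execute(() -> awaitQuietly(release)); // occupies the only worker
        pool.execute(() -> { });                    // fills the queue
        // Pool and queue are full, so the policy runs this task on the caller.
        pool.execute(() -> ranOnCaller.set(Thread.currentThread() == caller));

        release.countDown();
        pool.shutdown();
        try {
            pool.awaitTermination(5, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return ranOnCaller.get();
    }

    // Unbounded queue: extra tasks are queued, so the pool stays at core size.
    static int poolSizeWithUnboundedQueue() {
        ThreadPoolExecutor pool = new ThreadPoolExecutor(1, 4, 60, TimeUnit.SECONDS,
                new LinkedBlockingQueue<>());
        CountDownLatch release = new CountDownLatch(1);
        for (int i = 0; i < 3; i++) {
            pool.execute(() -> awaitQuietly(release));
        }
        int size = pool.getPoolSize();
        release.countDown();
        pool.shutdown();
        return size;
    }

    public static void main(String[] args) {
        System.out.println("overflow ran on caller: " + overflowRunsOnCaller());
        System.out.println("pool size with unbounded queue: " + poolSizeWithUnboundedQueue());
    }
}
```

If the same holds for `mapExecutor`, its maximum of 16 threads and its `CallerRunsPolicy` would have no practical effect; whether that is intended is not stated in the commit.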


@@ -0,0 +1,24 @@
package com.tencent.supersonic.common.jsqlparser;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
import net.sf.jsqlparser.statement.select.SelectItem;
import java.util.Set;
public class AliasAcquireVisitor extends ExpressionVisitorAdapter {
private Set<String> aliases;
public AliasAcquireVisitor(Set<String> aliases) {
this.aliases = aliases;
}
@Override
public void visit(SelectItem selectItem) {
Alias alias = selectItem.getAlias();
if (alias != null) {
aliases.add(alias.getName());
}
}
}


@@ -1,5 +1,6 @@
package com.tencent.supersonic.common.jsqlparser;
import com.google.common.collect.Sets;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
@@ -11,6 +12,7 @@ import java.util.Set;
public class FieldAcquireVisitor extends ExpressionVisitorAdapter {
private Set<String> fields;
private Set<String> aliases = Sets.newHashSet();
public FieldAcquireVisitor(Set<String> fields) {
this.fields = fields;
@@ -26,8 +28,9 @@ public class FieldAcquireVisitor extends ExpressionVisitorAdapter {
public void visit(SelectItem selectItem) {
Alias alias = selectItem.getAlias();
if (alias != null) {
fields.add(alias.getName());
aliases.add(alias.getName());
}
Expression expression = selectItem.getExpression();
if (expression != null) {
expression.accept(this);


@@ -0,0 +1,39 @@
package com.tencent.supersonic.common.jsqlparser;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.statement.select.SelectItem;
import net.sf.jsqlparser.statement.select.SelectItemVisitorAdapter;
import org.apache.commons.lang3.StringUtils;
import java.util.HashMap;
import java.util.Map;
public class FieldAliasReplaceNameVisitor extends SelectItemVisitorAdapter {
private Map<String, String> fieldNameMap;
private Map<String, String> aliasToActualExpression = new HashMap<>();
public FieldAliasReplaceNameVisitor(Map<String, String> fieldNameMap) {
this.fieldNameMap = fieldNameMap;
}
@Override
public void visit(SelectItem selectExpressionItem) {
Alias alias = selectExpressionItem.getAlias();
if (alias == null) {
return;
}
String aliasName = alias.getName();
String replaceValue = fieldNameMap.get(aliasName);
if (StringUtils.isBlank(replaceValue)) {
return;
}
aliasToActualExpression.put(aliasName, replaceValue);
alias.setName(replaceValue);
}
public Map<String, String> getAliasToActualExpression() {
return aliasToActualExpression;
}
}


@@ -7,15 +7,7 @@ import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.expression.operators.relational.LikeExpression;
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
import net.sf.jsqlparser.expression.operators.relational.*;
import net.sf.jsqlparser.schema.Column;
import org.apache.commons.collections.CollectionUtils;
@@ -34,6 +26,29 @@ public class FieldAndValueAcquireVisitor extends ExpressionVisitorAdapter {
this.fieldExpressions = fieldExpressions;
}
public void visit(Between between) {
Expression leftExpression = between.getLeftExpression();
String columnName = null;
if (leftExpression instanceof Column) {
Column column = (Column) leftExpression;
columnName = column.getColumnName();
}
Expression betweenExpressionStart = between.getBetweenExpressionStart();
Expression betweenExpressionEnd = between.getBetweenExpressionEnd();
FieldExpression fieldExpressionStart = new FieldExpression();
fieldExpressionStart.setFieldName(columnName);
fieldExpressionStart.setFieldValue(getFieldValue(betweenExpressionStart));
fieldExpressionStart.setOperator(JsqlConstants.GREATER_THAN_EQUALS);
fieldExpressions.add(fieldExpressionStart);
FieldExpression fieldExpressionEnd = new FieldExpression();
fieldExpressionEnd.setFieldName(columnName);
fieldExpressionEnd.setFieldValue(getFieldValue(betweenExpressionEnd));
fieldExpressionEnd.setOperator(JsqlConstants.MINOR_THAN_EQUALS);
fieldExpressions.add(fieldExpressionEnd);
}
public void visit(LikeExpression expr) {
Expression leftExpression = expr.getLeftExpression();
Expression rightExpression = expr.getRightExpression();


@@ -26,6 +26,7 @@ public class JsqlConstants {
public static final String EQUAL_CONSTANT = " 1 = 1 ";
public static final String IN_CONSTANT = " 1 in (1) ";
public static final String LIKE_CONSTANT = "1 like 1";
public static final String BETWEEN_AND_CONSTANT = "1 between 2 and 3";
public static final String IN = "IN";
public static final Map<String, String> rightMap = Stream.of(
new AbstractMap.SimpleEntry<>("<=", "<="), new AbstractMap.SimpleEntry<>("<", "<"),


@@ -1,35 +1,17 @@
package com.tencent.supersonic.common.jsqlparser;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.*;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.select.GroupByElement;
import net.sf.jsqlparser.statement.select.OrderByElement;
import net.sf.jsqlparser.statement.select.ParenthesedSelect;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectItem;
import net.sf.jsqlparser.statement.select.SelectVisitorAdapter;
import net.sf.jsqlparser.statement.select.SetOperationList;
import net.sf.jsqlparser.statement.select.*;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.*;
/** Sql Parser add Helper */
@Slf4j
@@ -144,42 +126,7 @@ public class SqlAddHelper {
return sql;
}
PlainSelect plainSelect = (PlainSelect) selectStatement;
List<String> chNameList = TimeDimensionEnum.getChNameList();
Boolean dateWhere = false;
for (String chName : chNameList) {
if (expression.toString().contains(chName)) {
dateWhere = true;
}
}
List<PlainSelect> plainSelectList = SqlSelectHelper.getWithItem(selectStatement);
if (!CollectionUtils.isEmpty(plainSelectList) && dateWhere) {
List<String> withNameList = SqlSelectHelper.getWithName(sql);
for (int i = 0; i < plainSelectList.size(); i++) {
if (plainSelectList.get(i).getFromItem() instanceof Table) {
Table table = (Table) plainSelectList.get(i).getFromItem();
if (withNameList.contains(table.getName())) {
continue;
}
}
Set<String> result = new HashSet<>();
List<PlainSelect> subPlainSelectList = new ArrayList<>();
subPlainSelectList.add(plainSelectList.get(i));
SqlSelectHelper.getWhereFields(subPlainSelectList, result);
if (TimeDimensionEnum.containsZhTimeDimension(new ArrayList<>(result))) {
continue;
}
Expression subWhere = plainSelectList.get(i).getWhere();
addWhere(plainSelectList.get(i), subWhere, expression);
}
return selectStatement.toString();
}
if (plainSelect.getFromItem() instanceof ParenthesedSelect && dateWhere) {
ParenthesedSelect parenthesedSelect = (ParenthesedSelect) plainSelect.getFromItem();
PlainSelect subPlainSelect = parenthesedSelect.getPlainSelect();
Expression subWhere = subPlainSelect.getWhere();
addWhere(subPlainSelect, subWhere, expression);
return selectStatement.toString();
}
Expression where = plainSelect.getWhere();
addWhere(plainSelect, where, expression);


@@ -1,10 +1,10 @@
package com.tencent.supersonic.common.jsqlparser;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.statement.select.PlainSelect;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
@@ -12,7 +12,7 @@ import java.util.Objects;
@Slf4j
public class SqlDateSelectHelper {
public static DateVisitor.DateBoundInfo getDateBoundInfo(String sql) {
public static DateVisitor.DateBoundInfo getDateBoundInfo(String sql, String dateField) {
List<PlainSelect> plainSelectList = SqlSelectHelper.getPlainSelect(sql);
if (plainSelectList.size() != 1) {
return null;
@@ -25,7 +25,7 @@ public class SqlDateSelectHelper {
if (Objects.isNull(where)) {
return null;
}
DateVisitor dateVisitor = new DateVisitor(TimeDimensionEnum.getChNameList());
DateVisitor dateVisitor = new DateVisitor(Collections.singletonList(dateField));
where.accept(dateVisitor);
return dateVisitor.getDateBoundInfo();
}


@@ -8,16 +8,7 @@ import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.expression.operators.relational.LikeExpression;
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
import net.sf.jsqlparser.expression.operators.relational.*;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.statement.select.AllColumns;
@@ -183,6 +174,8 @@ public class SqlRemoveHelper {
handleInExpression((InExpression) expression, removeFieldNames);
} else if (expression instanceof LikeExpression) {
handleLikeExpression((LikeExpression) expression, removeFieldNames);
} else if (expression instanceof Between) {
handleBetweenExpression((Between) expression, removeFieldNames);
}
} catch (JSQLParserException e) {
log.error("JSQLParserException", e);
@@ -226,6 +219,17 @@ public class SqlRemoveHelper {
updateLikeExpression(likeExpression, constantExpression);
}
private static void handleBetweenExpression(Between between, Set<String> removeFieldNames)
throws JSQLParserException {
String columnName = SqlSelectHelper.getColumnName(between.getLeftExpression());
if (!removeFieldNames.contains(columnName)) {
return;
}
Between constantExpression =
(Between) CCJSqlParserUtil.parseCondExpression(JsqlConstants.BETWEEN_AND_CONSTANT);
updateBetweenExpression(between, constantExpression);
}
private static void updateComparisonOperator(ComparisonOperator original,
ComparisonOperator constantExpression) {
original.setLeftExpression(constantExpression.getLeftExpression());
@@ -245,6 +249,12 @@ public class SqlRemoveHelper {
original.setRightExpression(constantExpression.getRightExpression());
}
private static void updateBetweenExpression(Between between, Between constantExpression) {
between.setBetweenExpressionEnd(constantExpression.getBetweenExpressionEnd());
between.setBetweenExpressionStart(constantExpression.getBetweenExpressionStart());
between.setLeftExpression(constantExpression.getLeftExpression());
}
public static String removeHavingCondition(String sql, Set<String> removeFieldNames) {
Select selectStatement = SqlSelectHelper.getSelect(sql);
if (!(selectStatement instanceof PlainSelect)) {
@@ -373,6 +383,10 @@ public class SqlRemoveHelper {
LikeExpression likeExpression = (LikeExpression) expression;
Expression leftExpression = likeExpression.getLeftExpression();
return recursionBase(leftExpression, expression, sqlEditEnum);
} else if (expression instanceof Between) {
Between between = (Between) expression;
Expression leftExpression = between.getLeftExpression();
return recursionBase(leftExpression, expression, sqlEditEnum);
}
return expression;
}


@@ -449,6 +449,23 @@ public class SqlReplaceHelper {
}
}
public static String replaceAliasFieldName(String sql, Map<String, String> fieldNameMap) {
Select selectStatement = SqlSelectHelper.getSelect(sql);
if (!(selectStatement instanceof PlainSelect)) {
return sql;
}
PlainSelect plainSelect = (PlainSelect) selectStatement;
FieldAliasReplaceNameVisitor visitor = new FieldAliasReplaceNameVisitor(fieldNameMap);
for (SelectItem selectItem : plainSelect.getSelectItems()) {
selectItem.accept(visitor);
}
Map<String, String> aliasToActualExpression = visitor.getAliasToActualExpression();
if (Objects.nonNull(aliasToActualExpression) && !aliasToActualExpression.isEmpty()) {
return replaceFields(selectStatement.toString(), aliasToActualExpression, true);
}
return selectStatement.toString();
}
public static String replaceAlias(String sql) {
Select selectStatement = SqlSelectHelper.getSelect(sql);
if (!(selectStatement instanceof PlainSelect)) {


@@ -133,6 +133,15 @@ public class SqlSelectHelper {
return result;
}
public static Set<String> getAliasFields(PlainSelect plainSelect) {
Set<String> result = new HashSet<>();
List<SelectItem<?>> selectItems = plainSelect.getSelectItems();
for (SelectItem selectItem : selectItems) {
selectItem.accept(new AliasAcquireVisitor(result));
}
return result;
}
public static List<PlainSelect> getPlainSelect(Select selectStatement) {
if (selectStatement == null) {
return null;
@@ -264,10 +273,16 @@ public class SqlSelectHelper {
public static List<String> getAllSelectFields(String sql) {
List<PlainSelect> plainSelects = getPlainSelects(getPlainSelect(sql));
Set<String> results = new HashSet<>();
Set<String> aliases = new HashSet<>();
for (PlainSelect plainSelect : plainSelects) {
List<String> fields = getFieldsByPlainSelect(plainSelect);
Set<String> subaliases = getAliasFields(plainSelect);
subaliases.removeAll(fields);
results.addAll(fields);
aliases.addAll(subaliases);
}
// exclude names that are only aliases, not real fields
results.removeAll(aliases);
return new ArrayList<>(results);
}
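
The revised `getAllSelectFields` above collects fields and aliases per sub-select, then removes every name that is only an alias. The set logic can be sketched without JSqlParser as follows. The input shape (one field set and one alias set per sub-select) and the names `AliasFilterSketch` and `physicalFields` are assumptions for illustration; the real method derives these sets from the parsed SQL.

```java
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;

// Illustrative sketch of the alias-exclusion logic, independent of JSqlParser.
class AliasFilterSketch {

    // fieldsPerSelect.get(i) and aliasesPerSelect.get(i) describe the i-th sub-select.
    static Set<String> physicalFields(List<Set<String>> fieldsPerSelect,
            List<Set<String>> aliasesPerSelect) {
        Set<String> results = new TreeSet<>(); // sorted only to make output stable
        Set<String> aliases = new HashSet<>();
        for (int i = 0; i < fieldsPerSelect.size(); i++) {
            Set<String> fields = fieldsPerSelect.get(i);
            Set<String> subAliases = new HashSet<>(aliasesPerSelect.get(i));
            // An alias that shadows a real field of the same select is kept as a field.
            subAliases.removeAll(fields);
            results.addAll(fields);
            aliases.addAll(subAliases);
        }
        // Names that exist only as aliases are dropped from the final result.
        results.removeAll(aliases);
        return results;
    }

    public static void main(String[] args) {
        // Roughly: SELECT dept, total FROM (SELECT dept, SUM(pay) AS total FROM t GROUP BY dept)
        Set<String> outerFields = Set.of("dept", "total");
        Set<String> innerFields = Set.of("dept", "pay");
        Set<String> outerAliases = new HashSet<>();
        Set<String> innerAliases = Set.of("total");
        System.out.println(physicalFields(List.of(outerFields, innerFields),
                List.of(outerAliases, innerAliases)));
    }
}
```

In this example `total` is referenced by the outer select but defined as an alias in the inner one, so only `dept` and `pay` remain.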


@@ -1,7 +1,6 @@
package com.tencent.supersonic.common.pojo;
import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.DateUtils;
import lombok.Data;
import org.springframework.util.CollectionUtils;
@@ -40,6 +39,8 @@ public class DateConf {
private boolean groupByDate;
private String dateField;
public List<String> getDateList() {
if (!CollectionUtils.isEmpty(dateList)) {
return dateList;
@@ -49,18 +50,6 @@ public class DateConf {
return DateUtils.getDateList(startDateStr, endDateStr, getPeriod());
}
public String getGroupByTimeDimension() {
if (DatePeriodEnum.DAY.equals(period)) {
return TimeDimensionEnum.DAY.getName();
} else if (DatePeriodEnum.WEEK.equals(period)) {
return TimeDimensionEnum.WEEK.getName();
} else if (DatePeriodEnum.MONTH.equals(period)) {
return TimeDimensionEnum.MONTH.getName();
} else {
return TimeDimensionEnum.DAY.getName();
}
}
@Override
public boolean equals(Object o) {
if (this == o) {


@@ -1,73 +1,5 @@
package com.tencent.supersonic.common.pojo.enums;
import org.springframework.util.CollectionUtils;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
public enum TimeDimensionEnum {
DAY("sys_imp_date", "数据日期"),
WEEK("sys_imp_week", "数据日期_周"),
MONTH("sys_imp_month", "数据日期_月");
private String name;
private String chName;
TimeDimensionEnum(String name, String chName) {
this.name = name;
this.chName = chName;
}
public static boolean containsTimeDimension(String fieldName) {
if (getNameList().contains(fieldName) || getChNameList().contains(fieldName)) {
return true;
}
return false;
}
public static List<String> getNameList() {
return Arrays.stream(TimeDimensionEnum.values()).map(TimeDimensionEnum::getName)
.collect(Collectors.toList());
}
public static List<String> getChNameList() {
return Arrays.stream(TimeDimensionEnum.values()).map(TimeDimensionEnum::getChName)
.collect(Collectors.toList());
}
public static Map<String, String> getChNameToNameMap() {
return Arrays.stream(TimeDimensionEnum.values()).collect(Collectors
.toMap(TimeDimensionEnum::getChName, TimeDimensionEnum::getName, (k1, k2) -> k1));
}
public static Map<String, String> getNameToNameMap() {
return Arrays.stream(TimeDimensionEnum.values()).collect(Collectors
.toMap(TimeDimensionEnum::getName, TimeDimensionEnum::getName, (k1, k2) -> k1));
}
public String getName() {
return name;
}
public String getChName() {
return chName;
}
/**
* Determine if a time dimension field is included in a Chinese/English text field
*
* @param fields field
* @return true/false
*/
public static boolean containsZhTimeDimension(List<String> fields) {
if (CollectionUtils.isEmpty(fields)) {
return false;
}
return fields.stream().anyMatch(field -> containsTimeDimension(field));
}
DAY, WEEK, MONTH;
}


@@ -4,7 +4,6 @@ import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.ItemDateResp;
import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
@@ -21,22 +20,14 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.StringJoiner;
import java.util.regex.Pattern;
import static com.tencent.supersonic.common.pojo.Constants.APOSTROPHE;
import static com.tencent.supersonic.common.pojo.Constants.COMMA;
import static com.tencent.supersonic.common.pojo.Constants.DAY_FORMAT;
import static com.tencent.supersonic.common.pojo.Constants.MONTH_FORMAT;
import static com.tencent.supersonic.common.pojo.Constants.*;
@Slf4j
@Component
@Data
public class DateModeUtils {
private final String sysDateCol = TimeDimensionEnum.DAY.getName();
private final String sysDateMonthCol = TimeDimensionEnum.MONTH.getName();
private final String sysDateWeekCol = TimeDimensionEnum.WEEK.getName();
@Value("${s2.query.parameter.sys.zipper.begin:start_}")
private String sysZipperDateColBegin;
@@ -60,8 +51,8 @@ public class DateModeUtils {
public String hasDataModeStr(ItemDateResp dateDate, DateConf dateInfo) {
if (Objects.isNull(dateDate) || StringUtils.isEmpty(dateDate.getStartDate())
|| StringUtils.isEmpty(dateDate.getEndDate())) {
return String.format("(%s >= '%s' and %s <= '%s')", sysDateCol, dateInfo.getStartDate(),
sysDateCol, dateInfo.getEndDate());
return String.format("(%s >= '%s' and %s <= '%s')", dateInfo.getDateField(),
dateInfo.getStartDate(), dateInfo.getDateField(), dateInfo.getEndDate());
} else {
log.info("dateDate:{}", dateDate);
}
@@ -79,27 +70,28 @@ public class DateModeUtils {
dateFormatStr, ChronoUnit.DAYS);
LocalDate dateMax = endData;
LocalDate dateMin = dateMax.minusDays(unit - 1);
return String.format("(%s >= '%s' and %s <= '%s')", sysDateCol, dateMin, sysDateCol,
dateMax);
return String.format("(%s >= '%s' and %s <= '%s')", dateInfo.getDateField(),
dateMin, dateInfo.getDateField(), dateMax);
}
if (DatePeriodEnum.MONTH.equals(dateInfo.getPeriod())) {
Long unit = getInterval(dateInfo.getStartDate(), dateInfo.getEndDate(),
dateFormatStr, ChronoUnit.MONTHS);
return generateMonthSql(endData, unit, dateFormatStr);
return generateMonthSql(endData, unit, dateFormatStr, dateInfo);
}
}
return String.format("(%s >= '%s' and %s <= '%s')", sysDateCol, dateInfo.getStartDate(),
sysDateCol, dateInfo.getEndDate());
return String.format("(%s >= '%s' and %s <= '%s')", dateInfo.getDateField(),
dateInfo.getStartDate(), dateInfo.getDateField(), dateInfo.getEndDate());
}
public String generateMonthSql(LocalDate endData, Long unit, String dateFormatStr) {
public String generateMonthSql(LocalDate endData, Long unit, String dateFormatStr,
DateConf dateConf) {
LocalDate dateMax = endData;
List<String> months = generateMonthStr(dateMax, unit, dateFormatStr);
if (!CollectionUtils.isEmpty(months)) {
StringJoiner joiner = new StringJoiner(",");
months.stream().forEach(month -> joiner.add("'" + month + "'"));
return String.format("(%s in (%s))", sysDateCol, joiner.toString());
return String.format("(%s in (%s))", dateConf.getDateField(), joiner.toString());
}
return "";
}
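
The `DateModeUtils` hunks replace the hard-coded `sys_imp_date` style columns with the date field carried on `DateConf`. The two predicate shapes they produce, a closed range and an `in` list, can be sketched as below. `DateRangePredicateSketch` and its method names are hypothetical, and `imp_date` is an invented column name. As in the diff, values are interpolated directly into the SQL text, so the sketch assumes trusted, already-validated inputs.

```java
import java.util.List;
import java.util.StringJoiner;

// Illustrative sketch of the predicate strings built from a configurable date field.
class DateRangePredicateSketch {

    // Closed range, mirroring the "(%s >= '%s' and %s <= '%s')" format in the diff.
    static String between(String dateField, String start, String end) {
        return String.format("(%s >= '%s' and %s <= '%s')", dateField, start, dateField, end);
    }

    // Membership test, mirroring the "(%s in (%s))" format used for month lists.
    static String in(String dateField, List<String> values) {
        StringJoiner joiner = new StringJoiner(",");
        values.forEach(value -> joiner.add("'" + value + "'"));
        return String.format("(%s in (%s))", dateField, joiner);
    }

    public static void main(String[] args) {
        System.out.println(between("imp_date", "2024-12-01", "2024-12-29"));
        System.out.println(in("imp_date", List.of("2024-11", "2024-12")));
    }
}
```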
@@ -116,8 +108,8 @@ public class DateModeUtils {
public String recentDayStr(ItemDateResp dateDate, DateConf dateInfo) {
ImmutablePair<String, String> dayRange = recentDay(dateDate, dateInfo);
return String.format("(%s >= '%s' and %s <= '%s')", sysDateCol, dayRange.left, sysDateCol,
dayRange.right);
return String.format("(%s >= '%s' and %s <= '%s')", dateInfo.getDateField(), dayRange.left,
dateInfo.getDateField(), dayRange.right);
}
public ImmutablePair<String, String> recentDay(ItemDateResp dateDate, DateConf dateInfo) {
@@ -134,24 +126,25 @@ public class DateModeUtils {
return ImmutablePair.of(start, dateDate.getEndDate());
}
public String recentMonthStr(LocalDate endData, Long unit, String dateFormatStr) {
public String recentMonthStr(LocalDate endData, Long unit, String dateFormatStr,
DateConf dateInfo) {
DateTimeFormatter formatter = DateTimeFormatter.ofPattern(dateFormatStr);
String endStr = endData.format(formatter);
String start = endData.minusMonths(unit).format(formatter);
return String.format("(%s >= '%s' and %s <= '%s')", sysDateMonthCol, start, sysDateMonthCol,
endStr);
return String.format("(%s >= '%s' and %s <= '%s')", dateInfo.getDateField(), start,
dateInfo.getDateField(), endStr);
}
public String recentMonthStr(ItemDateResp dateDate, DateConf dateInfo) {
List<ImmutablePair<String, String>> range = recentMonth(dateDate, dateInfo);
if (range.size() == 1) {
return String.format("(%s >= '%s' and %s <= '%s')", sysDateMonthCol, range.get(0).left,
sysDateMonthCol, range.get(0).right);
return String.format("(%s >= '%s' and %s <= '%s')", dateInfo.getDateField(),
range.get(0).left, dateInfo.getDateField(), range.get(0).right);
}
if (range.size() > 0) {
StringJoiner joiner = new StringJoiner(",");
range.stream().forEach(month -> joiner.add("'" + month.left + "'"));
return String.format("(%s in (%s))", sysDateCol, joiner.toString());
return String.format("(%s in (%s))", dateInfo.getDateField(), joiner.toString());
}
return "";
}
@@ -181,17 +174,17 @@ public class DateModeUtils {
return ret;
}
public String recentWeekStr(LocalDate endData, Long unit) {
public String recentWeekStr(LocalDate endData, Long unit, DateConf dataInfo) {
DateTimeFormatter formatter = DateTimeFormatter.ofPattern(DAY_FORMAT);
String start = endData.minusDays(unit * 7).format(formatter);
return String.format("(%s >= '%s' and %s <= '%s')", sysDateWeekCol, start, sysDateWeekCol,
endData.format(formatter));
return String.format("(%s >= '%s' and %s <= '%s')", dataInfo.getDateField(), start,
dataInfo.getDateField(), endData.format(formatter));
}
public String recentWeekStr(ItemDateResp dateDate, DateConf dateInfo) {
ImmutablePair<String, String> dayRange = recentWeek(dateDate, dateInfo);
return String.format("(%s >= '%s' and %s <= '%s')", sysDateWeekCol, dayRange.left,
sysDateWeekCol, dayRange.right);
return String.format("(%s >= '%s' and %s <= '%s')", dateInfo.getDateField(), dayRange.left,
dateInfo.getDateField(), dayRange.right);
}
public ImmutablePair<String, String> recentWeek(ItemDateResp dateDate, DateConf dateInfo) {
@@ -242,26 +235,27 @@ public class DateModeUtils {
* @return
*/
public String betweenDateStr(DateConf dateInfo) {
String dateField = dateInfo.getDateField();
if (DatePeriodEnum.MONTH.equals(dateInfo.getPeriod())) {
// startDate YYYYMM
if (!dateInfo.getStartDate().contains(Constants.MINUS)) {
return String.format("%s >= '%s' and %s <= '%s'", sysDateMonthCol,
dateInfo.getStartDate(), sysDateMonthCol, dateInfo.getEndDate());
return String.format("%s >= '%s' and %s <= '%s'", dateField,
dateInfo.getStartDate(), dateField, dateInfo.getEndDate());
}
LocalDate endData =
LocalDate.parse(dateInfo.getEndDate(), DateTimeFormatter.ofPattern(DAY_FORMAT));
LocalDate startData = LocalDate.parse(dateInfo.getStartDate(),
DateTimeFormatter.ofPattern(DAY_FORMAT));
DateTimeFormatter formatter = DateTimeFormatter.ofPattern(MONTH_FORMAT);
return String.format("%s >= '%s' and %s <= '%s'", sysDateMonthCol,
startData.format(formatter), sysDateMonthCol, endData.format(formatter));
return String.format("%s >= '%s' and %s <= '%s'", dateField,
startData.format(formatter), dateField, endData.format(formatter));
}
if (DatePeriodEnum.WEEK.equals(dateInfo.getPeriod())) {
return String.format("%s >= '%s' and %s <= '%s'", sysDateWeekCol,
dateInfo.getStartDate(), sysDateWeekCol, dateInfo.getEndDate());
return String.format("%s >= '%s' and %s <= '%s'", dateField, dateInfo.getStartDate(),
dateField, dateInfo.getEndDate());
}
return String.format("%s >= '%s' and %s <= '%s'", sysDateCol, dateInfo.getStartDate(),
sysDateCol, dateInfo.getEndDate());
return String.format("%s >= '%s' and %s <= '%s'", dateField, dateInfo.getStartDate(),
dateField, dateInfo.getEndDate());
}
/**
@@ -273,12 +267,12 @@ public class DateModeUtils {
public String listDateStr(DateConf dateInfo) {
StringJoiner joiner = new StringJoiner(COMMA);
dateInfo.getDateList().stream().forEach(date -> joiner.add(APOSTROPHE + date + APOSTROPHE));
String dateCol = sysDateCol;
String dateCol = dateInfo.getDateField();
if (DatePeriodEnum.MONTH.equals(dateInfo.getPeriod())) {
dateCol = sysDateMonthCol;
dateCol = dateInfo.getDateField();
}
if (DatePeriodEnum.WEEK.equals(dateInfo.getPeriod())) {
dateCol = sysDateWeekCol;
dateCol = dateInfo.getDateField();
}
return String.format("(%s in (%s))", dateCol, joiner.toString());
}
@@ -299,25 +293,26 @@ public class DateModeUtils {
if (DatePeriodEnum.DAY.equals(dateInfo.getPeriod())) {
LocalDate dateMax = LocalDate.now().minusDays(1);
LocalDate dateMin = dateMax.minusDays(unit - 1);
return String.format("(%s >= '%s' and %s <= '%s')", sysDateCol, dateMin, sysDateCol,
dateMax);
return String.format("(%s >= '%s' and %s <= '%s')", dateInfo.getDateField(), dateMin,
dateInfo.getDateField(), dateMax);
}
if (DatePeriodEnum.WEEK.equals(dateInfo.getPeriod())) {
LocalDate dateMax = LocalDate.now().minusDays(1);
return recentWeekStr(dateMax, unit.longValue());
return recentWeekStr(dateMax, unit.longValue(), dateInfo);
}
if (DatePeriodEnum.MONTH.equals(dateInfo.getPeriod())) {
LocalDate dateMax = LocalDate.now().minusDays(1);
return recentMonthStr(dateMax, unit.longValue(), MONTH_FORMAT);
return recentMonthStr(dateMax, unit.longValue(), MONTH_FORMAT, dateInfo);
}
if (DatePeriodEnum.YEAR.equals(dateInfo.getPeriod())) {
LocalDate dateMax = LocalDate.now().minusDays(1);
return recentMonthStr(dateMax, unit.longValue() * 12, MONTH_FORMAT);
return recentMonthStr(dateMax, unit.longValue() * 12, MONTH_FORMAT, dateInfo);
}
return String.format("(%s >= '%s' and %s <= '%s')", sysDateCol,
LocalDate.now().minusDays(2), sysDateCol, LocalDate.now().minusDays(1));
return String.format("(%s >= '%s' and %s <= '%s')", dateInfo.getDateField(),
LocalDate.now().minusDays(2), dateInfo.getDateField(),
LocalDate.now().minusDays(1));
}
public String getDateWhereStr(DateConf dateInfo) {
@@ -349,32 +344,7 @@ public class DateModeUtils {
}
public String getSysDateCol(DateConf dateInfo) {
if (DatePeriodEnum.DAY.equals(dateInfo.getPeriod())) {
return sysDateCol;
}
if (DatePeriodEnum.WEEK.equals(dateInfo.getPeriod())) {
return sysDateWeekCol;
}
if (DatePeriodEnum.MONTH.equals(dateInfo.getPeriod())) {
return sysDateMonthCol;
}
return "";
return dateInfo.getDateField();
}
public boolean isDateStr(String date) {
return Pattern.matches("[\\d\\s-:]+", date);
}
public DatePeriodEnum getPeriodByCol(String col) {
if (sysDateCol.equalsIgnoreCase(col)) {
return DatePeriodEnum.DAY;
}
if (sysDateWeekCol.equalsIgnoreCase(col)) {
return DatePeriodEnum.WEEK;
}
if (sysDateMonthCol.equalsIgnoreCase(col)) {
return DatePeriodEnum.MONTH;
}
return null;
}
}

@@ -10,6 +10,8 @@ import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service;
import static java.time.Duration.ofSeconds;
@Service
public class ZhipuModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "ZHIPU";
@@ -30,8 +32,9 @@ public class ZhipuModelFactory implements ModelFactory, InitializingBean {
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
.apiKey(embeddingModelConfig.getApiKey()).model(embeddingModelConfig.getModelName())
.maxRetries(embeddingModelConfig.getMaxRetries())
.logRequests(embeddingModelConfig.getLogRequests())
.maxRetries(embeddingModelConfig.getMaxRetries()).callTimeout(ofSeconds(60))
.connectTimeout(ofSeconds(60)).writeTimeout(ofSeconds(60))
.readTimeout(ofSeconds(60)).logRequests(embeddingModelConfig.getLogRequests())
.logResponses(embeddingModelConfig.getLogResponses()).build();
}

@@ -11,31 +11,31 @@ class SqlDateSelectHelperTest {
String sql = "SELECT 维度1,sum(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1";
DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(sql);
DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(sql, "数据日期");
Assert.assertEquals(dateBoundInfo.getLowerBound(), ">=");
Assert.assertEquals(dateBoundInfo.getLowerDate(), "2023-11-17");
sql = "SELECT 维度1,sum(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1";
dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(sql);
dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(sql, "数据日期");
Assert.assertEquals(dateBoundInfo.getLowerBound(), ">");
Assert.assertEquals(dateBoundInfo.getLowerDate(), "2023-11-17");
sql = "SELECT 维度1,sum(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1";
dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(sql);
dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(sql, "数据日期");
Assert.assertEquals(dateBoundInfo.getUpperBound(), "<=");
Assert.assertEquals(dateBoundInfo.getUpperDate(), "2023-11-17");
sql = "SELECT 维度1,sum(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 < '2023-11-17' GROUP BY 维度1";
dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(sql);
dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(sql, "数据日期");
Assert.assertEquals(dateBoundInfo.getUpperBound(), "<");
Assert.assertEquals(dateBoundInfo.getUpperDate(), "2023-11-17");
sql = "SELECT 维度1,sum(播放量) FROM 数据库 " + "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-10-17' "
+ "AND 数据日期 <= '2023-11-17' GROUP BY 维度1";
dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(sql);
dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(sql, "数据日期");
Assert.assertEquals(dateBoundInfo.getUpperBound(), "<=");
Assert.assertEquals(dateBoundInfo.getUpperDate(), "2023-11-17");
Assert.assertEquals(dateBoundInfo.getLowerBound(), ">=");

@@ -157,6 +157,14 @@ class SqlRemoveHelperTest {
replaceSql = SqlRemoveHelper.removeWhereCondition(sql, removeFieldNames);
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1) "
+ "AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11", replaceSql);
sql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 "
+ "and 歌曲名 between '2023-08-09' and '2024-08-09' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'"
+ " order by 播放量 desc limit 11";
replaceSql = SqlRemoveHelper.removeWhereCondition(sql, removeFieldNames);
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 "
+ "AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' "
+ "ORDER BY 播放量 DESC LIMIT 11", replaceSql);
}
@Test

@@ -324,6 +324,39 @@ class SqlReplaceHelperTest {
replaceSql);
}
@Test
void testReplaceAliasFieldName() {
Map<String, String> map = new HashMap<>();
map.put("总访问次数", "\"总访问次数\"");
map.put("访问次数", "\"访问次数\"");
String sql = "select 部门, sum(访问次数) as 总访问次数 from 超音数 where "
+ "datediff('day', 数据日期, '2023-09-05') <= 3 group by 部门 order by 总访问次数 desc limit 10";
String replaceSql = SqlReplaceHelper.replaceAliasFieldName(sql, map);
System.out.println(replaceSql);
Assert.assertEquals("SELECT 部门, sum(访问次数) AS \"总访问次数\" FROM 超音数 WHERE "
+ "datediff('day', 数据日期, '2023-09-05') <= 3 GROUP BY 部门 ORDER BY \"总访问次数\" DESC LIMIT 10",
replaceSql);
sql = "select 部门, sum(访问次数) as 总访问次数 from 超音数 where "
+ "(datediff('day', 数据日期, '2023-09-05') <= 3) and 数据日期 = '2023-10-10' "
+ "group by 部门 order by 总访问次数 desc limit 10";
replaceSql = SqlReplaceHelper.replaceAliasFieldName(sql, map);
System.out.println(replaceSql);
Assert.assertEquals("SELECT 部门, sum(访问次数) AS \"总访问次数\" FROM 超音数 WHERE "
+ "(datediff('day', 数据日期, '2023-09-05') <= 3) AND 数据日期 = '2023-10-10' "
+ "GROUP BY 部门 ORDER BY \"总访问次数\" DESC LIMIT 10", replaceSql);
sql = "select 部门, sum(访问次数) as 访问次数 from 超音数 where "
+ "(datediff('day', 数据日期, '2023-09-05') <= 3) and 数据日期 = '2023-10-10' "
+ "group by 部门 order by 访问次数 desc limit 10";
replaceSql = SqlReplaceHelper.replaceAliasFieldName(sql, map);
System.out.println(replaceSql);
Assert.assertEquals(
"SELECT 部门, sum(\"访问次数\") AS \"访问次数\" FROM 超音数 WHERE (datediff('day', 数据日期, "
+ "'2023-09-05') <= 3) AND 数据日期 = '2023-10-10' GROUP BY 部门 ORDER BY \"访问次数\" DESC LIMIT 10",
replaceSql);
}
@Test
void testReplaceAggAliasOrderbyField() {
String sql = "SELECT SUM(访问次数) AS top10总播放量 FROM (SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数 "

@@ -1,5 +1,5 @@
# Use an official OpenJDK runtime as a parent image
FROM openjdk:8-jdk
FROM openjdk:21-jdk-bullseye
# Set the working directory in the container
WORKDIR /usr/src/app
@@ -7,8 +7,6 @@ WORKDIR /usr/src/app
# Argument to pass in the supersonic version at build time
ARG SUPERSONIC_VERSION
RUN apt-get update
# Install necessary packages, including Postgres client
RUN apt-get update && apt-get install -y postgresql-client

@@ -1,3 +1,3 @@
#!/usr/bin/env bash
SUPERSONIC_VERSION=0.9.10-SNAPSHOT docker-compose -f docker-compose.yml -p supersonic up
SUPERSONIC_VERSION=latest docker-compose -f docker-compose.yml -p supersonic up

@@ -11,8 +11,8 @@ services:
POSTGRES_PASSWORD: supersonic_password
ports:
- "15432:5432"
volumes:
- postgres_data:/var/lib/postgresql
# volumes:
# - postgres_data:/var/lib/postgresql
networks:
- supersonic_network
dns:
@@ -72,9 +72,9 @@ services:
- 114.114.114.114
- 8.8.8.8
- 8.8.4.4
volumes:
#volumes:
#1.Named Volumes are best for persistent data managed by Docker.
- supersonic_data:/usr/src/app/supersonic-standalone-${SUPERSONIC_VERSION:-latest}
#- supersonic_data:/usr/src/app/supersonic-standalone-${SUPERSONIC_VERSION:-latest}
#2.Bind Mounts are suitable for frequent modifications and debugging.
# - ./conf/application-prd.yaml:/usr/src/app/supersonic-standalone-${SUPERSONIC_VERSION:-latest}/conf/application-prd.yaml
#3.Detailed Bind Mounts offer more control over the mount behavior.
@@ -84,9 +84,9 @@ services:
# bind:
# propagation: rprivate
# create_host_path: true
volumes:
postgres_data:
supersonic_data:
#volumes:
# postgres_data:
# supersonic_data:
networks:
supersonic_network:

@@ -12,6 +12,7 @@ TAGS="latest"
# If VERSION is provided, add it to TAGS and tag the image as latest
if [ -n "$VERSION" ]; then
TAGS="$TAGS $VERSION"
echo "Tagging Docker images $IMAGE_NAME:$VERSION to $IMAGE_NAME:latest"
docker tag $IMAGE_NAME:$VERSION $IMAGE_NAME:latest
fi

@@ -32,6 +32,16 @@ public class Dimension {
this.type = type;
this.isCreateDimension = isCreateDimension;
this.bizName = bizName;
this.expr = bizName;
}
public Dimension(String name, String bizName, String expr, DimensionType type,
Integer isCreateDimension) {
this.name = name;
this.type = type;
this.isCreateDimension = isCreateDimension;
this.bizName = bizName;
this.expr = expr;
}
public Dimension(String name, String bizName, DimensionType type, Integer isCreateDimension,
@@ -45,12 +55,7 @@ public class Dimension {
this.bizName = bizName;
}
public static Dimension getDefault() {
return new Dimension("数据日期", "imp_date", DimensionType.partition_time, 0, "imp_date",
Constants.DAY_FORMAT, new DimensionTimeTypeParams("false", "day"));
}
public String getFieldName() {
return bizName;
return expr;
}
}

@@ -23,11 +23,20 @@ public class Measure {
private String alias;
public Measure(String name, String bizName, String expr, String agg, Integer isCreateMetric) {
this.name = name;
this.agg = agg;
this.isCreateMetric = isCreateMetric;
this.bizName = bizName;
this.expr = expr;
}
public Measure(String name, String bizName, String agg, Integer isCreateMetric) {
this.name = name;
this.agg = agg;
this.isCreateMetric = isCreateMetric;
this.bizName = bizName;
this.expr = bizName;
}
public Measure(String bizName, String constraint) {
@@ -38,4 +47,5 @@ public class Measure {
public String getFieldName() {
return expr;
}
}

@@ -8,5 +8,5 @@ import java.util.List;
@Data
public class MetricDefineByMeasureParams extends MetricDefineParams {
private List<MeasureParam> measures = Lists.newArrayList();
private List<Measure> measures = Lists.newArrayList();
}

@@ -18,6 +18,8 @@ public class ModelDetail {
private String queryType;
private String dbType;
private String sqlQuery;
private String tableQuery;

@@ -1,7 +1,7 @@
package com.tencent.supersonic.headless.api.pojo.enums;
public enum MapModeEnum {
STRICT(0), MODERATE(2), LOOSE(4);
STRICT(0), MODERATE(2), LOOSE(4), ALL(6);
public int threshold;

@@ -1,6 +1,6 @@
package com.tencent.supersonic.headless.api.pojo.enums;
import com.tencent.supersonic.headless.api.pojo.MeasureParam;
import com.tencent.supersonic.headless.api.pojo.Measure;
import com.tencent.supersonic.headless.api.pojo.MetricDefineByMeasureParams;
import java.util.List;
@@ -32,7 +32,7 @@ public enum MetricType {
return true;
}
if (MetricDefineType.MEASURE.equals(metricDefineType)) {
List<MeasureParam> measures = typeParams.getMeasures();
List<Measure> measures = typeParams.getMeasures();
if (measures.size() > 1) {
return true;
}

@@ -7,6 +7,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaItem;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -31,7 +32,7 @@ public class DimensionReq extends SchemaItem {
private DataTypeEnums dataType;
private Map<String, Object> ext;
private Map<String, Object> ext = new HashMap();
private DimensionTimeTypeParams typeParams;
}

@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.api.pojo.response;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.RecordInfo;
import com.tencent.supersonic.common.util.AESEncryptionUtil;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
@@ -64,4 +65,8 @@ public class DatabaseResp extends RecordInfo {
}
return "";
}
public String passwordDecrypt() {
return AESEncryptionUtil.aesDecryptECB(password);
}
}

@@ -1,14 +1,26 @@
package com.tencent.supersonic.headless.api.pojo.response;
import com.google.common.collect.Sets;
import lombok.Data;
import lombok.ToString;
import java.util.List;
import java.util.Set;
@Data
@ToString(callSuper = true)
public class DimSchemaResp extends DimensionResp {
private Long useCnt = 0L;
private Set<String> fields = Sets.newHashSet();
@Override
public boolean equals(Object o) {
return super.equals(o);
}
@Override
public int hashCode() {
return super.hashCode();
}
}

@@ -2,13 +2,9 @@ package com.tencent.supersonic.headless.api.pojo.response;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.DataFormat;
import com.tencent.supersonic.headless.api.pojo.DrillDownDimension;
import com.tencent.supersonic.headless.api.pojo.MetricDefineByFieldParams;
import com.tencent.supersonic.headless.api.pojo.MetricDefineByMeasureParams;
import com.tencent.supersonic.headless.api.pojo.MetricDefineByMetricParams;
import com.tencent.supersonic.headless.api.pojo.RelateDimension;
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
import com.tencent.supersonic.headless.api.pojo.*;
import com.tencent.supersonic.headless.api.pojo.enums.MetricDefineType;
import com.tencent.supersonic.headless.api.pojo.enums.MetricType;
import lombok.Data;
import lombok.ToString;
import org.apache.commons.collections.CollectionUtils;
@@ -69,6 +65,19 @@ public class MetricResp extends SchemaItem {
private boolean containsPartitionDimensions;
public void setMetricDefinition(MetricDefineType type, MetricDefineParams params) {
if (MetricDefineType.MEASURE.equals(type)) {
assert params instanceof MetricDefineByMeasureParams;
metricDefineByMeasureParams = (MetricDefineByMeasureParams) params;
} else if (MetricDefineType.FIELD.equals(type)) {
assert params instanceof MetricDefineByFieldParams;
metricDefineByFieldParams = (MetricDefineByFieldParams) params;
} else if (MetricDefineType.METRIC.equals(type)) {
assert params instanceof MetricDefineByMetricParams;
metricDefineByMetricParams = (MetricDefineByMetricParams) params;
}
}
public void setClassifications(String tag) {
if (StringUtils.isBlank(tag)) {
classifications = Lists.newArrayList();
@@ -105,4 +114,8 @@ public class MetricResp extends SchemaItem {
}
return "";
}
public boolean isDerived() {
return MetricType.isDerived(metricDefineType, metricDefineByMeasureParams);
}
}

@@ -1,11 +1,26 @@
package com.tencent.supersonic.headless.api.pojo.response;
import com.google.common.collect.Sets;
import lombok.Data;
import lombok.ToString;
import java.util.Set;
@Data
@ToString(callSuper = true)
public class MetricSchemaResp extends MetricResp {
private Long useCnt = 0L;
private Set<String> fields = Sets.newHashSet();
@Override
public boolean equals(Object o) {
return super.equals(o);
}
@Override
public int hashCode() {
return super.hashCode();
}
}

@@ -1,17 +1,9 @@
package com.tencent.supersonic.headless.api.pojo.response;
import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.Dimension;
import com.tencent.supersonic.headless.api.pojo.DrillDownDimension;
import com.tencent.supersonic.headless.api.pojo.Field;
import com.tencent.supersonic.headless.api.pojo.Identify;
import com.tencent.supersonic.headless.api.pojo.ModelDetail;
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
import com.tencent.supersonic.headless.api.pojo.*;
import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString;
import lombok.*;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
@@ -26,6 +18,7 @@ import java.util.stream.Collectors;
@ToString(callSuper = true)
@AllArgsConstructor
@NoArgsConstructor
@Builder
public class ModelResp extends SchemaItem {
private Long domainId;
@@ -62,6 +55,14 @@ public class ModelResp extends SchemaItem {
return isOpen != null && isOpen == 1;
}
public List<Measure> getMeasures() {
return modelDetail != null ? modelDetail.getMeasures() : Lists.newArrayList();
}
public List<Identify> getIdentifiers() {
return modelDetail != null ? modelDetail.getIdentifiers() : Lists.newArrayList();
}
public List<Dimension> getTimeDimension() {
if (modelDetail == null) {
return Lists.newArrayList();

@@ -3,7 +3,6 @@ package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
@@ -13,12 +12,7 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.*;
import java.util.stream.Collectors;
/**
@@ -47,18 +41,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
protected Map<String, String> getFieldNameMap(ChatQueryContext chatQueryContext,
Long dataSetId) {
Map<String, String> result = getFieldNameMapFromDB(chatQueryContext, dataSetId);
if (chatQueryContext.containsPartitionDimensions(dataSetId)) {
result.put(TimeDimensionEnum.DAY.getChName(), TimeDimensionEnum.DAY.getChName());
result.put(TimeDimensionEnum.MONTH.getChName(), TimeDimensionEnum.MONTH.getChName());
result.put(TimeDimensionEnum.WEEK.getChName(), TimeDimensionEnum.WEEK.getChName());
result.put(TimeDimensionEnum.DAY.getName(), TimeDimensionEnum.DAY.getChName());
result.put(TimeDimensionEnum.MONTH.getName(), TimeDimensionEnum.MONTH.getChName());
result.put(TimeDimensionEnum.WEEK.getName(), TimeDimensionEnum.WEEK.getChName());
}
return result;
return getFieldNameMapFromDB(chatQueryContext, dataSetId);
}
private static Map<String, String> getFieldNameMapFromDB(ChatQueryContext chatQueryContext,
@@ -126,7 +109,6 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
}
return elements.stream();
}).collect(Collectors.toSet());
dimensions.add(TimeDimensionEnum.DAY.getChName());
return dimensions;
}
@@ -142,8 +124,6 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
Set<String> removeFieldNames = new HashSet<>();
removeFieldNames.addAll(TimeDimensionEnum.getChNameList());
removeFieldNames.addAll(TimeDimensionEnum.getNameList());
Map<String, String> fieldNameMap =
getFieldNameMapFromDB(chatQueryContext, semanticParseInfo.getDataSetId());
removeFieldNames.removeIf(fieldName -> fieldNameMap.containsKey(fieldName));

@@ -4,7 +4,6 @@ import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
@@ -53,10 +52,6 @@ public class GroupByCorrector extends BaseSemanticCorrector {
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {
return false;
}
// if only date in select not add group by.
if (selectFields.size() == 1 && TimeDimensionEnum.containsZhTimeDimension(selectFields)) {
return false;
}
if (SqlSelectHelper.hasGroupBy(correctS2SQL)) {
log.debug("No need to add 'group by', existed 'group by' in s2sql:{}", correctS2SQL);
return false;

@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
@@ -34,11 +35,8 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
+ "please take a review and help correct it if necessary." + "\n#Rules: "
+ "\n1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),sql=(corrected sql if NEGATIVE; empty string if POSITIVE)`."
+ "\n2.NO NEED to check date filters as the junior engineer seldom makes mistakes in this regard."
+ "\n3.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
+ "\n4.ALWAYS use `with` statement if nested aggregation is needed."
+ "\n5.ALWAYS enclose alias declared by `AS` command in underscores."
+ "\n6.Alias created by `AS` command must be in the same language ast the `Question`."
+ "\n#Question:{{question}} #InputSQL:{{sql}} #Response:";
+ "\n3.SQL columns and values must be mentioned in the `#Schema`."
+ "\n#Question:{{question}} #Schema:{{schema}} #InputSQL:{{sql}} #Response:";
public LLMSqlCorrector() {
ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("语义SQL修正")
@@ -67,12 +65,15 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
return;
}
Text2SQLExemplar exemplar = (Text2SQLExemplar) semanticParseInfo.getProperties()
.get(Text2SQLExemplar.PROPERTY_KEY);
ChatLanguageModel chatLanguageModel =
ModelProvider.getChatModel(chatApp.getChatModelConfig());
SemanticSqlExtractor extractor =
AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);
Prompt prompt = generatePrompt(chatQueryContext.getRequest().getQueryText(),
semanticParseInfo, chatApp.getPrompt());
semanticParseInfo, chatApp.getPrompt(), exemplar);
SemanticSql s2Sql = extractor.generateSemanticSql(prompt.toUserMessage().singleText());
keyPipelineLog.info("LLMSqlCorrector modelReq:\n{} \nmodelResp:\n{}", prompt.text(), s2Sql);
if ("NEGATIVE".equals(s2Sql.getOpinion()) && StringUtils.isNotBlank(s2Sql.getSql())) {
@@ -81,10 +82,11 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
}
private Prompt generatePrompt(String queryText, SemanticParseInfo semanticParseInfo,
String promptTemplate) {
String promptTemplate, Text2SQLExemplar exemplar) {
Map<String, Object> variable = new HashMap<>();
variable.put("question", queryText);
variable.put("sql", semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
variable.put("schema", exemplar.getDbSchema());
return PromptTemplate.from(promptTemplate).apply(variable);
}

@@ -1,14 +1,8 @@
package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.jsqlparser.AggregateEnum;
import com.tencent.supersonic.common.jsqlparser.FieldExpression;
import com.tencent.supersonic.common.jsqlparser.SqlAsHelper;
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.jsqlparser.*;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
@@ -21,11 +15,7 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.*;
import java.util.stream.Collectors;
/** Perform schema corrections on the Schema information in S2SQL. */
@@ -144,8 +134,6 @@ public class SchemaCorrector extends BaseSemanticCorrector {
Set<String> removeFieldNames = whereExpressionList.stream()
.filter(fieldExpression -> StringUtils.isBlank(fieldExpression.getFunction()))
.filter(fieldExpression -> !TimeDimensionEnum
.containsTimeDimension(fieldExpression.getFieldName()))
.filter(fieldExpression -> FilterOperatorEnum.EQUALS.getValue()
.equals(fieldExpression.getOperator()))
.filter(fieldExpression -> dimensions.contains(fieldExpression.getFieldName()))

@@ -5,7 +5,6 @@ import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlDateSelectHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
@@ -44,8 +43,7 @@ public class TimeCorrector extends BaseSemanticCorrector {
DataSetSchema dataSetSchema =
chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
if (Objects.isNull(dataSetSchema) || Objects.isNull(dataSetSchema.getPartitionDimension())
|| Objects.isNull(dataSetSchema.getPartitionDimension().getName())
|| TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
|| Objects.isNull(dataSetSchema.getPartitionDimension().getName())) {
return;
}
String partitionDimension = dataSetSchema.getPartitionDimension().getName();
@@ -75,7 +73,8 @@ public class TimeCorrector extends BaseSemanticCorrector {
private void addLowerBoundDate(SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(correctS2SQL);
DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(correctS2SQL,
semanticParseInfo.getDateInfo().getDateField());
if (dateBoundInfo != null && StringUtils.isBlank(dateBoundInfo.getLowerBound())
&& StringUtils.isNotBlank(dateBoundInfo.getUpperBound())

@@ -0,0 +1,40 @@
package com.tencent.supersonic.headless.chat.mapper;
import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import lombok.extern.slf4j.Slf4j;
import java.util.List;
import java.util.Map;
@Slf4j
public class AllFieldMapper extends BaseMapper {
@Override
public boolean accept(ChatQueryContext chatQueryContext) {
return MapModeEnum.ALL.equals(chatQueryContext.getRequest().getMapModeEnum());
}
@Override
public void doMap(ChatQueryContext chatQueryContext) {
Map<Long, DataSetSchema> schemaMap =
chatQueryContext.getSemanticSchema().getDataSetSchemaMap();
for (Map.Entry<Long, DataSetSchema> entry : schemaMap.entrySet()) {
List<SchemaElement> schemaElements = Lists.newArrayList();
schemaElements.addAll(entry.getValue().getDimensions());
schemaElements.addAll(entry.getValue().getMetrics());
for (SchemaElement schemaElement : schemaElements) {
chatQueryContext.getMapInfo().getMatchedElements(entry.getKey())
.add(SchemaElementMatch.builder().word(schemaElement.getName())
.element(schemaElement).detectWord(schemaElement.getName())
.similarity(1.0).build());
}
}
}
}

@@ -27,6 +27,10 @@ public abstract class BaseMapper implements SchemaMapper {
@Override
public void map(ChatQueryContext chatQueryContext) {
if (!accept(chatQueryContext)) {
return;
}
String simpleName = this.getClass().getSimpleName();
long startTime = System.currentTimeMillis();
log.debug("before {},mapInfo:{}", simpleName,
@@ -46,6 +50,10 @@ public abstract class BaseMapper implements SchemaMapper {
public abstract void doMap(ChatQueryContext chatQueryContext);
protected boolean accept(ChatQueryContext chatQueryContext) {
return true;
}
public void addToSchemaMap(SchemaMapInfo schemaMap, Long dataSetId,
SchemaElementMatch newElementMatch) {
Map<Long, List<SchemaElementMatch>> dataSetElementMatches =
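The `accept` hook added to `BaseMapper` works as a template-method guard: `map` checks it once, so subclasses such as `AllFieldMapper` and `EmbeddingMapper` no longer need an early `return` inside `doMap`. A minimal, self-contained sketch of the pattern follows. `MapMode`, `Mapper`, `AllModeMapper`, and `MapperGuardDemo` are hypothetical stand-ins for illustration, not SuperSonic classes.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-ins for illustration only; not the SuperSonic types.
enum MapMode { STRICT, LOOSE, ALL }

abstract class Mapper {
    // Template method: the guard runs once here instead of in every doMap.
    public final void map(MapMode mode, List<String> out) {
        if (!accept(mode)) {
            return;
        }
        doMap(out);
    }

    // Default mirrors the diff: a mapper runs unless it opts out.
    protected boolean accept(MapMode mode) {
        return true;
    }

    protected abstract void doMap(List<String> out);
}

class AllModeMapper extends Mapper {
    @Override
    protected boolean accept(MapMode mode) {
        return MapMode.ALL.equals(mode);
    }

    @Override
    protected void doMap(List<String> out) {
        out.add("all-fields");
    }
}

class MapperGuardDemo {
    // Runs the mapper under the given mode and returns what it produced.
    static List<String> run(MapMode mode) {
        List<String> out = new ArrayList<>();
        new AllModeMapper().map(mode, out);
        return out;
    }
}
```

Calling `MapperGuardDemo.run(MapMode.ALL)` yields one element, while any other mode yields an empty list because `doMap` is never reached.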

@@ -7,6 +7,8 @@ import com.tencent.supersonic.headless.chat.knowledge.MapResult;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
import java.util.HashMap;
@@ -14,10 +16,17 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ThreadPoolExecutor;
@Service
@Slf4j
public abstract class BaseMatchStrategy<T extends MapResult> implements MatchStrategy<T> {
@Autowired
@Qualifier("mapExecutor")
private ThreadPoolExecutor executor;
@Override
public Map<MatchText, List<T>> match(ChatQueryContext chatQueryContext, List<S2Term> terms,
Set<Long> detectDataSetIds) {
@@ -63,6 +72,18 @@ public abstract class BaseMatchStrategy<T extends MapResult> implements MatchStr
}
}
protected void executeTasks(List<Callable<Void>> tasks) {
try {
executor.invokeAll(tasks);
for (Callable<Void> future : tasks) {
future.call();
}
} catch (Exception e) {
Thread.currentThread().interrupt();
throw new RuntimeException("Task execution interrupted", e);
}
}
public double getThreshold(Double threshold, Double minThreshold, MapModeEnum mapModeEnum) {
if (MapModeEnum.STRICT.equals(mapModeEnum)) {
return 1.0d;
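`ExecutorService.invokeAll` blocks until every submitted task has completed and returns one `Future` per task, so a second pass that invokes `call()` on each `Callable` runs the same work again on the caller thread. A hedged sketch of an alternative that waits on the returned futures and surfaces task failures is shown here. `TaskRunner`, `runCount`, and the fixed-size pool are assumptions for the demo; the real `mapExecutor` bean may be configured differently.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;

class TaskRunner {
    // Runs every task exactly once on the pool and rethrows the first failure.
    static void executeTasks(ExecutorService executor, List<Callable<Void>> tasks) {
        try {
            for (Future<Void> future : executor.invokeAll(tasks)) {
                future.get(); // task is already done; get() only surfaces its exception
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // restore the flag only on real interruption
            throw new RuntimeException("Task execution interrupted", e);
        } catch (ExecutionException e) {
            throw new RuntimeException("Task execution failed", e.getCause());
        }
    }

    // Demo helper: counts how many times the task bodies actually ran.
    static int runCount(int taskCount) {
        AtomicInteger counter = new AtomicInteger();
        List<Callable<Void>> tasks = new ArrayList<>();
        for (int i = 0; i < taskCount; i++) {
            tasks.add(() -> {
                counter.incrementAndGet();
                return null;
            });
        }
        ExecutorService pool = Executors.newFixedThreadPool(4);
        try {
            executeTasks(pool, tasks);
        } finally {
            pool.shutdown();
        }
        return counter.get();
    }
}
```

With this shape, five submitted tasks increment the counter five times rather than ten.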

@@ -27,12 +27,12 @@ import java.util.stream.Collectors;
@Slf4j
public class DatabaseMatchStrategy extends SingleMatchStrategy<DatabaseMapResult> {
private List<SchemaElement> allElements;
private ThreadLocal<List<SchemaElement>> allElements = ThreadLocal.withInitial(ArrayList::new);
@Override
public Map<MatchText, List<DatabaseMapResult>> match(ChatQueryContext chatQueryContext,
List<S2Term> terms, Set<Long> detectDataSetIds) {
this.allElements = getSchemaElements(chatQueryContext);
allElements.set(getSchemaElements(chatQueryContext));
return super.match(chatQueryContext, terms, detectDataSetIds);
}
@@ -43,7 +43,7 @@ public class DatabaseMatchStrategy extends SingleMatchStrategy<DatabaseMapResult
}
Double metricDimensionThresholdConfig = getThreshold(chatQueryContext);
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements);
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements.get());
List<DatabaseMapResult> results = new ArrayList<>();
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
String name = entry.getKey();
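A `ThreadLocal` keeps one value per thread. A value stored on the calling thread in `match` is therefore not what a pool worker sees if the later `allElements.get()` runs inside a pooled task; the worker would get its own `withInitial` value instead. The sketch below demonstrates only that JDK behavior. `ThreadLocalVisibility` is a hypothetical name, and nothing here asserts how SuperSonic's executor actually schedules the detection step.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class ThreadLocalVisibility {
    private static final ThreadLocal<List<String>> ELEMENTS =
            ThreadLocal.withInitial(ArrayList::new);

    // Returns {size seen by the caller thread, size seen by a pool worker}.
    static int[] sizes() {
        ELEMENTS.set(List.of("dimension", "metric"));
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            Callable<Integer> probe = () -> ELEMENTS.get().size();
            int workerSize = pool.submit(probe).get();
            return new int[] {ELEMENTS.get().size(), workerSize};
        } catch (Exception e) {
            throw new RuntimeException(e);
        } finally {
            pool.shutdown();
            ELEMENTS.remove(); // avoid leaking the value on a reused thread
        }
    }
}
```

Passing the list into the task as a captured local or a parameter keeps the data visible to whichever thread ends up running it.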

@@ -20,12 +20,13 @@ import java.util.Objects;
*/
@Slf4j
public class EmbeddingMapper extends BaseMapper {
public void doMap(ChatQueryContext chatQueryContext) {
// Check if the map mode is LOOSE
if (!MapModeEnum.LOOSE.equals(chatQueryContext.getRequest().getMapModeEnum())) {
return;
}
@Override
public boolean accept(ChatQueryContext chatQueryContext) {
return MapModeEnum.LOOSE.equals(chatQueryContext.getRequest().getMapModeEnum());
}
public void doMap(ChatQueryContext chatQueryContext) {
// 1. Query from embedding by queryText
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
List<EmbeddingResult> matchResults = getMatches(chatQueryContext, matchStrategy);
@@ -62,4 +63,5 @@ public class EmbeddingMapper extends BaseMapper {
addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch);
}
}
}

@@ -16,10 +16,11 @@ import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING_MAPPER_NUMBER;
@@ -40,7 +41,7 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
@Override
public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext,
Set<Long> detectDataSetIds, Set<String> detectSegments) {
Set<EmbeddingResult> results = new HashSet<>();
Set<EmbeddingResult> results = ConcurrentHashMap.newKeySet();
int embeddingMapperBatch = Integer
.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH));
@@ -52,12 +53,24 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
List<List<String>> queryTextsSubList =
Lists.partition(queryTextsList, embeddingMapperBatch);
List<Callable<Void>> tasks = new ArrayList<>();
for (List<String> queryTextsSub : queryTextsSubList) {
tasks.add(createTask(chatQueryContext, detectDataSetIds, queryTextsSub, results));
}
executeTasks(tasks);
return new ArrayList<>(results);
}
private Callable<Void> createTask(ChatQueryContext chatQueryContext, Set<Long> detectDataSetIds,
List<String> queryTextsSub, Set<EmbeddingResult> results) {
return () -> {
List<EmbeddingResult> oneRoundResults =
detectByQueryTextsSub(detectDataSetIds, queryTextsSub, chatQueryContext);
selectResultInOneRound(results, oneRoundResults);
}
return new ArrayList<>(results);
synchronized (results) {
selectResultInOneRound(results, oneRoundResults);
}
return null;
};
}
private List<EmbeddingResult> detectByQueryTextsSub(Set<Long> detectDataSetIds,
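The batching shape in `detectByBatch` (partition the inputs, build one `Callable` per batch, merge into a concurrent set) can be sketched with the JDK alone. In this sketch `BatchCollector`, its `partition` helper, and the upper-casing "detection" step are hypothetical placeholders for Guava's `Lists.partition` and for `detectByQueryTextsSub`; it is an illustration under those assumptions, not the project's implementation.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

class BatchCollector {
    // JDK-only stand-in for Guava's Lists.partition.
    static <T> List<List<T>> partition(List<T> items, int batchSize) {
        List<List<T>> batches = new ArrayList<>();
        for (int i = 0; i < items.size(); i += batchSize) {
            batches.add(items.subList(i, Math.min(i + batchSize, items.size())));
        }
        return batches;
    }

    // One task per batch; each task writes into a shared thread-safe set.
    static Set<String> detect(List<String> texts, int batchSize) {
        Set<String> results = ConcurrentHashMap.newKeySet();
        List<Callable<Void>> tasks = new ArrayList<>();
        for (List<String> batch : partition(texts, batchSize)) {
            tasks.add(() -> {
                for (String text : batch) {
                    results.add(text.toUpperCase()); // placeholder for the real detection
                }
                return null;
            });
        }
        ExecutorService pool = Executors.newFixedThreadPool(2);
        try {
            for (Future<Void> future : pool.invokeAll(tasks)) {
                future.get();
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        } finally {
            pool.shutdown();
        }
        return results;
    }
}
```

A concurrent set makes individual `add` calls safe; a compound read-then-write step over the set would still need its own lock, which may be why the diff keeps a `synchronized (results)` block around `selectResultInOneRound`.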

@@ -35,11 +35,11 @@ public class MapperConfig extends ParameterConfig {
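"Minimum value of the dimension-value similarity threshold during dynamic adjustment", "number", "Mapper-related configuration");
public static final Parameter EMBEDDING_MAPPER_TEXT_SIZE =
new Parameter("s2.mapper.embedding.word.size", "4", "Text length used for embedding recall",
new Parameter("s2.mapper.embedding.word.size", "3", "Text length used for embedding recall",
"To improve embedding recall efficiency, semantic recall is performed on text of the specified length", "number", "Mapper-related configuration");
public static final Parameter EMBEDDING_MAPPER_TEXT_STEP =
new Parameter("s2.mapper.embedding.word.step", "3", "Step length of text for embedding recall",
new Parameter("s2.mapper.embedding.word.step", "2", "Step length of text for embedding recall",
"To improve embedding recall efficiency, recall advances by the specified step length", "number", "Mapper-related configuration");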
public static final Parameter EMBEDDING_MAPPER_BATCH =
@@ -51,7 +51,7 @@ public class MapperConfig extends ParameterConfig {
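"Number of text results returned per text in embedding semantic recall", "number", "Mapper-related configuration");
public static final Parameter EMBEDDING_MAPPER_THRESHOLD =
new Parameter("s2.mapper.embedding.threshold", "0.98", "Embedding recall similarity threshold", "Results with similarity below this threshold are discarded",
new Parameter("s2.mapper.embedding.threshold", "0.9", "Embedding recall similarity threshold", "Results with similarity below this threshold are discarded",
"number", "Mapper-related configuration");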
public static final Parameter EMBEDDING_MAPPER_ROUND_NUMBER =
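One plausible reading of `s2.mapper.embedding.word.size` and `s2.mapper.embedding.word.step` is a sliding window over the query text: each window is `size` units long and consecutive windows start `step` units apart. The sketch below illustrates that reading over characters. `TextWindows` is a hypothetical helper, and whether SuperSonic slides over characters or over segmented terms is an assumption not confirmed by this diff.

```java
import java.util.ArrayList;
import java.util.List;

class TextWindows {
    // Cuts text into windows of `size` characters, advancing `step` each time.
    static List<String> windows(String text, int size, int step) {
        List<String> out = new ArrayList<>();
        for (int start = 0; start < text.length(); start += step) {
            int end = Math.min(start + size, text.length());
            out.add(text.substring(start, end));
            if (end == text.length()) {
                break; // the last window already reaches the end of the text
            }
        }
        return out;
    }
}
```

Under this reading, moving the defaults from size 4 and step 3 to size 3 and step 2 produces shorter windows with the same one-unit overlap, and therefore more candidate segments per query.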

@@ -9,22 +9,23 @@ import lombok.extern.slf4j.Slf4j;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@Slf4j
public class TimeFieldMapper extends BaseMapper {
public class PartitionTimeMapper extends BaseMapper {
@Override
public boolean accept(ChatQueryContext chatQueryContext) {
return !(chatQueryContext.getRequest().getText2SQLType().equals(Text2SQLType.ONLY_RULE)
|| chatQueryContext.getMapInfo().isEmpty());
}
@Override
public void doMap(ChatQueryContext chatQueryContext) {
if (chatQueryContext.getRequest().getText2SQLType().equals(Text2SQLType.ONLY_RULE)) {
return;
}
Map<Long, DataSetSchema> schemaMap =
chatQueryContext.getSemanticSchema().getDataSetSchemaMap();
for (Map.Entry<Long, DataSetSchema> entry : schemaMap.entrySet()) {
List<SchemaElement> timeDims = entry.getValue().getDimensions().stream()
.filter(dim -> dim.getTimeFormat() != null).collect(Collectors.toList());
.filter(SchemaElement::isPartitionTime).toList();
for (SchemaElement schemaElement : timeDims) {
chatQueryContext.getMapInfo().getMatchedElements(entry.getKey())
.add(SchemaElementMatch.builder().word(schemaElement.getName())

@@ -21,14 +21,16 @@ import java.util.stream.Collectors;
@Slf4j
public class QueryFilterMapper extends BaseMapper {
private double similarity = 1.0;
private final double similarity = 1.0;
@Override
public boolean accept(ChatQueryContext chatQueryContext) {
return !chatQueryContext.getRequest().getDataSetIds().isEmpty();
}
@Override
public void doMap(ChatQueryContext chatQueryContext) {
Set<Long> dataSetIds = chatQueryContext.getRequest().getDataSetIds();
if (CollectionUtils.isEmpty(dataSetIds)) {
return;
}
SchemaMapInfo schemaMapInfo = chatQueryContext.getMapInfo();
clearOtherSchemaElementMatch(dataSetIds, schemaMapInfo);
for (Long dataSetId : dataSetIds) {

@@ -8,10 +8,11 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
@Service
@Slf4j
@@ -25,28 +26,38 @@ public abstract class SingleMatchStrategy<T extends MapResult> extends BaseMatch
Set<Long> detectDataSetIds) {
Map<Integer, Integer> regOffsetToLength = mapperHelper.getRegOffsetToLength(terms);
String text = chatQueryContext.getRequest().getQueryText();
Set<T> results = new HashSet<>();
Set<T> results = ConcurrentHashMap.newKeySet();
List<Callable<Void>> tasks = new ArrayList<>();
Set<String> detectSegments = new HashSet<>();
for (Integer startIndex = 0; startIndex <= text.length() - 1;) {
for (Integer index = startIndex; index <= text.length();) {
for (int startIndex = 0; startIndex <= text.length() - 1;) {
for (int index = startIndex; index <= text.length();) {
int offset = mapperHelper.getStepOffset(terms, startIndex);
index = mapperHelper.getStepIndex(regOffsetToLength, index);
if (index <= text.length()) {
String detectSegment = text.substring(startIndex, index).trim();
detectSegments.add(detectSegment);
List<T> oneRoundResults =
detectByStep(chatQueryContext, detectDataSetIds, detectSegment, offset);
selectResultInOneRound(results, oneRoundResults);
Callable<Void> task = createTask(chatQueryContext, detectDataSetIds,
detectSegment, offset, results);
tasks.add(task);
}
}
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
}
executeTasks(tasks);
return new ArrayList<>(results);
}
private Callable<Void> createTask(ChatQueryContext chatQueryContext, Set<Long> detectDataSetIds,
String detectSegment, int offset, Set<T> results) {
return () -> {
List<T> oneRoundResults =
detectByStep(chatQueryContext, detectDataSetIds, detectSegment, offset);
synchronized (results) {
selectResultInOneRound(results, oneRoundResults);
}
return null;
};
}
public abstract List<T> detectByStep(ChatQueryContext chatQueryContext,
Set<Long> detectDataSetIds, String detectSegment, int offset);
}

@@ -16,14 +16,15 @@ import java.util.List;
@Slf4j
public class TermDescMapper extends BaseMapper {
@Override
public boolean accept(ChatQueryContext chatQueryContext) {
return !(CollectionUtils.isEmpty(chatQueryContext.getMapInfo().getTermDescriptionToMap())
|| chatQueryContext.getRequest().isDescriptionMapped());
}
@Override
public void doMap(ChatQueryContext chatQueryContext) {
SchemaMapInfo mapInfo = chatQueryContext.getMapInfo();
List<SchemaElement> termElements = mapInfo.getTermDescriptionToMap();
if (CollectionUtils.isEmpty(termElements)
|| chatQueryContext.getRequest().isDescriptionMapped()) {
return;
}
List<SchemaElement> termElements = chatQueryContext.getMapInfo().getTermDescriptionToMap();
for (SchemaElement schemaElement : termElements) {
ChatQueryContext queryCtx =
buildQueryContext(chatQueryContext, schemaElement.getDescription());

@@ -2,7 +2,10 @@ package com.tencent.supersonic.headless.chat.parser.llm;
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.query.QueryManager;
@@ -50,6 +53,15 @@ public class LLMResponseService {
parseInfo.setQueryMode(semanticQuery.getQueryMode());
parseInfo.getSqlInfo().setParsedS2SQL(s2SQL);
parseInfo.getSqlInfo().setCorrectedS2SQL(s2SQL);
DataSetSchema dataSetSchema =
queryCtx.getSemanticSchema().getDataSetSchemaMap().get(parseInfo.getDataSetId());
SchemaElement partitionDimension = dataSetSchema.getPartitionDimension();
if (Objects.nonNull(partitionDimension)) {
DateConf dateConf = new DateConf();
dateConf.setDateField(partitionDimension.getName());
parseInfo.setDateInfo(dateConf);
}
queryCtx.getCandidateQueries().add(semanticQuery);
}

@@ -2,6 +2,8 @@ package com.tencent.supersonic.headless.chat.parser.rule;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
@@ -57,6 +59,10 @@ public class TimeRangeParser implements SemanticParser {
for (SemanticQuery query : queryContext.getCandidateQueries()) {
SemanticParseInfo parseInfo = query.getParseInfo();
if (queryContext.containsPartitionDimensions(parseInfo.getDataSetId())) {
DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap()
.get(parseInfo.getDataSetId());
SchemaElement partitionDimension = dataSetSchema.getPartitionDimension();
dateConf.setDateField(partitionDimension.getName());
parseInfo.setDateInfo(dateConf);
}
parseInfo.setScore(parseInfo.getScore() + dateConf.getDetectWord().length());

@@ -4,18 +4,8 @@ import com.tencent.supersonic.common.pojo.Aggregator;
import com.tencent.supersonic.common.pojo.Filter;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.*;
import com.tencent.supersonic.headless.api.pojo.request.*;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.query.BaseSemanticQuery;
import com.tencent.supersonic.headless.chat.query.QueryManager;
@@ -25,13 +15,8 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.*;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.TERM;
@@ -233,8 +218,6 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
protected void convertBizNameToName(DataSetSchema dataSetSchema,
QueryStructReq queryStructReq) {
Map<String, String> bizNameToName = dataSetSchema.getBizNameToName();
bizNameToName.putAll(TimeDimensionEnum.getNameToNameMap());
List<Order> orders = queryStructReq.getOrders();
if (CollectionUtils.isNotEmpty(orders)) {
for (Order order : orders) {

@@ -3,14 +3,13 @@ package com.tencent.supersonic.headless.chat.query.rule.detail;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.TimeMode;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
import lombok.extern.slf4j.Slf4j;
import java.time.LocalDate;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -33,10 +32,13 @@ public abstract class DetailSemanticQuery extends RuleSemanticQuery {
chatQueryContext.getSemanticSchema().getDataSetSchemaMap();
DataSetSchema dataSetSchema = dataSetSchemaMap.get(parseInfo.getDataSetId());
TimeDefaultConfig timeDefaultConfig = dataSetSchema.getDetailTypeTimeDefaultConfig();
SchemaElement partitionDimension = dataSetSchema.getPartitionDimension();
if (Objects.nonNull(timeDefaultConfig) && Objects.nonNull(timeDefaultConfig.getUnit())
if (Objects.nonNull(partitionDimension) && Objects.nonNull(timeDefaultConfig)
&& Objects.nonNull(timeDefaultConfig.getUnit())
&& timeDefaultConfig.getUnit() != -1) {
DateConf dateInfo = new DateConf();
dateInfo.setDateField(partitionDimension.getName());
int unit = timeDefaultConfig.getUnit();
String startDate = LocalDate.now().minusDays(unit).toString();
String endDate = startDate;

@@ -3,14 +3,13 @@ package com.tencent.supersonic.headless.chat.query.rule.metric;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.TimeMode;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
import lombok.extern.slf4j.Slf4j;
import java.time.LocalDate;
import java.util.List;
import java.util.Objects;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.METRIC;
@@ -40,10 +39,12 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery {
DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema().getDataSetSchemaMap()
.get(parseInfo.getDataSetId());
TimeDefaultConfig timeDefaultConfig = dataSetSchema.getMetricTypeTimeDefaultConfig();
DateConf dateInfo = new DateConf();
// Add a check that the time unit != -1
if (Objects.nonNull(timeDefaultConfig) && Objects.nonNull(timeDefaultConfig.getUnit())
SchemaElement partitionDimension = dataSetSchema.getPartitionDimension();
if (Objects.nonNull(partitionDimension) && Objects.nonNull(timeDefaultConfig)
&& Objects.nonNull(timeDefaultConfig.getUnit())
&& timeDefaultConfig.getUnit() != -1) {
DateConf dateInfo = new DateConf();
dateInfo.setDateField(partitionDimension.getName());
int unit = timeDefaultConfig.getUnit();
String startDate = LocalDate.now().minusDays(unit).toString();
String endDate = startDate;
@@ -55,8 +56,8 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery {
dateInfo.setPeriod(timeDefaultConfig.getPeriod());
dateInfo.setStartDate(startDate);
dateInfo.setEndDate(endDate);
// The date info is only set when the time unit is not -1, so this call was moved here
parseInfo.setDateInfo(dateInfo);
}
}
}

@@ -7,9 +7,7 @@ import com.tencent.supersonic.common.pojo.Filter;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
@@ -22,13 +20,7 @@ import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.*;
import java.util.stream.Collectors;
@Slf4j
@@ -179,14 +171,7 @@ public class QueryReqBuilder {
if (Objects.isNull(dateConf)) {
return "";
}
String dateField = TimeDimensionEnum.DAY.getName();
if (DatePeriodEnum.MONTH.equals(dateConf.getPeriod())) {
dateField = TimeDimensionEnum.MONTH.getName();
}
if (DatePeriodEnum.WEEK.equals(dateConf.getPeriod())) {
dateField = TimeDimensionEnum.WEEK.getName();
}
return dateField;
return dateConf.getDateField();
}
public static QueryStructReq buildStructRatioReq(SemanticParseInfo parseInfo,

@@ -7,7 +7,7 @@ public class HanadbAdaptor extends DefaultDbAdaptor {
@Override
public String rewriteSql(String sql) {
return sql.replaceAll("`", "\"");
return sql.replaceAll("`(.*?)`", "\"$1\"").replaceAll("\"([A-Z0-9_]+?)\"", "$1");
}
}
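The new `rewriteSql` expression works in two passes: backtick-quoted identifiers first become double-quoted, then any double-quoted identifier made up only of `A-Z`, `0-9`, and `_` loses its quotes. A small worked example using the same two `replaceAll` calls is given below; the wrapper class `HanaQuoteRewrite` and the sample SQL are hypothetical.

```java
class HanaQuoteRewrite {
    // Same two-step rewrite as in the hunk above, isolated for experimentation.
    static String rewriteSql(String sql) {
        return sql
                // Pass 1: `name` -> "name"
                .replaceAll("`(.*?)`", "\"$1\"")
                // Pass 2: "NAME" -> NAME when NAME is upper-case letters, digits, underscores
                .replaceAll("\"([A-Z0-9_]+?)\"", "$1");
    }
}
```

For the input ``SELECT `user_name`, `DEPT_ID` FROM `t_1` `` the result is `SELECT "user_name", DEPT_ID FROM "t_1"`: the mixed-case and lower-case identifiers keep their quotes, and the all-upper-case one is unquoted.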

@@ -8,7 +8,7 @@ import org.springframework.context.annotation.Configuration;
@Configuration
public class ExecutorConfig {
@Value("${s2.metricParser.agg.mysql.lowVersion:5.7}")
@Value("${s2.metricParser.agg.mysql.lowVersion:8.0}")
private String mysqlLowVersion;
@Value("${s2.metricParser.agg.ck.lowVersion:20.4}")

View File

@@ -1,8 +1,8 @@
package com.tencent.supersonic.headless.core.executor;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.core.pojo.Database;
import com.tencent.supersonic.headless.core.pojo.QueryStatement;
import com.tencent.supersonic.headless.core.utils.ComponentFactory;
import com.tencent.supersonic.headless.core.utils.SqlUtils;
@@ -38,7 +38,7 @@ public class JdbcExecutor implements QueryExecutor {
SqlUtils sqlUtils = ContextUtils.getBean(SqlUtils.class);
String sql = StringUtils.normalizeSpace(queryStatement.getSql());
log.info("executing SQL: {}", sql);
Database database = queryStatement.getOntology().getDatabase();
DatabaseResp database = queryStatement.getOntology().getDatabase();
SemanticQueryResp queryResultWithColumns = new SemanticQueryResp();
try {
SqlUtils sqlUtil = sqlUtils.init(database);

@@ -1,49 +0,0 @@
package com.tencent.supersonic.headless.core.pojo;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.RecordInfo;
import com.tencent.supersonic.common.pojo.enums.EngineType;
import com.tencent.supersonic.common.util.AESEncryptionUtil;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
@Data
@AllArgsConstructor
@NoArgsConstructor
@Builder
public class Database extends RecordInfo {
private Long id;
private Long domainId;
private String name;
private String description;
private String version;
private String url;
private String username;
private String password;
private String database;
private String schema;
/** mysql,clickhouse */
private EngineType type;
private List<String> admins = Lists.newArrayList();
private List<String> viewers = Lists.newArrayList();
public String passwordDecrypt() {
return AESEncryptionUtil.aesDecryptECB(password);
}
}

@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.core.pojo;
import com.alibaba.druid.pool.DruidDataSource;
import com.tencent.supersonic.headless.api.pojo.enums.DataType;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.core.utils.JdbcDataSourceUtils;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
@@ -106,7 +107,7 @@ public class JdbcDataSource {
}
}
public void removeDatasource(Database database) {
public void removeDatasource(DatabaseResp database) {
String key = getDataSourceKey(database);
@@ -128,7 +129,7 @@ public class JdbcDataSource {
}
}
public DruidDataSource getDataSource(Database database) throws RuntimeException {
public DruidDataSource getDataSource(DatabaseResp database) throws RuntimeException {
String name = database.getName();
String jdbcUrl = database.getUrl();
@@ -239,7 +240,7 @@ public class JdbcDataSource {
return druidDataSource;
}
private String getDataSourceKey(Database database) {
private String getDataSourceKey(DatabaseResp database) {
return JdbcDataSourceUtils.getKey(database.getName(), database.getUrl(),
database.getUsername(), database.passwordDecrypt(), "", false);
}

@@ -1,4 +1,4 @@
package com.tencent.supersonic.headless.core.translator.parser.s2sql;
package com.tencent.supersonic.headless.core.pojo;
import lombok.Builder;
import lombok.Data;

@@ -0,0 +1,42 @@
package com.tencent.supersonic.headless.core.pojo;
import com.tencent.supersonic.common.pojo.enums.EngineType;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.core.translator.parser.s2sql.DataModel;
import com.tencent.supersonic.headless.core.translator.parser.s2sql.Dimension;
import com.tencent.supersonic.headless.core.translator.parser.s2sql.Materialization;
import com.tencent.supersonic.headless.core.translator.parser.s2sql.Metric;
import lombok.Data;
import java.util.*;
import java.util.stream.Collectors;
@Data
public class Ontology {
private List<Metric> metrics = new ArrayList<>();
private Map<String, DataModel> dataModelMap = new HashMap<>();
private Map<String, List<Dimension>> dimensionMap = new HashMap<>();
private List<Materialization> materializationList = new ArrayList<>();
private List<JoinRelation> joinRelations;
private DatabaseResp database;
public List<Dimension> getDimensions() {
return dimensionMap.values().stream().flatMap(Collection::stream)
.collect(Collectors.toList());
}
public EngineType getDatabaseType() {
if (Objects.nonNull(database)) {
return EngineType.fromString(database.getType().toUpperCase());
}
return null;
}
public String getDatabaseVersion() {
if (Objects.nonNull(database)) {
return database.getVersion();
}
return null;
}
}

@@ -1,4 +1,4 @@
package com.tencent.supersonic.headless.core.translator.parser.s2sql;
package com.tencent.supersonic.headless.core.pojo;
import com.google.common.collect.Sets;
import com.tencent.supersonic.common.pojo.ColumnOrder;
@@ -9,7 +9,7 @@ import java.util.List;
import java.util.Set;
@Data
public class OntologyQueryParam {
public class OntologyQuery {
private Set<String> metrics = Sets.newHashSet();
private Set<String> dimensions = Sets.newHashSet();
private String where;

@@ -1,8 +1,6 @@
package com.tencent.supersonic.headless.core.pojo;
import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp;
import com.tencent.supersonic.headless.core.translator.parser.s2sql.Ontology;
import com.tencent.supersonic.headless.core.translator.parser.s2sql.OntologyQueryParam;
import lombok.Data;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Triple;
@@ -11,17 +9,18 @@ import org.apache.commons.lang3.tuple.Triple;
public class QueryStatement {
private Long dataSetId;
private String dataSetName;
private String sql;
private String errMsg;
private StructQueryParam structQueryParam;
private SqlQueryParam sqlQueryParam;
private OntologyQueryParam ontologyQueryParam;
private StructQuery structQuery;
private SqlQuery sqlQuery;
private OntologyQuery ontologyQuery;
private Integer status = 0;
private Boolean isS2SQL = false;
private Boolean enableOptimize = true;
private Triple<String, String, String> minMaxTime;
private Ontology ontology;
private SemanticSchemaResp semanticSchemaResp;
private SemanticSchemaResp semanticSchema;
private Integer limit = 1000;
private Boolean isTranslated = false;

@@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.core.pojo;
import lombok.Data;
@Data
public class SqlQueryParam {
public class SqlQuery {
private String sql;
private String table;
private boolean supportWith = true;

@@ -12,7 +12,7 @@ import java.util.ArrayList;
import java.util.List;
@Data
public class StructQueryParam {
public class StructQuery {
private List<String> groups = new ArrayList();
private List<Aggregator> aggregators = new ArrayList();
private List<Order> orders = new ArrayList();

Some files were not shown because too many files have changed in this diff.