33 Commits

Author SHA1 Message Date
jerryjzhang
a2f54d4c80 [improvement][headless]Remove unnecessary Database class
Some checks are pending
supersonic CentOS CI / build (21) (push) Waiting to run
supersonic mac CI / build (21) (push) Waiting to run
supersonic ubuntu CI / build (21) (push) Waiting to run
supersonic windows CI / build (21) (push) Waiting to run
2024-12-29 21:58:39 +08:00
jerryjzhang
6486257c9e [improvement][chat&headless]Remove deprecated system time fields. 2024-12-29 03:22:42 +08:00
HB
6f5e477e3c [fix][queryStat] Fields that are inconsistent between the table Filed and DO (#1986) 2024-12-28 15:15:06 +08:00
yxm-coding
683f01c33b 修复Network中所有前端路由资源请求均报404错误 #1982 (#1985) 2024-12-27 20:54:53 +08:00
yxm-coding
3e1e5ae209 fix(launchers): update addViewController to correctly redirect to the front-end page when accessing the domain while logged in (#1981) 2024-12-27 15:35:50 +08:00
jerryjzhang
0612833618 [fix][chat]Fix logic in s2sql parsing. 2024-12-27 14:18:20 +08:00
lexluo09
a23d1071a3 [improvement][chat] Optimize the logic for obtaining the generic thread pool (#1979) 2024-12-26 23:37:55 +08:00
jerryjzhang
94267f6028 [improvement][chat]Introduce AllFieldMapper to increase parsing robustness when normal pipeline fails.
[improvement][chat]Introduce `AllFieldMapper` to increase parsing robustness when normal pipeline fails.
2024-12-26 23:20:43 +08:00
jerryjzhang
a4d2df4063 [improvement][project]Adapt docker related scripts to new version. 2024-12-26 15:02:28 +08:00
jerryjzhang
d04a086c88 [improvement][chat]Support reviewing query memory based on direct user feedback. 2024-12-26 09:47:13 +08:00
jerryjzhang
68963b9ec9 [improvement][project]Adjust files based on code style. 2024-12-26 09:12:12 +08:00
jerryjzhang
d40400d2a4 [fix][chat]Memory enabled by the review task should be stored in embedding store. 2024-12-26 00:11:12 +08:00
lexluo09
c483bb891a [fix][chat] Fix the issue with the order of parallel execution in the map. (#1976) 2024-12-25 20:52:52 +08:00
zehuiHuang
6738aba19e Issues 1974 (#1975)
* [fix][common]Support 'BETWEEN AND' query condition parameter parsing `CURRENT`. #1972
2024-12-25 19:33:40 +08:00
zehuiHuang
493a8035cd [fix][common]Support 'BETWEEN AND' query condition parameter parsing CURRENT. #1972 (#1973) 2024-12-25 19:33:26 +08:00
lwhy
b425c49c5b [fix][headless] Unexpected update of dimensions when modifying a model (#1971) 2024-12-25 19:33:10 +08:00
jerryjzhang
4dca6eec5a [improvement][project]Adjust github issue forms. 2024-12-23 20:24:35 +08:00
pisces
642d6a02e1 feat(chat-sdk/chatitem): 消息支持导出图表图片 (#1937) 2024-12-23 09:05:56 +08:00
jerryjzhang
5de5b0a5e2 [fix][headless]Fix issue in determining mysql version to support with statement. 2024-12-22 21:40:58 +08:00
lexluo09
8c6ae62522 [improvement][chat] Change the embedding to execute in parallel (#1967) 2024-12-21 20:32:03 +08:00
jerryjzhang
7dc013dfb3 [fix][project]Use SpringDoc to support swagger in Spring 3.x 2024-12-21 19:48:20 +08:00
lexluo09
72780f9acf [improvement][common] The thread pool adopts a generic thread pool configuration. (#1966) 2024-12-21 19:01:53 +08:00
jerryjzhang
7b49412bde Merge branch 'master' of github.com:tencentmusic/supersonic 2024-12-21 18:49:41 +08:00
jerryjzhang
9f63aca132 [fix][chat]Fix minor logic issue. 2024-12-21 18:49:27 +08:00
lexluo09
f7fce0217f [improvement][chat] Use a generic thread pool to perform concurrent mapping. (#1965) 2024-12-21 11:58:02 +08:00
jerryjzhang
c2d155705f Merge remote-tracking branch 'origin/master' 2024-12-20 12:43:36 +08:00
jerryjzhang
5faf5f3ac4 [improvement][project]Adjust github issue forms. 2024-12-20 12:43:22 +08:00
jerryjzhang
d88d8b3beb [project]Adjust the LICENSE to impose stricter restrictions in commercial scenarios. 2024-12-19 22:49:03 +08:00
jerryjzhang
4cb2256351 [improvement][headless]Merge function of QueryConverter abstraction to QueryParser. 2024-12-19 21:45:24 +08:00
lexluo09
8b69d57c4b [improvement][chat] Fix the issue with the DatabaseMatchStrategy variable under multi-threading (#1963) 2024-12-19 10:04:17 +08:00
jerryjzhang
91856ddebd [improvement][chat]Inject schema info into the prompt of LLMSqlCorrector. 2024-12-19 09:55:32 +08:00
jerryjzhang
94e97c9a1d [improvement][chat]Use accept() pattern to improve code readability. 2024-12-19 09:47:38 +08:00
wwsheng009
b57eed47e2 SAP HANA DATABASE Source support improvement[优化SAPhana数据库的支持] (#1959)
* (improvement)(database) update the support for sap hana database source

* (fix)(common) add the default timeout for ZhipuAiEmbeddingModel,avoid the program error
2024-12-17 21:52:19 +08:00
190 changed files with 1840 additions and 2299 deletions

View File

@@ -11,59 +11,28 @@ body:
If it is an idea or help wanted, please go to: If it is an idea or help wanted, please go to:
[Github Discussion](https://github.com/tencentmusic/supersonic/discussions) [Github Discussion](https://github.com/tencentmusic/supersonic/discussions)
- type: checkboxes - type: input
id: version
attributes: attributes:
label: Search before asking label: SuperSonic version
description: > description: Please tell us which version you are using.
Please make sure to search in the [issues](https://github.com/tencentmusic/supersonic/issues?q=is%3Aissue) first to see placeholder: "0.9.8"
whether the same issue was reported already. validations:
options:
- label: >
I had searched in the [issues](https://github.com/tencentmusic/supersonic/issues?q=is%3Aissue) and found no similar
issues.
required: true required: true
- type: textarea - type: input
id: organization
attributes: attributes:
label: Version label: Your organization
description: What is the current version description: Please tell us your organization so that we can provide you better support and advice.
placeholder: > placeholder: "TME..."
Please provide the version you are using.
If it is the trunk version, please input commit id.
validations: validations:
required: true required: true
- type: textarea - type: textarea
attributes: attributes:
label: What's Wrong? label: Description
description: Describe the bug. description: Describe the bug you met.
placeholder: >
Describe the specific problem, the more detailed the better.
validations:
required: true
- type: textarea
attributes:
label: What You Expected?
validations:
required: true
- type: textarea
attributes:
label: How to Reproduce?
placeholder: >
Please try to give reproducing steps to facilitate quick location of the problem.
- What actions were performed
- Table building statement
- Import statement
- Cluster information: number of nodes, configuration, etc.
If it is hard to reproduce, please also explain the general scene.
- type: textarea
attributes:
label: Anything Else?
- type: checkboxes - type: checkboxes
attributes: attributes:
@@ -74,16 +43,6 @@ body:
options: options:
- label: Yes I am willing to submit a PR! - label: Yes I am willing to submit a PR!
- type: checkboxes
attributes:
label: Code of Conduct
description: The Code of Conduct helps create a safe space for everyone. We require that everyone agrees to it.
options:
- label: >
I agree to follow this project's
[Code of Conduct](https://www.apache.org/foundation/policies/conduct)
required: true
- type: markdown - type: markdown
attributes: attributes:
value: "Thanks for completing our form!" value: "Thanks for completing our form!"

View File

@@ -8,30 +8,20 @@ body:
attributes: attributes:
value: | value: |
Thank you very much for your good enhancement for SuperSonic. Thank you very much for your good enhancement for SuperSonic.
- type: checkboxes
attributes:
label: Search before asking
description: >
Please make sure to search in the [issues](https://github.com/tencentmusic/supersonic/issues?q=is%3Aissue) first to see
whether the same issue was reported already.
options:
- label: >
I had searched in the [issues](https://github.com/tencentmusic/supersonic/issues?q=is%3Aissue) and found no similar
issues.
required: true
- type: textarea - type: textarea
attributes: attributes:
label: Description label: Description
description: Describe the enhancement what you want, including motivation if it exists. description: Describe the enhancement what you want, including motivation if it exists.
- type: textarea - type: input
id: organization
attributes: attributes:
label: Solution label: Your organization
placeholder: > description: Please tell us your organization so that we can provide you better support and advice.
Add overview of proposed solution. placeholder: "TME..."
validations:
Add related materials like links if they exist. required: true
- type: checkboxes - type: checkboxes
attributes: attributes:
@@ -42,16 +32,6 @@ body:
options: options:
- label: Yes I am willing to submit a PR! - label: Yes I am willing to submit a PR!
- type: checkboxes
attributes:
label: Code of Conduct
description: The Code of Conduct helps create a safe space for everyone. We require that everyone agrees to it.
options:
- label: >
I agree to follow this project's
[Code of Conduct](https://www.apache.org/foundation/policies/conduct)
required: true
- type: markdown - type: markdown
attributes: attributes:
value: "Thanks for completing our form!" value: "Thanks for completing our form!"

View File

@@ -8,33 +8,19 @@ body:
value: | value: |
Thank you very much for your good ideas and suggestions for SuperSonic Thank you very much for your good ideas and suggestions for SuperSonic
- type: checkboxes
attributes:
label: Search before asking
description: >
Please make sure to search in the [issues](https://github.com/tencentmusic/supersonic/issues?q=is%3Aissue) first to see
whether the same issue was reported already.
options:
- label: >
I had searched in the [issues](https://github.com/tencentmusic/supersonic/issues?q=is%3Aissue) and found no similar
issues.
required: true
- type: textarea - type: textarea
attributes: attributes:
label: Description label: Description
description: Describe your ideas and needs. description: Describe your ideas and needs.
- type: textarea - type: input
id: organization
attributes: attributes:
label: Use case label: Your organization
placeholder: > description: Please tell us your organization so that we can provide you better support and advice.
What problem does this feature mainly solve, or what scenarios it is suitable for. placeholder: "TME..."
validations:
- type: textarea required: true
attributes:
label: Related issues
description: Is there currently another issue associated with this?
- type: checkboxes - type: checkboxes
attributes: attributes:
@@ -45,16 +31,4 @@ body:
options: options:
- label: Yes I am willing to submit a PR! - label: Yes I am willing to submit a PR!
- type: checkboxes
attributes:
label: Code of Conduct
description: The Code of Conduct helps create a safe space for everyone. We require that everyone agrees to it.
options:
- label: >
I agree to follow this project's
[Code of Conduct](https://www.apache.org/foundation/policies/conduct)
required: true
- type: markdown
attributes:
value: "Thanks for completing our form!"

View File

@@ -8,6 +8,7 @@ body:
value: | value: |
## Ask a Question about SuperSonic ## Ask a Question about SuperSonic
Please provide a detailed description of your question or the clarification you seek regarding the SuperSonic project. Please provide a detailed description of your question or the clarification you seek regarding the SuperSonic project.
- type: textarea - type: textarea
id: describe-question id: describe-question
attributes: attributes:
@@ -16,43 +17,12 @@ body:
placeholder: "Type your question here..." placeholder: "Type your question here..."
validations: validations:
required: true required: true
- type: textarea
id: additional-context - type: input
id: organization
attributes: attributes:
label: Provide any additional context or information label: Your organization
description: If your question is related to a specific part of the SuperSonic project or if you have already looked through certain documentation, please provide that information here. description: Please tell us your organization so that we can provide you better support and advice.
placeholder: "Add context here..." placeholder: "TME..."
validations: validations:
required: false required: true
- type: textarea
id: tried-to-resolve
attributes:
label: What have you tried to resolve your question
description: Let us know what you have done to try and understand or resolve your question. This can help us provide you with the most useful guidance.
placeholder: "I've already tried..."
validations:
required: false
- type: textarea
id: environment
attributes:
label: Your environment
description: Share details about your environment to help us reproduce the issue. Include your operating system, version of SuperSonic, and any other relevant details.
placeholder: "OS, SuperSonic version, etc..."
validations:
required: false
- type: textarea
id: screenshots-logs
attributes:
label: Screenshots or Logs
description: If applicable, add screenshots or logs to help explain your problem.
placeholder: "Paste your logs or attach screenshots here..."
validations:
required: false
- type: textarea
id: additional-information
attributes:
label: Additional information
description: Add any other context or details you think might be helpful for understanding your question.
placeholder: "Any other information..."
validations:
required: false

1
.gitignore vendored
View File

@@ -20,3 +20,4 @@ chm_db/
__pycache__/ __pycache__/
/dict /dict
assembly/build/*-SNAPSHOT assembly/build/*-SNAPSHOT
**/node_modules/

18
LICENSE
View File

@@ -1,19 +1,11 @@
SuperSonic is licensed under the MIT License, with the following additional conditions: SuperSonic is licensed under the MIT License, you can freely use or integrate SuperSonic within
your organization. However, if you want to provide or integrate SuperSonic to third parties
1. You may provide SuperSonic to third parties as a commercial software or service. However, as a commercial software or service, you must contact the producer to obtain a commercial license.
when the following conditions are met, you must contact the producer to obtain a commercial license:
a. Multi-tenant SaaS service: Unless explicitly authorized by SuperSonic in writing, you may not use the
SuperSonic source code to operate a multi-tenant SaaS service.
b. LOGO and copyright information: In the process of using SuperSonic, you may not remove or modify
the LOGO or copyright information on the SuperSonic UI. This restriction is inapplicable to uses of
SuperSonic that do not involve its frontend components.
Please contact jerryjzhang@tencent.com by email to inquire about licensing matters. Please contact jerryjzhang@tencent.com by email to inquire about licensing matters.
2. As a contributor, you should agree that: As a SuperSonic contributor, you should agree that:
a. The producer can adjust the open-source agreement to be more strict or relaxed as deemed necessary. a. The producer can adjust the open-source agreement to be stricter or relaxed as deemed necessary.
b. Your contributed code may be used for commercial purposes, including but not limited to its business operations. b. Your contributed code may be used for commercial purposes, including but not limited to its business operations.
Terms of the MIT License: Terms of the MIT License:

View File

@@ -17,6 +17,8 @@ public class ChatMemoryFilter {
private Integer agentId; private Integer agentId;
private Long queryId;
private String question; private String question;
private List<String> questions; private List<String> questions;

View File

@@ -26,4 +26,8 @@ public class ChatMemoryUpdateReq {
private MemoryReviewResult humanReviewRet; private MemoryReviewResult humanReviewRet;
private String humanReviewCmt; private String humanReviewCmt;
private MemoryReviewResult llmReviewRet;
private String llmReviewCmt;
} }

View File

@@ -44,7 +44,7 @@ public class SqlExecutor implements ChatQueryExecutor {
Text2SQLExemplar.class); Text2SQLExemplar.class);
MemoryService memoryService = ContextUtils.getBean(MemoryService.class); MemoryService memoryService = ContextUtils.getBean(MemoryService.class);
memoryService.createMemory(ChatMemory.builder() memoryService.createMemory(ChatMemory.builder().queryId(queryResult.getQueryId())
.agentId(executeContext.getAgent().getId()).status(MemoryStatus.PENDING) .agentId(executeContext.getAgent().getId()).status(MemoryStatus.PENDING)
.question(exemplar.getQuestion()).sideInfo(exemplar.getSideInfo()) .question(exemplar.getQuestion()).sideInfo(exemplar.getSideInfo())
.dbSchema(exemplar.getDbSchema()).s2sql(exemplar.getSql()) .dbSchema(exemplar.getDbSchema()).s2sql(exemplar.getSql())
@@ -77,6 +77,7 @@ public class SqlExecutor implements ChatQueryExecutor {
long startTime = System.currentTimeMillis(); long startTime = System.currentTimeMillis();
QueryResult queryResult = new QueryResult(); QueryResult queryResult = new QueryResult();
queryResult.setQueryId(executeContext.getRequest().getQueryId());
queryResult.setChatContext(parseInfo); queryResult.setChatContext(parseInfo);
queryResult.setQueryMode(parseInfo.getQueryMode()); queryResult.setQueryMode(parseInfo.getQueryMode());
queryResult.setQueryTimeCost(System.currentTimeMillis() - startTime); queryResult.setQueryTimeCost(System.currentTimeMillis() - startTime);

View File

@@ -3,11 +3,13 @@ package com.tencent.supersonic.chat.server.memory;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult; import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus; import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter; import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq;
import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.pojo.ChatMemory; import com.tencent.supersonic.chat.server.pojo.ChatMemory;
import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.MemoryService; import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.pojo.enums.AppModule; import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.util.ChatAppManager; import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.headless.server.utils.ModelConfigHelper; import com.tencent.supersonic.headless.server.utils.ModelConfigHelper;
@@ -123,7 +125,10 @@ public class MemoryReviewTask {
if (MemoryReviewResult.POSITIVE.equals(m.getLlmReviewRet())) { if (MemoryReviewResult.POSITIVE.equals(m.getLlmReviewRet())) {
m.setStatus(MemoryStatus.ENABLED); m.setStatus(MemoryStatus.ENABLED);
} }
memoryService.updateMemory(m); ChatMemoryUpdateReq memoryUpdateReq = ChatMemoryUpdateReq.builder().id(m.getId())
.status(m.getStatus()).llmReviewRet(m.getLlmReviewRet())
.llmReviewCmt(m.getLlmReviewCmt()).build();
memoryService.updateMemory(memoryUpdateReq, User.getDefaultUser());
} }
} }
} }

View File

@@ -100,10 +100,12 @@ public class NL2SQLParser implements ChatQueryParser {
queryNLReq.setMapModeEnum(mode); queryNLReq.setMapModeEnum(mode);
doParse(queryNLReq, parseResp); doParse(queryNLReq, parseResp);
} }
if (parseResp.getSelectedParses().isEmpty()) {
if (parseResp.getSelectedParses().isEmpty() && candidateParses.isEmpty()) {
queryNLReq.setMapModeEnum(MapModeEnum.LOOSE); queryNLReq.setMapModeEnum(MapModeEnum.LOOSE);
doParse(queryNLReq, parseResp); doParse(queryNLReq, parseResp);
} }
if (parseResp.getSelectedParses().isEmpty()) { if (parseResp.getSelectedParses().isEmpty()) {
errMsg.append(parseResp.getErrorMsg()); errMsg.append(parseResp.getErrorMsg());
continue; continue;
@@ -137,11 +139,18 @@ public class NL2SQLParser implements ChatQueryParser {
SemanticParseInfo userSelectParse = parseContext.getRequest().getSelectedParse(); SemanticParseInfo userSelectParse = parseContext.getRequest().getSelectedParse();
queryNLReq.setSelectedParseInfo(Objects.nonNull(userSelectParse) ? userSelectParse queryNLReq.setSelectedParseInfo(Objects.nonNull(userSelectParse) ? userSelectParse
: parseContext.getResponse().getSelectedParses().get(0)); : parseContext.getResponse().getSelectedParses().get(0));
parseContext.setResponse(new ChatParseResp(parseContext.getResponse().getQueryId())); parseContext.setResponse(new ChatParseResp(parseContext.getResponse().getQueryId()));
rewriteMultiTurn(parseContext, queryNLReq); rewriteMultiTurn(parseContext, queryNLReq);
addDynamicExemplars(parseContext, queryNLReq); addDynamicExemplars(parseContext, queryNLReq);
doParse(queryNLReq, parseContext.getResponse()); doParse(queryNLReq, parseContext.getResponse());
// try again with all semantic fields passed to LLM
if (parseContext.getResponse().getState().equals(ParseResp.ParseState.FAILED)) {
queryNLReq.setSelectedParseInfo(null);
queryNLReq.setMapModeEnum(MapModeEnum.ALL);
doParse(queryNLReq, parseContext.getResponse());
}
} }
} }

View File

@@ -23,6 +23,9 @@ public class ChatMemoryDO {
@TableField("agent_id") @TableField("agent_id")
private Integer agentId; private Integer agentId;
@TableField("query_id")
private Long queryId;
@TableField("question") @TableField("question")
private String question; private String question;

View File

@@ -2,11 +2,7 @@ package com.tencent.supersonic.chat.server.pojo;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult; import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus; import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import lombok.AllArgsConstructor; import lombok.*;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString;
import java.util.Date; import java.util.Date;
@@ -20,6 +16,8 @@ public class ChatMemory {
private Integer agentId; private Integer agentId;
private Long queryId;
private String question; private String question;
private String sideInfo; private String sideInfo;

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.chat.server.pojo; package com.tencent.supersonic.chat.server.pojo;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq; import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import lombok.Data; import lombok.Data;
@@ -8,6 +9,7 @@ import lombok.Data;
@Data @Data
public class ExecuteContext { public class ExecuteContext {
private ChatExecuteReq request; private ChatExecuteReq request;
private QueryResult response;
private Agent agent; private Agent agent;
private SemanticParseInfo parseInfo; private SemanticParseInfo parseInfo;

View File

@@ -43,13 +43,18 @@ public class DataInterpretProcessor implements ExecuteResultProcessor {
@Override @Override
public void process(ExecuteContext executeContext, QueryResult queryResult) { public boolean accept(ExecuteContext executeContext) {
Agent agent = executeContext.getAgent(); Agent agent = executeContext.getAgent();
ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY); ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY);
if (Objects.isNull(chatApp) || !chatApp.isEnable()) { return Objects.nonNull(chatApp) && chatApp.isEnable();
return;
} }
@Override
public void process(ExecuteContext executeContext) {
QueryResult queryResult = executeContext.getResponse();
Agent agent = executeContext.getAgent();
ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY);
Map<String, Object> variable = new HashMap<>(); Map<String, Object> variable = new HashMap<>();
variable.put("question", executeContext.getRequest().getQueryText()); variable.put("question", executeContext.getRequest().getQueryText());
variable.put("data", queryResult.getTextResult()); variable.put("data", queryResult.getTextResult());

View File

@@ -27,17 +27,18 @@ public class DimensionRecommendProcessor implements ExecuteResultProcessor {
private static final int recommend_dimension_size = 5; private static final int recommend_dimension_size = 5;
@Override @Override
public void process(ExecuteContext executeContext, QueryResult queryResult) { public boolean accept(ExecuteContext executeContext) {
SemanticParseInfo semanticParseInfo = executeContext.getParseInfo(); SemanticParseInfo semanticParseInfo = executeContext.getParseInfo();
if (!QueryType.AGGREGATE.equals(semanticParseInfo.getQueryType()) return QueryType.AGGREGATE.equals(semanticParseInfo.getQueryType())
|| CollectionUtils.isEmpty(semanticParseInfo.getMetrics())) { && !CollectionUtils.isEmpty(semanticParseInfo.getMetrics());
return;
} }
@Override
public void process(ExecuteContext executeContext) {
QueryResult queryResult = executeContext.getResponse();
SemanticParseInfo semanticParseInfo = executeContext.getParseInfo();
Long dataSetId = semanticParseInfo.getDataSetId(); Long dataSetId = semanticParseInfo.getDataSetId();
Optional<SchemaElement> firstMetric = semanticParseInfo.getMetrics().stream().findFirst(); Optional<SchemaElement> firstMetric = semanticParseInfo.getMetrics().stream().findFirst();
if (!firstMetric.isPresent()) {
return;
}
List<SchemaElement> dimensionRecommended = List<SchemaElement> dimensionRecommended =
getDimensions(firstMetric.get().getId(), dataSetId); getDimensions(firstMetric.get().getId(), dataSetId);
queryResult.setRecommendedDimensions(dimensionRecommended); queryResult.setRecommendedDimensions(dimensionRecommended);

View File

@@ -1,11 +1,12 @@
package com.tencent.supersonic.chat.server.processor.execute; package com.tencent.supersonic.chat.server.processor.execute;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext; import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.chat.server.processor.ResultProcessor; import com.tencent.supersonic.chat.server.processor.ResultProcessor;
/** A ExecuteResultProcessor wraps things up before returning execution results to the users. */ /** A ExecuteResultProcessor wraps things up before returning execution results to the users. */
public interface ExecuteResultProcessor extends ResultProcessor { public interface ExecuteResultProcessor extends ResultProcessor {
void process(ExecuteContext executeContext, QueryResult queryResult); boolean accept(ExecuteContext executeContext);
void process(ExecuteContext executeContext);
} }

View File

@@ -59,14 +59,18 @@ import static com.tencent.supersonic.common.pojo.Constants.TIME_FORMAT;
public class MetricRatioCalcProcessor implements ExecuteResultProcessor { public class MetricRatioCalcProcessor implements ExecuteResultProcessor {
@Override @Override
public void process(ExecuteContext executeContext, QueryResult queryResult) { public boolean accept(ExecuteContext executeContext) {
SemanticParseInfo semanticParseInfo = executeContext.getParseInfo(); SemanticParseInfo semanticParseInfo = executeContext.getParseInfo();
AggregatorConfig aggregatorConfig = ContextUtils.getBean(AggregatorConfig.class); AggregatorConfig aggregatorConfig = ContextUtils.getBean(AggregatorConfig.class);
if (CollectionUtils.isEmpty(semanticParseInfo.getMetrics()) return !CollectionUtils.isEmpty(semanticParseInfo.getMetrics())
|| !aggregatorConfig.getEnableRatio() && aggregatorConfig.getEnableRatio()
|| !QueryType.AGGREGATE.equals(semanticParseInfo.getQueryType())) { && QueryType.AGGREGATE.equals(semanticParseInfo.getQueryType());
return;
} }
@Override
public void process(ExecuteContext executeContext) {
QueryResult queryResult = executeContext.getResponse();
SemanticParseInfo semanticParseInfo = executeContext.getParseInfo();
AggregateInfo aggregateInfo = getAggregateInfo(executeContext.getRequest().getUser(), AggregateInfo aggregateInfo = getAggregateInfo(executeContext.getRequest().getUser(),
semanticParseInfo, queryResult); semanticParseInfo, queryResult);
queryResult.setAggregateInfo(aggregateInfo); queryResult.setAggregateInfo(aggregateInfo);

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.chat.server.processor.execute; package com.tencent.supersonic.chat.server.processor.execute;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext; import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.DictWordType; import com.tencent.supersonic.common.pojo.enums.DictWordType;
@@ -16,14 +15,7 @@ import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult; import dev.langchain4j.store.embedding.RetrieveQueryResult;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.Collections; import java.util.*;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
/** /**
@@ -34,17 +26,20 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
private static final int METRIC_RECOMMEND_SIZE = 5; private static final int METRIC_RECOMMEND_SIZE = 5;
@Override @Override
public void process(ExecuteContext executeContext, QueryResult queryResult) { public boolean accept(ExecuteContext executeContext) {
SemanticParseInfo parseInfo = executeContext.getParseInfo();
return Objects.nonNull(parseInfo.getQueryType())
&& parseInfo.getQueryType().equals(QueryType.AGGREGATE)
&& !CollectionUtils.isEmpty(parseInfo.getMetrics())
&& parseInfo.getMetrics().size() <= METRIC_RECOMMEND_SIZE;
}
@Override
public void process(ExecuteContext executeContext) {
fillSimilarMetric(executeContext.getParseInfo()); fillSimilarMetric(executeContext.getParseInfo());
} }
private void fillSimilarMetric(SemanticParseInfo parseInfo) { private void fillSimilarMetric(SemanticParseInfo parseInfo) {
if (Objects.isNull(parseInfo.getQueryType())
|| !parseInfo.getQueryType().equals(QueryType.AGGREGATE)
|| parseInfo.getMetrics().size() > METRIC_RECOMMEND_SIZE
|| CollectionUtils.isEmpty(parseInfo.getMetrics())) {
return;
}
List<String> metricNames = List<String> metricNames =
Collections.singletonList(parseInfo.getMetrics().iterator().next().getName()); Collections.singletonList(parseInfo.getMetrics().iterator().next().getName());
Map<String, Object> filterCondition = new HashMap<>(); Map<String, Object> filterCondition = new HashMap<>();

View File

@@ -43,14 +43,17 @@ public class ErrorMsgRewriteProcessor implements ParseResultProcessor {
.enable(false).build()); .enable(false).build());
} }
@Override
public boolean accept(ParseContext parseContext) {
ChatApp chatApp = parseContext.getAgent().getChatAppConfig().get(APP_KEY_ERROR_MESSAGE);
return StringUtils.isNotBlank(parseContext.getResponse().getErrorMsg())
&& Objects.nonNull(chatApp) && chatApp.isEnable();
}
@Override @Override
public void process(ParseContext parseContext) { public void process(ParseContext parseContext) {
String errMsg = parseContext.getResponse().getErrorMsg(); String errMsg = parseContext.getResponse().getErrorMsg();
ChatApp chatApp = parseContext.getAgent().getChatAppConfig().get(APP_KEY_ERROR_MESSAGE); ChatApp chatApp = parseContext.getAgent().getChatAppConfig().get(APP_KEY_ERROR_MESSAGE);
if (StringUtils.isBlank(errMsg) || Objects.isNull(chatApp) || !chatApp.isEnable()) {
return;
}
Map<String, Object> variables = new HashMap<>(); Map<String, Object> variables = new HashMap<>();
variables.put("user_question", parseContext.getRequest().getQueryText()); variables.put("user_question", parseContext.getRequest().getQueryText());
variables.put("system_message", errMsg); variables.put("system_message", errMsg);

View File

@@ -8,7 +8,6 @@ import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElement;
@@ -21,12 +20,7 @@ import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Pair;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.Arrays; import java.util.*;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
/** /**
@@ -34,6 +28,12 @@ import java.util.stream.Collectors;
**/ **/
@Slf4j @Slf4j
public class ParseInfoFormatProcessor implements ParseResultProcessor { public class ParseInfoFormatProcessor implements ParseResultProcessor {
@Override
public boolean accept(ParseContext parseContext) {
return !parseContext.getResponse().getSelectedParses().isEmpty();
}
@Override @Override
public void process(ParseContext parseContext) { public void process(ParseContext parseContext) {
parseContext.getResponse().getSelectedParses().forEach(p -> { parseContext.getResponse().getSelectedParses().forEach(p -> {
@@ -216,9 +216,6 @@ public class ParseInfoFormatProcessor implements ParseResultProcessor {
} }
private static boolean isPartitionDimension(DataSetSchema dataSetSchema, String sqlFieldName) { private static boolean isPartitionDimension(DataSetSchema dataSetSchema, String sqlFieldName) {
if (TimeDimensionEnum.containsTimeDimension(sqlFieldName)) {
return true;
}
if (Objects.isNull(dataSetSchema) || Objects.isNull(dataSetSchema.getPartitionDimension()) if (Objects.isNull(dataSetSchema) || Objects.isNull(dataSetSchema.getPartitionDimension())
|| Objects.isNull(dataSetSchema.getPartitionDimension().getName())) { || Objects.isNull(dataSetSchema.getPartitionDimension().getName())) {
return false; return false;

View File

@@ -6,5 +6,7 @@ import com.tencent.supersonic.chat.server.processor.ResultProcessor;
/** A ParseResultProcessor wraps things up before returning parsing results to the users. */ /** A ParseResultProcessor wraps things up before returning parsing results to the users. */
public interface ParseResultProcessor extends ResultProcessor { public interface ParseResultProcessor extends ResultProcessor {
boolean accept(ParseContext parseContext);
void process(ParseContext parseContext); void process(ParseContext parseContext);
} }

View File

@@ -23,6 +23,11 @@ import java.util.stream.Collectors;
@Slf4j @Slf4j
public class QueryRecommendProcessor implements ParseResultProcessor { public class QueryRecommendProcessor implements ParseResultProcessor {
@Override
public boolean accept(ParseContext parseContext) {
return true;
}
@Override @Override
public void process(ParseContext parseContext) { public void process(ParseContext parseContext) {
CompletableFuture.runAsync(() -> doProcess(parseContext)); CompletableFuture.runAsync(() -> doProcess(parseContext));

View File

@@ -10,6 +10,11 @@ import lombok.extern.slf4j.Slf4j;
@Slf4j @Slf4j
public class TimeCostCalcProcessor implements ParseResultProcessor { public class TimeCostCalcProcessor implements ParseResultProcessor {
@Override
public boolean accept(ParseContext parseContext) {
return true;
}
@Override @Override
public void process(ParseContext parseContext) { public void process(ParseContext parseContext) {
ChatParseResp parseResp = parseContext.getResponse(); ChatParseResp parseResp = parseContext.getResponse();

View File

@@ -53,7 +53,7 @@ public class ChatController {
} }
@PostMapping("/updateQAFeedback") @PostMapping("/updateQAFeedback")
public Boolean updateQAFeedback(@RequestParam(value = "id") Integer id, public Boolean updateQAFeedback(@RequestParam(value = "id") Long id,
@RequestParam(value = "score") Integer score, @RequestParam(value = "score") Integer score,
@RequestParam(value = "feedback", required = false) String feedback) { @RequestParam(value = "feedback", required = false) String feedback) {
return chatService.updateFeedback(id, score, feedback); return chatService.updateFeedback(id, score, feedback);

View File

@@ -13,7 +13,6 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import java.util.List; import java.util.List;
@@ -24,7 +23,7 @@ public interface ChatManageService {
boolean updateChatName(Long chatId, String chatName, String userName); boolean updateChatName(Long chatId, String chatName, String userName);
boolean updateFeedback(Integer id, Integer score, String feedback); boolean updateFeedback(Long id, Integer score, String feedback);
boolean updateChatIsTop(Long chatId, int isTop); boolean updateChatIsTop(Long chatId, int isTop);

View File

@@ -14,8 +14,6 @@ public interface MemoryService {
void updateMemory(ChatMemoryUpdateReq chatMemoryUpdateReq, User user); void updateMemory(ChatMemoryUpdateReq chatMemoryUpdateReq, User user);
void updateMemory(ChatMemory memory);
void batchDelete(List<Long> ids); void batchDelete(List<Long> ids);
PageInfo<ChatMemory> pageMemories(PageMemoryReq pageMemoryReq); PageInfo<ChatMemory> pageMemories(PageMemoryReq pageMemoryReq);

View File

@@ -20,13 +20,13 @@ import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.Executors;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@Slf4j @Slf4j
@@ -42,7 +42,9 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
@Autowired @Autowired
private ChatModelService chatModelService; private ChatModelService chatModelService;
private ExecutorService executorService = Executors.newFixedThreadPool(1); @Autowired
@Qualifier("chatExecutor")
private ThreadPoolExecutor executor;
@Override @Override
public List<Agent> getAgents(User user, AuthType authType) { public List<Agent> getAgents(User user, AuthType authType) {
@@ -108,7 +110,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
* @param agent * @param agent
*/ */
private void executeAgentExamplesAsync(Agent agent) { private void executeAgentExamplesAsync(Agent agent) {
executorService.execute(() -> doExecuteAgentExamples(agent)); executor.execute(() -> doExecuteAgentExamples(agent));
} }
private synchronized void doExecuteAgentExamples(Agent agent) { private synchronized void doExecuteAgentExamples(Agent agent) {

View File

@@ -2,9 +2,9 @@ package com.tencent.supersonic.chat.server.service.impl;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.github.pagehelper.PageInfo; import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq; import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq; import com.tencent.supersonic.chat.api.pojo.request.*;
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp; import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
@@ -15,11 +15,12 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.QueryDO; import com.tencent.supersonic.chat.server.persistence.dataobject.QueryDO;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository; import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.persistence.repository.ChatRepository; import com.tencent.supersonic.chat.server.persistence.repository.ChatRepository;
import com.tencent.supersonic.chat.server.pojo.ChatMemory;
import com.tencent.supersonic.chat.server.service.ChatManageService; import com.tencent.supersonic.chat.server.service.ChatManageService;
import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@@ -38,6 +39,8 @@ public class ChatManageServiceImpl implements ChatManageService {
private ChatRepository chatRepository; private ChatRepository chatRepository;
@Autowired @Autowired
private ChatQueryRepository chatQueryRepository; private ChatQueryRepository chatQueryRepository;
@Autowired
private MemoryService memoryService;
@Override @Override
public Long addChat(User user, String chatName, Integer agentId) { public Long addChat(User user, String chatName, Integer agentId) {
@@ -64,11 +67,28 @@ public class ChatManageServiceImpl implements ChatManageService {
} }
@Override @Override
public boolean updateFeedback(Integer id, Integer score, String feedback) { public boolean updateFeedback(Long id, Integer score, String feedback) {
QueryDO intelligentQueryDO = new QueryDO(); QueryDO intelligentQueryDO = new QueryDO();
intelligentQueryDO.setId(id); intelligentQueryDO.setId(id);
intelligentQueryDO.setQuestionId(id);
intelligentQueryDO.setScore(score); intelligentQueryDO.setScore(score);
intelligentQueryDO.setFeedback(feedback); intelligentQueryDO.setFeedback(feedback);
// enable or disable memory based on user feedback
if (score >= 5 || score <= 1) {
ChatMemoryFilter memoryFilter = ChatMemoryFilter.builder().queryId(id).build();
List<ChatMemory> memories = memoryService.getMemories(memoryFilter);
memories.forEach(m -> {
MemoryStatus status = score >= 5 ? MemoryStatus.ENABLED : MemoryStatus.DISABLED;
MemoryReviewResult reviewResult =
score >= 5 ? MemoryReviewResult.POSITIVE : MemoryReviewResult.NEGATIVE;
ChatMemoryUpdateReq memoryUpdateReq = ChatMemoryUpdateReq.builder().id(m.getId())
.status(status).humanReviewRet(reviewResult)
.humanReviewCmt("Reviewed as per user feedback").build();
memoryService.updateMemory(memoryUpdateReq, User.getDefaultUser());
});
}
return chatRepository.updateFeedback(intelligentQueryDO); return chatRepository.updateFeedback(intelligentQueryDO);
} }

View File

@@ -95,7 +95,12 @@ public class ChatQueryServiceImpl implements ChatQueryService {
ParseContext parseContext = buildParseContext(chatParseReq, new ChatParseResp(queryId)); ParseContext parseContext = buildParseContext(chatParseReq, new ChatParseResp(queryId));
chatQueryParsers.forEach(p -> p.parse(parseContext)); chatQueryParsers.forEach(p -> p.parse(parseContext));
parseResultProcessors.forEach(p -> p.process(parseContext));
for (ParseResultProcessor processor : parseResultProcessors) {
if (processor.accept(parseContext)) {
processor.process(parseContext);
}
}
if (!parseContext.needFeedback()) { if (!parseContext.needFeedback()) {
chatManageService.batchAddParse(chatParseReq, parseContext.getResponse()); chatManageService.batchAddParse(chatParseReq, parseContext.getResponse());
@@ -116,9 +121,12 @@ public class ChatQueryServiceImpl implements ChatQueryService {
} }
} }
executeContext.setResponse(queryResult);
if (queryResult != null) { if (queryResult != null) {
for (ExecuteResultProcessor processor : executeResultProcessors) { for (ExecuteResultProcessor processor : executeResultProcessors) {
processor.process(executeContext, queryResult); if (processor.accept(executeContext)) {
processor.process(executeContext);
}
} }
saveQueryResult(chatExecuteReq, queryResult); saveQueryResult(chatExecuteReq, queryResult);
} }

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.chat.server.service.impl; package com.tencent.supersonic.chat.server.service.impl;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
import com.github.pagehelper.PageHelper; import com.github.pagehelper.PageHelper;
import com.github.pagehelper.PageInfo; import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult; import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
@@ -9,6 +10,7 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq; import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq;
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq; import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.persistence.mapper.ChatMemoryMapper;
import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository; import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository;
import com.tencent.supersonic.chat.server.pojo.ChatMemory; import com.tencent.supersonic.chat.server.pojo.ChatMemory;
import com.tencent.supersonic.chat.server.service.MemoryService; import com.tencent.supersonic.chat.server.service.MemoryService;
@@ -16,7 +18,6 @@ import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.service.ExemplarService; import com.tencent.supersonic.common.service.ExemplarService;
import com.tencent.supersonic.common.util.BeanMapper;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
@@ -34,6 +35,9 @@ public class MemoryServiceImpl implements MemoryService {
@Autowired @Autowired
private ChatMemoryRepository chatMemoryRepository; private ChatMemoryRepository chatMemoryRepository;
@Autowired
private ChatMemoryMapper chatMemoryMapper;
@Autowired @Autowired
private ExemplarService exemplarService; private ExemplarService exemplarService;
@@ -57,20 +61,36 @@ public class MemoryServiceImpl implements MemoryService {
ChatMemoryDO chatMemoryDO = chatMemoryRepository.getMemory(chatMemoryUpdateReq.getId()); ChatMemoryDO chatMemoryDO = chatMemoryRepository.getMemory(chatMemoryUpdateReq.getId());
boolean hadEnabled = boolean hadEnabled =
MemoryStatus.ENABLED.toString().equals(chatMemoryDO.getStatus().trim()); MemoryStatus.ENABLED.toString().equals(chatMemoryDO.getStatus().trim());
chatMemoryDO.setUpdatedBy(user.getName());
chatMemoryDO.setUpdatedAt(new Date());
BeanMapper.mapper(chatMemoryUpdateReq, chatMemoryDO);
if (MemoryStatus.ENABLED.equals(chatMemoryUpdateReq.getStatus()) && !hadEnabled) { if (MemoryStatus.ENABLED.equals(chatMemoryUpdateReq.getStatus()) && !hadEnabled) {
enableMemory(chatMemoryDO); enableMemory(chatMemoryDO);
} else if (MemoryStatus.DISABLED.equals(chatMemoryUpdateReq.getStatus()) && hadEnabled) { } else if (MemoryStatus.DISABLED.equals(chatMemoryUpdateReq.getStatus()) && hadEnabled) {
disableMemory(chatMemoryDO); disableMemory(chatMemoryDO);
} }
chatMemoryRepository.updateMemory(chatMemoryDO);
}
@Override LambdaUpdateWrapper<ChatMemoryDO> updateWrapper = new LambdaUpdateWrapper<>();
public void updateMemory(ChatMemory memory) { updateWrapper.eq(ChatMemoryDO::getId, chatMemoryDO.getId());
chatMemoryRepository.updateMemory(getMemoryDO(memory)); if (Objects.nonNull(chatMemoryUpdateReq.getStatus())) {
updateWrapper.set(ChatMemoryDO::getStatus, chatMemoryUpdateReq.getStatus());
}
if (Objects.nonNull(chatMemoryUpdateReq.getLlmReviewRet())) {
updateWrapper.set(ChatMemoryDO::getLlmReviewRet,
chatMemoryUpdateReq.getLlmReviewRet().toString());
}
if (Objects.nonNull(chatMemoryUpdateReq.getLlmReviewCmt())) {
updateWrapper.set(ChatMemoryDO::getLlmReviewCmt, chatMemoryUpdateReq.getLlmReviewCmt());
}
if (Objects.nonNull(chatMemoryUpdateReq.getHumanReviewRet())) {
updateWrapper.set(ChatMemoryDO::getHumanReviewRet,
chatMemoryUpdateReq.getHumanReviewRet().toString());
}
if (Objects.nonNull(chatMemoryUpdateReq.getHumanReviewCmt())) {
updateWrapper.set(ChatMemoryDO::getHumanReviewCmt,
chatMemoryUpdateReq.getHumanReviewCmt());
}
updateWrapper.set(ChatMemoryDO::getUpdatedAt, new Date());
updateWrapper.set(ChatMemoryDO::getUpdatedBy, user.getName());
chatMemoryMapper.update(updateWrapper);
} }
@Override @Override
@@ -120,7 +140,7 @@ public class MemoryServiceImpl implements MemoryService {
return chatMemoryDOS.stream().map(this::getMemory).collect(Collectors.toList()); return chatMemoryDOS.stream().map(this::getMemory).collect(Collectors.toList());
} }
private void enableMemory(ChatMemoryDO memory) { public void enableMemory(ChatMemoryDO memory) {
memory.setStatus(MemoryStatus.ENABLED.toString()); memory.setStatus(MemoryStatus.ENABLED.toString());
exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()), exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
Text2SQLExemplar.builder().question(memory.getQuestion()) Text2SQLExemplar.builder().question(memory.getQuestion())
@@ -128,7 +148,7 @@ public class MemoryServiceImpl implements MemoryService {
.sql(memory.getS2sql()).build()); .sql(memory.getS2sql()).build());
} }
private void disableMemory(ChatMemoryDO memory) { public void disableMemory(ChatMemoryDO memory) {
memory.setStatus(MemoryStatus.DISABLED.toString()); memory.setStatus(MemoryStatus.DISABLED.toString());
exemplarService.removeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()), exemplarService.removeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
Text2SQLExemplar.builder().question(memory.getQuestion()) Text2SQLExemplar.builder().question(memory.getQuestion())

View File

@@ -21,9 +21,10 @@ public class SqlDialectFactory {
.withDatabaseProduct(DatabaseProduct.BIG_QUERY).withLiteralQuoteString("'") .withDatabaseProduct(DatabaseProduct.BIG_QUERY).withLiteralQuoteString("'")
.withLiteralEscapedQuoteString("''").withUnquotedCasing(Casing.UNCHANGED) .withLiteralEscapedQuoteString("''").withUnquotedCasing(Casing.UNCHANGED)
.withQuotedCasing(Casing.UNCHANGED).withCaseSensitive(false); .withQuotedCasing(Casing.UNCHANGED).withCaseSensitive(false);
public static final Context HANADB_CONTEXT = SqlDialect.EMPTY_CONTEXT public static final Context HANADB_CONTEXT =
.withDatabaseProduct(DatabaseProduct.BIG_QUERY).withLiteralQuoteString("'") SqlDialect.EMPTY_CONTEXT.withDatabaseProduct(DatabaseProduct.BIG_QUERY)
.withIdentifierQuoteString("\"").withLiteralEscapedQuoteString("''").withUnquotedCasing(Casing.UNCHANGED) .withLiteralQuoteString("'").withIdentifierQuoteString("\"")
.withLiteralEscapedQuoteString("''").withUnquotedCasing(Casing.UNCHANGED)
.withQuotedCasing(Casing.UNCHANGED).withCaseSensitive(true); .withQuotedCasing(Casing.UNCHANGED).withCaseSensitive(true);
private static Map<EngineType, SemanticSqlDialect> sqlDialectMap; private static Map<EngineType, SemanticSqlDialect> sqlDialectMap;

View File

@@ -0,0 +1,36 @@
package com.tencent.supersonic.common.config;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.springframework.context.annotation.Bean;
import org.springframework.stereotype.Component;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
@Component
public class ThreadPoolConfig {
@Bean("commonExecutor")
public ThreadPoolExecutor getCommonExecutor() {
return new ThreadPoolExecutor(8, 16, 60 * 3, TimeUnit.SECONDS,
new LinkedBlockingQueue<>(1024),
new ThreadFactoryBuilder().setNameFormat("supersonic-common-pool-").build(),
new ThreadPoolExecutor.CallerRunsPolicy());
}
@Bean("mapExecutor")
public ThreadPoolExecutor getMapExecutor() {
return new ThreadPoolExecutor(8, 16, 60 * 3, TimeUnit.SECONDS, new LinkedBlockingQueue<>(),
new ThreadFactoryBuilder().setNameFormat("supersonic-map-pool-").build(),
new ThreadPoolExecutor.CallerRunsPolicy());
}
@Bean("chatExecutor")
public ThreadPoolExecutor getChatExecutor() {
return new ThreadPoolExecutor(8, 16, 60 * 3, TimeUnit.SECONDS,
new LinkedBlockingQueue<>(1024),
new ThreadFactoryBuilder().setNameFormat("supersonic-chat-pool-").build(),
new ThreadPoolExecutor.CallerRunsPolicy());
}
}

View File

@@ -0,0 +1,24 @@
package com.tencent.supersonic.common.jsqlparser;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
import net.sf.jsqlparser.statement.select.SelectItem;
import java.util.Set;
public class AliasAcquireVisitor extends ExpressionVisitorAdapter {
private Set<String> aliases;
public AliasAcquireVisitor(Set<String> aliases) {
this.aliases = aliases;
}
@Override
public void visit(SelectItem selectItem) {
Alias alias = selectItem.getAlias();
if (alias != null) {
aliases.add(alias.getName());
}
}
}

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.common.jsqlparser; package com.tencent.supersonic.common.jsqlparser;
import com.google.common.collect.Sets;
import net.sf.jsqlparser.expression.Alias; import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter; import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
@@ -11,6 +12,7 @@ import java.util.Set;
public class FieldAcquireVisitor extends ExpressionVisitorAdapter { public class FieldAcquireVisitor extends ExpressionVisitorAdapter {
private Set<String> fields; private Set<String> fields;
private Set<String> aliases = Sets.newHashSet();
public FieldAcquireVisitor(Set<String> fields) { public FieldAcquireVisitor(Set<String> fields) {
this.fields = fields; this.fields = fields;
@@ -26,8 +28,9 @@ public class FieldAcquireVisitor extends ExpressionVisitorAdapter {
public void visit(SelectItem selectItem) { public void visit(SelectItem selectItem) {
Alias alias = selectItem.getAlias(); Alias alias = selectItem.getAlias();
if (alias != null) { if (alias != null) {
fields.add(alias.getName()); aliases.add(alias.getName());
} }
Expression expression = selectItem.getExpression(); Expression expression = selectItem.getExpression();
if (expression != null) { if (expression != null) {
expression.accept(this); expression.accept(this);

View File

@@ -0,0 +1,39 @@
package com.tencent.supersonic.common.jsqlparser;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.statement.select.SelectItem;
import net.sf.jsqlparser.statement.select.SelectItemVisitorAdapter;
import org.apache.commons.lang3.StringUtils;
import java.util.HashMap;
import java.util.Map;
public class FieldAliasReplaceNameVisitor extends SelectItemVisitorAdapter {
private Map<String, String> fieldNameMap;
private Map<String, String> aliasToActualExpression = new HashMap<>();
public FieldAliasReplaceNameVisitor(Map<String, String> fieldNameMap) {
this.fieldNameMap = fieldNameMap;
}
@Override
public void visit(SelectItem selectExpressionItem) {
Alias alias = selectExpressionItem.getAlias();
if (alias == null) {
return;
}
String aliasName = alias.getName();
String replaceValue = fieldNameMap.get(aliasName);
if (StringUtils.isBlank(replaceValue)) {
return;
}
aliasToActualExpression.put(aliasName, replaceValue);
alias.setName(replaceValue);
}
public Map<String, String> getAliasToActualExpression() {
return aliasToActualExpression;
}
}

View File

@@ -7,15 +7,7 @@ import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
import net.sf.jsqlparser.expression.Function; import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.LongValue; import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.StringValue; import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; import net.sf.jsqlparser.expression.operators.relational.*;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.expression.operators.relational.LikeExpression;
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
import net.sf.jsqlparser.schema.Column; import net.sf.jsqlparser.schema.Column;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
@@ -34,6 +26,29 @@ public class FieldAndValueAcquireVisitor extends ExpressionVisitorAdapter {
this.fieldExpressions = fieldExpressions; this.fieldExpressions = fieldExpressions;
} }
public void visit(Between between) {
Expression leftExpression = between.getLeftExpression();
String columnName = null;
if (leftExpression instanceof Column) {
Column column = (Column) leftExpression;
columnName = column.getColumnName();
}
Expression betweenExpressionStart = between.getBetweenExpressionStart();
Expression betweenExpressionEnd = between.getBetweenExpressionEnd();
FieldExpression fieldExpressionStart = new FieldExpression();
fieldExpressionStart.setFieldName(columnName);
fieldExpressionStart.setFieldValue(getFieldValue(betweenExpressionStart));
fieldExpressionStart.setOperator(JsqlConstants.GREATER_THAN_EQUALS);
fieldExpressions.add(fieldExpressionStart);
FieldExpression fieldExpressionEnd = new FieldExpression();
fieldExpressionEnd.setFieldName(columnName);
fieldExpressionEnd.setFieldValue(getFieldValue(betweenExpressionEnd));
fieldExpressionEnd.setOperator(JsqlConstants.MINOR_THAN_EQUALS);
fieldExpressions.add(fieldExpressionEnd);
}
public void visit(LikeExpression expr) { public void visit(LikeExpression expr) {
Expression leftExpression = expr.getLeftExpression(); Expression leftExpression = expr.getLeftExpression();
Expression rightExpression = expr.getRightExpression(); Expression rightExpression = expr.getRightExpression();

View File

@@ -26,6 +26,7 @@ public class JsqlConstants {
public static final String EQUAL_CONSTANT = " 1 = 1 "; public static final String EQUAL_CONSTANT = " 1 = 1 ";
public static final String IN_CONSTANT = " 1 in (1) "; public static final String IN_CONSTANT = " 1 in (1) ";
public static final String LIKE_CONSTANT = "1 like 1"; public static final String LIKE_CONSTANT = "1 like 1";
public static final String BETWEEN_AND_CONSTANT = "1 between 2 and 3";
public static final String IN = "IN"; public static final String IN = "IN";
public static final Map<String, String> rightMap = Stream.of( public static final Map<String, String> rightMap = Stream.of(
new AbstractMap.SimpleEntry<>("<=", "<="), new AbstractMap.SimpleEntry<>("<", "<"), new AbstractMap.SimpleEntry<>("<=", "<="), new AbstractMap.SimpleEntry<>("<", "<"),

View File

@@ -1,35 +1,17 @@
package com.tencent.supersonic.common.jsqlparser; package com.tencent.supersonic.common.jsqlparser;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.*;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression; import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression; import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo; import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.schema.Column; import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table; import net.sf.jsqlparser.statement.select.*;
import net.sf.jsqlparser.statement.select.GroupByElement;
import net.sf.jsqlparser.statement.select.OrderByElement;
import net.sf.jsqlparser.statement.select.ParenthesedSelect;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectItem;
import net.sf.jsqlparser.statement.select.SelectVisitorAdapter;
import net.sf.jsqlparser.statement.select.SetOperationList;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.ArrayList; import java.util.*;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
/** Sql Parser add Helper */ /** Sql Parser add Helper */
@Slf4j @Slf4j
@@ -144,42 +126,7 @@ public class SqlAddHelper {
return sql; return sql;
} }
PlainSelect plainSelect = (PlainSelect) selectStatement; PlainSelect plainSelect = (PlainSelect) selectStatement;
List<String> chNameList = TimeDimensionEnum.getChNameList();
Boolean dateWhere = false;
for (String chName : chNameList) {
if (expression.toString().contains(chName)) {
dateWhere = true;
}
}
List<PlainSelect> plainSelectList = SqlSelectHelper.getWithItem(selectStatement);
if (!CollectionUtils.isEmpty(plainSelectList) && dateWhere) {
List<String> withNameList = SqlSelectHelper.getWithName(sql);
for (int i = 0; i < plainSelectList.size(); i++) {
if (plainSelectList.get(i).getFromItem() instanceof Table) {
Table table = (Table) plainSelectList.get(i).getFromItem();
if (withNameList.contains(table.getName())) {
continue;
}
}
Set<String> result = new HashSet<>();
List<PlainSelect> subPlainSelectList = new ArrayList<>();
subPlainSelectList.add(plainSelectList.get(i));
SqlSelectHelper.getWhereFields(subPlainSelectList, result);
if (TimeDimensionEnum.containsZhTimeDimension(new ArrayList<>(result))) {
continue;
}
Expression subWhere = plainSelectList.get(i).getWhere();
addWhere(plainSelectList.get(i), subWhere, expression);
}
return selectStatement.toString();
}
if (plainSelect.getFromItem() instanceof ParenthesedSelect && dateWhere) {
ParenthesedSelect parenthesedSelect = (ParenthesedSelect) plainSelect.getFromItem();
PlainSelect subPlainSelect = parenthesedSelect.getPlainSelect();
Expression subWhere = subPlainSelect.getWhere();
addWhere(subPlainSelect, subWhere, expression);
return selectStatement.toString();
}
Expression where = plainSelect.getWhere(); Expression where = plainSelect.getWhere();
addWhere(plainSelect, where, expression); addWhere(plainSelect, where, expression);

View File

@@ -1,10 +1,10 @@
package com.tencent.supersonic.common.jsqlparser; package com.tencent.supersonic.common.jsqlparser;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.statement.select.PlainSelect; import net.sf.jsqlparser.statement.select.PlainSelect;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
@@ -12,7 +12,7 @@ import java.util.Objects;
@Slf4j @Slf4j
public class SqlDateSelectHelper { public class SqlDateSelectHelper {
public static DateVisitor.DateBoundInfo getDateBoundInfo(String sql) { public static DateVisitor.DateBoundInfo getDateBoundInfo(String sql, String dateField) {
List<PlainSelect> plainSelectList = SqlSelectHelper.getPlainSelect(sql); List<PlainSelect> plainSelectList = SqlSelectHelper.getPlainSelect(sql);
if (plainSelectList.size() != 1) { if (plainSelectList.size() != 1) {
return null; return null;
@@ -25,7 +25,7 @@ public class SqlDateSelectHelper {
if (Objects.isNull(where)) { if (Objects.isNull(where)) {
return null; return null;
} }
DateVisitor dateVisitor = new DateVisitor(TimeDimensionEnum.getChNameList()); DateVisitor dateVisitor = new DateVisitor(Collections.singletonList(dateField));
where.accept(dateVisitor); where.accept(dateVisitor);
return dateVisitor.getDateBoundInfo(); return dateVisitor.getDateBoundInfo();
} }

View File

@@ -8,16 +8,7 @@ import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.Parenthesis; import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression; import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression; import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; import net.sf.jsqlparser.expression.operators.relational.*;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.expression.operators.relational.LikeExpression;
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
import net.sf.jsqlparser.parser.CCJSqlParserUtil; import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column; import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.statement.select.AllColumns; import net.sf.jsqlparser.statement.select.AllColumns;
@@ -183,6 +174,8 @@ public class SqlRemoveHelper {
handleInExpression((InExpression) expression, removeFieldNames); handleInExpression((InExpression) expression, removeFieldNames);
} else if (expression instanceof LikeExpression) { } else if (expression instanceof LikeExpression) {
handleLikeExpression((LikeExpression) expression, removeFieldNames); handleLikeExpression((LikeExpression) expression, removeFieldNames);
} else if (expression instanceof Between) {
handleBetweenExpression((Between) expression, removeFieldNames);
} }
} catch (JSQLParserException e) { } catch (JSQLParserException e) {
log.error("JSQLParserException", e); log.error("JSQLParserException", e);
@@ -226,6 +219,17 @@ public class SqlRemoveHelper {
updateLikeExpression(likeExpression, constantExpression); updateLikeExpression(likeExpression, constantExpression);
} }
private static void handleBetweenExpression(Between between, Set<String> removeFieldNames)
throws JSQLParserException {
String columnName = SqlSelectHelper.getColumnName(between.getLeftExpression());
if (!removeFieldNames.contains(columnName)) {
return;
}
Between constantExpression =
(Between) CCJSqlParserUtil.parseCondExpression(JsqlConstants.BETWEEN_AND_CONSTANT);
updateBetweenExpression(between, constantExpression);
}
private static void updateComparisonOperator(ComparisonOperator original, private static void updateComparisonOperator(ComparisonOperator original,
ComparisonOperator constantExpression) { ComparisonOperator constantExpression) {
original.setLeftExpression(constantExpression.getLeftExpression()); original.setLeftExpression(constantExpression.getLeftExpression());
@@ -245,6 +249,12 @@ public class SqlRemoveHelper {
original.setRightExpression(constantExpression.getRightExpression()); original.setRightExpression(constantExpression.getRightExpression());
} }
private static void updateBetweenExpression(Between between, Between constantExpression) {
between.setBetweenExpressionEnd(constantExpression.getBetweenExpressionEnd());
between.setBetweenExpressionStart(constantExpression.getBetweenExpressionStart());
between.setLeftExpression(constantExpression.getLeftExpression());
}
public static String removeHavingCondition(String sql, Set<String> removeFieldNames) { public static String removeHavingCondition(String sql, Set<String> removeFieldNames) {
Select selectStatement = SqlSelectHelper.getSelect(sql); Select selectStatement = SqlSelectHelper.getSelect(sql);
if (!(selectStatement instanceof PlainSelect)) { if (!(selectStatement instanceof PlainSelect)) {
@@ -373,6 +383,10 @@ public class SqlRemoveHelper {
LikeExpression likeExpression = (LikeExpression) expression; LikeExpression likeExpression = (LikeExpression) expression;
Expression leftExpression = likeExpression.getLeftExpression(); Expression leftExpression = likeExpression.getLeftExpression();
return recursionBase(leftExpression, expression, sqlEditEnum); return recursionBase(leftExpression, expression, sqlEditEnum);
} else if (expression instanceof Between) {
Between between = (Between) expression;
Expression leftExpression = between.getLeftExpression();
return recursionBase(leftExpression, expression, sqlEditEnum);
} }
return expression; return expression;
} }

View File

@@ -449,6 +449,23 @@ public class SqlReplaceHelper {
} }
} }
public static String replaceAliasFieldName(String sql, Map<String, String> fieldNameMap) {
Select selectStatement = SqlSelectHelper.getSelect(sql);
if (!(selectStatement instanceof PlainSelect)) {
return sql;
}
PlainSelect plainSelect = (PlainSelect) selectStatement;
FieldAliasReplaceNameVisitor visitor = new FieldAliasReplaceNameVisitor(fieldNameMap);
for (SelectItem selectItem : plainSelect.getSelectItems()) {
selectItem.accept(visitor);
}
Map<String, String> aliasToActualExpression = visitor.getAliasToActualExpression();
if (Objects.nonNull(aliasToActualExpression) && !aliasToActualExpression.isEmpty()) {
return replaceFields(selectStatement.toString(), aliasToActualExpression, true);
}
return selectStatement.toString();
}
public static String replaceAlias(String sql) { public static String replaceAlias(String sql) {
Select selectStatement = SqlSelectHelper.getSelect(sql); Select selectStatement = SqlSelectHelper.getSelect(sql);
if (!(selectStatement instanceof PlainSelect)) { if (!(selectStatement instanceof PlainSelect)) {

View File

@@ -133,6 +133,15 @@ public class SqlSelectHelper {
return result; return result;
} }
public static Set<String> getAliasFields(PlainSelect plainSelect) {
Set<String> result = new HashSet<>();
List<SelectItem<?>> selectItems = plainSelect.getSelectItems();
for (SelectItem selectItem : selectItems) {
selectItem.accept(new AliasAcquireVisitor(result));
}
return result;
}
public static List<PlainSelect> getPlainSelect(Select selectStatement) { public static List<PlainSelect> getPlainSelect(Select selectStatement) {
if (selectStatement == null) { if (selectStatement == null) {
return null; return null;
@@ -264,10 +273,16 @@ public class SqlSelectHelper {
public static List<String> getAllSelectFields(String sql) { public static List<String> getAllSelectFields(String sql) {
List<PlainSelect> plainSelects = getPlainSelects(getPlainSelect(sql)); List<PlainSelect> plainSelects = getPlainSelects(getPlainSelect(sql));
Set<String> results = new HashSet<>(); Set<String> results = new HashSet<>();
Set<String> aliases = new HashSet<>();
for (PlainSelect plainSelect : plainSelects) { for (PlainSelect plainSelect : plainSelects) {
List<String> fields = getFieldsByPlainSelect(plainSelect); List<String> fields = getFieldsByPlainSelect(plainSelect);
Set<String> subaliases = getAliasFields(plainSelect);
subaliases.removeAll(fields);
results.addAll(fields); results.addAll(fields);
aliases.addAll(subaliases);
} }
// do not account in aliases
results.removeAll(aliases);
return new ArrayList<>(results); return new ArrayList<>(results);
} }

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.common.pojo; package com.tencent.supersonic.common.pojo;
import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum; import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.DateUtils; import com.tencent.supersonic.common.util.DateUtils;
import lombok.Data; import lombok.Data;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
@@ -40,6 +39,8 @@ public class DateConf {
private boolean groupByDate; private boolean groupByDate;
private String dateField;
public List<String> getDateList() { public List<String> getDateList() {
if (!CollectionUtils.isEmpty(dateList)) { if (!CollectionUtils.isEmpty(dateList)) {
return dateList; return dateList;
@@ -49,18 +50,6 @@ public class DateConf {
return DateUtils.getDateList(startDateStr, endDateStr, getPeriod()); return DateUtils.getDateList(startDateStr, endDateStr, getPeriod());
} }
public String getGroupByTimeDimension() {
if (DatePeriodEnum.DAY.equals(period)) {
return TimeDimensionEnum.DAY.getName();
} else if (DatePeriodEnum.WEEK.equals(period)) {
return TimeDimensionEnum.WEEK.getName();
} else if (DatePeriodEnum.MONTH.equals(period)) {
return TimeDimensionEnum.MONTH.getName();
} else {
return TimeDimensionEnum.DAY.getName();
}
}
@Override @Override
public boolean equals(Object o) { public boolean equals(Object o) {
if (this == o) { if (this == o) {

View File

@@ -1,73 +1,5 @@
package com.tencent.supersonic.common.pojo.enums; package com.tencent.supersonic.common.pojo.enums;
import org.springframework.util.CollectionUtils;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
public enum TimeDimensionEnum { public enum TimeDimensionEnum {
DAY("sys_imp_date", "数据日期"), DAY, WEEK, MONTH;
WEEK("sys_imp_week", "数据日期_周"),
MONTH("sys_imp_month", "数据日期_月");
private String name;
private String chName;
TimeDimensionEnum(String name, String chName) {
this.name = name;
this.chName = chName;
}
public static boolean containsTimeDimension(String fieldName) {
if (getNameList().contains(fieldName) || getChNameList().contains(fieldName)) {
return true;
}
return false;
}
public static List<String> getNameList() {
return Arrays.stream(TimeDimensionEnum.values()).map(TimeDimensionEnum::getName)
.collect(Collectors.toList());
}
public static List<String> getChNameList() {
return Arrays.stream(TimeDimensionEnum.values()).map(TimeDimensionEnum::getChName)
.collect(Collectors.toList());
}
public static Map<String, String> getChNameToNameMap() {
return Arrays.stream(TimeDimensionEnum.values()).collect(Collectors
.toMap(TimeDimensionEnum::getChName, TimeDimensionEnum::getName, (k1, k2) -> k1));
}
public static Map<String, String> getNameToNameMap() {
return Arrays.stream(TimeDimensionEnum.values()).collect(Collectors
.toMap(TimeDimensionEnum::getName, TimeDimensionEnum::getName, (k1, k2) -> k1));
}
public String getName() {
return name;
}
public String getChName() {
return chName;
}
/**
* Determine if a time dimension field is included in a Chinese/English text field
*
* @param fields field
* @return true/false
*/
public static boolean containsZhTimeDimension(List<String> fields) {
if (CollectionUtils.isEmpty(fields)) {
return false;
}
return fields.stream().anyMatch(field -> containsTimeDimension(field));
}
} }

View File

@@ -4,7 +4,6 @@ import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.ItemDateResp; import com.tencent.supersonic.common.pojo.ItemDateResp;
import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum; import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import lombok.Data; import lombok.Data;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
@@ -21,22 +20,14 @@ import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.StringJoiner; import java.util.StringJoiner;
import java.util.regex.Pattern;
import static com.tencent.supersonic.common.pojo.Constants.APOSTROPHE; import static com.tencent.supersonic.common.pojo.Constants.*;
import static com.tencent.supersonic.common.pojo.Constants.COMMA;
import static com.tencent.supersonic.common.pojo.Constants.DAY_FORMAT;
import static com.tencent.supersonic.common.pojo.Constants.MONTH_FORMAT;
@Slf4j @Slf4j
@Component @Component
@Data @Data
public class DateModeUtils { public class DateModeUtils {
private final String sysDateCol = TimeDimensionEnum.DAY.getName();
private final String sysDateMonthCol = TimeDimensionEnum.MONTH.getName();
private final String sysDateWeekCol = TimeDimensionEnum.WEEK.getName();
@Value("${s2.query.parameter.sys.zipper.begin:start_}") @Value("${s2.query.parameter.sys.zipper.begin:start_}")
private String sysZipperDateColBegin; private String sysZipperDateColBegin;
@@ -60,8 +51,8 @@ public class DateModeUtils {
public String hasDataModeStr(ItemDateResp dateDate, DateConf dateInfo) { public String hasDataModeStr(ItemDateResp dateDate, DateConf dateInfo) {
if (Objects.isNull(dateDate) || StringUtils.isEmpty(dateDate.getStartDate()) if (Objects.isNull(dateDate) || StringUtils.isEmpty(dateDate.getStartDate())
|| StringUtils.isEmpty(dateDate.getStartDate())) { || StringUtils.isEmpty(dateDate.getStartDate())) {
return String.format("(%s >= '%s' and %s <= '%s')", sysDateCol, dateInfo.getStartDate(), return String.format("(%s >= '%s' and %s <= '%s')", dateInfo.getDateField(),
sysDateCol, dateInfo.getEndDate()); dateInfo.getStartDate(), dateInfo.getDateField(), dateInfo.getEndDate());
} else { } else {
log.info("dateDate:{}", dateDate); log.info("dateDate:{}", dateDate);
} }
@@ -79,27 +70,28 @@ public class DateModeUtils {
dateFormatStr, ChronoUnit.DAYS); dateFormatStr, ChronoUnit.DAYS);
LocalDate dateMax = endData; LocalDate dateMax = endData;
LocalDate dateMin = dateMax.minusDays(unit - 1); LocalDate dateMin = dateMax.minusDays(unit - 1);
return String.format("(%s >= '%s' and %s <= '%s')", sysDateCol, dateMin, sysDateCol, return String.format("(%s >= '%s' and %s <= '%s')", dateInfo.getDateField(),
dateMax); dateMin, dateInfo.getDateField(), dateMax);
} }
if (DatePeriodEnum.MONTH.equals(dateInfo.getPeriod())) { if (DatePeriodEnum.MONTH.equals(dateInfo.getPeriod())) {
Long unit = getInterval(dateInfo.getStartDate(), dateInfo.getEndDate(), Long unit = getInterval(dateInfo.getStartDate(), dateInfo.getEndDate(),
dateFormatStr, ChronoUnit.MONTHS); dateFormatStr, ChronoUnit.MONTHS);
return generateMonthSql(endData, unit, dateFormatStr); return generateMonthSql(endData, unit, dateFormatStr, dateInfo);
} }
} }
return String.format("(%s >= '%s' and %s <= '%s')", sysDateCol, dateInfo.getStartDate(), return String.format("(%s >= '%s' and %s <= '%s')", dateInfo.getDateField(),
sysDateCol, dateInfo.getEndDate()); dateInfo.getStartDate(), dateInfo.getDateField(), dateInfo.getEndDate());
} }
public String generateMonthSql(LocalDate endData, Long unit, String dateFormatStr) { public String generateMonthSql(LocalDate endData, Long unit, String dateFormatStr,
DateConf dateConf) {
LocalDate dateMax = endData; LocalDate dateMax = endData;
List<String> months = generateMonthStr(dateMax, unit, dateFormatStr); List<String> months = generateMonthStr(dateMax, unit, dateFormatStr);
if (!CollectionUtils.isEmpty(months)) { if (!CollectionUtils.isEmpty(months)) {
StringJoiner joiner = new StringJoiner(","); StringJoiner joiner = new StringJoiner(",");
months.stream().forEach(month -> joiner.add("'" + month + "'")); months.stream().forEach(month -> joiner.add("'" + month + "'"));
return String.format("(%s in (%s))", sysDateCol, joiner.toString()); return String.format("(%s in (%s))", dateConf.getDateField(), joiner.toString());
} }
return ""; return "";
} }
@@ -116,8 +108,8 @@ public class DateModeUtils {
public String recentDayStr(ItemDateResp dateDate, DateConf dateInfo) { public String recentDayStr(ItemDateResp dateDate, DateConf dateInfo) {
ImmutablePair<String, String> dayRange = recentDay(dateDate, dateInfo); ImmutablePair<String, String> dayRange = recentDay(dateDate, dateInfo);
return String.format("(%s >= '%s' and %s <= '%s')", sysDateCol, dayRange.left, sysDateCol, return String.format("(%s >= '%s' and %s <= '%s')", dateInfo.getDateField(), dayRange.left,
dayRange.right); dateInfo.getDateField(), dayRange.right);
} }
public ImmutablePair<String, String> recentDay(ItemDateResp dateDate, DateConf dateInfo) { public ImmutablePair<String, String> recentDay(ItemDateResp dateDate, DateConf dateInfo) {
@@ -134,24 +126,25 @@ public class DateModeUtils {
return ImmutablePair.of(start, dateDate.getEndDate()); return ImmutablePair.of(start, dateDate.getEndDate());
} }
public String recentMonthStr(LocalDate endData, Long unit, String dateFormatStr) { public String recentMonthStr(LocalDate endData, Long unit, String dateFormatStr,
DateConf dateInfo) {
DateTimeFormatter formatter = DateTimeFormatter.ofPattern(dateFormatStr); DateTimeFormatter formatter = DateTimeFormatter.ofPattern(dateFormatStr);
String endStr = endData.format(formatter); String endStr = endData.format(formatter);
String start = endData.minusMonths(unit).format(formatter); String start = endData.minusMonths(unit).format(formatter);
return String.format("(%s >= '%s' and %s <= '%s')", sysDateMonthCol, start, sysDateMonthCol, return String.format("(%s >= '%s' and %s <= '%s')", dateInfo.getDateField(), start,
endStr); dateInfo.getDateField(), endStr);
} }
public String recentMonthStr(ItemDateResp dateDate, DateConf dateInfo) { public String recentMonthStr(ItemDateResp dateDate, DateConf dateInfo) {
List<ImmutablePair<String, String>> range = recentMonth(dateDate, dateInfo); List<ImmutablePair<String, String>> range = recentMonth(dateDate, dateInfo);
if (range.size() == 1) { if (range.size() == 1) {
return String.format("(%s >= '%s' and %s <= '%s')", sysDateMonthCol, range.get(0).left, return String.format("(%s >= '%s' and %s <= '%s')", dateInfo.getDateField(),
sysDateMonthCol, range.get(0).right); range.get(0).left, dateInfo.getDateField(), range.get(0).right);
} }
if (range.size() > 0) { if (range.size() > 0) {
StringJoiner joiner = new StringJoiner(","); StringJoiner joiner = new StringJoiner(",");
range.stream().forEach(month -> joiner.add("'" + month.left + "'")); range.stream().forEach(month -> joiner.add("'" + month.left + "'"));
return String.format("(%s in (%s))", sysDateCol, joiner.toString()); return String.format("(%s in (%s))", dateInfo.getDateField(), joiner.toString());
} }
return ""; return "";
} }
@@ -181,17 +174,17 @@ public class DateModeUtils {
return ret; return ret;
} }
public String recentWeekStr(LocalDate endData, Long unit) { public String recentWeekStr(LocalDate endData, Long unit, DateConf dataInfo) {
DateTimeFormatter formatter = DateTimeFormatter.ofPattern(DAY_FORMAT); DateTimeFormatter formatter = DateTimeFormatter.ofPattern(DAY_FORMAT);
String start = endData.minusDays(unit * 7).format(formatter); String start = endData.minusDays(unit * 7).format(formatter);
return String.format("(%s >= '%s' and %s <= '%s')", sysDateWeekCol, start, sysDateWeekCol, return String.format("(%s >= '%s' and %s <= '%s')", dataInfo.getDateField(), start,
endData.format(formatter)); dataInfo.getDateField(), endData.format(formatter));
} }
public String recentWeekStr(ItemDateResp dateDate, DateConf dateInfo) { public String recentWeekStr(ItemDateResp dateDate, DateConf dateInfo) {
ImmutablePair<String, String> dayRange = recentWeek(dateDate, dateInfo); ImmutablePair<String, String> dayRange = recentWeek(dateDate, dateInfo);
return String.format("(%s >= '%s' and %s <= '%s')", sysDateWeekCol, dayRange.left, return String.format("(%s >= '%s' and %s <= '%s')", dateInfo.getDateField(), dayRange.left,
sysDateWeekCol, dayRange.right); dateInfo.getDateField(), dayRange.right);
} }
public ImmutablePair<String, String> recentWeek(ItemDateResp dateDate, DateConf dateInfo) { public ImmutablePair<String, String> recentWeek(ItemDateResp dateDate, DateConf dateInfo) {
@@ -242,26 +235,27 @@ public class DateModeUtils {
* @return * @return
*/ */
public String betweenDateStr(DateConf dateInfo) { public String betweenDateStr(DateConf dateInfo) {
String dateField = dateInfo.getDateField();
if (DatePeriodEnum.MONTH.equals(dateInfo.getPeriod())) { if (DatePeriodEnum.MONTH.equals(dateInfo.getPeriod())) {
// startDate YYYYMM // startDate YYYYMM
if (!dateInfo.getStartDate().contains(Constants.MINUS)) { if (!dateInfo.getStartDate().contains(Constants.MINUS)) {
return String.format("%s >= '%s' and %s <= '%s'", sysDateMonthCol, return String.format("%s >= '%s' and %s <= '%s'", dateField,
dateInfo.getStartDate(), sysDateMonthCol, dateInfo.getEndDate()); dateInfo.getStartDate(), dateField, dateInfo.getEndDate());
} }
LocalDate endData = LocalDate endData =
LocalDate.parse(dateInfo.getEndDate(), DateTimeFormatter.ofPattern(DAY_FORMAT)); LocalDate.parse(dateInfo.getEndDate(), DateTimeFormatter.ofPattern(DAY_FORMAT));
LocalDate startData = LocalDate.parse(dateInfo.getStartDate(), LocalDate startData = LocalDate.parse(dateInfo.getStartDate(),
DateTimeFormatter.ofPattern(DAY_FORMAT)); DateTimeFormatter.ofPattern(DAY_FORMAT));
DateTimeFormatter formatter = DateTimeFormatter.ofPattern(MONTH_FORMAT); DateTimeFormatter formatter = DateTimeFormatter.ofPattern(MONTH_FORMAT);
return String.format("%s >= '%s' and %s <= '%s'", sysDateMonthCol, return String.format("%s >= '%s' and %s <= '%s'", dateField,
startData.format(formatter), sysDateMonthCol, endData.format(formatter)); startData.format(formatter), dateField, endData.format(formatter));
} }
if (DatePeriodEnum.WEEK.equals(dateInfo.getPeriod())) { if (DatePeriodEnum.WEEK.equals(dateInfo.getPeriod())) {
return String.format("%s >= '%s' and %s <= '%s'", sysDateWeekCol, return String.format("%s >= '%s' and %s <= '%s'", dateField, dateInfo.getStartDate(),
dateInfo.getStartDate(), sysDateWeekCol, dateInfo.getEndDate()); dateField, dateInfo.getEndDate());
} }
return String.format("%s >= '%s' and %s <= '%s'", sysDateCol, dateInfo.getStartDate(), return String.format("%s >= '%s' and %s <= '%s'", dateField, dateInfo.getStartDate(),
sysDateCol, dateInfo.getEndDate()); dateField, dateInfo.getEndDate());
} }
/** /**
@@ -273,12 +267,12 @@ public class DateModeUtils {
public String listDateStr(DateConf dateInfo) { public String listDateStr(DateConf dateInfo) {
StringJoiner joiner = new StringJoiner(COMMA); StringJoiner joiner = new StringJoiner(COMMA);
dateInfo.getDateList().stream().forEach(date -> joiner.add(APOSTROPHE + date + APOSTROPHE)); dateInfo.getDateList().stream().forEach(date -> joiner.add(APOSTROPHE + date + APOSTROPHE));
String dateCol = sysDateCol; String dateCol = dateInfo.getDateField();
if (DatePeriodEnum.MONTH.equals(dateInfo.getPeriod())) { if (DatePeriodEnum.MONTH.equals(dateInfo.getPeriod())) {
dateCol = sysDateMonthCol; dateCol = dateInfo.getDateField();
} }
if (DatePeriodEnum.WEEK.equals(dateInfo.getPeriod())) { if (DatePeriodEnum.WEEK.equals(dateInfo.getPeriod())) {
dateCol = sysDateWeekCol; dateCol = dateInfo.getDateField();
} }
return String.format("(%s in (%s))", dateCol, joiner.toString()); return String.format("(%s in (%s))", dateCol, joiner.toString());
} }
@@ -299,25 +293,26 @@ public class DateModeUtils {
if (DatePeriodEnum.DAY.equals(dateInfo.getPeriod())) { if (DatePeriodEnum.DAY.equals(dateInfo.getPeriod())) {
LocalDate dateMax = LocalDate.now().minusDays(1); LocalDate dateMax = LocalDate.now().minusDays(1);
LocalDate dateMin = dateMax.minusDays(unit - 1); LocalDate dateMin = dateMax.minusDays(unit - 1);
return String.format("(%s >= '%s' and %s <= '%s')", sysDateCol, dateMin, sysDateCol, return String.format("(%s >= '%s' and %s <= '%s')", dateInfo.getDateField(), dateMin,
dateMax); dateInfo.getDateField(), dateMax);
} }
if (DatePeriodEnum.WEEK.equals(dateInfo.getPeriod())) { if (DatePeriodEnum.WEEK.equals(dateInfo.getPeriod())) {
LocalDate dateMax = LocalDate.now().minusDays(1); LocalDate dateMax = LocalDate.now().minusDays(1);
return recentWeekStr(dateMax, unit.longValue()); return recentWeekStr(dateMax, unit.longValue(), dateInfo);
} }
if (DatePeriodEnum.MONTH.equals(dateInfo.getPeriod())) { if (DatePeriodEnum.MONTH.equals(dateInfo.getPeriod())) {
LocalDate dateMax = LocalDate.now().minusDays(1); LocalDate dateMax = LocalDate.now().minusDays(1);
return recentMonthStr(dateMax, unit.longValue(), MONTH_FORMAT); return recentMonthStr(dateMax, unit.longValue(), MONTH_FORMAT, dateInfo);
} }
if (DatePeriodEnum.YEAR.equals(dateInfo.getPeriod())) { if (DatePeriodEnum.YEAR.equals(dateInfo.getPeriod())) {
LocalDate dateMax = LocalDate.now().minusDays(1); LocalDate dateMax = LocalDate.now().minusDays(1);
return recentMonthStr(dateMax, unit.longValue() * 12, MONTH_FORMAT); return recentMonthStr(dateMax, unit.longValue() * 12, MONTH_FORMAT, dateInfo);
} }
return String.format("(%s >= '%s' and %s <= '%s')", sysDateCol, return String.format("(%s >= '%s' and %s <= '%s')", dateInfo.getDateField(),
LocalDate.now().minusDays(2), sysDateCol, LocalDate.now().minusDays(1)); LocalDate.now().minusDays(2), dateInfo.getDateField(),
LocalDate.now().minusDays(1));
} }
public String getDateWhereStr(DateConf dateInfo) { public String getDateWhereStr(DateConf dateInfo) {
@@ -349,32 +344,7 @@ public class DateModeUtils {
} }
public String getSysDateCol(DateConf dateInfo) { public String getSysDateCol(DateConf dateInfo) {
if (DatePeriodEnum.DAY.equals(dateInfo.getPeriod())) { return dateInfo.getDateField();
return sysDateCol;
}
if (DatePeriodEnum.WEEK.equals(dateInfo.getPeriod())) {
return sysDateWeekCol;
}
if (DatePeriodEnum.MONTH.equals(dateInfo.getPeriod())) {
return sysDateMonthCol;
}
return "";
} }
public boolean isDateStr(String date) {
return Pattern.matches("[\\d\\s-:]+", date);
}
public DatePeriodEnum getPeriodByCol(String col) {
if (sysDateCol.equalsIgnoreCase(col)) {
return DatePeriodEnum.DAY;
}
if (sysDateWeekCol.equalsIgnoreCase(col)) {
return DatePeriodEnum.WEEK;
}
if (sysDateMonthCol.equalsIgnoreCase(col)) {
return DatePeriodEnum.MONTH;
}
return null;
}
} }

View File

@@ -10,6 +10,8 @@ import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
import org.springframework.beans.factory.InitializingBean; import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import static java.time.Duration.ofSeconds;
@Service @Service
public class ZhipuModelFactory implements ModelFactory, InitializingBean { public class ZhipuModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "ZHIPU"; public static final String PROVIDER = "ZHIPU";
@@ -30,8 +32,9 @@ public class ZhipuModelFactory implements ModelFactory, InitializingBean {
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) { public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl()) return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
.apiKey(embeddingModelConfig.getApiKey()).model(embeddingModelConfig.getModelName()) .apiKey(embeddingModelConfig.getApiKey()).model(embeddingModelConfig.getModelName())
.maxRetries(embeddingModelConfig.getMaxRetries()) .maxRetries(embeddingModelConfig.getMaxRetries()).callTimeout(ofSeconds(60))
.logRequests(embeddingModelConfig.getLogRequests()) .connectTimeout(ofSeconds(60)).writeTimeout(ofSeconds(60))
.readTimeout(ofSeconds(60)).logRequests(embeddingModelConfig.getLogRequests())
.logResponses(embeddingModelConfig.getLogResponses()).build(); .logResponses(embeddingModelConfig.getLogResponses()).build();
} }

View File

@@ -11,31 +11,31 @@ class SqlDateSelectHelperTest {
String sql = "SELECT 维度1,sum(播放量) FROM 数据库 " String sql = "SELECT 维度1,sum(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1"; + "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1";
DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(sql); DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(sql, "数据日期");
Assert.assertEquals(dateBoundInfo.getLowerBound(), ">="); Assert.assertEquals(dateBoundInfo.getLowerBound(), ">=");
Assert.assertEquals(dateBoundInfo.getLowerDate(), "2023-11-17"); Assert.assertEquals(dateBoundInfo.getLowerDate(), "2023-11-17");
sql = "SELECT 维度1,sum(播放量) FROM 数据库 " sql = "SELECT 维度1,sum(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1"; + "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1";
dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(sql); dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(sql, "数据日期");
Assert.assertEquals(dateBoundInfo.getLowerBound(), ">"); Assert.assertEquals(dateBoundInfo.getLowerBound(), ">");
Assert.assertEquals(dateBoundInfo.getLowerDate(), "2023-11-17"); Assert.assertEquals(dateBoundInfo.getLowerDate(), "2023-11-17");
sql = "SELECT 维度1,sum(播放量) FROM 数据库 " sql = "SELECT 维度1,sum(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1"; + "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1";
dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(sql); dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(sql, "数据日期");
Assert.assertEquals(dateBoundInfo.getUpperBound(), "<="); Assert.assertEquals(dateBoundInfo.getUpperBound(), "<=");
Assert.assertEquals(dateBoundInfo.getUpperDate(), "2023-11-17"); Assert.assertEquals(dateBoundInfo.getUpperDate(), "2023-11-17");
sql = "SELECT 维度1,sum(播放量) FROM 数据库 " sql = "SELECT 维度1,sum(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 < '2023-11-17' GROUP BY 维度1"; + "WHERE (歌手名 = '张三') AND 数据日期 < '2023-11-17' GROUP BY 维度1";
dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(sql); dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(sql, "数据日期");
Assert.assertEquals(dateBoundInfo.getUpperBound(), "<"); Assert.assertEquals(dateBoundInfo.getUpperBound(), "<");
Assert.assertEquals(dateBoundInfo.getUpperDate(), "2023-11-17"); Assert.assertEquals(dateBoundInfo.getUpperDate(), "2023-11-17");
sql = "SELECT 维度1,sum(播放量) FROM 数据库 " + "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-10-17' " sql = "SELECT 维度1,sum(播放量) FROM 数据库 " + "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-10-17' "
+ "AND 数据日期 <= '2023-11-17' GROUP BY 维度1"; + "AND 数据日期 <= '2023-11-17' GROUP BY 维度1";
dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(sql); dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(sql, "数据日期");
Assert.assertEquals(dateBoundInfo.getUpperBound(), "<="); Assert.assertEquals(dateBoundInfo.getUpperBound(), "<=");
Assert.assertEquals(dateBoundInfo.getUpperDate(), "2023-11-17"); Assert.assertEquals(dateBoundInfo.getUpperDate(), "2023-11-17");
Assert.assertEquals(dateBoundInfo.getLowerBound(), ">="); Assert.assertEquals(dateBoundInfo.getLowerBound(), ">=");

View File

@@ -157,6 +157,14 @@ class SqlRemoveHelperTest {
replaceSql = SqlRemoveHelper.removeWhereCondition(sql, removeFieldNames); replaceSql = SqlRemoveHelper.removeWhereCondition(sql, removeFieldNames);
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1) " Assert.assertEquals("SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1) "
+ "AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11", replaceSql); + "AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11", replaceSql);
sql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 "
+ "and 歌曲名 between '2023-08-09' and '2024-08-09' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'"
+ " order by 播放量 desc limit 11";
replaceSql = SqlRemoveHelper.removeWhereCondition(sql, removeFieldNames);
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 "
+ "AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' "
+ "ORDER BY 播放量 DESC LIMIT 11", replaceSql);
} }
@Test @Test

View File

@@ -324,6 +324,39 @@ class SqlReplaceHelperTest {
replaceSql); replaceSql);
} }
@Test
void testReplaceAliasFieldName() {
Map<String, String> map = new HashMap<>();
map.put("总访问次数", "\"总访问次数\"");
map.put("访问次数", "\"访问次数\"");
String sql = "select 部门, sum(访问次数) as 总访问次数 from 超音数 where "
+ "datediff('day', 数据日期, '2023-09-05') <= 3 group by 部门 order by 总访问次数 desc limit 10";
String replaceSql = SqlReplaceHelper.replaceAliasFieldName(sql, map);
System.out.println(replaceSql);
Assert.assertEquals("SELECT 部门, sum(访问次数) AS \"总访问次数\" FROM 超音数 WHERE "
+ "datediff('day', 数据日期, '2023-09-05') <= 3 GROUP BY 部门 ORDER BY \"总访问次数\" DESC LIMIT 10",
replaceSql);
sql = "select 部门, sum(访问次数) as 总访问次数 from 超音数 where "
+ "(datediff('day', 数据日期, '2023-09-05') <= 3) and 数据日期 = '2023-10-10' "
+ "group by 部门 order by 总访问次数 desc limit 10";
replaceSql = SqlReplaceHelper.replaceAliasFieldName(sql, map);
System.out.println(replaceSql);
Assert.assertEquals("SELECT 部门, sum(访问次数) AS \"总访问次数\" FROM 超音数 WHERE "
+ "(datediff('day', 数据日期, '2023-09-05') <= 3) AND 数据日期 = '2023-10-10' "
+ "GROUP BY 部门 ORDER BY \"总访问次数\" DESC LIMIT 10", replaceSql);
sql = "select 部门, sum(访问次数) as 访问次数 from 超音数 where "
+ "(datediff('day', 数据日期, '2023-09-05') <= 3) and 数据日期 = '2023-10-10' "
+ "group by 部门 order by 访问次数 desc limit 10";
replaceSql = SqlReplaceHelper.replaceAliasFieldName(sql, map);
System.out.println(replaceSql);
Assert.assertEquals(
"SELECT 部门, sum(\"访问次数\") AS \"访问次数\" FROM 超音数 WHERE (datediff('day', 数据日期, "
+ "'2023-09-05') <= 3) AND 数据日期 = '2023-10-10' GROUP BY 部门 ORDER BY \"访问次数\" DESC LIMIT 10",
replaceSql);
}
@Test @Test
void testReplaceAggAliasOrderbyField() { void testReplaceAggAliasOrderbyField() {
String sql = "SELECT SUM(访问次数) AS top10总播放量 FROM (SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数 " String sql = "SELECT SUM(访问次数) AS top10总播放量 FROM (SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数 "

View File

@@ -1,5 +1,5 @@
# Use an official OpenJDK runtime as a parent image # Use an official OpenJDK runtime as a parent image
FROM openjdk:8-jdk FROM openjdk:21-jdk-bullseye
# Set the working directory in the container # Set the working directory in the container
WORKDIR /usr/src/app WORKDIR /usr/src/app
@@ -7,8 +7,6 @@ WORKDIR /usr/src/app
# Argument to pass in the supersonic version at build time # Argument to pass in the supersonic version at build time
ARG SUPERSONIC_VERSION ARG SUPERSONIC_VERSION
RUN apt-get update
# Install necessary packages, including Postgres client # Install necessary packages, including Postgres client
RUN apt-get update && apt-get install -y postgresql-client RUN apt-get update && apt-get install -y postgresql-client

View File

@@ -1,3 +1,3 @@
#!/usr/bin/env bash #!/usr/bin/env bash
SUPERSONIC_VERSION=0.9.10-SNAPSHOT docker-compose -f docker-compose.yml -p supersonic up SUPERSONIC_VERSION=latest docker-compose -f docker-compose.yml -p supersonic up

View File

@@ -11,8 +11,8 @@ services:
POSTGRES_PASSWORD: supersonic_password POSTGRES_PASSWORD: supersonic_password
ports: ports:
- "15432:5432" - "15432:5432"
volumes: # volumes:
- postgres_data:/var/lib/postgresql # - postgres_data:/var/lib/postgresql
networks: networks:
- supersonic_network - supersonic_network
dns: dns:
@@ -72,9 +72,9 @@ services:
- 114.114.114.114 - 114.114.114.114
- 8.8.8.8 - 8.8.8.8
- 8.8.4.4 - 8.8.4.4
volumes: #volumes:
#1.Named Volumes are best for persistent data managed by Docker. #1.Named Volumes are best for persistent data managed by Docker.
- supersonic_data:/usr/src/app/supersonic-standalone-${SUPERSONIC_VERSION:-latest} #- supersonic_data:/usr/src/app/supersonic-standalone-${SUPERSONIC_VERSION:-latest}
#2.Bind Mounts are suitable for frequent modifications and debugging. #2.Bind Mounts are suitable for frequent modifications and debugging.
# - ./conf/application-prd.yaml:/usr/src/app/supersonic-standalone-${SUPERSONIC_VERSION:-latest}/conf/application-prd.yaml # - ./conf/application-prd.yaml:/usr/src/app/supersonic-standalone-${SUPERSONIC_VERSION:-latest}/conf/application-prd.yaml
#3.Detailed Bind Mounts offer more control over the mount behavior. #3.Detailed Bind Mounts offer more control over the mount behavior.
@@ -84,9 +84,9 @@ services:
# bind: # bind:
# propagation: rprivate # propagation: rprivate
# create_host_path: true # create_host_path: true
volumes: #volumes:
postgres_data: # postgres_data:
supersonic_data: # supersonic_data:
networks: networks:
supersonic_network: supersonic_network:

View File

@@ -12,6 +12,7 @@ TAGS="latest"
# If VERSION is provided, add it to TAGS and tag the image as latest # If VERSION is provided, add it to TAGS and tag the image as latest
if [ -n "$VERSION" ]; then if [ -n "$VERSION" ]; then
TAGS="$TAGS $VERSION" TAGS="$TAGS $VERSION"
echo "Tagging Docker images $IMAGE_NAME:$VERSION to $IMAGE_NAME:latest"
docker tag $IMAGE_NAME:$VERSION $IMAGE_NAME:latest docker tag $IMAGE_NAME:$VERSION $IMAGE_NAME:latest
fi fi

View File

@@ -32,6 +32,16 @@ public class Dimension {
this.type = type; this.type = type;
this.isCreateDimension = isCreateDimension; this.isCreateDimension = isCreateDimension;
this.bizName = bizName; this.bizName = bizName;
this.expr = bizName;
}
public Dimension(String name, String bizName, String expr, DimensionType type,
Integer isCreateDimension) {
this.name = name;
this.type = type;
this.isCreateDimension = isCreateDimension;
this.bizName = bizName;
this.expr = expr;
} }
public Dimension(String name, String bizName, DimensionType type, Integer isCreateDimension, public Dimension(String name, String bizName, DimensionType type, Integer isCreateDimension,
@@ -45,12 +55,7 @@ public class Dimension {
this.bizName = bizName; this.bizName = bizName;
} }
public static Dimension getDefault() {
return new Dimension("数据日期", "imp_date", DimensionType.partition_time, 0, "imp_date",
Constants.DAY_FORMAT, new DimensionTimeTypeParams("false", "day"));
}
public String getFieldName() { public String getFieldName() {
return bizName; return expr;
} }
} }

View File

@@ -23,11 +23,20 @@ public class Measure {
private String alias; private String alias;
public Measure(String name, String bizName, String expr, String agg, Integer isCreateMetric) {
this.name = name;
this.agg = agg;
this.isCreateMetric = isCreateMetric;
this.bizName = bizName;
this.expr = expr;
}
public Measure(String name, String bizName, String agg, Integer isCreateMetric) { public Measure(String name, String bizName, String agg, Integer isCreateMetric) {
this.name = name; this.name = name;
this.agg = agg; this.agg = agg;
this.isCreateMetric = isCreateMetric; this.isCreateMetric = isCreateMetric;
this.bizName = bizName; this.bizName = bizName;
this.expr = bizName;
} }
public Measure(String bizName, String constraint) { public Measure(String bizName, String constraint) {
@@ -38,4 +47,5 @@ public class Measure {
public String getFieldName() { public String getFieldName() {
return expr; return expr;
} }
} }

View File

@@ -8,5 +8,5 @@ import java.util.List;
@Data @Data
public class MetricDefineByMeasureParams extends MetricDefineParams { public class MetricDefineByMeasureParams extends MetricDefineParams {
private List<MeasureParam> measures = Lists.newArrayList(); private List<Measure> measures = Lists.newArrayList();
} }

View File

@@ -18,6 +18,8 @@ public class ModelDetail {
private String queryType; private String queryType;
private String dbType;
private String sqlQuery; private String sqlQuery;
private String tableQuery; private String tableQuery;

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.headless.api.pojo.enums; package com.tencent.supersonic.headless.api.pojo.enums;
public enum MapModeEnum { public enum MapModeEnum {
STRICT(0), MODERATE(2), LOOSE(4); STRICT(0), MODERATE(2), LOOSE(4), ALL(6);
public int threshold; public int threshold;

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.headless.api.pojo.enums; package com.tencent.supersonic.headless.api.pojo.enums;
import com.tencent.supersonic.headless.api.pojo.MeasureParam; import com.tencent.supersonic.headless.api.pojo.Measure;
import com.tencent.supersonic.headless.api.pojo.MetricDefineByMeasureParams; import com.tencent.supersonic.headless.api.pojo.MetricDefineByMeasureParams;
import java.util.List; import java.util.List;
@@ -32,7 +32,7 @@ public enum MetricType {
return true; return true;
} }
if (MetricDefineType.MEASURE.equals(metricDefineType)) { if (MetricDefineType.MEASURE.equals(metricDefineType)) {
List<MeasureParam> measures = typeParams.getMeasures(); List<Measure> measures = typeParams.getMeasures();
if (measures.size() > 1) { if (measures.size() > 1) {
return true; return true;
} }

View File

@@ -7,6 +7,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaItem;
import jakarta.validation.constraints.NotNull; import jakarta.validation.constraints.NotNull;
import lombok.Data; import lombok.Data;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@@ -31,7 +32,7 @@ public class DimensionReq extends SchemaItem {
private DataTypeEnums dataType; private DataTypeEnums dataType;
private Map<String, Object> ext; private Map<String, Object> ext = new HashMap();
private DimensionTimeTypeParams typeParams; private DimensionTimeTypeParams typeParams;
} }

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.api.pojo.response;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.RecordInfo; import com.tencent.supersonic.common.pojo.RecordInfo;
import com.tencent.supersonic.common.util.AESEncryptionUtil;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
@@ -64,4 +65,8 @@ public class DatabaseResp extends RecordInfo {
} }
return ""; return "";
} }
public String passwordDecrypt() {
return AESEncryptionUtil.aesDecryptECB(password);
}
} }

View File

@@ -1,14 +1,26 @@
package com.tencent.supersonic.headless.api.pojo.response; package com.tencent.supersonic.headless.api.pojo.response;
import com.google.common.collect.Sets;
import lombok.Data; import lombok.Data;
import lombok.ToString; import lombok.ToString;
import java.util.List; import java.util.Set;
@Data @Data
@ToString(callSuper = true) @ToString(callSuper = true)
public class DimSchemaResp extends DimensionResp { public class DimSchemaResp extends DimensionResp {
private Long useCnt = 0L; private Long useCnt = 0L;
private Set<String> fields = Sets.newHashSet();
@Override
public boolean equals(Object o) {
return super.equals(o);
}
@Override
public int hashCode() {
return super.hashCode();
}
} }

View File

@@ -2,13 +2,9 @@ package com.tencent.supersonic.headless.api.pojo.response;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.DataFormat; import com.tencent.supersonic.common.pojo.DataFormat;
import com.tencent.supersonic.headless.api.pojo.DrillDownDimension; import com.tencent.supersonic.headless.api.pojo.*;
import com.tencent.supersonic.headless.api.pojo.MetricDefineByFieldParams;
import com.tencent.supersonic.headless.api.pojo.MetricDefineByMeasureParams;
import com.tencent.supersonic.headless.api.pojo.MetricDefineByMetricParams;
import com.tencent.supersonic.headless.api.pojo.RelateDimension;
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
import com.tencent.supersonic.headless.api.pojo.enums.MetricDefineType; import com.tencent.supersonic.headless.api.pojo.enums.MetricDefineType;
import com.tencent.supersonic.headless.api.pojo.enums.MetricType;
import lombok.Data; import lombok.Data;
import lombok.ToString; import lombok.ToString;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
@@ -69,6 +65,19 @@ public class MetricResp extends SchemaItem {
private boolean containsPartitionDimensions; private boolean containsPartitionDimensions;
public void setMetricDefinition(MetricDefineType type, MetricDefineParams params) {
if (MetricDefineType.MEASURE.equals(type)) {
assert params instanceof MetricDefineByMeasureParams;
metricDefineByMeasureParams = (MetricDefineByMeasureParams) params;
} else if (MetricDefineType.FIELD.equals(type)) {
assert params instanceof MetricDefineByFieldParams;
metricDefineByFieldParams = (MetricDefineByFieldParams) params;
} else if (MetricDefineType.METRIC.equals(type)) {
assert params instanceof MetricDefineByMetricParams;
metricDefineByMetricParams = (MetricDefineByMetricParams) params;
}
}
public void setClassifications(String tag) { public void setClassifications(String tag) {
if (StringUtils.isBlank(tag)) { if (StringUtils.isBlank(tag)) {
classifications = Lists.newArrayList(); classifications = Lists.newArrayList();
@@ -105,4 +114,8 @@ public class MetricResp extends SchemaItem {
} }
return ""; return "";
} }
public boolean isDerived() {
return MetricType.isDerived(metricDefineType, metricDefineByMeasureParams);
}
} }

View File

@@ -1,11 +1,26 @@
package com.tencent.supersonic.headless.api.pojo.response; package com.tencent.supersonic.headless.api.pojo.response;
import com.google.common.collect.Sets;
import lombok.Data; import lombok.Data;
import lombok.ToString; import lombok.ToString;
import java.util.Set;
@Data @Data
@ToString(callSuper = true) @ToString(callSuper = true)
public class MetricSchemaResp extends MetricResp { public class MetricSchemaResp extends MetricResp {
private Long useCnt = 0L; private Long useCnt = 0L;
private Set<String> fields = Sets.newHashSet();
@Override
public boolean equals(Object o) {
return super.equals(o);
}
@Override
public int hashCode() {
return super.hashCode();
}
} }

View File

@@ -1,17 +1,9 @@
package com.tencent.supersonic.headless.api.pojo.response; package com.tencent.supersonic.headless.api.pojo.response;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.Dimension; import com.tencent.supersonic.headless.api.pojo.*;
import com.tencent.supersonic.headless.api.pojo.DrillDownDimension;
import com.tencent.supersonic.headless.api.pojo.Field;
import com.tencent.supersonic.headless.api.pojo.Identify;
import com.tencent.supersonic.headless.api.pojo.ModelDetail;
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType; import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType;
import lombok.AllArgsConstructor; import lombok.*;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.ArrayList; import java.util.ArrayList;
@@ -26,6 +18,7 @@ import java.util.stream.Collectors;
@ToString(callSuper = true) @ToString(callSuper = true)
@AllArgsConstructor @AllArgsConstructor
@NoArgsConstructor @NoArgsConstructor
@Builder
public class ModelResp extends SchemaItem { public class ModelResp extends SchemaItem {
private Long domainId; private Long domainId;
@@ -62,6 +55,14 @@ public class ModelResp extends SchemaItem {
return isOpen != null && isOpen == 1; return isOpen != null && isOpen == 1;
} }
public List<Measure> getMeasures() {
return modelDetail != null ? modelDetail.getMeasures() : Lists.newArrayList();
}
public List<Identify> getIdentifiers() {
return modelDetail != null ? modelDetail.getIdentifiers() : Lists.newArrayList();
}
public List<Dimension> getTimeDimension() { public List<Dimension> getTimeDimension() {
if (modelDetail == null) { if (modelDetail == null) {
return Lists.newArrayList(); return Lists.newArrayList();

View File

@@ -3,7 +3,6 @@ package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper; import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper; import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum; import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
@@ -13,12 +12,7 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Pair;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.ArrayList; import java.util.*;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
/** /**
@@ -47,18 +41,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
protected Map<String, String> getFieldNameMap(ChatQueryContext chatQueryContext, protected Map<String, String> getFieldNameMap(ChatQueryContext chatQueryContext,
Long dataSetId) { Long dataSetId) {
return getFieldNameMapFromDB(chatQueryContext, dataSetId);
Map<String, String> result = getFieldNameMapFromDB(chatQueryContext, dataSetId);
if (chatQueryContext.containsPartitionDimensions(dataSetId)) {
result.put(TimeDimensionEnum.DAY.getChName(), TimeDimensionEnum.DAY.getChName());
result.put(TimeDimensionEnum.MONTH.getChName(), TimeDimensionEnum.MONTH.getChName());
result.put(TimeDimensionEnum.WEEK.getChName(), TimeDimensionEnum.WEEK.getChName());
result.put(TimeDimensionEnum.DAY.getName(), TimeDimensionEnum.DAY.getChName());
result.put(TimeDimensionEnum.MONTH.getName(), TimeDimensionEnum.MONTH.getChName());
result.put(TimeDimensionEnum.WEEK.getName(), TimeDimensionEnum.WEEK.getChName());
}
return result;
} }
private static Map<String, String> getFieldNameMapFromDB(ChatQueryContext chatQueryContext, private static Map<String, String> getFieldNameMapFromDB(ChatQueryContext chatQueryContext,
@@ -126,7 +109,6 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
} }
return elements.stream(); return elements.stream();
}).collect(Collectors.toSet()); }).collect(Collectors.toSet());
dimensions.add(TimeDimensionEnum.DAY.getChName());
return dimensions; return dimensions;
} }
@@ -142,8 +124,6 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
SemanticParseInfo semanticParseInfo) { SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
Set<String> removeFieldNames = new HashSet<>(); Set<String> removeFieldNames = new HashSet<>();
removeFieldNames.addAll(TimeDimensionEnum.getChNameList());
removeFieldNames.addAll(TimeDimensionEnum.getNameList());
Map<String, String> fieldNameMap = Map<String, String> fieldNameMap =
getFieldNameMapFromDB(chatQueryContext, semanticParseInfo.getDataSetId()); getFieldNameMapFromDB(chatQueryContext, semanticParseInfo.getDataSetId());
removeFieldNames.removeIf(fieldName -> fieldNameMap.containsKey(fieldName)); removeFieldNames.removeIf(fieldName -> fieldNameMap.containsKey(fieldName));

View File

@@ -4,7 +4,6 @@ import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper; import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
@@ -53,10 +52,6 @@ public class GroupByCorrector extends BaseSemanticCorrector {
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) { if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {
return false; return false;
} }
// if only date in select not add group by.
if (selectFields.size() == 1 && TimeDimensionEnum.containsZhTimeDimension(selectFields)) {
return false;
}
if (SqlSelectHelper.hasGroupBy(correctS2SQL)) { if (SqlSelectHelper.hasGroupBy(correctS2SQL)) {
log.debug("No need to add 'group by', existed 'group by' in s2sql:{}", correctS2SQL); log.debug("No need to add 'group by', existed 'group by' in s2sql:{}", correctS2SQL);
return false; return false;

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless.chat.corrector; package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.pojo.enums.AppModule; import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.util.ChatAppManager; import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
@@ -34,11 +35,8 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
+ "please take a review and help correct it if necessary." + "\n#Rules: " + "please take a review and help correct it if necessary." + "\n#Rules: "
+ "\n1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),sql=(corrected sql if NEGATIVE; empty string if POSITIVE)`." + "\n1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),sql=(corrected sql if NEGATIVE; empty string if POSITIVE)`."
+ "\n2.NO NEED to check date filters as the junior engineer seldom makes mistakes in this regard." + "\n2.NO NEED to check date filters as the junior engineer seldom makes mistakes in this regard."
+ "\n3.DO NOT miss the AGGREGATE operator of metrics, always add it as needed." + "\n3.SQL columns and values must be mentioned in the `#Schema`."
+ "\n4.ALWAYS use `with` statement if nested aggregation is needed." + "\n#Question:{{question}} #Schema:{{schema}} #InputSQL:{{sql}} #Response:";
+ "\n5.ALWAYS enclose alias declared by `AS` command in underscores."
+ "\n6.Alias created by `AS` command must be in the same language ast the `Question`."
+ "\n#Question:{{question}} #InputSQL:{{sql}} #Response:";
public LLMSqlCorrector() { public LLMSqlCorrector() {
ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("语义SQL修正") ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("语义SQL修正")
@@ -67,12 +65,15 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
return; return;
} }
Text2SQLExemplar exemplar = (Text2SQLExemplar) semanticParseInfo.getProperties()
.get(Text2SQLExemplar.PROPERTY_KEY);
ChatLanguageModel chatLanguageModel = ChatLanguageModel chatLanguageModel =
ModelProvider.getChatModel(chatApp.getChatModelConfig()); ModelProvider.getChatModel(chatApp.getChatModelConfig());
SemanticSqlExtractor extractor = SemanticSqlExtractor extractor =
AiServices.create(SemanticSqlExtractor.class, chatLanguageModel); AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);
Prompt prompt = generatePrompt(chatQueryContext.getRequest().getQueryText(), Prompt prompt = generatePrompt(chatQueryContext.getRequest().getQueryText(),
semanticParseInfo, chatApp.getPrompt()); semanticParseInfo, chatApp.getPrompt(), exemplar);
SemanticSql s2Sql = extractor.generateSemanticSql(prompt.toUserMessage().singleText()); SemanticSql s2Sql = extractor.generateSemanticSql(prompt.toUserMessage().singleText());
keyPipelineLog.info("LLMSqlCorrector modelReq:\n{} \nmodelResp:\n{}", prompt.text(), s2Sql); keyPipelineLog.info("LLMSqlCorrector modelReq:\n{} \nmodelResp:\n{}", prompt.text(), s2Sql);
if ("NEGATIVE".equals(s2Sql.getOpinion()) && StringUtils.isNotBlank(s2Sql.getSql())) { if ("NEGATIVE".equals(s2Sql.getOpinion()) && StringUtils.isNotBlank(s2Sql.getSql())) {
@@ -81,10 +82,11 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
} }
private Prompt generatePrompt(String queryText, SemanticParseInfo semanticParseInfo, private Prompt generatePrompt(String queryText, SemanticParseInfo semanticParseInfo,
String promptTemplate) { String promptTemplate, Text2SQLExemplar exemplar) {
Map<String, Object> variable = new HashMap<>(); Map<String, Object> variable = new HashMap<>();
variable.put("question", queryText); variable.put("question", queryText);
variable.put("sql", semanticParseInfo.getSqlInfo().getCorrectedS2SQL()); variable.put("sql", semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
variable.put("schema", exemplar.getDbSchema());
return PromptTemplate.from(promptTemplate).apply(variable); return PromptTemplate.from(promptTemplate).apply(variable);
} }

View File

@@ -1,14 +1,8 @@
package com.tencent.supersonic.headless.chat.corrector; package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.jsqlparser.AggregateEnum; import com.tencent.supersonic.common.jsqlparser.*;
import com.tencent.supersonic.common.jsqlparser.FieldExpression;
import com.tencent.supersonic.common.jsqlparser.SqlAsHelper;
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.DateUtils; import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
@@ -21,11 +15,7 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.ArrayList; import java.util.*;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
/** Perform schema corrections on the Schema information in S2SQL. */ /** Perform schema corrections on the Schema information in S2SQL. */
@@ -144,8 +134,6 @@ public class SchemaCorrector extends BaseSemanticCorrector {
Set<String> removeFieldNames = whereExpressionList.stream() Set<String> removeFieldNames = whereExpressionList.stream()
.filter(fieldExpression -> StringUtils.isBlank(fieldExpression.getFunction())) .filter(fieldExpression -> StringUtils.isBlank(fieldExpression.getFunction()))
.filter(fieldExpression -> !TimeDimensionEnum
.containsTimeDimension(fieldExpression.getFieldName()))
.filter(fieldExpression -> FilterOperatorEnum.EQUALS.getValue() .filter(fieldExpression -> FilterOperatorEnum.EQUALS.getValue()
.equals(fieldExpression.getOperator())) .equals(fieldExpression.getOperator()))
.filter(fieldExpression -> dimensions.contains(fieldExpression.getFieldName())) .filter(fieldExpression -> dimensions.contains(fieldExpression.getFieldName()))

View File

@@ -5,7 +5,6 @@ import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlDateSelectHelper; import com.tencent.supersonic.common.jsqlparser.SqlDateSelectHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.QueryConfig; import com.tencent.supersonic.headless.api.pojo.QueryConfig;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
@@ -44,8 +43,7 @@ public class TimeCorrector extends BaseSemanticCorrector {
DataSetSchema dataSetSchema = DataSetSchema dataSetSchema =
chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId); chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
if (Objects.isNull(dataSetSchema) || Objects.isNull(dataSetSchema.getPartitionDimension()) if (Objects.isNull(dataSetSchema) || Objects.isNull(dataSetSchema.getPartitionDimension())
|| Objects.isNull(dataSetSchema.getPartitionDimension().getName()) || Objects.isNull(dataSetSchema.getPartitionDimension().getName())) {
|| TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
return; return;
} }
String partitionDimension = dataSetSchema.getPartitionDimension().getName(); String partitionDimension = dataSetSchema.getPartitionDimension().getName();
@@ -75,7 +73,8 @@ public class TimeCorrector extends BaseSemanticCorrector {
private void addLowerBoundDate(SemanticParseInfo semanticParseInfo) { private void addLowerBoundDate(SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(correctS2SQL); DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(correctS2SQL,
semanticParseInfo.getDateInfo().getDateField());
if (dateBoundInfo != null && StringUtils.isBlank(dateBoundInfo.getLowerBound()) if (dateBoundInfo != null && StringUtils.isBlank(dateBoundInfo.getLowerBound())
&& StringUtils.isNotBlank(dateBoundInfo.getUpperBound()) && StringUtils.isNotBlank(dateBoundInfo.getUpperBound())

View File

@@ -0,0 +1,40 @@
package com.tencent.supersonic.headless.chat.mapper;
import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import lombok.extern.slf4j.Slf4j;
import java.util.List;
import java.util.Map;
@Slf4j
public class AllFieldMapper extends BaseMapper {
@Override
public boolean accept(ChatQueryContext chatQueryContext) {
return MapModeEnum.ALL.equals(chatQueryContext.getRequest().getMapModeEnum());
}
@Override
public void doMap(ChatQueryContext chatQueryContext) {
Map<Long, DataSetSchema> schemaMap =
chatQueryContext.getSemanticSchema().getDataSetSchemaMap();
for (Map.Entry<Long, DataSetSchema> entry : schemaMap.entrySet()) {
List<SchemaElement> schemaElements = Lists.newArrayList();
schemaElements.addAll(entry.getValue().getDimensions());
schemaElements.addAll(entry.getValue().getMetrics());
for (SchemaElement schemaElement : schemaElements) {
chatQueryContext.getMapInfo().getMatchedElements(entry.getKey())
.add(SchemaElementMatch.builder().word(schemaElement.getName())
.element(schemaElement).detectWord(schemaElement.getName())
.similarity(1.0).build());
}
}
}
}

View File

@@ -27,6 +27,10 @@ public abstract class BaseMapper implements SchemaMapper {
@Override @Override
public void map(ChatQueryContext chatQueryContext) { public void map(ChatQueryContext chatQueryContext) {
if (!accept(chatQueryContext)) {
return;
}
String simpleName = this.getClass().getSimpleName(); String simpleName = this.getClass().getSimpleName();
long startTime = System.currentTimeMillis(); long startTime = System.currentTimeMillis();
log.debug("before {},mapInfo:{}", simpleName, log.debug("before {},mapInfo:{}", simpleName,
@@ -46,6 +50,10 @@ public abstract class BaseMapper implements SchemaMapper {
public abstract void doMap(ChatQueryContext chatQueryContext); public abstract void doMap(ChatQueryContext chatQueryContext);
protected boolean accept(ChatQueryContext chatQueryContext) {
return true;
}
public void addToSchemaMap(SchemaMapInfo schemaMap, Long dataSetId, public void addToSchemaMap(SchemaMapInfo schemaMap, Long dataSetId,
SchemaElementMatch newElementMatch) { SchemaElementMatch newElementMatch) {
Map<Long, List<SchemaElementMatch>> dataSetElementMatches = Map<Long, List<SchemaElementMatch>> dataSetElementMatches =

View File

@@ -7,6 +7,8 @@ import com.tencent.supersonic.headless.chat.knowledge.MapResult;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.HashMap; import java.util.HashMap;
@@ -14,10 +16,17 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ThreadPoolExecutor;
@Service @Service
@Slf4j @Slf4j
public abstract class BaseMatchStrategy<T extends MapResult> implements MatchStrategy<T> { public abstract class BaseMatchStrategy<T extends MapResult> implements MatchStrategy<T> {
@Autowired
@Qualifier("mapExecutor")
private ThreadPoolExecutor executor;
@Override @Override
public Map<MatchText, List<T>> match(ChatQueryContext chatQueryContext, List<S2Term> terms, public Map<MatchText, List<T>> match(ChatQueryContext chatQueryContext, List<S2Term> terms,
Set<Long> detectDataSetIds) { Set<Long> detectDataSetIds) {
@@ -63,6 +72,18 @@ public abstract class BaseMatchStrategy<T extends MapResult> implements MatchStr
} }
} }
protected void executeTasks(List<Callable<Void>> tasks) {
try {
executor.invokeAll(tasks);
for (Callable<Void> future : tasks) {
future.call();
}
} catch (Exception e) {
Thread.currentThread().interrupt();
throw new RuntimeException("Task execution interrupted", e);
}
}
public double getThreshold(Double threshold, Double minThreshold, MapModeEnum mapModeEnum) { public double getThreshold(Double threshold, Double minThreshold, MapModeEnum mapModeEnum) {
if (MapModeEnum.STRICT.equals(mapModeEnum)) { if (MapModeEnum.STRICT.equals(mapModeEnum)) {
return 1.0d; return 1.0d;

View File

@@ -27,12 +27,12 @@ import java.util.stream.Collectors;
@Slf4j @Slf4j
public class DatabaseMatchStrategy extends SingleMatchStrategy<DatabaseMapResult> { public class DatabaseMatchStrategy extends SingleMatchStrategy<DatabaseMapResult> {
private List<SchemaElement> allElements; private ThreadLocal<List<SchemaElement>> allElements = ThreadLocal.withInitial(ArrayList::new);
@Override @Override
public Map<MatchText, List<DatabaseMapResult>> match(ChatQueryContext chatQueryContext, public Map<MatchText, List<DatabaseMapResult>> match(ChatQueryContext chatQueryContext,
List<S2Term> terms, Set<Long> detectDataSetIds) { List<S2Term> terms, Set<Long> detectDataSetIds) {
this.allElements = getSchemaElements(chatQueryContext); allElements.set(getSchemaElements(chatQueryContext));
return super.match(chatQueryContext, terms, detectDataSetIds); return super.match(chatQueryContext, terms, detectDataSetIds);
} }
@@ -43,7 +43,7 @@ public class DatabaseMatchStrategy extends SingleMatchStrategy<DatabaseMapResult
} }
Double metricDimensionThresholdConfig = getThreshold(chatQueryContext); Double metricDimensionThresholdConfig = getThreshold(chatQueryContext);
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements); Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements.get());
List<DatabaseMapResult> results = new ArrayList<>(); List<DatabaseMapResult> results = new ArrayList<>();
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) { for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
String name = entry.getKey(); String name = entry.getKey();

View File

@@ -20,12 +20,13 @@ import java.util.Objects;
*/ */
@Slf4j @Slf4j
public class EmbeddingMapper extends BaseMapper { public class EmbeddingMapper extends BaseMapper {
public void doMap(ChatQueryContext chatQueryContext) {
// Check if the map mode is LOOSE @Override
if (!MapModeEnum.LOOSE.equals(chatQueryContext.getRequest().getMapModeEnum())) { public boolean accept(ChatQueryContext chatQueryContext) {
return; return MapModeEnum.LOOSE.equals(chatQueryContext.getRequest().getMapModeEnum());
} }
public void doMap(ChatQueryContext chatQueryContext) {
// 1. Query from embedding by queryText // 1. Query from embedding by queryText
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class); EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
List<EmbeddingResult> matchResults = getMatches(chatQueryContext, matchStrategy); List<EmbeddingResult> matchResults = getMatches(chatQueryContext, matchStrategy);
@@ -62,4 +63,5 @@ public class EmbeddingMapper extends BaseMapper {
addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch); addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch);
} }
} }
} }

View File

@@ -16,10 +16,11 @@ import org.springframework.stereotype.Service;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Comparator; import java.util.Comparator;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING_MAPPER_NUMBER; import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING_MAPPER_NUMBER;
@@ -40,7 +41,7 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
@Override @Override
public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext, public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext,
Set<Long> detectDataSetIds, Set<String> detectSegments) { Set<Long> detectDataSetIds, Set<String> detectSegments) {
Set<EmbeddingResult> results = new HashSet<>(); Set<EmbeddingResult> results = ConcurrentHashMap.newKeySet();
int embeddingMapperBatch = Integer int embeddingMapperBatch = Integer
.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH)); .valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH));
@@ -52,12 +53,24 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
List<List<String>> queryTextsSubList = List<List<String>> queryTextsSubList =
Lists.partition(queryTextsList, embeddingMapperBatch); Lists.partition(queryTextsList, embeddingMapperBatch);
List<Callable<Void>> tasks = new ArrayList<>();
for (List<String> queryTextsSub : queryTextsSubList) { for (List<String> queryTextsSub : queryTextsSubList) {
tasks.add(createTask(chatQueryContext, detectDataSetIds, queryTextsSub, results));
}
executeTasks(tasks);
return new ArrayList<>(results);
}
private Callable<Void> createTask(ChatQueryContext chatQueryContext, Set<Long> detectDataSetIds,
List<String> queryTextsSub, Set<EmbeddingResult> results) {
return () -> {
List<EmbeddingResult> oneRoundResults = List<EmbeddingResult> oneRoundResults =
detectByQueryTextsSub(detectDataSetIds, queryTextsSub, chatQueryContext); detectByQueryTextsSub(detectDataSetIds, queryTextsSub, chatQueryContext);
synchronized (results) {
selectResultInOneRound(results, oneRoundResults); selectResultInOneRound(results, oneRoundResults);
} }
return new ArrayList<>(results); return null;
};
} }
private List<EmbeddingResult> detectByQueryTextsSub(Set<Long> detectDataSetIds, private List<EmbeddingResult> detectByQueryTextsSub(Set<Long> detectDataSetIds,

View File

@@ -35,11 +35,11 @@ public class MapperConfig extends ParameterConfig {
"维度值相似度阈值在动态调整中的最低值", "number", "Mapper相关配置"); "维度值相似度阈值在动态调整中的最低值", "number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_TEXT_SIZE = public static final Parameter EMBEDDING_MAPPER_TEXT_SIZE =
new Parameter("s2.mapper.embedding.word.size", "4", "用于向量召回文本长度", new Parameter("s2.mapper.embedding.word.size", "3", "用于向量召回文本长度",
"为提高向量召回效率, 按指定长度进行向量语义召回", "number", "Mapper相关配置"); "为提高向量召回效率, 按指定长度进行向量语义召回", "number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_TEXT_STEP = public static final Parameter EMBEDDING_MAPPER_TEXT_STEP =
new Parameter("s2.mapper.embedding.word.step", "3", "向量召回文本每步长度", new Parameter("s2.mapper.embedding.word.step", "2", "向量召回文本每步长度",
"为提高向量召回效率, 按指定每步长度进行召回", "number", "Mapper相关配置"); "为提高向量召回效率, 按指定每步长度进行召回", "number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_BATCH = public static final Parameter EMBEDDING_MAPPER_BATCH =
@@ -51,7 +51,7 @@ public class MapperConfig extends ParameterConfig {
"每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置"); "每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_THRESHOLD = public static final Parameter EMBEDDING_MAPPER_THRESHOLD =
new Parameter("s2.mapper.embedding.threshold", "0.98", "向量召回相似度阈值", "相似度小于该阈值的则舍弃", new Parameter("s2.mapper.embedding.threshold", "0.9", "向量召回相似度阈值", "相似度小于该阈值的则舍弃",
"number", "Mapper相关配置"); "number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_ROUND_NUMBER = public static final Parameter EMBEDDING_MAPPER_ROUND_NUMBER =

View File

@@ -9,22 +9,23 @@ import lombok.extern.slf4j.Slf4j;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors;
@Slf4j @Slf4j
public class TimeFieldMapper extends BaseMapper { public class PartitionTimeMapper extends BaseMapper {
@Override
public boolean accept(ChatQueryContext chatQueryContext) {
return !(chatQueryContext.getRequest().getText2SQLType().equals(Text2SQLType.ONLY_RULE)
|| chatQueryContext.getMapInfo().isEmpty());
}
@Override @Override
public void doMap(ChatQueryContext chatQueryContext) { public void doMap(ChatQueryContext chatQueryContext) {
if (chatQueryContext.getRequest().getText2SQLType().equals(Text2SQLType.ONLY_RULE)) {
return;
}
Map<Long, DataSetSchema> schemaMap = Map<Long, DataSetSchema> schemaMap =
chatQueryContext.getSemanticSchema().getDataSetSchemaMap(); chatQueryContext.getSemanticSchema().getDataSetSchemaMap();
for (Map.Entry<Long, DataSetSchema> entry : schemaMap.entrySet()) { for (Map.Entry<Long, DataSetSchema> entry : schemaMap.entrySet()) {
List<SchemaElement> timeDims = entry.getValue().getDimensions().stream() List<SchemaElement> timeDims = entry.getValue().getDimensions().stream()
.filter(dim -> dim.getTimeFormat() != null).collect(Collectors.toList()); .filter(SchemaElement::isPartitionTime).toList();
for (SchemaElement schemaElement : timeDims) { for (SchemaElement schemaElement : timeDims) {
chatQueryContext.getMapInfo().getMatchedElements(entry.getKey()) chatQueryContext.getMapInfo().getMatchedElements(entry.getKey())
.add(SchemaElementMatch.builder().word(schemaElement.getName()) .add(SchemaElementMatch.builder().word(schemaElement.getName())

View File

@@ -21,14 +21,16 @@ import java.util.stream.Collectors;
@Slf4j @Slf4j
public class QueryFilterMapper extends BaseMapper { public class QueryFilterMapper extends BaseMapper {
private double similarity = 1.0; private final double similarity = 1.0;
@Override
public boolean accept(ChatQueryContext chatQueryContext) {
return !chatQueryContext.getRequest().getDataSetIds().isEmpty();
}
@Override @Override
public void doMap(ChatQueryContext chatQueryContext) { public void doMap(ChatQueryContext chatQueryContext) {
Set<Long> dataSetIds = chatQueryContext.getRequest().getDataSetIds(); Set<Long> dataSetIds = chatQueryContext.getRequest().getDataSetIds();
if (CollectionUtils.isEmpty(dataSetIds)) {
return;
}
SchemaMapInfo schemaMapInfo = chatQueryContext.getMapInfo(); SchemaMapInfo schemaMapInfo = chatQueryContext.getMapInfo();
clearOtherSchemaElementMatch(dataSetIds, schemaMapInfo); clearOtherSchemaElementMatch(dataSetIds, schemaMapInfo);
for (Long dataSetId : dataSetIds) { for (Long dataSetId : dataSetIds) {

View File

@@ -8,10 +8,11 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
@Service @Service
@Slf4j @Slf4j
@@ -25,28 +26,38 @@ public abstract class SingleMatchStrategy<T extends MapResult> extends BaseMatch
Set<Long> detectDataSetIds) { Set<Long> detectDataSetIds) {
Map<Integer, Integer> regOffsetToLength = mapperHelper.getRegOffsetToLength(terms); Map<Integer, Integer> regOffsetToLength = mapperHelper.getRegOffsetToLength(terms);
String text = chatQueryContext.getRequest().getQueryText(); String text = chatQueryContext.getRequest().getQueryText();
Set<T> results = new HashSet<>(); Set<T> results = ConcurrentHashMap.newKeySet();
List<Callable<Void>> tasks = new ArrayList<>();
Set<String> detectSegments = new HashSet<>(); for (int startIndex = 0; startIndex <= text.length() - 1;) {
for (int index = startIndex; index <= text.length();) {
for (Integer startIndex = 0; startIndex <= text.length() - 1;) {
for (Integer index = startIndex; index <= text.length();) {
int offset = mapperHelper.getStepOffset(terms, startIndex); int offset = mapperHelper.getStepOffset(terms, startIndex);
index = mapperHelper.getStepIndex(regOffsetToLength, index); index = mapperHelper.getStepIndex(regOffsetToLength, index);
if (index <= text.length()) { if (index <= text.length()) {
String detectSegment = text.substring(startIndex, index).trim(); String detectSegment = text.substring(startIndex, index).trim();
detectSegments.add(detectSegment); Callable<Void> task = createTask(chatQueryContext, detectDataSetIds,
List<T> oneRoundResults = detectSegment, offset, results);
detectByStep(chatQueryContext, detectDataSetIds, detectSegment, offset); tasks.add(task);
selectResultInOneRound(results, oneRoundResults);
} }
} }
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex); startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
} }
executeTasks(tasks);
return new ArrayList<>(results); return new ArrayList<>(results);
} }
private Callable<Void> createTask(ChatQueryContext chatQueryContext, Set<Long> detectDataSetIds,
String detectSegment, int offset, Set<T> results) {
return () -> {
List<T> oneRoundResults =
detectByStep(chatQueryContext, detectDataSetIds, detectSegment, offset);
synchronized (results) {
selectResultInOneRound(results, oneRoundResults);
}
return null;
};
}
public abstract List<T> detectByStep(ChatQueryContext chatQueryContext, public abstract List<T> detectByStep(ChatQueryContext chatQueryContext,
Set<Long> detectDataSetIds, String detectSegment, int offset); Set<Long> detectDataSetIds, String detectSegment, int offset);
} }

View File

@@ -17,13 +17,14 @@ import java.util.List;
public class TermDescMapper extends BaseMapper { public class TermDescMapper extends BaseMapper {
@Override @Override
public void doMap(ChatQueryContext chatQueryContext) { public boolean accept(ChatQueryContext chatQueryContext) {
SchemaMapInfo mapInfo = chatQueryContext.getMapInfo(); return !(CollectionUtils.isEmpty(chatQueryContext.getMapInfo().getTermDescriptionToMap())
List<SchemaElement> termElements = mapInfo.getTermDescriptionToMap(); || chatQueryContext.getRequest().isDescriptionMapped());
if (CollectionUtils.isEmpty(termElements)
|| chatQueryContext.getRequest().isDescriptionMapped()) {
return;
} }
@Override
public void doMap(ChatQueryContext chatQueryContext) {
List<SchemaElement> termElements = chatQueryContext.getMapInfo().getTermDescriptionToMap();
for (SchemaElement schemaElement : termElements) { for (SchemaElement schemaElement : termElements) {
ChatQueryContext queryCtx = ChatQueryContext queryCtx =
buildQueryContext(chatQueryContext, schemaElement.getDescription()); buildQueryContext(chatQueryContext, schemaElement.getDescription());

View File

@@ -2,7 +2,10 @@ package com.tencent.supersonic.headless.chat.parser.llm;
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper; import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.query.QueryManager; import com.tencent.supersonic.headless.chat.query.QueryManager;
@@ -50,6 +53,15 @@ public class LLMResponseService {
parseInfo.setQueryMode(semanticQuery.getQueryMode()); parseInfo.setQueryMode(semanticQuery.getQueryMode());
parseInfo.getSqlInfo().setParsedS2SQL(s2SQL); parseInfo.getSqlInfo().setParsedS2SQL(s2SQL);
parseInfo.getSqlInfo().setCorrectedS2SQL(s2SQL); parseInfo.getSqlInfo().setCorrectedS2SQL(s2SQL);
DataSetSchema dataSetSchema =
queryCtx.getSemanticSchema().getDataSetSchemaMap().get(parseInfo.getDataSetId());
SchemaElement partitionDimension = dataSetSchema.getPartitionDimension();
if (Objects.nonNull(partitionDimension)) {
DateConf dateConf = new DateConf();
dateConf.setDateField(partitionDimension.getName());
parseInfo.setDateInfo(dateConf);
}
queryCtx.getCandidateQueries().add(semanticQuery); queryCtx.getCandidateQueries().add(semanticQuery);
} }

View File

@@ -2,6 +2,8 @@ package com.tencent.supersonic.headless.chat.parser.rule;
import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum; import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.parser.SemanticParser; import com.tencent.supersonic.headless.chat.parser.SemanticParser;
@@ -57,6 +59,10 @@ public class TimeRangeParser implements SemanticParser {
for (SemanticQuery query : queryContext.getCandidateQueries()) { for (SemanticQuery query : queryContext.getCandidateQueries()) {
SemanticParseInfo parseInfo = query.getParseInfo(); SemanticParseInfo parseInfo = query.getParseInfo();
if (queryContext.containsPartitionDimensions(parseInfo.getDataSetId())) { if (queryContext.containsPartitionDimensions(parseInfo.getDataSetId())) {
DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap()
.get(parseInfo.getDataSetId());
SchemaElement partitionDimension = dataSetSchema.getPartitionDimension();
dateConf.setDateField(partitionDimension.getName());
parseInfo.setDateInfo(dateConf); parseInfo.setDateInfo(dateConf);
} }
parseInfo.setScore(parseInfo.getScore() + dateConf.getDetectWord().length()); parseInfo.setScore(parseInfo.getScore() + dateConf.getDetectWord().length());

View File

@@ -4,18 +4,8 @@ import com.tencent.supersonic.common.pojo.Aggregator;
import com.tencent.supersonic.common.pojo.Filter; import com.tencent.supersonic.common.pojo.Filter;
import com.tencent.supersonic.common.pojo.Order; import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.headless.api.pojo.*;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.request.*;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.query.BaseSemanticQuery; import com.tencent.supersonic.headless.chat.query.BaseSemanticQuery;
import com.tencent.supersonic.headless.chat.query.QueryManager; import com.tencent.supersonic.headless.chat.query.QueryManager;
@@ -25,13 +15,8 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import java.util.ArrayList; import java.util.*;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry; import java.util.Map.Entry;
import java.util.Objects;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.TERM; import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.TERM;
@@ -233,8 +218,6 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
protected void convertBizNameToName(DataSetSchema dataSetSchema, protected void convertBizNameToName(DataSetSchema dataSetSchema,
QueryStructReq queryStructReq) { QueryStructReq queryStructReq) {
Map<String, String> bizNameToName = dataSetSchema.getBizNameToName(); Map<String, String> bizNameToName = dataSetSchema.getBizNameToName();
bizNameToName.putAll(TimeDimensionEnum.getNameToNameMap());
List<Order> orders = queryStructReq.getOrders(); List<Order> orders = queryStructReq.getOrders();
if (CollectionUtils.isNotEmpty(orders)) { if (CollectionUtils.isNotEmpty(orders)) {
for (Order order : orders) { for (Order order : orders) {

View File

@@ -3,14 +3,13 @@ package com.tencent.supersonic.headless.chat.query.rule.detail;
import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.TimeMode; import com.tencent.supersonic.common.pojo.enums.TimeMode;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig; import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery; import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import java.time.LocalDate; import java.time.LocalDate;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
@@ -33,10 +32,13 @@ public abstract class DetailSemanticQuery extends RuleSemanticQuery {
chatQueryContext.getSemanticSchema().getDataSetSchemaMap(); chatQueryContext.getSemanticSchema().getDataSetSchemaMap();
DataSetSchema dataSetSchema = dataSetSchemaMap.get(parseInfo.getDataSetId()); DataSetSchema dataSetSchema = dataSetSchemaMap.get(parseInfo.getDataSetId());
TimeDefaultConfig timeDefaultConfig = dataSetSchema.getDetailTypeTimeDefaultConfig(); TimeDefaultConfig timeDefaultConfig = dataSetSchema.getDetailTypeTimeDefaultConfig();
SchemaElement partitionDimension = dataSetSchema.getPartitionDimension();
if (Objects.nonNull(timeDefaultConfig) && Objects.nonNull(timeDefaultConfig.getUnit()) if (Objects.nonNull(partitionDimension) && Objects.nonNull(timeDefaultConfig)
&& Objects.nonNull(timeDefaultConfig.getUnit())
&& timeDefaultConfig.getUnit() != -1) { && timeDefaultConfig.getUnit() != -1) {
DateConf dateInfo = new DateConf(); DateConf dateInfo = new DateConf();
dateInfo.setDateField(partitionDimension.getName());
int unit = timeDefaultConfig.getUnit(); int unit = timeDefaultConfig.getUnit();
String startDate = LocalDate.now().minusDays(unit).toString(); String startDate = LocalDate.now().minusDays(unit).toString();
String endDate = startDate; String endDate = startDate;

View File

@@ -3,14 +3,13 @@ package com.tencent.supersonic.headless.chat.query.rule.metric;
import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.TimeMode; import com.tencent.supersonic.common.pojo.enums.TimeMode;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig; import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery; import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import java.time.LocalDate; import java.time.LocalDate;
import java.util.List;
import java.util.Objects; import java.util.Objects;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.METRIC; import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.METRIC;
@@ -40,10 +39,12 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery {
DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema().getDataSetSchemaMap() DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema().getDataSetSchemaMap()
.get(parseInfo.getDataSetId()); .get(parseInfo.getDataSetId());
TimeDefaultConfig timeDefaultConfig = dataSetSchema.getMetricTypeTimeDefaultConfig(); TimeDefaultConfig timeDefaultConfig = dataSetSchema.getMetricTypeTimeDefaultConfig();
DateConf dateInfo = new DateConf(); SchemaElement partitionDimension = dataSetSchema.getPartitionDimension();
// 加上时间!=-1 判断 if (Objects.nonNull(partitionDimension) && Objects.nonNull(timeDefaultConfig)
if (Objects.nonNull(timeDefaultConfig) && Objects.nonNull(timeDefaultConfig.getUnit()) && Objects.nonNull(timeDefaultConfig.getUnit())
&& timeDefaultConfig.getUnit() != -1) { && timeDefaultConfig.getUnit() != -1) {
DateConf dateInfo = new DateConf();
dateInfo.setDateField(partitionDimension.getName());
int unit = timeDefaultConfig.getUnit(); int unit = timeDefaultConfig.getUnit();
String startDate = LocalDate.now().minusDays(unit).toString(); String startDate = LocalDate.now().minusDays(unit).toString();
String endDate = startDate; String endDate = startDate;
@@ -55,8 +56,8 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery {
dateInfo.setPeriod(timeDefaultConfig.getPeriod()); dateInfo.setPeriod(timeDefaultConfig.getPeriod());
dateInfo.setStartDate(startDate); dateInfo.setStartDate(startDate);
dateInfo.setEndDate(endDate); dateInfo.setEndDate(endDate);
// 时间不为-1才设置时间所以移到这里
parseInfo.setDateInfo(dateInfo); parseInfo.setDateInfo(dateInfo);
} }
} }
} }

View File

@@ -7,9 +7,7 @@ import com.tencent.supersonic.common.pojo.Filter;
import com.tencent.supersonic.common.pojo.Order; import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum; import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SqlInfo; import com.tencent.supersonic.headless.api.pojo.SqlInfo;
@@ -22,13 +20,7 @@ import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.ArrayList; import java.util.*;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@Slf4j @Slf4j
@@ -179,14 +171,7 @@ public class QueryReqBuilder {
if (Objects.isNull(dateConf)) { if (Objects.isNull(dateConf)) {
return ""; return "";
} }
String dateField = TimeDimensionEnum.DAY.getName(); return dateConf.getDateField();
if (DatePeriodEnum.MONTH.equals(dateConf.getPeriod())) {
dateField = TimeDimensionEnum.MONTH.getName();
}
if (DatePeriodEnum.WEEK.equals(dateConf.getPeriod())) {
dateField = TimeDimensionEnum.WEEK.getName();
}
return dateField;
} }
public static QueryStructReq buildStructRatioReq(SemanticParseInfo parseInfo, public static QueryStructReq buildStructRatioReq(SemanticParseInfo parseInfo,

View File

@@ -7,7 +7,7 @@ public class HanadbAdaptor extends DefaultDbAdaptor {
@Override @Override
public String rewriteSql(String sql) { public String rewriteSql(String sql) {
return sql.replaceAll("`", "\""); return sql.replaceAll("`(.*?)`", "\"$1\"").replaceAll("\"([A-Z0-9_]+?)\"", "$1");
} }
} }

View File

@@ -8,7 +8,7 @@ import org.springframework.context.annotation.Configuration;
@Configuration @Configuration
public class ExecutorConfig { public class ExecutorConfig {
@Value("${s2.metricParser.agg.mysql.lowVersion:5.7}") @Value("${s2.metricParser.agg.mysql.lowVersion:8.0}")
private String mysqlLowVersion; private String mysqlLowVersion;
@Value("${s2.metricParser.agg.ck.lowVersion:20.4}") @Value("${s2.metricParser.agg.ck.lowVersion:20.4}")

View File

@@ -1,8 +1,8 @@
package com.tencent.supersonic.headless.core.executor; package com.tencent.supersonic.headless.core.executor;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.core.pojo.Database;
import com.tencent.supersonic.headless.core.pojo.QueryStatement; import com.tencent.supersonic.headless.core.pojo.QueryStatement;
import com.tencent.supersonic.headless.core.utils.ComponentFactory; import com.tencent.supersonic.headless.core.utils.ComponentFactory;
import com.tencent.supersonic.headless.core.utils.SqlUtils; import com.tencent.supersonic.headless.core.utils.SqlUtils;
@@ -38,7 +38,7 @@ public class JdbcExecutor implements QueryExecutor {
SqlUtils sqlUtils = ContextUtils.getBean(SqlUtils.class); SqlUtils sqlUtils = ContextUtils.getBean(SqlUtils.class);
String sql = StringUtils.normalizeSpace(queryStatement.getSql()); String sql = StringUtils.normalizeSpace(queryStatement.getSql());
log.info("executing SQL: {}", sql); log.info("executing SQL: {}", sql);
Database database = queryStatement.getOntology().getDatabase(); DatabaseResp database = queryStatement.getOntology().getDatabase();
SemanticQueryResp queryResultWithColumns = new SemanticQueryResp(); SemanticQueryResp queryResultWithColumns = new SemanticQueryResp();
try { try {
SqlUtils sqlUtil = sqlUtils.init(database); SqlUtils sqlUtil = sqlUtils.init(database);

View File

@@ -1,49 +0,0 @@
package com.tencent.supersonic.headless.core.pojo;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.RecordInfo;
import com.tencent.supersonic.common.pojo.enums.EngineType;
import com.tencent.supersonic.common.util.AESEncryptionUtil;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
@Data
@AllArgsConstructor
@NoArgsConstructor
@Builder
public class Database extends RecordInfo {
private Long id;
private Long domainId;
private String name;
private String description;
private String version;
private String url;
private String username;
private String password;
private String database;
private String schema;
/** mysql,clickhouse */
private EngineType type;
private List<String> admins = Lists.newArrayList();
private List<String> viewers = Lists.newArrayList();
public String passwordDecrypt() {
return AESEncryptionUtil.aesDecryptECB(password);
}
}

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.core.pojo;
import com.alibaba.druid.pool.DruidDataSource; import com.alibaba.druid.pool.DruidDataSource;
import com.tencent.supersonic.headless.api.pojo.enums.DataType; import com.tencent.supersonic.headless.api.pojo.enums.DataType;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.core.utils.JdbcDataSourceUtils; import com.tencent.supersonic.headless.core.utils.JdbcDataSourceUtils;
import lombok.Getter; import lombok.Getter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@@ -106,7 +107,7 @@ public class JdbcDataSource {
} }
} }
public void removeDatasource(Database database) { public void removeDatasource(DatabaseResp database) {
String key = getDataSourceKey(database); String key = getDataSourceKey(database);
@@ -128,7 +129,7 @@ public class JdbcDataSource {
} }
} }
public DruidDataSource getDataSource(Database database) throws RuntimeException { public DruidDataSource getDataSource(DatabaseResp database) throws RuntimeException {
String name = database.getName(); String name = database.getName();
String jdbcUrl = database.getUrl(); String jdbcUrl = database.getUrl();
@@ -239,7 +240,7 @@ public class JdbcDataSource {
return druidDataSource; return druidDataSource;
} }
private String getDataSourceKey(Database database) { private String getDataSourceKey(DatabaseResp database) {
return JdbcDataSourceUtils.getKey(database.getName(), database.getUrl(), return JdbcDataSourceUtils.getKey(database.getName(), database.getUrl(),
database.getUsername(), database.passwordDecrypt(), "", false); database.getUsername(), database.passwordDecrypt(), "", false);
} }

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.headless.core.translator.parser.s2sql; package com.tencent.supersonic.headless.core.pojo;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;

View File

@@ -0,0 +1,42 @@
package com.tencent.supersonic.headless.core.pojo;
import com.tencent.supersonic.common.pojo.enums.EngineType;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.core.translator.parser.s2sql.DataModel;
import com.tencent.supersonic.headless.core.translator.parser.s2sql.Dimension;
import com.tencent.supersonic.headless.core.translator.parser.s2sql.Materialization;
import com.tencent.supersonic.headless.core.translator.parser.s2sql.Metric;
import lombok.Data;
import java.util.*;
import java.util.stream.Collectors;
@Data
public class Ontology {
private List<Metric> metrics = new ArrayList<>();
private Map<String, DataModel> dataModelMap = new HashMap<>();
private Map<String, List<Dimension>> dimensionMap = new HashMap<>();
private List<Materialization> materializationList = new ArrayList<>();
private List<JoinRelation> joinRelations;
private DatabaseResp database;
public List<Dimension> getDimensions() {
return dimensionMap.values().stream().flatMap(Collection::stream)
.collect(Collectors.toList());
}
public EngineType getDatabaseType() {
if (Objects.nonNull(database)) {
return EngineType.fromString(database.getType().toUpperCase());
}
return null;
}
public String getDatabaseVersion() {
if (Objects.nonNull(database)) {
return database.getVersion();
}
return null;
}
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.headless.core.translator.parser.s2sql; package com.tencent.supersonic.headless.core.pojo;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import com.tencent.supersonic.common.pojo.ColumnOrder; import com.tencent.supersonic.common.pojo.ColumnOrder;
@@ -9,7 +9,7 @@ import java.util.List;
import java.util.Set; import java.util.Set;
@Data @Data
public class OntologyQueryParam { public class OntologyQuery {
private Set<String> metrics = Sets.newHashSet(); private Set<String> metrics = Sets.newHashSet();
private Set<String> dimensions = Sets.newHashSet(); private Set<String> dimensions = Sets.newHashSet();
private String where; private String where;

View File

@@ -1,8 +1,6 @@
package com.tencent.supersonic.headless.core.pojo; package com.tencent.supersonic.headless.core.pojo;
import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp;
import com.tencent.supersonic.headless.core.translator.parser.s2sql.Ontology;
import com.tencent.supersonic.headless.core.translator.parser.s2sql.OntologyQueryParam;
import lombok.Data; import lombok.Data;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Triple; import org.apache.commons.lang3.tuple.Triple;
@@ -11,17 +9,18 @@ import org.apache.commons.lang3.tuple.Triple;
public class QueryStatement { public class QueryStatement {
private Long dataSetId; private Long dataSetId;
private String dataSetName;
private String sql; private String sql;
private String errMsg; private String errMsg;
private StructQueryParam structQueryParam; private StructQuery structQuery;
private SqlQueryParam sqlQueryParam; private SqlQuery sqlQuery;
private OntologyQueryParam ontologyQueryParam; private OntologyQuery ontologyQuery;
private Integer status = 0; private Integer status = 0;
private Boolean isS2SQL = false; private Boolean isS2SQL = false;
private Boolean enableOptimize = true; private Boolean enableOptimize = true;
private Triple<String, String, String> minMaxTime; private Triple<String, String, String> minMaxTime;
private Ontology ontology; private Ontology ontology;
private SemanticSchemaResp semanticSchemaResp; private SemanticSchemaResp semanticSchema;
private Integer limit = 1000; private Integer limit = 1000;
private Boolean isTranslated = false; private Boolean isTranslated = false;

View File

@@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.core.pojo;
import lombok.Data; import lombok.Data;
@Data @Data
public class SqlQueryParam { public class SqlQuery {
private String sql; private String sql;
private String table; private String table;
private boolean supportWith = true; private boolean supportWith = true;

View File

@@ -12,7 +12,7 @@ import java.util.ArrayList;
import java.util.List; import java.util.List;
@Data @Data
public class StructQueryParam { public class StructQuery {
private List<String> groups = new ArrayList(); private List<String> groups = new ArrayList();
private List<Aggregator> aggregators = new ArrayList(); private List<Aggregator> aggregators = new ArrayList();
private List<Order> orders = new ArrayList(); private List<Order> orders = new ArrayList();

View File

@@ -3,10 +3,9 @@ package com.tencent.supersonic.headless.core.translator;
import com.tencent.supersonic.common.calcite.SqlMergeWithUtils; import com.tencent.supersonic.common.calcite.SqlMergeWithUtils;
import com.tencent.supersonic.common.pojo.enums.EngineType; import com.tencent.supersonic.common.pojo.enums.EngineType;
import com.tencent.supersonic.headless.core.pojo.QueryStatement; import com.tencent.supersonic.headless.core.pojo.QueryStatement;
import com.tencent.supersonic.headless.core.pojo.SqlQueryParam; import com.tencent.supersonic.headless.core.pojo.SqlQuery;
import com.tencent.supersonic.headless.core.translator.converter.QueryConverter;
import com.tencent.supersonic.headless.core.translator.optimizer.QueryOptimizer; import com.tencent.supersonic.headless.core.translator.optimizer.QueryOptimizer;
import com.tencent.supersonic.headless.core.translator.parser.s2sql.OntologyQueryParam; import com.tencent.supersonic.headless.core.translator.parser.QueryParser;
import com.tencent.supersonic.headless.core.utils.ComponentFactory; import com.tencent.supersonic.headless.core.utils.ComponentFactory;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
@@ -26,48 +25,52 @@ public class DefaultSemanticTranslator implements SemanticTranslator {
return; return;
} }
try { try {
for (QueryConverter converter : ComponentFactory.getQueryConverters()) { for (QueryParser parser : ComponentFactory.getQueryParsers()) {
if (converter.accept(queryStatement)) { if (parser.accept(queryStatement)) {
log.debug("QueryConverter accept [{}]", converter.getClass().getName()); log.debug("QueryConverter accept [{}]", parser.getClass().getName());
converter.convert(queryStatement); parser.parse(queryStatement);
if (queryStatement.getStatus() != 0) {
break;
} }
} }
doOntologyParse(queryStatement); }
if (!queryStatement.isOk()) {
throw new Exception(String.format("parse ontology table [%s] error [%s]",
queryStatement.getSqlQuery().getTable(), queryStatement.getErrMsg()));
}
if (StringUtils.isNotBlank(queryStatement.getSqlQueryParam().getSimplifiedSql())) { mergeOntologyQuery(queryStatement);
queryStatement.setSql(queryStatement.getSqlQueryParam().getSimplifiedSql());
if (StringUtils.isNotBlank(queryStatement.getSqlQuery().getSimplifiedSql())) {
queryStatement.setSql(queryStatement.getSqlQuery().getSimplifiedSql());
} }
if (StringUtils.isBlank(queryStatement.getSql())) { if (StringUtils.isBlank(queryStatement.getSql())) {
throw new RuntimeException("parse exception: " + queryStatement.getErrMsg()); throw new RuntimeException("parse exception: " + queryStatement.getErrMsg());
} }
for (QueryOptimizer queryOptimizer : ComponentFactory.getQueryOptimizers()) { for (QueryOptimizer optimizer : ComponentFactory.getQueryOptimizers()) {
queryOptimizer.rewrite(queryStatement); if (optimizer.accept(queryStatement)) {
optimizer.rewrite(queryStatement);
} }
}
log.info("translated query SQL: [{}]",
StringUtils.normalizeSpace(queryStatement.getSql()));
} catch (Exception e) { } catch (Exception e) {
queryStatement.setErrMsg(e.getMessage()); queryStatement.setErrMsg(e.getMessage());
log.error("Failed to translate query [{}]", e.getMessage(), e); log.error("Failed to translate query [{}]", e.getMessage(), e);
} }
} }
private void doOntologyParse(QueryStatement queryStatement) throws Exception { private void mergeOntologyQuery(QueryStatement queryStatement) throws Exception {
OntologyQueryParam ontologyQueryParam = queryStatement.getOntologyQueryParam(); SqlQuery sqlQuery = queryStatement.getSqlQuery();
log.info("parse with ontology: [{}]", ontologyQueryParam); String ontologyQuerySql = sqlQuery.getSql();
ComponentFactory.getQueryParser().parse(queryStatement); String ontologyInnerTable = sqlQuery.getTable();
if (!queryStatement.isOk()) {
throw new Exception(String.format("parse ontology table [%s] error [%s]",
queryStatement.getSqlQueryParam().getTable(), queryStatement.getErrMsg()));
}
SqlQueryParam sqlQueryParam = queryStatement.getSqlQueryParam();
String ontologyQuerySql = sqlQueryParam.getSql();
String ontologyInnerTable = sqlQueryParam.getTable();
String ontologyInnerSql = queryStatement.getSql(); String ontologyInnerSql = queryStatement.getSql();
List<Pair<String, String>> tables = new ArrayList<>(); List<Pair<String, String>> tables = new ArrayList<>();
tables.add(Pair.of(ontologyInnerTable, ontologyInnerSql)); tables.add(Pair.of(ontologyInnerTable, ontologyInnerSql));
if (sqlQueryParam.isSupportWith()) { if (sqlQuery.isSupportWith()) {
EngineType engineType = queryStatement.getOntology().getDatabase().getType(); EngineType engineType = queryStatement.getOntology().getDatabaseType();
if (!SqlMergeWithUtils.hasWith(engineType, ontologyQuerySql)) { if (!SqlMergeWithUtils.hasWith(engineType, ontologyQuerySql)) {
String withSql = "with " + tables.stream() String withSql = "with " + tables.stream()
.map(t -> String.format("%s as (%s)", t.getLeft(), t.getRight())) .map(t -> String.format("%s as (%s)", t.getLeft(), t.getRight()))
@@ -84,9 +87,9 @@ public class DefaultSemanticTranslator implements SemanticTranslator {
} }
} else { } else {
for (Pair<String, String> tb : tables) { for (Pair<String, String> tb : tables) {
ontologyQuerySql = ontologyQuerySql = StringUtils.replace(ontologyQuerySql, tb.getLeft(),
StringUtils.replace(ontologyQuerySql, tb.getLeft(), "(" + tb.getRight() "(" + tb.getRight() + ") " + (sqlQuery.isWithAlias() ? "" : tb.getLeft()),
+ ") " + (sqlQueryParam.isWithAlias() ? "" : tb.getLeft()), -1); -1);
} }
queryStatement.setSql(ontologyQuerySql); queryStatement.setSql(ontologyQuerySql);
} }

Some files were not shown because too many files have changed in this diff Show More