(feature)(project) change dsl to s2ql in project and queryMode to llmParser (#250)

This commit is contained in:
lexluo09
2023-10-18 09:53:01 +08:00
committed by GitHub
parent bf5be11549
commit 8d81f63e08
77 changed files with 320 additions and 305 deletions

View File

@@ -14,7 +14,7 @@ import com.tencent.supersonic.semantic.api.model.response.MetricResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
import com.tencent.supersonic.semantic.api.query.request.QueryDimValueReq;
import com.tencent.supersonic.semantic.api.query.request.QueryDslReq;
import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq;
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
@@ -37,7 +37,7 @@ public interface SemanticInterpreter {
QueryResultWithSchemaResp queryByMultiStruct(QueryMultiStructReq queryMultiStructReq, User user);
QueryResultWithSchemaResp queryByDsl(QueryDslReq queryDslReq, User user);
QueryResultWithSchemaResp queryByS2QL(QueryS2QLReq queryS2QLReq, User user);
QueryResultWithSchemaResp queryDimValue(QueryDimValueReq queryDimValueReq, User user);

View File

@@ -5,7 +5,7 @@ import lombok.Data;
@Data
public class SqlInfo {
private String llmParseSql;
private String s2QL;
private String logicSql;
private String querySql;
}

View File

@@ -2,7 +2,7 @@ package com.tencent.supersonic.chat.agent.tool;
public enum AgentToolType {
RULE,
DSL,
LLM_S2QL,
PLUGIN,
INTERPRET
}

View File

@@ -0,0 +1,16 @@
package com.tencent.supersonic.chat.agent.tool;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@NoArgsConstructor
@AllArgsConstructor
public class CommonAgentTool extends AgentTool {
protected List<Long> modelIds;
}

View File

@@ -5,9 +5,7 @@ import lombok.Data;
import java.util.List;
@Data
public class DslTool extends AgentTool {
private List<Long> modelIds;
public class LLMParserTool extends CommonAgentTool {
private List<String> exampleQuestions;

View File

@@ -7,9 +7,8 @@ import org.apache.commons.collections.CollectionUtils;
import java.util.List;
@Data
public class RuleQueryTool extends AgentTool {
public class RuleQueryTool extends CommonAgentTool {
private List<Long> modelIds;
private List<String> queryModes;

View File

@@ -1,9 +1,9 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
import com.tencent.supersonic.chat.parser.llm.s2ql.ParseResult;
import com.tencent.supersonic.chat.query.llm.s2ql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2ql.LLMReq.ElementValue;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
@@ -67,11 +67,11 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector {
return null;
}
DSLParseResult dslParseResult = JsonUtil.toObject(JsonUtil.toString(context), DSLParseResult.class);
if (Objects.isNull(dslParseResult) || Objects.isNull(dslParseResult.getLlmReq())) {
ParseResult parseResult = JsonUtil.toObject(JsonUtil.toString(context), ParseResult.class);
if (Objects.isNull(parseResult) || Objects.isNull(parseResult.getLlmReq())) {
return null;
}
LLMReq llmReq = dslParseResult.getLlmReq();
LLMReq llmReq = parseResult.getLlmReq();
return llmReq.getLinking();
}

View File

@@ -5,7 +5,7 @@ import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper;
import com.tencent.supersonic.chat.parser.llm.s2ql.S2QLDateHelper;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils;
@@ -72,7 +72,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
String sql = semanticCorrectInfo.getSql();
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sql);
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(DateUtils.DATE_FIELD)) {
String currentDate = DSLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId());
String currentDate = S2QLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId());
if (StringUtils.isNotBlank(currentDate)) {
sql = SqlParserAddHelper.addParenthesisToWhere(sql);
sql = SqlParserAddHelper.addWhere(sql, DateUtils.DATE_FIELD, currentDate);

View File

@@ -88,7 +88,7 @@ public class MapperHelper {
AgentService agentService = ContextUtils.getBean(AgentService.class);
Set<Long> detectModelIds = agentService.getDslToolsModelIds(request.getAgentId(), null);
Set<Long> detectModelIds = agentService.getModelIds(request.getAgentId(), null);
//contains all
if (agentService.containsAllModel(detectModelIds)) {
if (Objects.nonNull(modelId) && modelId > 0) {

View File

@@ -5,7 +5,7 @@ import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
@@ -20,7 +20,7 @@ public class SatisfactionChecker {
// check all the parse info in candidate
public static boolean check(QueryContext queryContext) {
for (SemanticQuery query : queryContext.getCandidateQueries()) {
if (query.getQueryMode().equals(DslQuery.QUERY_MODE)) {
if (query.getQueryMode().equals(S2QLQuery.QUERY_MODE)) {
continue;
}
if (checkThreshold(queryContext.getRequest().getQueryText(), query.getParseInfo())) {

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.llm.dsl;
package com.tencent.supersonic.chat.parser.llm.s2ql;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;

View File

@@ -1,8 +1,8 @@
package com.tencent.supersonic.chat.parser.llm.dsl;
package com.tencent.supersonic.chat.parser.llm.s2ql;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.DslTool;
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
import com.tencent.supersonic.chat.api.component.SemanticCorrector;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
@@ -18,10 +18,10 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.config.LLMParserConfig;
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
import com.tencent.supersonic.chat.query.llm.dsl.LLMResp;
import com.tencent.supersonic.chat.query.llm.s2ql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2ql.LLMReq.ElementValue;
import com.tencent.supersonic.chat.query.llm.s2ql.LLMResp;
import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
@@ -59,18 +59,18 @@ import org.springframework.util.CollectionUtils;
import org.springframework.web.client.RestTemplate;
@Slf4j
public class LLMDslParser implements SemanticParser {
public class LLMS2QLParser implements SemanticParser {
@Override
public void parse(QueryContext queryCtx, ChatContext chatCtx) {
QueryReq request = queryCtx.getRequest();
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
if (StringUtils.isEmpty(llmParserConfig.getUrl())) {
log.info("llm parser url is empty, skip dsl parser, llmParserConfig:{}", llmParserConfig);
log.info("llm parser url is empty, skip {} , llmParserConfig:{}", LLMS2QLParser.class, llmParserConfig);
return;
}
if (SatisfactionChecker.check(queryCtx)) {
log.info("skip dsl parser, queryText:{}", request.getQueryText());
log.info("skip {}, queryText:{}", LLMS2QLParser.class, request.getQueryText());
return;
}
try {
@@ -79,9 +79,9 @@ public class LLMDslParser implements SemanticParser {
return;
}
DslTool dslTool = getDslTool(request, modelId);
if (Objects.isNull(dslTool)) {
log.info("no dsl tool in this agent, skip dsl parser");
CommonAgentTool commonAgentTool = getParserTool(request, modelId);
if (Objects.isNull(commonAgentTool)) {
log.info("no tool in this agent, skip {}", LLMS2QLParser.class);
return;
}
@@ -91,10 +91,10 @@ public class LLMDslParser implements SemanticParser {
if (Objects.isNull(llmResp)) {
return;
}
DSLParseResult dslParseResult = DSLParseResult.builder().request(request)
.dslTool(dslTool).llmReq(llmReq).llmResp(llmResp).build();
ParseResult parseResult = ParseResult.builder().request(request)
.commonAgentTool(commonAgentTool).llmReq(llmReq).llmResp(llmResp).build();
SemanticParseInfo parseInfo = getParseInfo(queryCtx, modelId, dslTool, dslParseResult);
SemanticParseInfo parseInfo = getParseInfo(queryCtx, modelId, commonAgentTool, parseResult);
SemanticCorrectInfo semanticCorrectInfo = getCorrectorSql(queryCtx, parseInfo, llmResp.getSqlOutput());
@@ -103,7 +103,7 @@ public class LLMDslParser implements SemanticParser {
updateParseInfo(semanticCorrectInfo, modelId, parseInfo);
} catch (Exception e) {
log.error("LLMDSLParser error", e);
log.error("LLMS2QLParser error", e);
}
}
@@ -243,12 +243,12 @@ public class LLMDslParser implements SemanticParser {
.queryFilters(queryCtx.getRequest().getQueryFilters()).sql(sql)
.parseInfo(parseInfo).build();
List<SemanticCorrector> dslCorrections = ComponentFactory.getSqlCorrections();
List<SemanticCorrector> corrections = ComponentFactory.getSqlCorrections();
dslCorrections.forEach(dslCorrection -> {
corrections.forEach(correction -> {
try {
dslCorrection.correct(correctInfo);
log.info("sqlCorrection:{} sql:{}", dslCorrection.getClass().getSimpleName(), correctInfo.getSql());
correction.correct(correctInfo);
log.info("sqlCorrection:{} sql:{}", correction.getClass().getSimpleName(), correctInfo.getSql());
} catch (Exception e) {
log.error(String.format("correct error,correctInfo:%s", correctInfo), e);
}
@@ -256,21 +256,21 @@ public class LLMDslParser implements SemanticParser {
return correctInfo;
}
private SemanticParseInfo getParseInfo(QueryContext queryCtx, Long modelId, DslTool dslTool,
DSLParseResult dslParseResult) {
PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(DslQuery.QUERY_MODE);
private SemanticParseInfo getParseInfo(QueryContext queryCtx, Long modelId, CommonAgentTool commonAgentTool,
ParseResult parseResult) {
PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(S2QLQuery.QUERY_MODE);
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(modelId));
Map<String, Object> properties = new HashMap<>();
properties.put(Constants.CONTEXT, dslParseResult);
properties.put(Constants.CONTEXT, parseResult);
properties.put("type", "internal");
properties.put("name", dslTool.getName());
properties.put("name", commonAgentTool.getName());
parseInfo.setProperties(properties);
parseInfo.setScore(queryCtx.getRequest().getQueryText().length());
parseInfo.setQueryMode(semanticQuery.getQueryMode());
parseInfo.getSqlInfo().setLlmParseSql(dslParseResult.getLlmResp().getSqlOutput());
parseInfo.getSqlInfo().setS2QL(parseResult.getLlmResp().getSqlOutput());
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
@@ -284,10 +284,11 @@ public class LLMDslParser implements SemanticParser {
return parseInfo;
}
private DslTool getDslTool(QueryReq request, Long modelId) {
private CommonAgentTool getParserTool(QueryReq request, Long modelId) {
AgentService agentService = ContextUtils.getBean(AgentService.class);
List<DslTool> dslTools = agentService.getDslTools(request.getAgentId(), AgentToolType.DSL);
Optional<DslTool> dslToolOptional = dslTools.stream()
List<CommonAgentTool> commonAgentTools = agentService.getParserTools(request.getAgentId(),
AgentToolType.LLM_S2QL);
Optional<CommonAgentTool> llmParserTool = commonAgentTools.stream()
.filter(tool -> {
List<Long> modelIds = tool.getModelIds();
if (agentService.containsAllModel(new HashSet<>(modelIds))) {
@@ -296,18 +297,18 @@ public class LLMDslParser implements SemanticParser {
return modelIds.contains(modelId);
})
.findFirst();
return dslToolOptional.orElse(null);
return llmParserTool.orElse(null);
}
private Long getModelId(QueryContext queryCtx, ChatContext chatCtx, Integer agentId) {
AgentService agentService = ContextUtils.getBean(AgentService.class);
Set<Long> distinctModelIds = agentService.getDslToolsModelIds(agentId, AgentToolType.DSL);
Set<Long> distinctModelIds = agentService.getModelIds(agentId, AgentToolType.LLM_S2QL);
if (agentService.containsAllModel(distinctModelIds)) {
distinctModelIds = new HashSet<>();
}
ModelResolver modelResolver = ComponentFactory.getModelResolver();
Long modelId = modelResolver.resolve(queryCtx, chatCtx, distinctModelIds);
log.info("resolve modelId:{},dslModels:{}", modelId, distinctModelIds);
log.info("resolve modelId:{},llmParser Models:{}", modelId, distinctModelIds);
return modelId;
}
@@ -354,7 +355,7 @@ public class LLMDslParser implements SemanticParser {
linking.addAll(getValueList(queryCtx, modelId, semanticSchema));
llmReq.setLinking(linking);
String currentDate = DSLDateHelper.getReferenceDate(modelId);
String currentDate = S2QLDateHelper.getReferenceDate(modelId);
llmReq.setCurrentDate(currentDate);
return llmReq;
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.llm.dsl;
package com.tencent.supersonic.chat.parser.llm.s2ql;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.llm.dsl;
package com.tencent.supersonic.chat.parser.llm.s2ql;
import com.tencent.supersonic.chat.api.pojo.ChatContext;

View File

@@ -1,9 +1,9 @@
package com.tencent.supersonic.chat.parser.llm.dsl;
package com.tencent.supersonic.chat.parser.llm.s2ql;
import com.tencent.supersonic.chat.agent.tool.DslTool;
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
import com.tencent.supersonic.chat.query.llm.dsl.LLMResp;
import com.tencent.supersonic.chat.query.llm.s2ql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2ql.LLMResp;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
@@ -13,7 +13,7 @@ import lombok.NoArgsConstructor;
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class DSLParseResult {
public class ParseResult {
private LLMReq llmReq;
@@ -21,5 +21,5 @@ public class DSLParseResult {
private QueryReq request;
private DslTool dslTool;
private CommonAgentTool commonAgentTool;
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.llm.dsl;
package com.tencent.supersonic.chat.parser.llm.s2ql;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq;
@@ -11,7 +11,7 @@ import java.util.List;
import java.util.Objects;
import org.apache.commons.collections.CollectionUtils;
public class DSLDateHelper {
public class S2QLDateHelper {
public static String getReferenceDate(Long modelId) {
String defaultDate = DateUtils.getBeforeDate(0);

View File

@@ -9,7 +9,7 @@ import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
import com.tencent.supersonic.chat.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery;
import com.tencent.supersonic.chat.service.PluginService;
import com.tencent.supersonic.common.util.ContextUtils;
import java.net.URI;
@@ -102,7 +102,7 @@ public class FunctionBasedParser extends PluginParser {
log.info("user decide Model:{}", modelId);
List<Plugin> plugins = getPluginList(queryContext);
List<PluginParseConfig> functionDOList = plugins.stream().filter(plugin -> {
if (DslQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) {
if (S2QLQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) {
return false;
}
if (plugin.getParseModeConfig() == null) {

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.query.llm.dsl;
package com.tencent.supersonic.chat.query.llm.s2ql;
import java.util.List;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.query.llm.dsl;
package com.tencent.supersonic.chat.query.llm.s2ql;
import java.util.List;
import lombok.Data;

View File

@@ -1,10 +1,10 @@
package com.tencent.supersonic.chat.query.llm.dsl;
package com.tencent.supersonic.chat.query.llm.s2ql;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
import com.tencent.supersonic.chat.parser.llm.s2ql.ParseResult;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.utils.ComponentFactory;
@@ -16,7 +16,7 @@ import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
import com.tencent.supersonic.semantic.api.query.request.QueryDslReq;
import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@@ -26,12 +26,12 @@ import org.springframework.stereotype.Component;
@Slf4j
@Component
public class DslQuery extends PluginSemanticQuery {
public class S2QLQuery extends PluginSemanticQuery {
public static final String QUERY_MODE = "DSL";
public static final String QUERY_MODE = "LLM_S2QL";
protected SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
public DslQuery() {
public S2QLQuery() {
QueryManager.register(this);
}
@@ -45,10 +45,10 @@ public class DslQuery extends PluginSemanticQuery {
LLMResp llmResp = getLlmResp();
long startTime = System.currentTimeMillis();
QueryDslReq queryDslReq = getQueryDslReq(llmResp);
QueryResultWithSchemaResp queryResp = semanticInterpreter.queryByDsl(queryDslReq, user);
QueryS2QLReq queryS2QLReq = getQueryS2QLReq(llmResp);
QueryResultWithSchemaResp queryResp = semanticInterpreter.queryByS2QL(queryS2QLReq, user);
log.info("queryByDsl cost:{},querySql:{}", System.currentTimeMillis() - startTime, llmResp.getSqlOutput());
log.info("queryByS2QL cost:{},querySql:{}", System.currentTimeMillis() - startTime, llmResp.getSqlOutput());
QueryResult queryResult = new QueryResult();
if (Objects.nonNull(queryResp)) {
@@ -69,13 +69,12 @@ public class DslQuery extends PluginSemanticQuery {
private LLMResp getLlmResp() {
String json = JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT));
DSLParseResult dslParseResult = JsonUtil.toObject(json, DSLParseResult.class);
return dslParseResult.getLlmResp();
ParseResult parseResult = JsonUtil.toObject(json, ParseResult.class);
return parseResult.getLlmResp();
}
private QueryDslReq getQueryDslReq(LLMResp llmResp) {
QueryDslReq queryDslReq = QueryReqBuilder.buildDslReq(llmResp.getCorrectorSql(), parseInfo.getModelId());
return queryDslReq;
private QueryS2QLReq getQueryS2QLReq(LLMResp llmResp) {
return QueryReqBuilder.buildS2QLReq(llmResp.getCorrectorSql(), parseInfo.getModelId());
}
@Override
@@ -84,7 +83,7 @@ public class DslQuery extends PluginSemanticQuery {
try {
explainSqlReq = ExplainSqlReq.builder()
.queryTypeEnum(QueryTypeEnum.SQL)
.queryReq(getQueryDslReq(getLlmResp()))
.queryReq(getQueryS2QLReq(getLlmResp()))
.build();
return semanticInterpreter.explain(explainSqlReq, user);
} catch (Exception e) {

View File

@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.List;
@@ -24,7 +24,7 @@ public class EntityInfoExecuteResponder implements ExecuteResponder {
return;
}
String queryMode = semanticParseInfo.getQueryMode();
if (QueryManager.isPluginQuery(queryMode) && !DslQuery.QUERY_MODE.equals(queryMode)) {
if (QueryManager.isPluginQuery(queryMode) && !S2QLQuery.QUERY_MODE.equals(queryMode)) {
return;
}
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);

View File

@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.List;
@@ -24,7 +24,7 @@ public class EntityInfoParseResponder implements ParseResponder {
QueryReq queryReq = queryContext.getRequest();
selectedParses.forEach(parseInfo -> {
if (QueryManager.isPluginQuery(parseInfo.getQueryMode())
&& !DslQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
&& !S2QLQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
return;
}
//1. set entity info

View File

@@ -3,7 +3,7 @@ package com.tencent.supersonic.chat.service;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.DslTool;
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
import java.util.List;
import java.util.Set;
@@ -19,9 +19,9 @@ public interface AgentService {
void deleteAgent(Integer id);
List<DslTool> getDslTools(Integer agentId, AgentToolType agentToolType);
List<CommonAgentTool> getParserTools(Integer agentId, AgentToolType agentToolType);
Set<Long> getDslToolsModelIds(Integer agentId, AgentToolType agentToolType);
Set<Long> getModelIds(Integer agentId, AgentToolType agentToolType);
boolean containsAllModel(Set<Long> detectModelIds);
}

View File

@@ -5,7 +5,7 @@ import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.DslTool;
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
import com.tencent.supersonic.chat.persistence.dataobject.AgentDO;
import com.tencent.supersonic.chat.persistence.repository.AgentRepository;
import com.tencent.supersonic.chat.service.AgentService;
@@ -87,7 +87,7 @@ public class AgentServiceImpl implements AgentService {
return agentDO;
}
public List<DslTool> getDslTools(Integer agentId, AgentToolType agentToolType) {
public List<CommonAgentTool> getParserTools(Integer agentId, AgentToolType agentToolType) {
Agent agent = getAgent(agentId);
if (agent == null) {
return Lists.newArrayList();
@@ -96,15 +96,16 @@ public class AgentServiceImpl implements AgentService {
if (CollectionUtils.isEmpty(tools)) {
return Lists.newArrayList();
}
return tools.stream().map(tool -> JSONObject.parseObject(tool, DslTool.class)).collect(Collectors.toList());
return tools.stream().map(tool -> JSONObject.parseObject(tool, CommonAgentTool.class))
.collect(Collectors.toList());
}
public Set<Long> getDslToolsModelIds(Integer agentId, AgentToolType agentToolType) {
List<DslTool> dslTools = getDslTools(agentId, agentToolType);
if (CollectionUtils.isEmpty(dslTools)) {
public Set<Long> getModelIds(Integer agentId, AgentToolType agentToolType) {
List<CommonAgentTool> commonAgentTools = getParserTools(agentId, agentToolType);
if (CollectionUtils.isEmpty(commonAgentTools)) {
return new HashSet<>();
}
return dslTools.stream().map(DslTool::getModelIds)
return commonAgentTools.stream().map(CommonAgentTool::getModelIds)
.filter(modelIds -> !CollectionUtils.isEmpty(modelIds))
.flatMap(Collection::stream)
.collect(Collectors.toSet());

View File

@@ -19,15 +19,15 @@ import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
import com.tencent.supersonic.chat.parser.llm.s2ql.ParseResult;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.persistence.dataobject.CostType;
import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.QuerySelector;
import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
import com.tencent.supersonic.chat.query.llm.dsl.LLMResp;
import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery;
import com.tencent.supersonic.chat.query.llm.s2ql.LLMResp;
import com.tencent.supersonic.chat.responder.execute.ExecuteResponder;
import com.tencent.supersonic.chat.responder.parse.ParseResponder;
import com.tencent.supersonic.chat.service.ChatService;
@@ -239,7 +239,7 @@ public class QueryServiceImpl implements QueryService {
}
private void saveSolvedQuery(ExecuteQueryReq queryReq, SemanticParseInfo parseInfo,
ChatQueryDO chatQueryDO, QueryResult queryResult) {
ChatQueryDO chatQueryDO, QueryResult queryResult) {
if (queryResult.getResponse() == null && CollectionUtils.isEmpty(queryResult.getQueryResults())) {
return;
}
@@ -305,12 +305,12 @@ public class QueryServiceImpl implements QueryService {
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
if (DslQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
if (S2QLQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
Map<String, Map<String, String>> havingFiledNameToValueMap = new HashMap<>();
String json = JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT));
DSLParseResult dslParseResult = JsonUtil.toObject(json, DSLParseResult.class);
LLMResp llmResp = dslParseResult.getLlmResp();
ParseResult parseResult = JsonUtil.toObject(json, ParseResult.class);
LLMResp llmResp = parseResult.getLlmResp();
String correctorSql = llmResp.getCorrectorSql();
log.info("correctorSql before replacing:{}", correctorSql);
@@ -330,9 +330,9 @@ public class QueryServiceImpl implements QueryService {
correctorSql = SqlParserReplaceHelper.replaceHavingValue(correctorSql, havingFiledNameToValueMap);
log.info("correctorSql after replacing:{}", correctorSql);
llmResp.setCorrectorSql(correctorSql);
dslParseResult.setLlmResp(llmResp);
parseResult.setLlmResp(llmResp);
Map<String, Object> properties = new HashMap<>();
properties.put(Constants.CONTEXT, dslParseResult);
properties.put(Constants.CONTEXT, parseResult);
parseInfo.setProperties(properties);
parseInfo.getSqlInfo().setLogicSql(correctorSql);
semanticQuery.setParseInfo(parseInfo);
@@ -399,29 +399,29 @@ public class QueryServiceImpl implements QueryService {
if (CollectionUtils.isEmpty(metricFilters)) {
return;
}
for (QueryFilter dslQueryFilter : metricFilters) {
for (QueryFilter queryFilter : metricFilters) {
Map<String, String> map = new HashMap<>();
for (FilterExpression filterExpression : filterExpressionList) {
if (filterExpression.getFieldName() != null
&& filterExpression.getFieldName().contains(dslQueryFilter.getName())
&& dslQueryFilter.getOperator().getValue().equals(filterExpression.getOperator())) {
map.put(filterExpression.getFieldValue().toString(), dslQueryFilter.getValue().toString());
&& filterExpression.getFieldName().contains(queryFilter.getName())
&& queryFilter.getOperator().getValue().equals(filterExpression.getOperator())) {
map.put(filterExpression.getFieldValue().toString(), queryFilter.getValue().toString());
contextMetricFilters.stream().forEach(o -> {
if (o.getName().equals(dslQueryFilter.getName())) {
o.setValue(dslQueryFilter.getValue());
if (o.getName().equals(queryFilter.getName())) {
o.setValue(queryFilter.getValue());
}
});
break;
}
}
filedNameToValueMap.put(dslQueryFilter.getName(), map);
filedNameToValueMap.put(queryFilter.getName(), map);
}
}
private SemanticParseInfo getSemanticParseInfo(QueryDataReq queryData, ChatParseDO chatParseDO) {
SemanticParseInfo parseInfo = JsonUtil.toObject(chatParseDO.getParseInfo(), SemanticParseInfo.class);
if (DslQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
if (S2QLQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
return parseInfo;
}
if (CollectionUtils.isNotEmpty(queryData.getDimensions())) {

View File

@@ -9,7 +9,7 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import com.tencent.supersonic.chat.parser.llm.dsl.ModelResolver;
import com.tencent.supersonic.chat.parser.llm.s2ql.ModelResolver;
import com.tencent.supersonic.chat.query.QuerySelector;
import com.tencent.supersonic.chat.responder.execute.ExecuteResponder;
import com.tencent.supersonic.chat.responder.parse.ParseResponder;
@@ -20,7 +20,7 @@ public class ComponentFactory {
private static List<SchemaMapper> schemaMappers = new ArrayList<>();
private static List<SemanticParser> semanticParsers = new ArrayList<>();
private static List<SemanticCorrector> dslCorrections = new ArrayList<>();
private static List<SemanticCorrector> s2QLCorrections = new ArrayList<>();
private static SemanticInterpreter semanticInterpreter;
private static List<ParseResponder> parseResponders = new ArrayList<>();
private static List<ExecuteResponder> executeResponders = new ArrayList<>();
@@ -35,7 +35,8 @@ public class ComponentFactory {
}
public static List<SemanticCorrector> getSqlCorrections() {
return CollectionUtils.isEmpty(dslCorrections) ? init(SemanticCorrector.class, dslCorrections) : dslCorrections;
return CollectionUtils.isEmpty(s2QLCorrections) ? init(SemanticCorrector.class,
s2QLCorrections) : s2QLCorrections;
}
public static List<ParseResponder> getParseResponders() {

View File

@@ -12,7 +12,7 @@ import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import com.tencent.supersonic.semantic.api.query.pojo.Filter;
import com.tencent.supersonic.semantic.api.query.request.QueryDslReq;
import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq;
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.time.LocalDate;
@@ -123,19 +123,19 @@ public class QueryReqBuilder {
}
/**
* convert to QueryDslReq
* convert to QueryS2QLReq
*
* @param querySql
* @param modelId
* @return
*/
public static QueryDslReq buildDslReq(String querySql, Long modelId) {
QueryDslReq queryDslReq = new QueryDslReq();
public static QueryS2QLReq buildS2QLReq(String querySql, Long modelId) {
QueryS2QLReq queryS2QLReq = new QueryS2QLReq();
if (Objects.nonNull(querySql)) {
queryDslReq.setSql(querySql);
queryS2QLReq.setSql(querySql);
}
queryDslReq.setModelId(modelId);
return queryDslReq;
queryS2QLReq.setModelId(modelId);
return queryS2QLReq;
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.llm.dsl;
package com.tencent.supersonic.chat.parser.llm.s2ql;
import static org.mockito.Mockito.when;
@@ -16,7 +16,7 @@ import org.junit.jupiter.api.Test;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
class LLMDslParserTest {
class LLMS2QLParserTest {
@Test
void setFilter() {
@@ -72,9 +72,9 @@ class LLMDslParserTest {
.parseInfo(parseInfo)
.build();
LLMDslParser llmDslParser = new LLMDslParser();
LLMS2QLParser llms2QLParser = new LLMS2QLParser();
llmDslParser.updateParseInfo(semanticCorrectInfo, 2L, parseInfo);
llms2QLParser.updateParseInfo(semanticCorrectInfo, 2L, parseInfo);
}
}

View File

@@ -17,7 +17,7 @@ import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
import com.tencent.supersonic.semantic.api.query.request.QueryDimValueReq;
import com.tencent.supersonic.semantic.api.query.request.QueryDslReq;
import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq;
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import com.tencent.supersonic.semantic.model.domain.DimensionService;
@@ -56,9 +56,9 @@ public class LocalSemanticInterpreter extends BaseSemanticInterpreter {
@Override
@SneakyThrows
public QueryResultWithSchemaResp queryByDsl(QueryDslReq queryDslReq, User user) {
public QueryResultWithSchemaResp queryByS2QL(QueryS2QLReq queryS2QLReq, User user) {
queryService = ContextUtils.getBean(QueryService.class);
Object object = queryService.queryBySql(queryDslReq, user);
Object object = queryService.queryBySql(queryS2QLReq, user);
QueryResultWithSchemaResp queryResultWithSchemaResp = JsonUtil.toObject(JsonUtil.toString(object),
QueryResultWithSchemaResp.class);
return queryResultWithSchemaResp;

View File

@@ -31,7 +31,7 @@ import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
import com.tencent.supersonic.semantic.api.query.request.QueryDimValueReq;
import com.tencent.supersonic.semantic.api.query.request.QueryDslReq;
import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq;
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.net.URI;
@@ -83,10 +83,10 @@ public class RemoteSemanticInterpreter extends BaseSemanticInterpreter {
}
@Override
public QueryResultWithSchemaResp queryByDsl(QueryDslReq queryDslReq, User user) {
public QueryResultWithSchemaResp queryByS2QL(QueryS2QLReq queryS2QLReq, User user) {
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
return searchByRestTemplate(defaultSemanticConfig.getSemanticUrl() + defaultSemanticConfig.getSearchBySqlPath(),
new Gson().toJson(queryDslReq));
new Gson().toJson(queryS2QLReq));
}
public QueryResultWithSchemaResp searchByRestTemplate(String url, String jsonReq) {