(improvement)(headless)Rename QueryTextReq to QueryNLReq to explicitly reflect natural language concept.

This commit is contained in:
jerryjzhang
2024-07-08 10:20:20 +08:00
parent 9911e6772c
commit efd617b2e5
12 changed files with 83 additions and 84 deletions

View File

@@ -15,7 +15,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp; import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService; import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
@@ -69,11 +69,11 @@ public class NL2SQLParser implements ChatParser {
} }
processMultiTurn(chatParseContext); processMultiTurn(chatParseContext);
QueryTextReq queryTextReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext); QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
addDynamicExemplars(chatParseContext.getAgent().getId(), queryTextReq); addDynamicExemplars(chatParseContext.getAgent().getId(), queryNLReq);
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class); ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryTextReq); ParseResp text2SqlParseResp = chatQueryService.performParsing(queryNLReq);
if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) { if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) {
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses()); parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
} }
@@ -149,8 +149,8 @@ public class NL2SQLParser implements ChatParser {
// derive mapping result of current question and parsing result of last question. // derive mapping result of current question and parsing result of last question.
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class); ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
QueryTextReq queryTextReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext); QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
MapResp currentMapResult = chatQueryService.performMapping(queryTextReq); MapResp currentMapResult = chatQueryService.performMapping(queryNLReq);
List<ParseResp> historyParseResults = getHistoryParseResult(chatParseContext.getChatId(), 1); List<ParseResp> historyParseResults = getHistoryParseResult(chatParseContext.getChatId(), 1);
if (historyParseResults.size() == 0) { if (historyParseResults.size() == 0) {
@@ -168,7 +168,7 @@ public class NL2SQLParser implements ChatParser {
.curtSchema(curtMapStr) .curtSchema(curtMapStr)
.histSchema(histMapStr) .histSchema(histMapStr)
.histSQL(histSQL) .histSQL(histSQL)
.modelConfig(queryTextReq.getModelConfig()) .modelConfig(queryNLReq.getModelConfig())
.build()); .build());
chatParseContext.setQueryText(rewrittenQuery); chatParseContext.setQueryText(rewrittenQuery);
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}", log.info("Last Query: {} Current Query: {}, Rewritten Query: {}",
@@ -225,13 +225,13 @@ public class NL2SQLParser implements ChatParser {
return contextualList; return contextualList;
} }
private void addDynamicExemplars(Integer agentId, QueryTextReq queryTextReq) { private void addDynamicExemplars(Integer agentId, QueryNLReq queryNLReq) {
ExemplarServiceImpl exemplarManager = ContextUtils.getBean(ExemplarServiceImpl.class); ExemplarServiceImpl exemplarManager = ContextUtils.getBean(ExemplarServiceImpl.class);
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class); EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId); String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
List<SqlExemplar> exemplars = exemplarManager.recallExemplars(memoryCollectionName, List<SqlExemplar> exemplars = exemplarManager.recallExemplars(memoryCollectionName,
queryTextReq.getQueryText(), 5); queryNLReq.getQueryText(), 5);
queryTextReq.getDynamicExemplars().addAll(exemplars); queryNLReq.getDynamicExemplars().addAll(exemplars);
} }
@Builder @Builder

View File

@@ -10,7 +10,7 @@ import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException; import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
@@ -76,7 +76,7 @@ public class ChatQueryController {
} }
@PostMapping("queryContext") @PostMapping("queryContext")
public Object queryContext(@RequestBody QueryTextReq queryCtx, public Object queryContext(@RequestBody QueryNLReq queryCtx,
HttpServletRequest request, HttpServletResponse response) { HttpServletRequest request, HttpServletResponse response) {
queryCtx.setUser(UserHolder.findUser(request, response)); queryCtx.setUser(UserHolder.findUser(request, response));
return chatService.queryContext(queryCtx.getChatId()); return chatService.queryContext(queryCtx.getChatId());

View File

@@ -21,7 +21,7 @@ import com.tencent.supersonic.common.util.BeanMapper;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq; import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp; import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult; import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
@@ -60,8 +60,8 @@ public class ChatServiceImpl implements ChatService {
if (!agent.enableSearch()) { if (!agent.enableSearch()) {
return Lists.newArrayList(); return Lists.newArrayList();
} }
QueryTextReq queryTextReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext); QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
return retrieveService.retrieve(queryTextReq); return retrieveService.retrieve(queryNLReq);
} }
@Override @Override
@@ -137,8 +137,8 @@ public class ChatServiceImpl implements ChatService {
} }
private void supplyMapInfo(ChatParseContext chatParseContext) { private void supplyMapInfo(ChatParseContext chatParseContext) {
QueryTextReq queryTextReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext); QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
MapResp mapResp = chatQueryService.performMapping(queryTextReq); MapResp mapResp = chatQueryService.performMapping(queryNLReq);
chatParseContext.setMapInfo(mapResp.getMapInfo()); chatParseContext.setMapInfo(mapResp.getMapInfo());
} }

View File

@@ -4,35 +4,35 @@ import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext; import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.util.BeanMapper; import com.tencent.supersonic.common.util.BeanMapper;
import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import org.apache.commons.collections.MapUtils; import org.apache.commons.collections.MapUtils;
import java.util.Objects; import java.util.Objects;
public class QueryReqConverter { public class QueryReqConverter {
public static QueryTextReq buildText2SqlQueryReq(ChatParseContext chatParseContext) { public static QueryNLReq buildText2SqlQueryReq(ChatParseContext chatParseContext) {
QueryTextReq queryTextReq = new QueryTextReq(); QueryNLReq queryNLReq = new QueryNLReq();
BeanMapper.mapper(chatParseContext, queryTextReq); BeanMapper.mapper(chatParseContext, queryNLReq);
Agent agent = chatParseContext.getAgent(); Agent agent = chatParseContext.getAgent();
if (agent == null) { if (agent == null) {
return queryTextReq; return queryNLReq;
} }
if (agent.containsLLMParserTool() && agent.containsRuleTool()) { if (agent.containsLLMParserTool() && agent.containsRuleTool()) {
queryTextReq.setText2SQLType(Text2SQLType.RULE_AND_LLM); queryNLReq.setText2SQLType(Text2SQLType.RULE_AND_LLM);
} else if (agent.containsLLMParserTool()) { } else if (agent.containsLLMParserTool()) {
queryTextReq.setText2SQLType(Text2SQLType.ONLY_LLM); queryNLReq.setText2SQLType(Text2SQLType.ONLY_LLM);
} else if (agent.containsRuleTool()) { } else if (agent.containsRuleTool()) {
queryTextReq.setText2SQLType(Text2SQLType.ONLY_RULE); queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE);
} }
queryTextReq.setDataSetIds(agent.getDataSetIds()); queryNLReq.setDataSetIds(agent.getDataSetIds());
if (Objects.nonNull(queryTextReq.getMapInfo()) if (Objects.nonNull(queryNLReq.getMapInfo())
&& MapUtils.isNotEmpty(queryTextReq.getMapInfo().getDataSetElementMatches())) { && MapUtils.isNotEmpty(queryNLReq.getMapInfo().getDataSetElementMatches())) {
queryTextReq.setMapInfo(queryTextReq.getMapInfo()); queryNLReq.setMapInfo(queryNLReq.getMapInfo());
} }
queryTextReq.setModelConfig(agent.getModelConfig()); queryNLReq.setModelConfig(agent.getModelConfig());
queryTextReq.setPromptConfig(agent.getPromptConfig()); queryNLReq.setPromptConfig(agent.getPromptConfig());
return queryTextReq; return queryNLReq;
} }
} }

View File

@@ -16,7 +16,7 @@ import java.util.List;
import java.util.Set; import java.util.Set;
@Data @Data
public class QueryTextReq { public class QueryNLReq {
private String queryText; private String queryText;
private Integer chatId; private Integer chatId;
private Set<Long> dataSetIds = Sets.newHashSet(); private Set<Long> dataSetIds = Sets.newHashSet();

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.headless.chat.parser.llm; package com.tencent.supersonic.headless.chat.parser.llm;
import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
@@ -22,7 +22,7 @@ public class ParseResult {
private LLMResp llmResp; private LLMResp llmResp;
private QueryTextReq request; private QueryNLReq request;
private List<LLMReq.ElementValue> linkingValues; private List<LLMReq.ElementValue> linkingValues;
} }

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.headless.server.facade.rest; package com.tencent.supersonic.headless.server.facade.rest;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp; import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService; import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
import com.tencent.supersonic.headless.server.facade.service.RetrieveService; import com.tencent.supersonic.headless.server.facade.service.RetrieveService;
@@ -26,27 +26,27 @@ public class ChatQueryApiController {
private RetrieveService retrieveService; private RetrieveService retrieveService;
@PostMapping("/chat/search") @PostMapping("/chat/search")
public Object search(@RequestBody QueryTextReq queryTextReq, public Object search(@RequestBody QueryNLReq queryNLReq,
HttpServletRequest request, HttpServletRequest request,
HttpServletResponse response) throws Exception { HttpServletResponse response) throws Exception {
queryTextReq.setUser(UserHolder.findUser(request, response)); queryNLReq.setUser(UserHolder.findUser(request, response));
return retrieveService.retrieve(queryTextReq); return retrieveService.retrieve(queryNLReq);
} }
@PostMapping("/chat/map") @PostMapping("/chat/map")
public MapResp map(@RequestBody QueryTextReq queryTextReq, public MapResp map(@RequestBody QueryNLReq queryNLReq,
HttpServletRequest request, HttpServletRequest request,
HttpServletResponse response) { HttpServletResponse response) {
queryTextReq.setUser(UserHolder.findUser(request, response)); queryNLReq.setUser(UserHolder.findUser(request, response));
return chatQueryService.performMapping(queryTextReq); return chatQueryService.performMapping(queryNLReq);
} }
@PostMapping("/chat/parse") @PostMapping("/chat/parse")
public Object parse(@RequestBody QueryTextReq queryTextReq, public Object parse(@RequestBody QueryNLReq queryNLReq,
HttpServletRequest request, HttpServletRequest request,
HttpServletResponse response) throws Exception { HttpServletResponse response) throws Exception {
queryTextReq.setUser(UserHolder.findUser(request, response)); queryNLReq.setUser(UserHolder.findUser(request, response));
return chatQueryService.performParsing(queryTextReq); return chatQueryService.performParsing(queryNLReq);
} }
} }

View File

@@ -7,23 +7,23 @@ import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq; import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryMapReq; import com.tencent.supersonic.headless.api.pojo.request.QueryMapReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp; import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp;
import com.tencent.supersonic.headless.api.pojo.response.MapResp; import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult; import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
/*** /***dd
* SemanticLayerService for query and search * SemanticLayerService for query and search
*/ */
public interface ChatQueryService { public interface ChatQueryService {
MapResp performMapping(QueryTextReq queryTextReq); MapResp performMapping(QueryNLReq queryNLReq);
MapInfoResp map(QueryMapReq queryMapReq); MapInfoResp map(QueryMapReq queryMapReq);
ParseResp performParsing(QueryTextReq queryTextReq); ParseResp performParsing(QueryNLReq queryNLReq);
@Deprecated @Deprecated
QueryResult performExecution(ExecuteQueryReq queryReq) throws Exception; QueryResult performExecution(ExecuteQueryReq queryReq) throws Exception;

View File

@@ -1,12 +1,12 @@
package com.tencent.supersonic.headless.server.facade.service; package com.tencent.supersonic.headless.server.facade.service;
import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.response.SearchResult; import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
import java.util.List; import java.util.List;
public interface RetrieveService { public interface RetrieveService {
List<SearchResult> retrieve(QueryTextReq queryCtx); List<SearchResult> retrieve(QueryNLReq queryCtx);
} }

View File

@@ -29,7 +29,7 @@ import com.tencent.supersonic.headless.api.pojo.enums.CostType;
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod; import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq; import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq; import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq; import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq;
@@ -118,40 +118,40 @@ public class ChatQueryServiceImpl implements ChatQueryService {
private ChatWorkflowEngine chatWorkflowEngine; private ChatWorkflowEngine chatWorkflowEngine;
@Override @Override
public MapResp performMapping(QueryTextReq queryTextReq) { public MapResp performMapping(QueryNLReq queryNLReq) {
MapResp mapResp = new MapResp(); MapResp mapResp = new MapResp();
ChatQueryContext queryCtx = buildQueryContext(queryTextReq); ChatQueryContext queryCtx = buildChatQueryContext(queryNLReq);
ComponentFactory.getSchemaMappers().forEach(mapper -> { ComponentFactory.getSchemaMappers().forEach(mapper -> {
mapper.map(queryCtx); mapper.map(queryCtx);
}); });
SchemaMapInfo mapInfo = queryCtx.getMapInfo(); SchemaMapInfo mapInfo = queryCtx.getMapInfo();
mapResp.setMapInfo(mapInfo); mapResp.setMapInfo(mapInfo);
mapResp.setQueryText(queryTextReq.getQueryText()); mapResp.setQueryText(queryNLReq.getQueryText());
return mapResp; return mapResp;
} }
@Override @Override
public MapInfoResp map(QueryMapReq queryMapReq) { public MapInfoResp map(QueryMapReq queryMapReq) {
QueryTextReq queryTextReq = new QueryTextReq(); QueryNLReq queryNLReq = new QueryNLReq();
BeanUtils.copyProperties(queryMapReq, queryTextReq); BeanUtils.copyProperties(queryMapReq, queryNLReq);
List<DataSetResp> dataSets = dataSetService.getDataSets(queryMapReq.getDataSetNames(), queryMapReq.getUser()); List<DataSetResp> dataSets = dataSetService.getDataSets(queryMapReq.getDataSetNames(), queryMapReq.getUser());
Set<Long> dataSetIds = dataSets.stream().map(SchemaItem::getId).collect(Collectors.toSet()); Set<Long> dataSetIds = dataSets.stream().map(SchemaItem::getId).collect(Collectors.toSet());
queryTextReq.setDataSetIds(dataSetIds); queryNLReq.setDataSetIds(dataSetIds);
MapResp mapResp = performMapping(queryTextReq); MapResp mapResp = performMapping(queryNLReq);
dataSetIds.retainAll(mapResp.getMapInfo().getDataSetElementMatches().keySet()); dataSetIds.retainAll(mapResp.getMapInfo().getDataSetElementMatches().keySet());
return convert(mapResp, queryMapReq.getTopN(), dataSetIds); return convert(mapResp, queryMapReq.getTopN(), dataSetIds);
} }
@Override @Override
public ParseResp performParsing(QueryTextReq queryTextReq) { public ParseResp performParsing(QueryNLReq queryNLReq) {
ParseResp parseResult = new ParseResp(queryTextReq.getChatId(), queryTextReq.getQueryText()); ParseResp parseResult = new ParseResp(queryNLReq.getChatId(), queryNLReq.getQueryText());
// build queryContext and chatContext // build queryContext and chatContext
ChatQueryContext queryCtx = buildQueryContext(queryTextReq); ChatQueryContext queryCtx = buildChatQueryContext(queryNLReq);
// in order to support multi-turn conversation, chat context is needed // in order to support multi-turn conversation, chat context is needed
ChatContext chatCtx = chatContextService.getOrCreateContext(queryTextReq.getChatId()); ChatContext chatCtx = chatContextService.getOrCreateContext(queryNLReq.getChatId());
chatWorkflowEngine.execute(queryCtx, chatCtx, parseResult); chatWorkflowEngine.execute(queryCtx, chatCtx, parseResult);
@@ -161,21 +161,20 @@ public class ChatQueryServiceImpl implements ChatQueryService {
return parseResult; return parseResult;
} }
public ChatQueryContext buildQueryContext(QueryTextReq queryTextReq) { public ChatQueryContext buildChatQueryContext(QueryNLReq queryNLReq) {
SemanticSchema semanticSchema = schemaService.getSemanticSchema(); SemanticSchema semanticSchema = schemaService.getSemanticSchema();
Map<Long, List<Long>> modelIdToDataSetIds = dataSetService.getModelIdToDataSetIds(); Map<Long, List<Long>> modelIdToDataSetIds = dataSetService.getModelIdToDataSetIds();
ChatQueryContext queryCtx = ChatQueryContext.builder() ChatQueryContext queryCtx = ChatQueryContext.builder()
.queryFilters(queryTextReq.getQueryFilters()) .queryFilters(queryNLReq.getQueryFilters())
.semanticSchema(semanticSchema) .semanticSchema(semanticSchema)
.candidateQueries(new ArrayList<>()) .candidateQueries(new ArrayList<>())
.mapInfo(new SchemaMapInfo()) .mapInfo(new SchemaMapInfo())
.modelIdToDataSetIds(modelIdToDataSetIds) .modelIdToDataSetIds(modelIdToDataSetIds)
.text2SQLType(queryTextReq.getText2SQLType()) .text2SQLType(queryNLReq.getText2SQLType())
.mapModeEnum(queryTextReq.getMapModeEnum()) .mapModeEnum(queryNLReq.getMapModeEnum())
.dataSetIds(queryTextReq.getDataSetIds()) .dataSetIds(queryNLReq.getDataSetIds())
.build(); .build();
BeanUtils.copyProperties(queryTextReq, queryCtx); BeanUtils.copyProperties(queryNLReq, queryCtx);
return queryCtx; return queryCtx;
} }
@@ -278,9 +277,9 @@ public class ChatQueryServiceImpl implements ChatQueryService {
validFilter(semanticQuery.getParseInfo().getMetricFilters()); validFilter(semanticQuery.getParseInfo().getMetricFilters());
//init s2sql //init s2sql
semanticQuery.initS2Sql(semanticSchema, user); semanticQuery.initS2Sql(semanticSchema, user);
QueryTextReq queryTextReq = new QueryTextReq(); QueryNLReq queryNLReq = new QueryNLReq();
queryTextReq.setQueryFilters(new QueryFilters()); queryNLReq.setQueryFilters(new QueryFilters());
queryTextReq.setUser(user); queryNLReq.setUser(user);
} }
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq(); SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
QueryResult queryResult = doExecution(semanticQueryReq, semanticQuery.getParseInfo(), user); QueryResult queryResult = doExecution(semanticQueryReq, semanticQuery.getParseInfo(), user);

View File

@@ -8,7 +8,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.response.S2Term; import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.api.pojo.response.SearchResult; import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.ChatQueryContext;
@@ -64,9 +64,9 @@ public class RetrieveServiceImpl implements RetrieveService {
@Autowired @Autowired
private SearchMatchStrategy searchMatchStrategy; private SearchMatchStrategy searchMatchStrategy;
@Override @Override
public List<SearchResult> retrieve(QueryTextReq queryTextReq) { public List<SearchResult> retrieve(QueryNLReq queryNLReq) {
String queryText = queryTextReq.getQueryText(); String queryText = queryNLReq.getQueryText();
// 1.get meta info // 1.get meta info
SemanticSchema semanticSchemaDb = schemaService.getSemanticSchema(); SemanticSchema semanticSchemaDb = schemaService.getSemanticSchema();
List<SchemaElement> metricsDb = semanticSchemaDb.getMetrics(); List<SchemaElement> metricsDb = semanticSchemaDb.getMetrics();
@@ -76,10 +76,10 @@ public class RetrieveServiceImpl implements RetrieveService {
// 2.detect by segment // 2.detect by segment
List<S2Term> originals = knowledgeBaseService.getTerms(queryText, modelIdToDataSetIds); List<S2Term> originals = knowledgeBaseService.getTerms(queryText, modelIdToDataSetIds);
log.debug("hanlp parse result: {}", originals); log.debug("hanlp parse result: {}", originals);
Set<Long> dataSetIds = queryTextReq.getDataSetIds(); Set<Long> dataSetIds = queryNLReq.getDataSetIds();
ChatQueryContext chatQueryContext = new ChatQueryContext(); ChatQueryContext chatQueryContext = new ChatQueryContext();
BeanUtils.copyProperties(queryTextReq, chatQueryContext); BeanUtils.copyProperties(queryNLReq, chatQueryContext);
chatQueryContext.setModelIdToDataSetIds(dataSetService.getModelIdToDataSetIds()); chatQueryContext.setModelIdToDataSetIds(dataSetService.getModelIdToDataSetIds());
Map<MatchText, List<HanlpMapResult>> regTextMap = Map<MatchText, List<HanlpMapResult>> regTextMap =
@@ -100,12 +100,12 @@ public class RetrieveServiceImpl implements RetrieveService {
return Lists.newArrayList(); return Lists.newArrayList();
} }
Map.Entry<MatchText, List<HanlpMapResult>> searchTextEntry = mostSimilarSearchResult.get(); Map.Entry<MatchText, List<HanlpMapResult>> searchTextEntry = mostSimilarSearchResult.get();
log.debug("searchTextEntry:{},queryTextReq:{}", searchTextEntry, queryTextReq); log.debug("searchTextEntry:{},queryNLReq:{}", searchTextEntry, queryNLReq);
Set<SearchResult> searchResults = new LinkedHashSet(); Set<SearchResult> searchResults = new LinkedHashSet();
DataSetInfoStat dataSetInfoStat = NatureHelper.getDataSetStat(originals); DataSetInfoStat dataSetInfoStat = NatureHelper.getDataSetStat(originals);
List<Long> possibleDataSets = getPossibleDataSets(queryTextReq, originals, dataSetInfoStat, dataSetIds); List<Long> possibleDataSets = getPossibleDataSets(queryNLReq, originals, dataSetInfoStat, dataSetIds);
// 5.1 priority dimension metric // 5.1 priority dimension metric
boolean existMetricAndDimension = searchMetricAndDimension(new HashSet<>(possibleDataSets), dataSetIdToName, boolean existMetricAndDimension = searchMetricAndDimension(new HashSet<>(possibleDataSets), dataSetIdToName,
@@ -120,14 +120,14 @@ public class RetrieveServiceImpl implements RetrieveService {
Set<SearchResult> searchResultSet = searchDimensionValue(metricsDb, dataSetIdToName, Set<SearchResult> searchResultSet = searchDimensionValue(metricsDb, dataSetIdToName,
dataSetInfoStat.getMetricDataSetCount(), existMetricAndDimension, dataSetInfoStat.getMetricDataSetCount(), existMetricAndDimension,
matchText, natureToNameMap, natureToNameEntry, queryTextReq.getQueryFilters()); matchText, natureToNameMap, natureToNameEntry, queryNLReq.getQueryFilters());
searchResults.addAll(searchResultSet); searchResults.addAll(searchResultSet);
} }
return searchResults.stream().limit(RESULT_SIZE).collect(Collectors.toList()); return searchResults.stream().limit(RESULT_SIZE).collect(Collectors.toList());
} }
private List<Long> getPossibleDataSets(QueryTextReq queryCtx, List<S2Term> originals, private List<Long> getPossibleDataSets(QueryNLReq queryCtx, List<S2Term> originals,
DataSetInfoStat dataSetInfoStat, Set<Long> dataSetIds) { DataSetInfoStat dataSetInfoStat, Set<Long> dataSetIds) {
if (CollectionUtils.isNotEmpty(dataSetIds)) { if (CollectionUtils.isNotEmpty(dataSetIds)) {
return new ArrayList<>(dataSetIds); return new ArrayList<>(dataSetIds);

View File

@@ -22,9 +22,9 @@ import java.util.Arrays;
import java.util.List; import java.util.List;
/** /**
* QueryTextReqBuilderTest * QueryNLReqBuilderTest
*/ */
class QueryTextReqBuilderTest { class QueryNLReqBuilderTest {
@Test @Test
void buildS2SQLReq() { void buildS2SQLReq() {