(improvement)(headless)Rename QueryTextReq to QueryNLReq to explicitly reflect natural language concept.

This commit is contained in:
jerryjzhang
2024-07-08 10:20:20 +08:00
parent 9911e6772c
commit efd617b2e5
12 changed files with 83 additions and 84 deletions

View File

@@ -15,7 +15,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
@@ -69,11 +69,11 @@ public class NL2SQLParser implements ChatParser {
}
processMultiTurn(chatParseContext);
QueryTextReq queryTextReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
addDynamicExemplars(chatParseContext.getAgent().getId(), queryTextReq);
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
addDynamicExemplars(chatParseContext.getAgent().getId(), queryNLReq);
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryTextReq);
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryNLReq);
if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) {
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
}
@@ -149,8 +149,8 @@ public class NL2SQLParser implements ChatParser {
// derive mapping result of current question and parsing result of last question.
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
QueryTextReq queryTextReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
MapResp currentMapResult = chatQueryService.performMapping(queryTextReq);
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
MapResp currentMapResult = chatQueryService.performMapping(queryNLReq);
List<ParseResp> historyParseResults = getHistoryParseResult(chatParseContext.getChatId(), 1);
if (historyParseResults.size() == 0) {
@@ -168,7 +168,7 @@ public class NL2SQLParser implements ChatParser {
.curtSchema(curtMapStr)
.histSchema(histMapStr)
.histSQL(histSQL)
.modelConfig(queryTextReq.getModelConfig())
.modelConfig(queryNLReq.getModelConfig())
.build());
chatParseContext.setQueryText(rewrittenQuery);
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}",
@@ -225,13 +225,13 @@ public class NL2SQLParser implements ChatParser {
return contextualList;
}
private void addDynamicExemplars(Integer agentId, QueryTextReq queryTextReq) {
private void addDynamicExemplars(Integer agentId, QueryNLReq queryNLReq) {
ExemplarServiceImpl exemplarManager = ContextUtils.getBean(ExemplarServiceImpl.class);
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
List<SqlExemplar> exemplars = exemplarManager.recallExemplars(memoryCollectionName,
queryTextReq.getQueryText(), 5);
queryTextReq.getDynamicExemplars().addAll(exemplars);
queryNLReq.getQueryText(), 5);
queryNLReq.getDynamicExemplars().addAll(exemplars);
}
@Builder

View File

@@ -10,7 +10,7 @@ import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.BeanUtils;
@@ -76,7 +76,7 @@ public class ChatQueryController {
}
@PostMapping("queryContext")
public Object queryContext(@RequestBody QueryTextReq queryCtx,
public Object queryContext(@RequestBody QueryNLReq queryCtx,
HttpServletRequest request, HttpServletResponse response) {
queryCtx.setUser(UserHolder.findUser(request, response));
return chatService.queryContext(queryCtx.getChatId());

View File

@@ -21,7 +21,7 @@ import com.tencent.supersonic.common.util.BeanMapper;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
@@ -60,8 +60,8 @@ public class ChatServiceImpl implements ChatService {
if (!agent.enableSearch()) {
return Lists.newArrayList();
}
QueryTextReq queryTextReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
return retrieveService.retrieve(queryTextReq);
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
return retrieveService.retrieve(queryNLReq);
}
@Override
@@ -137,8 +137,8 @@ public class ChatServiceImpl implements ChatService {
}
private void supplyMapInfo(ChatParseContext chatParseContext) {
QueryTextReq queryTextReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
MapResp mapResp = chatQueryService.performMapping(queryTextReq);
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
MapResp mapResp = chatQueryService.performMapping(queryNLReq);
chatParseContext.setMapInfo(mapResp.getMapInfo());
}

View File

@@ -4,35 +4,35 @@ import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.util.BeanMapper;
import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import org.apache.commons.collections.MapUtils;
import java.util.Objects;
public class QueryReqConverter {
public static QueryTextReq buildText2SqlQueryReq(ChatParseContext chatParseContext) {
QueryTextReq queryTextReq = new QueryTextReq();
BeanMapper.mapper(chatParseContext, queryTextReq);
public static QueryNLReq buildText2SqlQueryReq(ChatParseContext chatParseContext) {
QueryNLReq queryNLReq = new QueryNLReq();
BeanMapper.mapper(chatParseContext, queryNLReq);
Agent agent = chatParseContext.getAgent();
if (agent == null) {
return queryTextReq;
return queryNLReq;
}
if (agent.containsLLMParserTool() && agent.containsRuleTool()) {
queryTextReq.setText2SQLType(Text2SQLType.RULE_AND_LLM);
queryNLReq.setText2SQLType(Text2SQLType.RULE_AND_LLM);
} else if (agent.containsLLMParserTool()) {
queryTextReq.setText2SQLType(Text2SQLType.ONLY_LLM);
queryNLReq.setText2SQLType(Text2SQLType.ONLY_LLM);
} else if (agent.containsRuleTool()) {
queryTextReq.setText2SQLType(Text2SQLType.ONLY_RULE);
queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE);
}
queryTextReq.setDataSetIds(agent.getDataSetIds());
if (Objects.nonNull(queryTextReq.getMapInfo())
&& MapUtils.isNotEmpty(queryTextReq.getMapInfo().getDataSetElementMatches())) {
queryTextReq.setMapInfo(queryTextReq.getMapInfo());
queryNLReq.setDataSetIds(agent.getDataSetIds());
if (Objects.nonNull(queryNLReq.getMapInfo())
&& MapUtils.isNotEmpty(queryNLReq.getMapInfo().getDataSetElementMatches())) {
queryNLReq.setMapInfo(queryNLReq.getMapInfo());
}
queryTextReq.setModelConfig(agent.getModelConfig());
queryTextReq.setPromptConfig(agent.getPromptConfig());
return queryTextReq;
queryNLReq.setModelConfig(agent.getModelConfig());
queryNLReq.setPromptConfig(agent.getPromptConfig());
return queryNLReq;
}
}