mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 20:25:12 +00:00
(improvement)(headless&chat)Execute against SemanticLayerService instead of ChatQueryService in chat-server and rename several classes by the way.
This commit is contained in:
@@ -5,25 +5,32 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.chat.server.util.ResultFormatter;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
|
||||
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||
import com.tencent.supersonic.headless.server.web.service.ChatContextService;
|
||||
import lombok.SneakyThrows;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class SqlExecutor implements ChatExecutor {
|
||||
|
||||
@SneakyThrows
|
||||
@Override
|
||||
public QueryResult execute(ChatExecuteContext chatExecuteContext) {
|
||||
ExecuteQueryReq executeQueryReq = buildExecuteReq(chatExecuteContext);
|
||||
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
|
||||
QueryResult queryResult = chatQueryService.performExecution(executeQueryReq);
|
||||
QueryResult queryResult = doExecute(chatExecuteContext);
|
||||
|
||||
if (queryResult != null) {
|
||||
String textResult = ResultFormatter.transform2TextNew(queryResult.getQueryColumns(),
|
||||
queryResult.getQueryResults());
|
||||
@@ -48,16 +55,43 @@ public class SqlExecutor implements ChatExecutor {
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
private ExecuteQueryReq buildExecuteReq(ChatExecuteContext chatExecuteContext) {
|
||||
SemanticParseInfo parseInfo = chatExecuteContext.getParseInfo();
|
||||
return ExecuteQueryReq.builder()
|
||||
.queryId(chatExecuteContext.getQueryId())
|
||||
.chatId(chatExecuteContext.getChatId())
|
||||
.queryText(chatExecuteContext.getQueryText())
|
||||
.parseInfo(parseInfo)
|
||||
.saveAnswer(chatExecuteContext.isSaveAnswer())
|
||||
.user(chatExecuteContext.getUser())
|
||||
@SneakyThrows
|
||||
private QueryResult doExecute(ChatExecuteContext chatExecuteContext) {
|
||||
SemanticLayerService semanticLayer = ContextUtils.getBean(SemanticLayerService.class);
|
||||
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
|
||||
|
||||
ChatContext chatCtx = chatContextService.getOrCreateContext(chatExecuteContext.getChatId());
|
||||
|
||||
QuerySqlReq sqlReq = QuerySqlReq.builder()
|
||||
.sql(chatExecuteContext.getParseInfo().getSqlInfo().getCorrectS2SQL())
|
||||
.build();
|
||||
sqlReq.setSqlInfo(chatExecuteContext.getParseInfo().getSqlInfo());
|
||||
sqlReq.setDataSetId(chatExecuteContext.getParseInfo().getDataSetId());
|
||||
|
||||
long startTime = System.currentTimeMillis();
|
||||
SemanticQueryResp queryResp = semanticLayer.queryByReq(sqlReq, chatExecuteContext.getUser());
|
||||
QueryResult queryResult = new QueryResult();
|
||||
|
||||
if (queryResp != null) {
|
||||
queryResult.setQueryAuthorization(queryResp.getQueryAuthorization());
|
||||
List<Map<String, Object>> resultList = queryResp == null ? new ArrayList<>()
|
||||
: queryResp.getResultList();
|
||||
List<QueryColumn> columns = queryResp == null ? new ArrayList<>() : queryResp.getColumns();
|
||||
queryResult.setQueryTimeCost(System.currentTimeMillis() - startTime);
|
||||
queryResult.setQuerySql(queryResp.getSql());
|
||||
queryResult.setQueryResults(resultList);
|
||||
queryResult.setQueryColumns(columns);
|
||||
queryResult.setQueryMode(chatExecuteContext.getParseInfo().getQueryMode());
|
||||
queryResult.setQueryState(QueryState.SUCCESS);
|
||||
|
||||
chatCtx.setParseInfo(chatExecuteContext.getParseInfo());
|
||||
chatContextService.updateContext(chatCtx);
|
||||
} else {
|
||||
queryResult.setQueryState(QueryState.INVALID);
|
||||
queryResult.setQueryMode(chatExecuteContext.getParseInfo().getQueryMode());
|
||||
}
|
||||
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
public String buildSchemaStr(SemanticParseInfo parseInfo) {
|
||||
|
||||
@@ -15,7 +15,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
|
||||
@@ -69,11 +69,11 @@ public class NL2SQLParser implements ChatParser {
|
||||
}
|
||||
|
||||
processMultiTurn(chatParseContext);
|
||||
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
addDynamicExemplars(chatParseContext.getAgent().getId(), queryReq);
|
||||
QueryTextReq queryTextReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
addDynamicExemplars(chatParseContext.getAgent().getId(), queryTextReq);
|
||||
|
||||
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
|
||||
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq);
|
||||
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryTextReq);
|
||||
if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) {
|
||||
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
|
||||
}
|
||||
@@ -149,8 +149,8 @@ public class NL2SQLParser implements ChatParser {
|
||||
|
||||
// derive mapping result of current question and parsing result of last question.
|
||||
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
|
||||
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
MapResp currentMapResult = chatQueryService.performMapping(queryReq);
|
||||
QueryTextReq queryTextReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
MapResp currentMapResult = chatQueryService.performMapping(queryTextReq);
|
||||
|
||||
List<ParseResp> historyParseResults = getHistoryParseResult(chatParseContext.getChatId(), 1);
|
||||
if (historyParseResults.size() == 0) {
|
||||
@@ -168,7 +168,7 @@ public class NL2SQLParser implements ChatParser {
|
||||
.curtSchema(curtMapStr)
|
||||
.histSchema(histMapStr)
|
||||
.histSQL(histSQL)
|
||||
.llmConfig(queryReq.getLlmConfig())
|
||||
.llmConfig(queryTextReq.getLlmConfig())
|
||||
.build());
|
||||
chatParseContext.setQueryText(rewrittenQuery);
|
||||
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}",
|
||||
@@ -225,13 +225,13 @@ public class NL2SQLParser implements ChatParser {
|
||||
return contextualList;
|
||||
}
|
||||
|
||||
private void addDynamicExemplars(Integer agentId, QueryReq queryReq) {
|
||||
private void addDynamicExemplars(Integer agentId, QueryTextReq queryTextReq) {
|
||||
ExemplarServiceImpl exemplarManager = ContextUtils.getBean(ExemplarServiceImpl.class);
|
||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
|
||||
List<SqlExemplar> exemplars = exemplarManager.recallExemplars(memoryCollectionName,
|
||||
queryReq.getQueryText(), 5);
|
||||
queryReq.getDynamicExemplars().addAll(exemplars);
|
||||
queryTextReq.getQueryText(), 5);
|
||||
queryTextReq.getDynamicExemplars().addAll(exemplars);
|
||||
}
|
||||
|
||||
@Builder
|
||||
|
||||
@@ -10,7 +10,7 @@ import com.tencent.supersonic.chat.server.service.ChatService;
|
||||
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
@@ -76,7 +76,7 @@ public class ChatQueryController {
|
||||
}
|
||||
|
||||
@PostMapping("queryContext")
|
||||
public Object queryContext(@RequestBody QueryReq queryCtx,
|
||||
public Object queryContext(@RequestBody QueryTextReq queryCtx,
|
||||
HttpServletRequest request, HttpServletResponse response) {
|
||||
queryCtx.setUser(UserHolder.findUser(request, response));
|
||||
return chatService.queryContext(queryCtx.getChatId());
|
||||
|
||||
@@ -21,7 +21,7 @@ import com.tencent.supersonic.common.util.BeanMapper;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
@@ -60,8 +60,8 @@ public class ChatServiceImpl implements ChatService {
|
||||
if (!agent.enableSearch()) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
return retrieveService.retrieve(queryReq);
|
||||
QueryTextReq queryTextReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
return retrieveService.retrieve(queryTextReq);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -137,8 +137,8 @@ public class ChatServiceImpl implements ChatService {
|
||||
}
|
||||
|
||||
private void supplyMapInfo(ChatParseContext chatParseContext) {
|
||||
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
MapResp mapResp = chatQueryService.performMapping(queryReq);
|
||||
QueryTextReq queryTextReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
MapResp mapResp = chatQueryService.performMapping(queryTextReq);
|
||||
chatParseContext.setMapInfo(mapResp.getMapInfo());
|
||||
}
|
||||
|
||||
|
||||
@@ -4,35 +4,35 @@ import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.common.util.BeanMapper;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq;
|
||||
import org.apache.commons.collections.MapUtils;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
public class QueryReqConverter {
|
||||
|
||||
public static QueryReq buildText2SqlQueryReq(ChatParseContext chatParseContext) {
|
||||
QueryReq queryReq = new QueryReq();
|
||||
BeanMapper.mapper(chatParseContext, queryReq);
|
||||
public static QueryTextReq buildText2SqlQueryReq(ChatParseContext chatParseContext) {
|
||||
QueryTextReq queryTextReq = new QueryTextReq();
|
||||
BeanMapper.mapper(chatParseContext, queryTextReq);
|
||||
Agent agent = chatParseContext.getAgent();
|
||||
if (agent == null) {
|
||||
return queryReq;
|
||||
return queryTextReq;
|
||||
}
|
||||
if (agent.containsLLMParserTool() && agent.containsRuleTool()) {
|
||||
queryReq.setText2SQLType(Text2SQLType.RULE_AND_LLM);
|
||||
queryTextReq.setText2SQLType(Text2SQLType.RULE_AND_LLM);
|
||||
} else if (agent.containsLLMParserTool()) {
|
||||
queryReq.setText2SQLType(Text2SQLType.ONLY_LLM);
|
||||
queryTextReq.setText2SQLType(Text2SQLType.ONLY_LLM);
|
||||
} else if (agent.containsRuleTool()) {
|
||||
queryReq.setText2SQLType(Text2SQLType.ONLY_RULE);
|
||||
queryTextReq.setText2SQLType(Text2SQLType.ONLY_RULE);
|
||||
}
|
||||
queryReq.setDataSetIds(agent.getDataSetIds());
|
||||
if (Objects.nonNull(queryReq.getMapInfo())
|
||||
&& MapUtils.isNotEmpty(queryReq.getMapInfo().getDataSetElementMatches())) {
|
||||
queryReq.setMapInfo(queryReq.getMapInfo());
|
||||
queryTextReq.setDataSetIds(agent.getDataSetIds());
|
||||
if (Objects.nonNull(queryTextReq.getMapInfo())
|
||||
&& MapUtils.isNotEmpty(queryTextReq.getMapInfo().getDataSetElementMatches())) {
|
||||
queryTextReq.setMapInfo(queryTextReq.getMapInfo());
|
||||
}
|
||||
queryReq.setLlmConfig(agent.getLlmConfig());
|
||||
queryReq.setPromptConfig(agent.getPromptConfig());
|
||||
return queryReq;
|
||||
queryTextReq.setLlmConfig(agent.getLlmConfig());
|
||||
queryTextReq.setPromptConfig(agent.getPromptConfig());
|
||||
return queryTextReq;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user