(improvement)(Headless) support multiturn text-to-sql (#983)

This commit is contained in:
mainmain
2024-05-13 14:13:02 +08:00
committed by GitHub
parent 947a01e8ba
commit 0e28d6cbcc
15 changed files with 407 additions and 11 deletions

View File

@@ -1,18 +1,47 @@
package com.tencent.supersonic.chat.server.parser;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.core.chat.mapper.SchemaMapper;
import com.tencent.supersonic.headless.core.chat.parser.llm.LLMRequestService;
import com.tencent.supersonic.headless.core.chat.parser.llm.RewriteQueryGeneration;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.server.service.ChatQueryService;
import java.util.List;
import java.util.Collections;
import java.util.ArrayList;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import com.tencent.supersonic.headless.server.service.impl.ChatQueryServiceImpl;
import com.tencent.supersonic.headless.server.utils.ComponentFactory;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.core.env.Environment;
import static com.tencent.supersonic.common.pojo.Constants.CONTEXT;
@Slf4j
public class NL2SQLParser implements ChatParser {
private int contextualNum = 5;
private List<SchemaMapper> schemaMappers = ComponentFactory.getSchemaMappers();
@Override
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
if (!chatParseContext.enableNL2SQL()) {
@@ -21,7 +50,15 @@ public class NL2SQLParser implements ChatParser {
if (checkSkip(parseResp)) {
return;
}
considerMultiturn(chatParseContext, parseResp);
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
Environment environment = ContextUtils.getBean(Environment.class);
String multiTurn = environment.getProperty("multi.turn");
if (StringUtils.isNotBlank(multiTurn) && Boolean.parseBoolean(multiTurn)) {
queryReq.setMapInfo(new SchemaMapInfo());
}
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq);
if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) {
@@ -30,6 +67,100 @@ public class NL2SQLParser implements ChatParser {
parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime());
}
private void considerMultiturn(ChatParseContext chatParseContext, ParseResp parseResp) {
Environment environment = ContextUtils.getBean(Environment.class);
RewriteQueryGeneration rewriteQueryGeneration = ContextUtils.getBean(RewriteQueryGeneration.class);
String multiTurn = environment.getProperty("multi.turn");
String multiNum = environment.getProperty("multi.num");
if (StringUtils.isBlank(multiTurn) || !Boolean.parseBoolean(multiTurn)) {
return;
}
log.info("multi turn text-to-sql!");
List<ParseResp> contextualList = getContextualList(parseResp, multiNum);
List<String> contextualQuestions = getContextualQuestionsWithLink(contextualList);
StringBuffer currentPromptSb = new StringBuffer();
if (contextualQuestions.size() == 0) {
currentPromptSb.append("contextualQuestions:" + "\n");
} else {
currentPromptSb.append("contextualQuestions:" + "\n" + String.join("\n", contextualQuestions) + "\n");
}
String currentQuestion = getQueryLinks(chatParseContext);
currentPromptSb.append("currentQuestion:" + currentQuestion + "\n");
currentPromptSb.append("rewritingCurrentQuestion:\n");
String rewriteQuery = rewriteQueryGeneration.generation(currentPromptSb.toString());
log.info("rewriteQuery:{}", rewriteQuery);
chatParseContext.setQueryText(rewriteQuery);
}
private List<String> getContextualQuestionsWithLink(List<ParseResp> contextualList) {
List<String> contextualQuestions = new ArrayList<>();
contextualList.stream().forEach(o -> {
Map<String, Object> map = JsonUtil.toMap(JsonUtil.toString(
o.getSelectedParses().get(0).getProperties().get(CONTEXT)), String.class, Object.class);
LLMReq llmReq = JsonUtil.toObject(JsonUtil.toString(map.get("llmReq")), LLMReq.class);
List<LLMReq.ElementValue> linking = llmReq.getLinking();
List<String> priorLinkingList = new ArrayList<>();
for (LLMReq.ElementValue priorLinking : linking) {
String fieldName = priorLinking.getFieldName();
String fieldValue = priorLinking.getFieldValue();
priorLinkingList.add("" + fieldValue + "‘是一个‘" + fieldName + "");
}
String linkingListStr = String.join("", priorLinkingList);
String questionAugmented = String.format("%s (补充信息:%s) ", o.getQueryText(), linkingListStr);
contextualQuestions.add(questionAugmented);
});
return contextualQuestions;
}
private List<ParseResp> getContextualList(ParseResp parseResp, String multiNum) {
ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
List<ParseResp> contextualParseInfoList = chatQueryRepository.getContextualParseInfo(
parseResp.getChatId()).stream().filter(o -> o.getSelectedParses().get(0)
.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)
).collect(Collectors.toList());
if (StringUtils.isNotBlank(multiNum) && StringUtils.isNumeric(multiNum)) {
int num = Integer.parseInt(multiNum);
contextualNum = Math.min(contextualNum, num);
}
List<ParseResp> contextualList = contextualParseInfoList.subList(0,
Math.min(contextualNum, contextualParseInfoList.size()));
Collections.reverse(contextualList);
return contextualList;
}
private String getQueryLinks(ChatParseContext chatParseContext) {
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
ChatQueryServiceImpl chatQueryService = ContextUtils.getBean(ChatQueryServiceImpl.class);
// build queryContext and chatContext
QueryContext queryCtx = chatQueryService.buildQueryContext(queryReq);
// 1. mapper
if (Objects.isNull(chatParseContext.getMapInfo())
|| MapUtils.isEmpty(chatParseContext.getMapInfo().getDataSetElementMatches())) {
schemaMappers.forEach(mapper -> {
mapper.map(queryCtx);
});
}
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
Long dataSetId = requestService.getDataSetId(queryCtx);
log.info("dataSetId:{}", dataSetId);
if (dataSetId == null) {
return null;
}
List<LLMReq.ElementValue> linkingValues = requestService.getValueList(queryCtx, dataSetId);
List<String> priorLinkingList = new ArrayList<>();
for (LLMReq.ElementValue priorLinking : linkingValues) {
String fieldName = priorLinking.getFieldName();
String fieldValue = priorLinking.getFieldValue();
priorLinkingList.add("" + fieldValue + "‘是一个‘" + fieldName + "");
}
String linkingListStr = String.join("", priorLinkingList);
String questionAugmented = String.format("%s (补充信息:%s) ", chatParseContext.getQueryText(), linkingListStr);
log.info("questionAugmented:{}", questionAugmented);
return questionAugmented;
}
private boolean checkSkip(ParseResp parseResp) {
List<SemanticParseInfo> selectedParses = parseResp.getSelectedParses();
for (SemanticParseInfo semanticParseInfo : selectedParses) {

View File

@@ -18,4 +18,6 @@ public interface ChatParseMapper {
List<ChatParseDO> getParseInfoList(List<Long> questionIds);
List<ChatParseDO> getContextualParseInfo(Integer chatId);
}

View File

@@ -36,4 +36,6 @@ public interface ChatQueryRepository {
Boolean deleteChatQuery(Long questionId);
List<ParseResp> getContextualParseInfo(Integer chatId);
}

View File

@@ -26,6 +26,7 @@ import org.springframework.beans.BeanUtils;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Repository;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
@@ -44,8 +45,8 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
private final ShowCaseCustomMapper showCaseCustomMapper;
public ChatQueryRepositoryImpl(ChatQueryDOMapper chatQueryDOMapper,
ChatParseMapper chatParseMapper,
ShowCaseCustomMapper showCaseCustomMapper) {
ChatParseMapper chatParseMapper,
ShowCaseCustomMapper showCaseCustomMapper) {
this.chatQueryDOMapper = chatQueryDOMapper;
this.chatParseMapper = chatParseMapper;
this.showCaseCustomMapper = showCaseCustomMapper;
@@ -131,7 +132,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
@Override
public List<ChatParseDO> batchSaveParseInfo(ChatParseReq chatParseReq,
ParseResp parseResult, List<SemanticParseInfo> candidateParses) {
ParseResp parseResult, List<SemanticParseInfo> candidateParses) {
List<ChatParseDO> chatParseDOList = new ArrayList<>();
getChatParseDO(chatParseReq, parseResult.getQueryId(), candidateParses, chatParseDOList);
if (!CollectionUtils.isEmpty(candidateParses)) {
@@ -141,7 +142,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
}
public void getChatParseDO(ChatParseReq chatParseReq, Long queryId,
List<SemanticParseInfo> parses, List<ChatParseDO> chatParseDOList) {
List<SemanticParseInfo> parses, List<ChatParseDO> chatParseDOList) {
for (int i = 0; i < parses.size(); i++) {
ChatParseDO chatParseDO = new ChatParseDO();
chatParseDO.setChatId(Long.valueOf(chatParseReq.getChatId()));
@@ -193,4 +194,17 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
return chatQueryDOMapper.deleteByPrimaryKey(questionId);
}
@Override
public List<ParseResp> getContextualParseInfo(Integer chatId) {
List<ChatParseDO> chatParseDOList = chatParseMapper.getContextualParseInfo(chatId);
List<ParseResp> semanticParseInfoList = chatParseDOList.stream().map(parseInfo -> {
ParseResp parseResp = new ParseResp(chatId, parseInfo.getQueryText());
List<SemanticParseInfo> selectedParses = new ArrayList<>();
selectedParses.add(JSONObject.parseObject(parseInfo.getParseInfo(), SemanticParseInfo.class));
parseResp.setSelectedParses(selectedParses);
return parseResp;
}).collect(Collectors.toList());
return semanticParseInfoList;
}
}

View File

@@ -8,6 +8,7 @@ import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.server.service.ChatManageService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping;
@@ -24,11 +25,8 @@ import java.util.List;
@RequestMapping({"/api/chat/manage", "/openapi/chat/manage"})
public class ChatController {
private final ChatManageService chatService;
public ChatController(ChatManageService chatService) {
this.chatService = chatService;
}
@Autowired
private ChatManageService chatService;
@PostMapping("/save")
public Boolean save(@RequestParam(value = "chatName") String chatName,

View File

@@ -64,6 +64,7 @@ public class ChatServiceImpl implements ChatService {
@Override
public ParseResp performParsing(ChatParseReq chatParseReq) {
String queryText = chatParseReq.getQueryText();
ParseResp parseResp = new ParseResp(chatParseReq.getChatId(), chatParseReq.getQueryText());
chatManageService.createChatQuery(chatParseReq, parseResp);
ChatParseContext chatParseContext = buildParseContext(chatParseReq);
@@ -73,6 +74,8 @@ public class ChatServiceImpl implements ChatService {
for (ParseResultProcessor processor : parseResultProcessors) {
processor.process(chatParseContext, parseResp);
}
parseResp.setQueryText(queryText);
chatParseReq.setQueryText(queryText);
chatManageService.batchAddParse(chatParseReq, parseResp);
return parseResp;
}

View File

@@ -46,4 +46,10 @@
</foreach>
</select>
<select id="getContextualParseInfo" resultMap="ChatParse">
select *
from s2_chat_parse
where chat_id = #{chatId} order by question_id desc limit 10
</select>
</mapper>