mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(fix)(headless)Fix multi-turn conversation issue, recalling history ChatQuery instead of ChatParse.
This commit is contained in:
@@ -1,13 +1,13 @@
|
||||
package com.tencent.supersonic.chat.server.executor;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||
import com.tencent.supersonic.chat.server.parser.ParserConfig;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatManageService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
@@ -19,6 +19,7 @@ import dev.langchain4j.provider.ModelProvider;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
|
||||
@@ -70,8 +71,8 @@ public class PlainTextExecutor implements ChatQueryExecutor {
|
||||
? agentMultiTurnConfig.isEnableMultiTurn() : globalMultiTurnConfig;
|
||||
|
||||
if (Boolean.TRUE.equals(multiTurnConfig)) {
|
||||
List<ParseResp> parseResps = getHistoryParseResult(executeContext.getChatId(), 5);
|
||||
parseResps.stream().forEach(p -> {
|
||||
List<QueryResp> queryResps = getHistoryQueries(executeContext.getChatId(), 5);
|
||||
queryResps.stream().forEach(p -> {
|
||||
historyInput.append(p.getQueryText());
|
||||
historyInput.append(";");
|
||||
});
|
||||
@@ -80,12 +81,15 @@ public class PlainTextExecutor implements ChatQueryExecutor {
|
||||
return historyInput.toString();
|
||||
}
|
||||
|
||||
private List<ParseResp> getHistoryParseResult(int chatId, int multiNum) {
|
||||
ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
|
||||
List<ParseResp> contextualParseInfoList = chatQueryRepository.getContextualParseInfo(chatId)
|
||||
.stream().filter(p -> p.getState() == ParseResp.ParseState.COMPLETED).collect(Collectors.toList());
|
||||
private List<QueryResp> getHistoryQueries(int chatId, int multiNum) {
|
||||
ChatManageService chatManageService = ContextUtils.getBean(ChatManageService.class);
|
||||
List<QueryResp> contextualParseInfoList = chatManageService.getChatQueries(chatId)
|
||||
.stream()
|
||||
.filter(q -> Objects.nonNull(q.getQueryResult())
|
||||
&& q.getQueryResult().getQueryState() == QueryState.SUCCESS)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
List<ParseResp> contextualList = contextualParseInfoList.subList(0,
|
||||
List<QueryResp> contextualList = contextualParseInfoList.subList(0,
|
||||
Math.min(multiNum, contextualParseInfoList.size()));
|
||||
Collections.reverse(contextualList);
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.chat.server.service.ChatContextService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatManageService;
|
||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
@@ -19,6 +20,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
@@ -176,21 +178,22 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
|
||||
MapResp currentMapResult = chatLayerService.performMapping(queryNLReq);
|
||||
|
||||
List<ParseResp> historyParseResults = getHistoryParseResult(parseContext.getChatId(), 1);
|
||||
if (historyParseResults.size() == 0) {
|
||||
List<QueryResp> historyQueries = getHistoryQueries(parseContext.getChatId(), 1);
|
||||
if (historyQueries.size() == 0) {
|
||||
return;
|
||||
}
|
||||
ParseResp lastParseResult = historyParseResults.get(0);
|
||||
Long dataId = lastParseResult.getSelectedParses().get(0).getDataSetId();
|
||||
QueryResp lastQuery = historyQueries.get(0);
|
||||
SemanticParseInfo lastParseInfo = lastQuery.getParseInfos().get(0);
|
||||
Long dataId = lastParseInfo.getDataSetId();
|
||||
|
||||
String curtMapStr = generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataId));
|
||||
String histMapStr = generateSchemaPrompt(lastParseResult.getSelectedParses().get(0).getElementMatches());
|
||||
String histSQL = lastParseResult.getSelectedParses().get(0).getSqlInfo().getCorrectedS2SQL();
|
||||
String histMapStr = generateSchemaPrompt(lastParseInfo.getElementMatches());
|
||||
String histSQL = lastParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
|
||||
Map<String, Object> variables = new HashMap<>();
|
||||
variables.put("current_question", currentMapResult.getQueryText());
|
||||
variables.put("current_schema", curtMapStr);
|
||||
variables.put("history_question", lastParseResult.getQueryText());
|
||||
variables.put("history_question", lastQuery.getQueryText());
|
||||
variables.put("history_schema", histMapStr);
|
||||
variables.put("history_sql", histSQL);
|
||||
|
||||
@@ -203,7 +206,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
|
||||
parseContext.setQueryText(rewrittenQuery);
|
||||
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}",
|
||||
lastParseResult.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery);
|
||||
lastQuery.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery);
|
||||
}
|
||||
|
||||
private String rewriteErrorMessage(ChatLanguageModel chatLanguageModel, String userQuestion,
|
||||
@@ -255,12 +258,15 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
return prompt.toString();
|
||||
}
|
||||
|
||||
private List<ParseResp> getHistoryParseResult(int chatId, int multiNum) {
|
||||
ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
|
||||
List<ParseResp> contextualParseInfoList = chatQueryRepository.getContextualParseInfo(chatId)
|
||||
.stream().filter(p -> p.getState() == ParseResp.ParseState.COMPLETED).collect(Collectors.toList());
|
||||
private List<QueryResp> getHistoryQueries(int chatId, int multiNum) {
|
||||
ChatManageService chatManageService = ContextUtils.getBean(ChatManageService.class);
|
||||
List<QueryResp> contextualParseInfoList = chatManageService.getChatQueries(chatId)
|
||||
.stream()
|
||||
.filter(q -> Objects.nonNull(q.getQueryResult())
|
||||
&& q.getQueryResult().getQueryState() == QueryState.SUCCESS)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
List<ParseResp> contextualList = contextualParseInfoList.subList(0,
|
||||
List<QueryResp> contextualList = contextualParseInfoList.subList(0,
|
||||
Math.min(multiNum, contextualParseInfoList.size()));
|
||||
Collections.reverse(contextualList);
|
||||
return contextualList;
|
||||
|
||||
@@ -16,6 +16,7 @@ public class PlainTextParser implements ChatQueryParser {
|
||||
SemanticParseInfo parseInfo = new SemanticParseInfo();
|
||||
parseInfo.setQueryMode("PLAIN_TEXT");
|
||||
parseResp.getSelectedParses().add(parseInfo);
|
||||
parseResp.setState(ParseResp.ParseState.COMPLETED);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,142 +1,25 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.Date;
|
||||
|
||||
|
||||
@Data
|
||||
public class ChatParseDO {
|
||||
|
||||
/**
|
||||
* questionId
|
||||
*/
|
||||
private Long questionId;
|
||||
|
||||
/**
|
||||
* chatId
|
||||
*/
|
||||
private Long chatId;
|
||||
private Integer chatId;
|
||||
|
||||
/**
|
||||
* parseId
|
||||
*/
|
||||
private Integer parseId;
|
||||
|
||||
/**
|
||||
* createTime
|
||||
*/
|
||||
private Date createTime;
|
||||
|
||||
/**
|
||||
* queryText
|
||||
*/
|
||||
private String queryText;
|
||||
|
||||
/**
|
||||
* userName
|
||||
*/
|
||||
private String userName;
|
||||
|
||||
|
||||
/**
|
||||
* parseInfo
|
||||
*/
|
||||
private String parseInfo;
|
||||
|
||||
/**
|
||||
* isCandidate
|
||||
*/
|
||||
private Integer isCandidate;
|
||||
|
||||
/**
|
||||
* return question_id
|
||||
*/
|
||||
public Long getQuestionId() {
|
||||
return questionId;
|
||||
}
|
||||
|
||||
/**
|
||||
* questionId
|
||||
*/
|
||||
public void setQuestionId(Long questionId) {
|
||||
this.questionId = questionId;
|
||||
}
|
||||
|
||||
/**
|
||||
* return create_time
|
||||
*/
|
||||
public Date getCreateTime() {
|
||||
return createTime;
|
||||
}
|
||||
|
||||
/**
|
||||
* createTime
|
||||
*/
|
||||
public void setCreateTime(Date createTime) {
|
||||
this.createTime = createTime;
|
||||
}
|
||||
|
||||
/**
|
||||
* return user_name
|
||||
*/
|
||||
public String getUserName() {
|
||||
return userName;
|
||||
}
|
||||
|
||||
/**
|
||||
* userName
|
||||
*/
|
||||
public void setUserName(String userName) {
|
||||
this.userName = userName == null ? null : userName.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
* return chat_id
|
||||
*/
|
||||
public Long getChatId() {
|
||||
return chatId;
|
||||
}
|
||||
|
||||
/**
|
||||
* chatId
|
||||
*/
|
||||
public void setChatId(Long chatId) {
|
||||
this.chatId = chatId;
|
||||
}
|
||||
|
||||
/**
|
||||
* return query_text
|
||||
*/
|
||||
public String getQueryText() {
|
||||
return queryText;
|
||||
}
|
||||
|
||||
/**
|
||||
* queryText
|
||||
*/
|
||||
public void setQueryText(String queryText) {
|
||||
this.queryText = queryText == null ? null : queryText.trim();
|
||||
}
|
||||
|
||||
public Integer getIsCandidate() {
|
||||
return isCandidate;
|
||||
}
|
||||
|
||||
public Integer getParseId() {
|
||||
return parseId;
|
||||
}
|
||||
|
||||
public String getParseInfo() {
|
||||
return parseInfo;
|
||||
}
|
||||
|
||||
public void setParseId(Integer parseId) {
|
||||
this.parseId = parseId;
|
||||
}
|
||||
|
||||
public void setIsCandidate(Integer isCandidate) {
|
||||
this.isCandidate = isCandidate;
|
||||
}
|
||||
|
||||
public void setParseInfo(String parseInfo) {
|
||||
this.parseInfo = parseInfo;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,6 +18,8 @@ public interface ChatQueryRepository {
|
||||
|
||||
QueryResp getChatQuery(Long queryId);
|
||||
|
||||
List<QueryResp> getChatQueries(Integer chatId);
|
||||
|
||||
ChatQueryDO getChatQueryDO(Long queryId);
|
||||
|
||||
List<QueryResp> queryShowCase(PageQueryInfoReq pageQueryInfoCommend, int agentId);
|
||||
@@ -35,6 +37,4 @@ public interface ChatQueryRepository {
|
||||
|
||||
List<ChatParseDO> getParseInfoList(List<Long> questionIds);
|
||||
|
||||
List<ParseResp> getContextualParseInfo(Integer chatId);
|
||||
|
||||
}
|
||||
|
||||
@@ -90,6 +90,16 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
||||
return chatQueryDOMapper.selectById(queryId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<QueryResp> getChatQueries(Integer chatId) {
|
||||
QueryWrapper<ChatQueryDO> queryWrapper = new QueryWrapper<>();
|
||||
queryWrapper.lambda().eq(ChatQueryDO::getChatId, chatId);
|
||||
queryWrapper.lambda().orderByDesc(ChatQueryDO::getQuestionId);
|
||||
return chatQueryDOMapper.selectList(queryWrapper).stream()
|
||||
.map(q -> convertTo(q))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<QueryResp> queryShowCase(PageQueryInfoReq pageQueryInfoReq, int agentId) {
|
||||
return showCaseCustomMapper.queryShowCase(pageQueryInfoReq.getLimitStart(),
|
||||
@@ -145,7 +155,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
||||
List<SemanticParseInfo> parses, List<ChatParseDO> chatParseDOList) {
|
||||
for (int i = 0; i < parses.size(); i++) {
|
||||
ChatParseDO chatParseDO = new ChatParseDO();
|
||||
chatParseDO.setChatId(Long.valueOf(chatParseReq.getChatId()));
|
||||
chatParseDO.setChatId(chatParseReq.getChatId());
|
||||
chatParseDO.setQuestionId(queryId);
|
||||
chatParseDO.setQueryText(chatParseReq.getQueryText());
|
||||
chatParseDO.setParseInfo(JsonUtil.toString(parses.get(i)));
|
||||
@@ -179,17 +189,4 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
||||
return chatParseMapper.getParseInfoList(questionIds);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ParseResp> getContextualParseInfo(Integer chatId) {
|
||||
List<ChatParseDO> chatParseDOList = chatParseMapper.getContextualParseInfo(chatId);
|
||||
List<ParseResp> semanticParseInfoList = chatParseDOList.stream().map(parseInfo -> {
|
||||
ParseResp parseResp = new ParseResp(parseInfo.getQueryText());
|
||||
List<SemanticParseInfo> selectedParses = new ArrayList<>();
|
||||
selectedParses.add(JSONObject.parseObject(parseInfo.getParseInfo(), SemanticParseInfo.class));
|
||||
parseResp.setSelectedParses(selectedParses);
|
||||
return parseResp;
|
||||
}).collect(Collectors.toList());
|
||||
return semanticParseInfoList;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -35,6 +35,8 @@ public interface ChatManageService {
|
||||
|
||||
QueryResp getChatQuery(Long queryId);
|
||||
|
||||
List<QueryResp> getChatQueries(Integer chatId);
|
||||
|
||||
ShowCaseResp queryShowCase(PageQueryInfoReq pageQueryInfoReq, int agentId);
|
||||
|
||||
ChatQueryDO saveQueryResult(ChatExecuteReq chatExecuteReq, QueryResult queryResult);
|
||||
|
||||
@@ -107,6 +107,13 @@ public class ChatManageServiceImpl implements ChatManageService {
|
||||
return chatQueryRepository.getChatQuery(queryId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<QueryResp> getChatQueries(Integer chatId) {
|
||||
List<QueryResp> queries = chatQueryRepository.getChatQueries(chatId);
|
||||
fillParseInfo(queries);
|
||||
return queries;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ShowCaseResp queryShowCase(PageQueryInfoReq pageQueryInfoReq, int agentId) {
|
||||
ShowCaseResp showCaseResp = new ShowCaseResp();
|
||||
|
||||
@@ -2,6 +2,6 @@ package com.tencent.supersonic.headless.api.pojo.enums;
|
||||
|
||||
public enum SchemaType {
|
||||
|
||||
VIEW,
|
||||
DATASET,
|
||||
MODEL
|
||||
}
|
||||
|
||||
@@ -376,7 +376,7 @@ public class SchemaServiceImpl implements SchemaService {
|
||||
semanticSchemaResp.setModelResps(modelList);
|
||||
semanticSchemaResp.setModelRelas(modelRelas);
|
||||
semanticSchemaResp.setModelIds(modelIds);
|
||||
semanticSchemaResp.setSchemaType(SchemaType.VIEW);
|
||||
semanticSchemaResp.setSchemaType(SchemaType.DATASET);
|
||||
} else if (!CollectionUtils.isEmpty(schemaFilterReq.getModelIds())) {
|
||||
List<ModelSchemaResp> modelSchemaResps = fetchModelSchemaResps(schemaFilterReq.getModelIds());
|
||||
semanticSchemaResp.setMetrics(modelSchemaResps.stream().map(ModelSchemaResp::getMetrics)
|
||||
|
||||
@@ -359,4 +359,7 @@ alter table s2_agent add `model_config` text null;
|
||||
alter table s2_agent add `enable_memory_review` tinyint DEFAULT 0;
|
||||
|
||||
--20240718
|
||||
alter table s2_chat_memory add `side_info` TEXT DEFAULT NULL COMMENT '辅助信息';
|
||||
alter table s2_chat_memory add `side_info` TEXT DEFAULT NULL COMMENT '辅助信息';
|
||||
|
||||
--20240730
|
||||
alter table s2_chat_parse modify column `chat_id` int(11);
|
||||
@@ -45,7 +45,7 @@ CREATE TABLE IF NOT EXISTS `s2_chat_query`
|
||||
CREATE TABLE IF NOT EXISTS `s2_chat_parse`
|
||||
(
|
||||
`question_id` BIGINT NOT NULL,
|
||||
`chat_id` BIGINT NOT NULL ,
|
||||
`chat_id` INT NOT NULL ,
|
||||
`parse_id` INT NOT NULL ,
|
||||
`create_time` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
`query_text` varchar(500),
|
||||
|
||||
@@ -169,8 +169,8 @@ CREATE TABLE IF NOT EXISTS `s2_chat_context` (
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS `s2_chat_parse` (
|
||||
`question_id` bigint(20) NOT NULL,
|
||||
`chat_id` bigint(20) NOT NULL,
|
||||
`question_id` bigint NOT NULL,
|
||||
`chat_id` int(11) NOT NULL,
|
||||
`parse_id` int(11) NOT NULL,
|
||||
`create_time` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
`query_text` varchar(500) DEFAULT NULL,
|
||||
|
||||
@@ -44,8 +44,8 @@ CREATE TABLE IF NOT EXISTS `s2_chat_query`
|
||||
|
||||
CREATE TABLE IF NOT EXISTS `s2_chat_parse`
|
||||
(
|
||||
`question_id` BIGINT NOT NULL,
|
||||
`chat_id` BIGINT NOT NULL ,
|
||||
`question_id` BIGINT NOT NULL,
|
||||
`chat_id` INT NOT NULL ,
|
||||
`parse_id` INT NOT NULL ,
|
||||
`create_time` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
`query_text` varchar(500),
|
||||
|
||||
Reference in New Issue
Block a user