(fix)(headless)Fix multi-turn conversation issue, recalling history ChatQuery instead of ChatParse.

This commit is contained in:
jerryjzhang
2024-07-30 21:57:44 +08:00
parent 12a504585f
commit 4a5bb9e457
14 changed files with 71 additions and 168 deletions

View File

@@ -1,13 +1,13 @@
package com.tencent.supersonic.chat.server.executor;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.parser.ParserConfig;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatManageService;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import dev.langchain4j.data.message.AiMessage;
@@ -19,6 +19,7 @@ import dev.langchain4j.provider.ModelProvider;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
@@ -70,8 +71,8 @@ public class PlainTextExecutor implements ChatQueryExecutor {
? agentMultiTurnConfig.isEnableMultiTurn() : globalMultiTurnConfig;
if (Boolean.TRUE.equals(multiTurnConfig)) {
List<ParseResp> parseResps = getHistoryParseResult(executeContext.getChatId(), 5);
parseResps.stream().forEach(p -> {
List<QueryResp> queryResps = getHistoryQueries(executeContext.getChatId(), 5);
queryResps.stream().forEach(p -> {
historyInput.append(p.getQueryText());
historyInput.append(";");
});
@@ -80,12 +81,15 @@ public class PlainTextExecutor implements ChatQueryExecutor {
return historyInput.toString();
}
private List<ParseResp> getHistoryParseResult(int chatId, int multiNum) {
ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
List<ParseResp> contextualParseInfoList = chatQueryRepository.getContextualParseInfo(chatId)
.stream().filter(p -> p.getState() == ParseResp.ParseState.COMPLETED).collect(Collectors.toList());
private List<QueryResp> getHistoryQueries(int chatId, int multiNum) {
ChatManageService chatManageService = ContextUtils.getBean(ChatManageService.class);
List<QueryResp> contextualParseInfoList = chatManageService.getChatQueries(chatId)
.stream()
.filter(q -> Objects.nonNull(q.getQueryResult())
&& q.getQueryResult().getQueryState() == QueryState.SUCCESS)
.collect(Collectors.toList());
List<ParseResp> contextualList = contextualParseInfoList.subList(0,
List<QueryResp> contextualList = contextualParseInfoList.subList(0,
Math.min(multiNum, contextualParseInfoList.size()));
Collections.reverse(contextualList);

View File

@@ -1,10 +1,11 @@
package com.tencent.supersonic.chat.server.parser;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.chat.server.service.ChatContextService;
import com.tencent.supersonic.chat.server.service.ChatManageService;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
@@ -19,6 +20,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.server.pojo.ChatContext;
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
@@ -176,21 +178,22 @@ public class NL2SQLParser implements ChatQueryParser {
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
MapResp currentMapResult = chatLayerService.performMapping(queryNLReq);
List<ParseResp> historyParseResults = getHistoryParseResult(parseContext.getChatId(), 1);
if (historyParseResults.size() == 0) {
List<QueryResp> historyQueries = getHistoryQueries(parseContext.getChatId(), 1);
if (historyQueries.size() == 0) {
return;
}
ParseResp lastParseResult = historyParseResults.get(0);
Long dataId = lastParseResult.getSelectedParses().get(0).getDataSetId();
QueryResp lastQuery = historyQueries.get(0);
SemanticParseInfo lastParseInfo = lastQuery.getParseInfos().get(0);
Long dataId = lastParseInfo.getDataSetId();
String curtMapStr = generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataId));
String histMapStr = generateSchemaPrompt(lastParseResult.getSelectedParses().get(0).getElementMatches());
String histSQL = lastParseResult.getSelectedParses().get(0).getSqlInfo().getCorrectedS2SQL();
String histMapStr = generateSchemaPrompt(lastParseInfo.getElementMatches());
String histSQL = lastParseInfo.getSqlInfo().getCorrectedS2SQL();
Map<String, Object> variables = new HashMap<>();
variables.put("current_question", currentMapResult.getQueryText());
variables.put("current_schema", curtMapStr);
variables.put("history_question", lastParseResult.getQueryText());
variables.put("history_question", lastQuery.getQueryText());
variables.put("history_schema", histMapStr);
variables.put("history_sql", histSQL);
@@ -203,7 +206,7 @@ public class NL2SQLParser implements ChatQueryParser {
parseContext.setQueryText(rewrittenQuery);
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}",
lastParseResult.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery);
lastQuery.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery);
}
private String rewriteErrorMessage(ChatLanguageModel chatLanguageModel, String userQuestion,
@@ -255,12 +258,15 @@ public class NL2SQLParser implements ChatQueryParser {
return prompt.toString();
}
private List<ParseResp> getHistoryParseResult(int chatId, int multiNum) {
ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
List<ParseResp> contextualParseInfoList = chatQueryRepository.getContextualParseInfo(chatId)
.stream().filter(p -> p.getState() == ParseResp.ParseState.COMPLETED).collect(Collectors.toList());
private List<QueryResp> getHistoryQueries(int chatId, int multiNum) {
ChatManageService chatManageService = ContextUtils.getBean(ChatManageService.class);
List<QueryResp> contextualParseInfoList = chatManageService.getChatQueries(chatId)
.stream()
.filter(q -> Objects.nonNull(q.getQueryResult())
&& q.getQueryResult().getQueryState() == QueryState.SUCCESS)
.collect(Collectors.toList());
List<ParseResp> contextualList = contextualParseInfoList.subList(0,
List<QueryResp> contextualList = contextualParseInfoList.subList(0,
Math.min(multiNum, contextualParseInfoList.size()));
Collections.reverse(contextualList);
return contextualList;

View File

@@ -16,6 +16,7 @@ public class PlainTextParser implements ChatQueryParser {
SemanticParseInfo parseInfo = new SemanticParseInfo();
parseInfo.setQueryMode("PLAIN_TEXT");
parseResp.getSelectedParses().add(parseInfo);
parseResp.setState(ParseResp.ParseState.COMPLETED);
}
}

View File

@@ -1,142 +1,25 @@
package com.tencent.supersonic.chat.server.persistence.dataobject;
import lombok.Data;
import java.util.Date;
@Data
public class ChatParseDO {
/**
* questionId
*/
private Long questionId;
/**
* chatId
*/
private Long chatId;
private Integer chatId;
/**
* parseId
*/
private Integer parseId;
/**
* createTime
*/
private Date createTime;
/**
* queryText
*/
private String queryText;
/**
* userName
*/
private String userName;
/**
* parseInfo
*/
private String parseInfo;
/**
* isCandidate
*/
private Integer isCandidate;
/**
* return question_id
*/
public Long getQuestionId() {
return questionId;
}
/**
* questionId
*/
public void setQuestionId(Long questionId) {
this.questionId = questionId;
}
/**
* return create_time
*/
public Date getCreateTime() {
return createTime;
}
/**
* createTime
*/
public void setCreateTime(Date createTime) {
this.createTime = createTime;
}
/**
* return user_name
*/
public String getUserName() {
return userName;
}
/**
* userName
*/
public void setUserName(String userName) {
this.userName = userName == null ? null : userName.trim();
}
/**
* return chat_id
*/
public Long getChatId() {
return chatId;
}
/**
* chatId
*/
public void setChatId(Long chatId) {
this.chatId = chatId;
}
/**
* return query_text
*/
public String getQueryText() {
return queryText;
}
/**
* queryText
*/
public void setQueryText(String queryText) {
this.queryText = queryText == null ? null : queryText.trim();
}
public Integer getIsCandidate() {
return isCandidate;
}
public Integer getParseId() {
return parseId;
}
public String getParseInfo() {
return parseInfo;
}
public void setParseId(Integer parseId) {
this.parseId = parseId;
}
public void setIsCandidate(Integer isCandidate) {
this.isCandidate = isCandidate;
}
public void setParseInfo(String parseInfo) {
this.parseInfo = parseInfo;
}
}

View File

@@ -18,6 +18,8 @@ public interface ChatQueryRepository {
QueryResp getChatQuery(Long queryId);
List<QueryResp> getChatQueries(Integer chatId);
ChatQueryDO getChatQueryDO(Long queryId);
List<QueryResp> queryShowCase(PageQueryInfoReq pageQueryInfoCommend, int agentId);
@@ -35,6 +37,4 @@ public interface ChatQueryRepository {
List<ChatParseDO> getParseInfoList(List<Long> questionIds);
List<ParseResp> getContextualParseInfo(Integer chatId);
}

View File

@@ -90,6 +90,16 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
return chatQueryDOMapper.selectById(queryId);
}
@Override
public List<QueryResp> getChatQueries(Integer chatId) {
QueryWrapper<ChatQueryDO> queryWrapper = new QueryWrapper<>();
queryWrapper.lambda().eq(ChatQueryDO::getChatId, chatId);
queryWrapper.lambda().orderByDesc(ChatQueryDO::getQuestionId);
return chatQueryDOMapper.selectList(queryWrapper).stream()
.map(q -> convertTo(q))
.collect(Collectors.toList());
}
@Override
public List<QueryResp> queryShowCase(PageQueryInfoReq pageQueryInfoReq, int agentId) {
return showCaseCustomMapper.queryShowCase(pageQueryInfoReq.getLimitStart(),
@@ -145,7 +155,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
List<SemanticParseInfo> parses, List<ChatParseDO> chatParseDOList) {
for (int i = 0; i < parses.size(); i++) {
ChatParseDO chatParseDO = new ChatParseDO();
chatParseDO.setChatId(Long.valueOf(chatParseReq.getChatId()));
chatParseDO.setChatId(chatParseReq.getChatId());
chatParseDO.setQuestionId(queryId);
chatParseDO.setQueryText(chatParseReq.getQueryText());
chatParseDO.setParseInfo(JsonUtil.toString(parses.get(i)));
@@ -179,17 +189,4 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
return chatParseMapper.getParseInfoList(questionIds);
}
@Override
public List<ParseResp> getContextualParseInfo(Integer chatId) {
List<ChatParseDO> chatParseDOList = chatParseMapper.getContextualParseInfo(chatId);
List<ParseResp> semanticParseInfoList = chatParseDOList.stream().map(parseInfo -> {
ParseResp parseResp = new ParseResp(parseInfo.getQueryText());
List<SemanticParseInfo> selectedParses = new ArrayList<>();
selectedParses.add(JSONObject.parseObject(parseInfo.getParseInfo(), SemanticParseInfo.class));
parseResp.setSelectedParses(selectedParses);
return parseResp;
}).collect(Collectors.toList());
return semanticParseInfoList;
}
}

View File

@@ -35,6 +35,8 @@ public interface ChatManageService {
QueryResp getChatQuery(Long queryId);
List<QueryResp> getChatQueries(Integer chatId);
ShowCaseResp queryShowCase(PageQueryInfoReq pageQueryInfoReq, int agentId);
ChatQueryDO saveQueryResult(ChatExecuteReq chatExecuteReq, QueryResult queryResult);

View File

@@ -107,6 +107,13 @@ public class ChatManageServiceImpl implements ChatManageService {
return chatQueryRepository.getChatQuery(queryId);
}
@Override
public List<QueryResp> getChatQueries(Integer chatId) {
List<QueryResp> queries = chatQueryRepository.getChatQueries(chatId);
fillParseInfo(queries);
return queries;
}
@Override
public ShowCaseResp queryShowCase(PageQueryInfoReq pageQueryInfoReq, int agentId) {
ShowCaseResp showCaseResp = new ShowCaseResp();

View File

@@ -2,6 +2,6 @@ package com.tencent.supersonic.headless.api.pojo.enums;
public enum SchemaType {
VIEW,
DATASET,
MODEL
}

View File

@@ -376,7 +376,7 @@ public class SchemaServiceImpl implements SchemaService {
semanticSchemaResp.setModelResps(modelList);
semanticSchemaResp.setModelRelas(modelRelas);
semanticSchemaResp.setModelIds(modelIds);
semanticSchemaResp.setSchemaType(SchemaType.VIEW);
semanticSchemaResp.setSchemaType(SchemaType.DATASET);
} else if (!CollectionUtils.isEmpty(schemaFilterReq.getModelIds())) {
List<ModelSchemaResp> modelSchemaResps = fetchModelSchemaResps(schemaFilterReq.getModelIds());
semanticSchemaResp.setMetrics(modelSchemaResps.stream().map(ModelSchemaResp::getMetrics)

View File

@@ -359,4 +359,7 @@ alter table s2_agent add `model_config` text null;
alter table s2_agent add `enable_memory_review` tinyint DEFAULT 0;
--20240718
alter table s2_chat_memory add `side_info` TEXT DEFAULT NULL COMMENT '辅助信息';
alter table s2_chat_memory add `side_info` TEXT DEFAULT NULL COMMENT '辅助信息';
--20240730
alter table s2_chat_parse modify column `chat_id` int(11);

View File

@@ -45,7 +45,7 @@ CREATE TABLE IF NOT EXISTS `s2_chat_query`
CREATE TABLE IF NOT EXISTS `s2_chat_parse`
(
`question_id` BIGINT NOT NULL,
`chat_id` BIGINT NOT NULL ,
`chat_id` INT NOT NULL ,
`parse_id` INT NOT NULL ,
`create_time` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
`query_text` varchar(500),

View File

@@ -169,8 +169,8 @@ CREATE TABLE IF NOT EXISTS `s2_chat_context` (
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
CREATE TABLE IF NOT EXISTS `s2_chat_parse` (
`question_id` bigint(20) NOT NULL,
`chat_id` bigint(20) NOT NULL,
`question_id` bigint NOT NULL,
`chat_id` int(11) NOT NULL,
`parse_id` int(11) NOT NULL,
`create_time` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
`query_text` varchar(500) DEFAULT NULL,

View File

@@ -44,8 +44,8 @@ CREATE TABLE IF NOT EXISTS `s2_chat_query`
CREATE TABLE IF NOT EXISTS `s2_chat_parse`
(
`question_id` BIGINT NOT NULL,
`chat_id` BIGINT NOT NULL ,
`question_id` BIGINT NOT NULL,
`chat_id` INT NOT NULL ,
`parse_id` INT NOT NULL ,
`create_time` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
`query_text` varchar(500),