diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java index 63c03194a..79eb00a10 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java @@ -1,13 +1,13 @@ package com.tencent.supersonic.chat.server.executor; +import com.tencent.supersonic.chat.api.pojo.response.QueryResp; import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.MultiTurnConfig; import com.tencent.supersonic.chat.server.parser.ParserConfig; -import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository; import com.tencent.supersonic.chat.server.pojo.ExecuteContext; import com.tencent.supersonic.chat.server.service.AgentService; +import com.tencent.supersonic.chat.server.service.ChatManageService; import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.headless.api.pojo.response.QueryState; import dev.langchain4j.data.message.AiMessage; @@ -19,6 +19,7 @@ import dev.langchain4j.provider.ModelProvider; import java.util.Collections; import java.util.List; +import java.util.Objects; import java.util.stream.Collectors; import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE; @@ -70,8 +71,8 @@ public class PlainTextExecutor implements ChatQueryExecutor { ? agentMultiTurnConfig.isEnableMultiTurn() : globalMultiTurnConfig; if (Boolean.TRUE.equals(multiTurnConfig)) { - List parseResps = getHistoryParseResult(executeContext.getChatId(), 5); - parseResps.stream().forEach(p -> { + List queryResps = getHistoryQueries(executeContext.getChatId(), 5); + queryResps.stream().forEach(p -> { historyInput.append(p.getQueryText()); historyInput.append(";"); }); @@ -80,12 +81,15 @@ public class PlainTextExecutor implements ChatQueryExecutor { return historyInput.toString(); } - private List getHistoryParseResult(int chatId, int multiNum) { - ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class); - List contextualParseInfoList = chatQueryRepository.getContextualParseInfo(chatId) - .stream().filter(p -> p.getState() == ParseResp.ParseState.COMPLETED).collect(Collectors.toList()); + private List getHistoryQueries(int chatId, int multiNum) { + ChatManageService chatManageService = ContextUtils.getBean(ChatManageService.class); + List contextualParseInfoList = chatManageService.getChatQueries(chatId) + .stream() + .filter(q -> Objects.nonNull(q.getQueryResult()) + && q.getQueryResult().getQueryState() == QueryState.SUCCESS) + .collect(Collectors.toList()); - List contextualList = contextualParseInfoList.subList(0, + List contextualList = contextualParseInfoList.subList(0, Math.min(multiNum, contextualParseInfoList.size())); Collections.reverse(contextualList); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index 136dfb6dd..5330a957d 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -1,10 +1,11 @@ package com.tencent.supersonic.chat.server.parser; +import com.tencent.supersonic.chat.api.pojo.response.QueryResp; import com.tencent.supersonic.chat.server.agent.MultiTurnConfig; -import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository; import com.tencent.supersonic.chat.server.plugin.PluginQueryManager; import com.tencent.supersonic.chat.server.pojo.ParseContext; import com.tencent.supersonic.chat.server.service.ChatContextService; +import com.tencent.supersonic.chat.server.service.ChatManageService; import com.tencent.supersonic.chat.server.util.QueryReqConverter; import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.pojo.Text2SQLExemplar; @@ -19,6 +20,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; import com.tencent.supersonic.headless.api.pojo.response.MapResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.server.pojo.ChatContext; +import com.tencent.supersonic.headless.api.pojo.response.QueryState; import com.tencent.supersonic.headless.server.facade.service.ChatLayerService; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.model.chat.ChatLanguageModel; @@ -176,21 +178,22 @@ public class NL2SQLParser implements ChatQueryParser { QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext); MapResp currentMapResult = chatLayerService.performMapping(queryNLReq); - List historyParseResults = getHistoryParseResult(parseContext.getChatId(), 1); - if (historyParseResults.size() == 0) { + List historyQueries = getHistoryQueries(parseContext.getChatId(), 1); + if (historyQueries.size() == 0) { return; } - ParseResp lastParseResult = historyParseResults.get(0); - Long dataId = lastParseResult.getSelectedParses().get(0).getDataSetId(); + QueryResp lastQuery = historyQueries.get(0); + SemanticParseInfo lastParseInfo = lastQuery.getParseInfos().get(0); + Long dataId = lastParseInfo.getDataSetId(); String curtMapStr = generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataId)); - String histMapStr = generateSchemaPrompt(lastParseResult.getSelectedParses().get(0).getElementMatches()); - String histSQL = lastParseResult.getSelectedParses().get(0).getSqlInfo().getCorrectedS2SQL(); + String histMapStr = generateSchemaPrompt(lastParseInfo.getElementMatches()); + String histSQL = lastParseInfo.getSqlInfo().getCorrectedS2SQL(); Map variables = new HashMap<>(); variables.put("current_question", currentMapResult.getQueryText()); variables.put("current_schema", curtMapStr); - variables.put("history_question", lastParseResult.getQueryText()); + variables.put("history_question", lastQuery.getQueryText()); variables.put("history_schema", histMapStr); variables.put("history_sql", histSQL); @@ -203,7 +206,7 @@ public class NL2SQLParser implements ChatQueryParser { parseContext.setQueryText(rewrittenQuery); log.info("Last Query: {} Current Query: {}, Rewritten Query: {}", - lastParseResult.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery); + lastQuery.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery); } private String rewriteErrorMessage(ChatLanguageModel chatLanguageModel, String userQuestion, @@ -255,12 +258,15 @@ public class NL2SQLParser implements ChatQueryParser { return prompt.toString(); } - private List getHistoryParseResult(int chatId, int multiNum) { - ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class); - List contextualParseInfoList = chatQueryRepository.getContextualParseInfo(chatId) - .stream().filter(p -> p.getState() == ParseResp.ParseState.COMPLETED).collect(Collectors.toList()); + private List getHistoryQueries(int chatId, int multiNum) { + ChatManageService chatManageService = ContextUtils.getBean(ChatManageService.class); + List contextualParseInfoList = chatManageService.getChatQueries(chatId) + .stream() + .filter(q -> Objects.nonNull(q.getQueryResult()) + && q.getQueryResult().getQueryState() == QueryState.SUCCESS) + .collect(Collectors.toList()); - List contextualList = contextualParseInfoList.subList(0, + List contextualList = contextualParseInfoList.subList(0, Math.min(multiNum, contextualParseInfoList.size())); Collections.reverse(contextualList); return contextualList; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/PlainTextParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/PlainTextParser.java index 0c8267d0d..7f5de681b 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/PlainTextParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/PlainTextParser.java @@ -16,6 +16,7 @@ public class PlainTextParser implements ChatQueryParser { SemanticParseInfo parseInfo = new SemanticParseInfo(); parseInfo.setQueryMode("PLAIN_TEXT"); parseResp.getSelectedParses().add(parseInfo); + parseResp.setState(ParseResp.ParseState.COMPLETED); } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/ChatParseDO.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/ChatParseDO.java index 608ca0257..6cc3a514c 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/ChatParseDO.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/ChatParseDO.java @@ -1,142 +1,25 @@ package com.tencent.supersonic.chat.server.persistence.dataobject; +import lombok.Data; + import java.util.Date; - +@Data public class ChatParseDO { - /** - * questionId - */ private Long questionId; - /** - * chatId - */ - private Long chatId; + private Integer chatId; - /** - * parseId - */ private Integer parseId; - /** - * createTime - */ private Date createTime; - /** - * queryText - */ private String queryText; - /** - * userName - */ private String userName; - - /** - * parseInfo - */ private String parseInfo; - /** - * isCandidate - */ private Integer isCandidate; - - /** - * return question_id - */ - public Long getQuestionId() { - return questionId; - } - - /** - * questionId - */ - public void setQuestionId(Long questionId) { - this.questionId = questionId; - } - - /** - * return create_time - */ - public Date getCreateTime() { - return createTime; - } - - /** - * createTime - */ - public void setCreateTime(Date createTime) { - this.createTime = createTime; - } - - /** - * return user_name - */ - public String getUserName() { - return userName; - } - - /** - * userName - */ - public void setUserName(String userName) { - this.userName = userName == null ? null : userName.trim(); - } - - /** - * return chat_id - */ - public Long getChatId() { - return chatId; - } - - /** - * chatId - */ - public void setChatId(Long chatId) { - this.chatId = chatId; - } - - /** - * return query_text - */ - public String getQueryText() { - return queryText; - } - - /** - * queryText - */ - public void setQueryText(String queryText) { - this.queryText = queryText == null ? null : queryText.trim(); - } - - public Integer getIsCandidate() { - return isCandidate; - } - - public Integer getParseId() { - return parseId; - } - - public String getParseInfo() { - return parseInfo; - } - - public void setParseId(Integer parseId) { - this.parseId = parseId; - } - - public void setIsCandidate(Integer isCandidate) { - this.isCandidate = isCandidate; - } - - public void setParseInfo(String parseInfo) { - this.parseInfo = parseInfo; - } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/ChatQueryRepository.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/ChatQueryRepository.java index 2686aa603..51099181d 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/ChatQueryRepository.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/ChatQueryRepository.java @@ -18,6 +18,8 @@ public interface ChatQueryRepository { QueryResp getChatQuery(Long queryId); + List getChatQueries(Integer chatId); + ChatQueryDO getChatQueryDO(Long queryId); List queryShowCase(PageQueryInfoReq pageQueryInfoCommend, int agentId); @@ -35,6 +37,4 @@ public interface ChatQueryRepository { List getParseInfoList(List questionIds); - List getContextualParseInfo(Integer chatId); - } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/ChatQueryRepositoryImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/ChatQueryRepositoryImpl.java index 108ae51e1..2f481197a 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/ChatQueryRepositoryImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/ChatQueryRepositoryImpl.java @@ -90,6 +90,16 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository { return chatQueryDOMapper.selectById(queryId); } + @Override + public List getChatQueries(Integer chatId) { + QueryWrapper queryWrapper = new QueryWrapper<>(); + queryWrapper.lambda().eq(ChatQueryDO::getChatId, chatId); + queryWrapper.lambda().orderByDesc(ChatQueryDO::getQuestionId); + return chatQueryDOMapper.selectList(queryWrapper).stream() + .map(q -> convertTo(q)) + .collect(Collectors.toList()); + } + @Override public List queryShowCase(PageQueryInfoReq pageQueryInfoReq, int agentId) { return showCaseCustomMapper.queryShowCase(pageQueryInfoReq.getLimitStart(), @@ -145,7 +155,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository { List parses, List chatParseDOList) { for (int i = 0; i < parses.size(); i++) { ChatParseDO chatParseDO = new ChatParseDO(); - chatParseDO.setChatId(Long.valueOf(chatParseReq.getChatId())); + chatParseDO.setChatId(chatParseReq.getChatId()); chatParseDO.setQuestionId(queryId); chatParseDO.setQueryText(chatParseReq.getQueryText()); chatParseDO.setParseInfo(JsonUtil.toString(parses.get(i))); @@ -179,17 +189,4 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository { return chatParseMapper.getParseInfoList(questionIds); } - @Override - public List getContextualParseInfo(Integer chatId) { - List chatParseDOList = chatParseMapper.getContextualParseInfo(chatId); - List semanticParseInfoList = chatParseDOList.stream().map(parseInfo -> { - ParseResp parseResp = new ParseResp(parseInfo.getQueryText()); - List selectedParses = new ArrayList<>(); - selectedParses.add(JSONObject.parseObject(parseInfo.getParseInfo(), SemanticParseInfo.class)); - parseResp.setSelectedParses(selectedParses); - return parseResp; - }).collect(Collectors.toList()); - return semanticParseInfoList; - } - } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/ChatManageService.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/ChatManageService.java index 7c189c508..28de00f4e 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/ChatManageService.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/ChatManageService.java @@ -35,6 +35,8 @@ public interface ChatManageService { QueryResp getChatQuery(Long queryId); + List getChatQueries(Integer chatId); + ShowCaseResp queryShowCase(PageQueryInfoReq pageQueryInfoReq, int agentId); ChatQueryDO saveQueryResult(ChatExecuteReq chatExecuteReq, QueryResult queryResult); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatManageServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatManageServiceImpl.java index 44c5622e0..d9e24514f 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatManageServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatManageServiceImpl.java @@ -107,6 +107,13 @@ public class ChatManageServiceImpl implements ChatManageService { return chatQueryRepository.getChatQuery(queryId); } + @Override + public List getChatQueries(Integer chatId) { + List queries = chatQueryRepository.getChatQueries(chatId); + fillParseInfo(queries); + return queries; + } + @Override public ShowCaseResp queryShowCase(PageQueryInfoReq pageQueryInfoReq, int agentId) { ShowCaseResp showCaseResp = new ShowCaseResp(); diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/SchemaType.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/SchemaType.java index e3bbfda41..878cc1fa6 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/SchemaType.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/SchemaType.java @@ -2,6 +2,6 @@ package com.tencent.supersonic.headless.api.pojo.enums; public enum SchemaType { - VIEW, + DATASET, MODEL } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/SchemaServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/SchemaServiceImpl.java index dbed60883..fdce37d4f 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/SchemaServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/SchemaServiceImpl.java @@ -376,7 +376,7 @@ public class SchemaServiceImpl implements SchemaService { semanticSchemaResp.setModelResps(modelList); semanticSchemaResp.setModelRelas(modelRelas); semanticSchemaResp.setModelIds(modelIds); - semanticSchemaResp.setSchemaType(SchemaType.VIEW); + semanticSchemaResp.setSchemaType(SchemaType.DATASET); } else if (!CollectionUtils.isEmpty(schemaFilterReq.getModelIds())) { List modelSchemaResps = fetchModelSchemaResps(schemaFilterReq.getModelIds()); semanticSchemaResp.setMetrics(modelSchemaResps.stream().map(ModelSchemaResp::getMetrics) diff --git a/launchers/standalone/src/main/resources/config.update/sql-update.sql b/launchers/standalone/src/main/resources/config.update/sql-update.sql index fa2400bf8..8d351f0f9 100644 --- a/launchers/standalone/src/main/resources/config.update/sql-update.sql +++ b/launchers/standalone/src/main/resources/config.update/sql-update.sql @@ -359,4 +359,7 @@ alter table s2_agent add `model_config` text null; alter table s2_agent add `enable_memory_review` tinyint DEFAULT 0; --20240718 -alter table s2_chat_memory add `side_info` TEXT DEFAULT NULL COMMENT '辅助信息'; \ No newline at end of file +alter table s2_chat_memory add `side_info` TEXT DEFAULT NULL COMMENT '辅助信息'; + +--20240730 +alter table s2_chat_parse modify column `chat_id` int(11); \ No newline at end of file diff --git a/launchers/standalone/src/main/resources/db/schema-h2.sql b/launchers/standalone/src/main/resources/db/schema-h2.sql index 94b414b0a..b28202996 100644 --- a/launchers/standalone/src/main/resources/db/schema-h2.sql +++ b/launchers/standalone/src/main/resources/db/schema-h2.sql @@ -45,7 +45,7 @@ CREATE TABLE IF NOT EXISTS `s2_chat_query` CREATE TABLE IF NOT EXISTS `s2_chat_parse` ( `question_id` BIGINT NOT NULL, - `chat_id` BIGINT NOT NULL , + `chat_id` INT NOT NULL , `parse_id` INT NOT NULL , `create_time` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, `query_text` varchar(500), diff --git a/launchers/standalone/src/main/resources/db/schema-mysql.sql b/launchers/standalone/src/main/resources/db/schema-mysql.sql index 3cf6da9a9..4d9cc39f0 100644 --- a/launchers/standalone/src/main/resources/db/schema-mysql.sql +++ b/launchers/standalone/src/main/resources/db/schema-mysql.sql @@ -169,8 +169,8 @@ CREATE TABLE IF NOT EXISTS `s2_chat_context` ( ) ENGINE=InnoDB DEFAULT CHARSET=utf8; CREATE TABLE IF NOT EXISTS `s2_chat_parse` ( - `question_id` bigint(20) NOT NULL, - `chat_id` bigint(20) NOT NULL, + `question_id` bigint NOT NULL, + `chat_id` int(11) NOT NULL, `parse_id` int(11) NOT NULL, `create_time` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, `query_text` varchar(500) DEFAULT NULL, diff --git a/launchers/standalone/src/test/resources/db/schema-h2.sql b/launchers/standalone/src/test/resources/db/schema-h2.sql index c1e38555d..d4d413f3c 100644 --- a/launchers/standalone/src/test/resources/db/schema-h2.sql +++ b/launchers/standalone/src/test/resources/db/schema-h2.sql @@ -44,8 +44,8 @@ CREATE TABLE IF NOT EXISTS `s2_chat_query` CREATE TABLE IF NOT EXISTS `s2_chat_parse` ( - `question_id` BIGINT NOT NULL, - `chat_id` BIGINT NOT NULL , + `question_id` BIGINT NOT NULL, + `chat_id` INT NOT NULL , `parse_id` INT NOT NULL , `create_time` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, `query_text` varchar(500),