[improvement]Use QueryWrapper in place of hard-coded SQLs (#1944)
Some checks are pending
supersonic CentOS CI / build (11) (push) Waiting to run
supersonic CentOS CI / build (21) (push) Waiting to run
supersonic CentOS CI / build (8) (push) Waiting to run
supersonic mac CI / build (11) (push) Waiting to run
supersonic mac CI / build (21) (push) Waiting to run
supersonic mac CI / build (8) (push) Waiting to run
supersonic ubuntu CI / build (11) (push) Waiting to run
supersonic ubuntu CI / build (21) (push) Waiting to run
supersonic ubuntu CI / build (8) (push) Waiting to run
supersonic windows CI / build (11) (push) Waiting to run
supersonic windows CI / build (21) (push) Waiting to run
supersonic windows CI / build (8) (push) Waiting to run

* [improvement][launcher]Use API to get element ID avoiding hard-code.

* [fix][launcher]Fix mysql scripts.

* [improvement][launcher]Support DuckDB database and refactor translator code structure.

* [improvement][headless-fe] Revamped the interaction for semantic modeling routing and successfully implemented the switching between dimension and dataset management.

* [improvement][Headless] Add table ddl in Dbschema

* [improvement][Headless] Add get database by type

* [improvement][Headless] Supports automatic batch creation of models based on db table names.

* [improvement][Headless] Supports getting domain by bizName

* [improvement][launcher]Refactor unit tests and demo data.

* [fix][launcher]Change default vector dimension to 512.

* [improvement](Dict) add dimValueAliasMap info for KnowledgeBaseService

* [improvement][headless]Use QueryWrapper to replace hard-code SQL in mapper xml.

* [improvement][chat]Introduce ChatMemory to delegate ChatMemoryDO.

* [fix][common]Fix embedding store sys configs.

* [fix][common]Fix postgres schema, using varchar instead of char.

* [improvement][launcher]Change supersonic docker deployment from mysql to postgres.

* [Fix][launcher]Fix a number of issues related to semantic modeling.

* [Fix][headless]Fix the evaluation logic of agg type.

* [fix][assembly]Fix Dockerfile and add docker compose run script.

* [fix][chat]Fix "multiple assignments to same column "similar_queries".

* [improvement][headless]Use LamdaQueryWrapper to avoid hard-coded column names.

* [improvement][headless]Refactor headless infra to support advanced semantic modelling.

* [improvement][headless]Change class name `Dim` to `Dimension`.

* [improvement][chat]Introduce `TimeFieldMapper` to always map time field.

* [fix][headless]Remove unnecessary dimension existence check.

* [fix][chat]Fix adjusted filters don't take effect.

---------
This commit is contained in:
Jun Zhang
2024-12-08 13:32:29 +08:00
committed by GitHub
parent 0fc29304a8
commit e55f43c737
120 changed files with 844 additions and 5810 deletions

View File

@@ -4,9 +4,11 @@ import javax.validation.constraints.NotNull;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import lombok.Builder;
import lombok.Data;
@Data
@Builder
public class ChatMemoryUpdateReq {
@NotNull(message = "id不可为空")

View File

@@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.server.executor;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.pojo.ChatContext;
import com.tencent.supersonic.chat.server.pojo.ChatMemory;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.chat.server.service.ChatContextService;
import com.tencent.supersonic.chat.server.service.MemoryService;
@@ -44,7 +44,7 @@ public class SqlExecutor implements ChatQueryExecutor {
Text2SQLExemplar.class);
MemoryService memoryService = ContextUtils.getBean(MemoryService.class);
memoryService.createMemory(ChatMemoryDO.builder()
memoryService.createMemory(ChatMemory.builder()
.agentId(executeContext.getAgent().getId()).status(MemoryStatus.PENDING)
.question(exemplar.getQuestion()).sideInfo(exemplar.getSideInfo())
.dbSchema(exemplar.getDbSchema()).s2sql(exemplar.getSql())

View File

@@ -1,9 +1,10 @@
package com.tencent.supersonic.chat.server.memory;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.pojo.ChatMemory;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.common.pojo.ChatApp;
@@ -66,7 +67,7 @@ public class MemoryReviewTask {
}
ChatMemoryFilter chatMemoryFilter =
ChatMemoryFilter.builder().agentId(agent.getId()).build();
memoryService.getMemories(chatMemoryFilter).stream().forEach(memory -> {
memoryService.getMemories(chatMemoryFilter).forEach(memory -> {
try {
processMemory(memory, agent);
} catch (Exception e) {
@@ -77,23 +78,19 @@ public class MemoryReviewTask {
}
}
private void processMemory(ChatMemoryDO m, Agent agent) {
private void processMemory(ChatMemory m, Agent agent) {
if (Objects.isNull(agent)) {
log.warn("Agent id {} not found or memory review disabled", m.getAgentId());
return;
}
ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY);
if (Objects.isNull(chatApp) || !chatApp.isEnable()) {
// if either LLM or human has reviewed, just return
if (Objects.nonNull(m.getLlmReviewRet()) || Objects.nonNull(m.getHumanReviewRet())) {
return;
}
// 如果大模型已经评估过,则不再评估
if (Objects.nonNull(m.getLlmReviewRet())) {
// directly enable memory if the LLM determines it positive
if (MemoryReviewResult.POSITIVE.equals(m.getLlmReviewRet())) {
memoryService.enableMemory(m);
}
ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY);
if (Objects.isNull(chatApp) || !chatApp.isEnable()) {
return;
}
@@ -112,19 +109,19 @@ public class MemoryReviewTask {
}
}
private String createPromptString(ChatMemoryDO m, String promptTemplate) {
private String createPromptString(ChatMemory m, String promptTemplate) {
return String.format(promptTemplate, m.getQuestion(), m.getDbSchema(), m.getSideInfo(),
m.getS2sql());
}
private void processResponse(String response, ChatMemoryDO m) {
private void processResponse(String response, ChatMemory m) {
Matcher matcher = OUTPUT_PATTERN.matcher(response);
if (matcher.find()) {
m.setLlmReviewRet(MemoryReviewResult.getMemoryReviewResult(matcher.group(1)));
m.setLlmReviewCmt(matcher.group(2));
// directly enable memory if the LLM determines it positive
if (MemoryReviewResult.POSITIVE.equals(m.getLlmReviewRet())) {
memoryService.enableMemory(m);
m.setStatus(MemoryStatus.ENABLED);
}
memoryService.updateMemory(m);
}

View File

@@ -91,6 +91,7 @@ public class NL2SQLParser implements ChatQueryParser {
// mapModes
Set<Long> requestedDatasets = queryNLReq.getDataSetIds();
List<SemanticParseInfo> candidateParses = Lists.newArrayList();
StringBuilder errMsg = new StringBuilder();
for (Long datasetId : requestedDatasets) {
queryNLReq.setDataSetIds(Collections.singleton(datasetId));
ChatParseResp parseResp = new ChatParseResp(parseContext.getRequest().getQueryId());
@@ -104,6 +105,7 @@ public class NL2SQLParser implements ChatQueryParser {
doParse(queryNLReq, parseResp);
}
if (parseResp.getSelectedParses().isEmpty()) {
errMsg.append(parseResp.getErrorMsg());
continue;
}
// for one dataset select the top 1 parse after sorting
@@ -116,6 +118,10 @@ public class NL2SQLParser implements ChatQueryParser {
SemanticParseInfo.sort(candidateParses);
parseContext.getResponse().setSelectedParses(
candidateParses.subList(0, Math.min(parserShowCount, candidateParses.size())));
if (parseContext.getResponse().getSelectedParses().isEmpty()) {
parseContext.getResponse().setState(ParseResp.ParseState.FAILED);
parseContext.getResponse().setErrorMsg(errMsg.toString());
}
}
// next go with llm-based parsers unless LLM is disabled or use feedback is needed.

View File

@@ -4,17 +4,17 @@ import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.ToString;
import lombok.NoArgsConstructor;
import java.util.Date;
@Data
@Builder
@ToString
@NoArgsConstructor
@AllArgsConstructor
@TableName("s2_chat_memory")
public class ChatMemoryDO {
@TableId(type = IdType.AUTO)
@@ -36,16 +36,16 @@ public class ChatMemoryDO {
private String s2sql;
@TableField("status")
private MemoryStatus status;
private String status;
@TableField("llm_review")
private MemoryReviewResult llmReviewRet;
private String llmReviewRet;
@TableField("llm_comment")
private String llmReviewCmt;
@TableField("human_review")
private MemoryReviewResult humanReviewRet;
private String humanReviewRet;
@TableField("human_comment")
private String humanReviewCmt;

View File

@@ -20,7 +20,6 @@ import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryReposi
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.PageUtils;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseTimeCostResp;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;

View File

@@ -0,0 +1,48 @@
package com.tencent.supersonic.chat.server.pojo;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString;
import java.util.Date;
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
@ToString
public class ChatMemory {
private Long id;
private Integer agentId;
private String question;
private String sideInfo;
private String dbSchema;
private String s2sql;
private MemoryStatus status;
private MemoryReviewResult llmReviewRet;
private String llmReviewCmt;
private MemoryReviewResult humanReviewRet;
private String humanReviewCmt;
private String createdBy;
private Date createdAt;
private String updatedBy;
private Date updatedAt;
}

View File

@@ -56,8 +56,7 @@ public class QueryRecommendProcessor implements ParseResultProcessor {
private void updateChatQuery(ChatQueryDO chatQueryDO) {
ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
UpdateWrapper<ChatQueryDO> updateWrapper = new UpdateWrapper<>();
updateWrapper.eq("question_id", chatQueryDO.getQuestionId());
updateWrapper.set("similar_queries", chatQueryDO.getSimilarQueries());
updateWrapper.lambda().eq(ChatQueryDO::getQuestionId, chatQueryDO.getQuestionId());
chatQueryRepository.updateChatQuery(chatQueryDO, updateWrapper);
}
}

View File

@@ -9,7 +9,7 @@ import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryCreateReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq;
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.pojo.ChatMemory;
import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq;
@@ -32,7 +32,7 @@ public class MemoryController {
public Boolean createMemory(@RequestBody ChatMemoryCreateReq chatMemoryCreateReq,
HttpServletRequest request, HttpServletResponse response) {
User user = UserHolder.findUser(request, response);
memoryService.createMemory(ChatMemoryDO.builder().agentId(chatMemoryCreateReq.getAgentId())
memoryService.createMemory(ChatMemory.builder().agentId(chatMemoryCreateReq.getAgentId())
.s2sql(chatMemoryCreateReq.getS2sql()).question(chatMemoryCreateReq.getQuestion())
.dbSchema(chatMemoryCreateReq.getDbSchema()).status(chatMemoryCreateReq.getStatus())
.humanReviewRet(MemoryReviewResult.POSITIVE).createdBy(user.getName())
@@ -49,7 +49,7 @@ public class MemoryController {
}
@RequestMapping("/pageMemories")
public PageInfo<ChatMemoryDO> pageMemories(@RequestBody PageMemoryReq pageMemoryReq) {
public PageInfo<ChatMemory> pageMemories(@RequestBody PageMemoryReq pageMemoryReq) {
return memoryService.pageMemories(pageMemoryReq);
}

View File

@@ -4,27 +4,22 @@ import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq;
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.pojo.ChatMemory;
import com.tencent.supersonic.common.pojo.User;
import java.util.List;
public interface MemoryService {
void createMemory(ChatMemoryDO memory);
void createMemory(ChatMemory memory);
void updateMemory(ChatMemoryUpdateReq chatMemoryUpdateReq, User user);
void updateMemory(ChatMemoryDO memory);
void enableMemory(ChatMemoryDO memory);
void disableMemory(ChatMemoryDO memory);
void updateMemory(ChatMemory memory);
void batchDelete(List<Long> ids);
PageInfo<ChatMemoryDO> pageMemories(PageMemoryReq pageMemoryReq);
PageInfo<ChatMemory> pageMemories(PageMemoryReq pageMemoryReq);
List<ChatMemoryDO> getMemories(ChatMemoryFilter chatMemoryFilter);
List<ChatMemory> getMemories(ChatMemoryFilter chatMemoryFilter);
List<ChatMemoryDO> getMemoriesForLlmReview();
}

View File

@@ -6,8 +6,8 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.VisualConfig;
import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper;
import com.tencent.supersonic.chat.server.pojo.ChatMemory;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatQueryService;
import com.tencent.supersonic.chat.server.service.MemoryService;
@@ -121,7 +121,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
ChatMemoryFilter chatMemoryFilter =
ChatMemoryFilter.builder().agentId(agent.getId()).questions(examples).build();
List<String> memoriesExisted = memoryService.getMemories(chatMemoryFilter).stream()
.map(ChatMemoryDO::getQuestion).collect(Collectors.toList());
.map(ChatMemory::getQuestion).collect(Collectors.toList());
for (String example : examples) {
if (memoriesExisted.contains(example)) {
continue;

View File

@@ -18,11 +18,7 @@ import com.tencent.supersonic.chat.server.service.ChatManageService;
import com.tencent.supersonic.chat.server.service.ChatQueryService;
import com.tencent.supersonic.chat.server.util.ComponentFactory;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.jsqlparser.FieldExpression;
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.jsqlparser.*;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.util.DateUtils;
@@ -48,11 +44,7 @@ import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
import net.sf.jsqlparser.expression.operators.relational.ParenthesedExpressionList;
import net.sf.jsqlparser.expression.operators.relational.*;
import net.sf.jsqlparser.schema.Column;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
@@ -60,14 +52,7 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.*;
import java.util.stream.Collectors;
@Slf4j
@@ -210,20 +195,22 @@ public class ChatQueryServiceImpl implements ChatQueryService {
private void handleLLMQueryMode(ChatQueryDataReq chatQueryDataReq, SemanticQuery semanticQuery,
DataSetSchema dataSetSchema, User user) throws Exception {
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
List<String> fields = getFieldsFromSql(parseInfo);
if (checkMetricReplace(fields, chatQueryDataReq.getMetrics())) {
log.info("llm begin replace metrics!");
String rebuiltS2SQL;
if (checkMetricReplace(chatQueryDataReq, parseInfo)) {
log.info("rebuild S2SQL with adjusted metrics!");
SchemaElement metricToReplace = chatQueryDataReq.getMetrics().iterator().next();
replaceMetrics(parseInfo, metricToReplace);
rebuiltS2SQL = replaceMetrics(parseInfo, metricToReplace);
} else {
log.info("llm begin revise filters!");
String correctorSql = reviseCorrectS2SQL(chatQueryDataReq, parseInfo, dataSetSchema);
parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
semanticQuery.setParseInfo(parseInfo);
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
SemanticTranslateResp explain = semanticLayerService.translate(semanticQueryReq, user);
parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
log.info("rebuild S2SQL with adjusted filters!");
rebuiltS2SQL = replaceFilters(chatQueryDataReq, parseInfo, dataSetSchema);
}
// reset SqlInfo and request re-translation
parseInfo.getSqlInfo().setCorrectedS2SQL(rebuiltS2SQL);
parseInfo.getSqlInfo().setParsedS2SQL(rebuiltS2SQL);
parseInfo.getSqlInfo().setQuerySQL(null);
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
SemanticTranslateResp explain = semanticLayerService.translate(semanticQueryReq, user);
parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
}
private void handleRuleQueryMode(SemanticQuery semanticQuery, DataSetSchema dataSetSchema,
@@ -243,7 +230,9 @@ public class ChatQueryServiceImpl implements ChatQueryService {
return queryResult;
}
private boolean checkMetricReplace(List<String> oriFields, Set<SchemaElement> metrics) {
private boolean checkMetricReplace(ChatQueryDataReq chatQueryDataReq, SemanticParseInfo parseInfo) {
List<String> oriFields = getFieldsFromSql(parseInfo);
Set<SchemaElement> metrics = chatQueryDataReq.getMetrics();
if (CollectionUtils.isEmpty(oriFields) || CollectionUtils.isEmpty(metrics)) {
return false;
}
@@ -252,8 +241,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
return !oriFields.containsAll(metricNames);
}
private String reviseCorrectS2SQL(ChatQueryDataReq queryData, SemanticParseInfo parseInfo,
DataSetSchema dataSetSchema) {
private String replaceFilters(ChatQueryDataReq queryData, SemanticParseInfo parseInfo,
DataSetSchema dataSetSchema) {
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
log.info("correctorSql before replacing:{}", correctorSql);
// get where filter and having filter
@@ -290,7 +279,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
return correctorSql;
}
private void replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) {
private String replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) {
List<String> oriMetrics = parseInfo.getMetrics().stream().map(SchemaElement::getName)
.collect(Collectors.toList());
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
@@ -302,7 +291,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
correctorSql = SqlReplaceHelper.replaceAggFields(correctorSql, fieldMap);
}
log.info("after replaceMetrics:{}", correctorSql);
parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
return correctorSql;
}
private QueryResult doExecution(SemanticQueryReq semanticQueryReq, String queryMode, User user)
@@ -477,6 +466,9 @@ public class ChatQueryServiceImpl implements ChatQueryService {
}
private void mergeParseInfo(SemanticParseInfo parseInfo, ChatQueryDataReq queryData) {
if (Objects.nonNull(queryData.getDateInfo())) {
parseInfo.setDateInfo(queryData.getDateInfo());
}
if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
return;
}
@@ -492,9 +484,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
if (!CollectionUtils.isEmpty(queryData.getMetricFilters())) {
parseInfo.setMetricFilters(queryData.getMetricFilters());
}
if (Objects.nonNull(queryData.getDateInfo())) {
parseInfo.setDateInfo(queryData.getDateInfo());
}
parseInfo.setSqlInfo(new SqlInfo());
}

View File

@@ -3,12 +3,14 @@ package com.tencent.supersonic.chat.server.service.impl;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.github.pagehelper.PageHelper;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq;
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository;
import com.tencent.supersonic.chat.server.pojo.ChatMemory;
import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
@@ -16,12 +18,15 @@ import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.service.ExemplarService;
import com.tencent.supersonic.common.util.BeanMapper;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.Date;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
@Service
public class MemoryServiceImpl implements MemoryService {
@@ -36,20 +41,22 @@ public class MemoryServiceImpl implements MemoryService {
private EmbeddingConfig embeddingConfig;
@Override
public void createMemory(ChatMemoryDO memory) {
public void createMemory(ChatMemory memory) {
// if an existing enabled memory has the same question, just skip
List<ChatMemoryDO> memories =
List<ChatMemory> memories =
getMemories(ChatMemoryFilter.builder().agentId(memory.getAgentId())
.question(memory.getQuestion()).status(MemoryStatus.ENABLED).build());
if (memories.size() == 0) {
chatMemoryRepository.createMemory(memory);
if (memories.isEmpty()) {
ChatMemoryDO memoryDO = getMemoryDO(memory);
chatMemoryRepository.createMemory(memoryDO);
}
}
@Override
public void updateMemory(ChatMemoryUpdateReq chatMemoryUpdateReq, User user) {
ChatMemoryDO chatMemoryDO = chatMemoryRepository.getMemory(chatMemoryUpdateReq.getId());
boolean hadEnabled = MemoryStatus.ENABLED.equals(chatMemoryDO.getStatus());
boolean hadEnabled =
MemoryStatus.ENABLED.toString().equals(chatMemoryDO.getStatus().trim());
chatMemoryDO.setUpdatedBy(user.getName());
chatMemoryDO.setUpdatedAt(new Date());
BeanMapper.mapper(chatMemoryUpdateReq, chatMemoryDO);
@@ -58,12 +65,12 @@ public class MemoryServiceImpl implements MemoryService {
} else if (MemoryStatus.DISABLED.equals(chatMemoryUpdateReq.getStatus()) && hadEnabled) {
disableMemory(chatMemoryDO);
}
updateMemory(chatMemoryDO);
chatMemoryRepository.updateMemory(chatMemoryDO);
}
@Override
public void updateMemory(ChatMemoryDO memory) {
chatMemoryRepository.updateMemory(memory);
public void updateMemory(ChatMemory memory) {
chatMemoryRepository.updateMemory(getMemoryDO(memory));
}
@Override
@@ -72,7 +79,7 @@ public class MemoryServiceImpl implements MemoryService {
}
@Override
public PageInfo<ChatMemoryDO> pageMemories(PageMemoryReq pageMemoryReq) {
public PageInfo<ChatMemory> pageMemories(PageMemoryReq pageMemoryReq) {
ChatMemoryFilter chatMemoryFilter = pageMemoryReq.getChatMemoryFilter();
chatMemoryFilter.setSort(pageMemoryReq.getSort());
chatMemoryFilter.setOrderCondition(pageMemoryReq.getOrderCondition());
@@ -81,7 +88,7 @@ public class MemoryServiceImpl implements MemoryService {
}
@Override
public List<ChatMemoryDO> getMemories(ChatMemoryFilter chatMemoryFilter) {
public List<ChatMemory> getMemories(ChatMemoryFilter chatMemoryFilter) {
QueryWrapper<ChatMemoryDO> queryWrapper = new QueryWrapper<>();
if (chatMemoryFilter.getAgentId() != null) {
queryWrapper.lambda().eq(ChatMemoryDO::getAgentId, chatMemoryFilter.getAgentId());
@@ -109,32 +116,52 @@ public class MemoryServiceImpl implements MemoryService {
queryWrapper.orderBy(true, chatMemoryFilter.isAsc(),
chatMemoryFilter.getOrderCondition());
}
return chatMemoryRepository.getMemories(queryWrapper);
List<ChatMemoryDO> chatMemoryDOS = chatMemoryRepository.getMemories(queryWrapper);
return chatMemoryDOS.stream().map(this::getMemory).collect(Collectors.toList());
}
@Override
public List<ChatMemoryDO> getMemoriesForLlmReview() {
QueryWrapper<ChatMemoryDO> queryWrapper = new QueryWrapper<>();
queryWrapper.lambda().eq(ChatMemoryDO::getStatus, MemoryStatus.PENDING)
.isNull(ChatMemoryDO::getLlmReviewRet);
return chatMemoryRepository.getMemories(queryWrapper);
}
@Override
public void enableMemory(ChatMemoryDO memory) {
memory.setStatus(MemoryStatus.ENABLED);
private void enableMemory(ChatMemoryDO memory) {
memory.setStatus(MemoryStatus.ENABLED.toString());
exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
Text2SQLExemplar.builder().question(memory.getQuestion())
.sideInfo(memory.getSideInfo()).dbSchema(memory.getDbSchema())
.sql(memory.getS2sql()).build());
}
@Override
public void disableMemory(ChatMemoryDO memory) {
memory.setStatus(MemoryStatus.DISABLED);
private void disableMemory(ChatMemoryDO memory) {
memory.setStatus(MemoryStatus.DISABLED.toString());
exemplarService.removeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
Text2SQLExemplar.builder().question(memory.getQuestion())
.sideInfo(memory.getSideInfo()).dbSchema(memory.getDbSchema())
.sql(memory.getS2sql()).build());
}
private ChatMemoryDO getMemoryDO(ChatMemory memory) {
ChatMemoryDO memoryDO = new ChatMemoryDO();
BeanUtils.copyProperties(memory, memoryDO);
memoryDO.setStatus(memory.getStatus().toString().trim());
if (Objects.nonNull(memory.getHumanReviewRet())) {
memoryDO.setHumanReviewRet(memory.getHumanReviewRet().toString().trim());
}
if (Objects.nonNull(memory.getLlmReviewRet())) {
memoryDO.setLlmReviewRet(memory.getLlmReviewRet().toString().trim());
}
return memoryDO;
}
private ChatMemory getMemory(ChatMemoryDO memoryDO) {
ChatMemory memory = new ChatMemory();
BeanUtils.copyProperties(memoryDO, memory);
memory.setStatus(MemoryStatus.valueOf(memoryDO.getStatus().trim()));
if (Objects.nonNull(memoryDO.getHumanReviewRet())) {
memory.setHumanReviewRet(
MemoryReviewResult.valueOf(memoryDO.getHumanReviewRet().trim()));
}
if (Objects.nonNull(memoryDO.getLlmReviewRet())) {
memory.setLlmReviewRet(MemoryReviewResult.valueOf(memoryDO.getLlmReviewRet().trim()));
}
return memory;
}
}