mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
[improvement]Use QueryWrapper in place of hard-coded SQLs (#1944)
Some checks are pending
supersonic CentOS CI / build (11) (push) Waiting to run
supersonic CentOS CI / build (21) (push) Waiting to run
supersonic CentOS CI / build (8) (push) Waiting to run
supersonic mac CI / build (11) (push) Waiting to run
supersonic mac CI / build (21) (push) Waiting to run
supersonic mac CI / build (8) (push) Waiting to run
supersonic ubuntu CI / build (11) (push) Waiting to run
supersonic ubuntu CI / build (21) (push) Waiting to run
supersonic ubuntu CI / build (8) (push) Waiting to run
supersonic windows CI / build (11) (push) Waiting to run
supersonic windows CI / build (21) (push) Waiting to run
supersonic windows CI / build (8) (push) Waiting to run
Some checks are pending
supersonic CentOS CI / build (11) (push) Waiting to run
supersonic CentOS CI / build (21) (push) Waiting to run
supersonic CentOS CI / build (8) (push) Waiting to run
supersonic mac CI / build (11) (push) Waiting to run
supersonic mac CI / build (21) (push) Waiting to run
supersonic mac CI / build (8) (push) Waiting to run
supersonic ubuntu CI / build (11) (push) Waiting to run
supersonic ubuntu CI / build (21) (push) Waiting to run
supersonic ubuntu CI / build (8) (push) Waiting to run
supersonic windows CI / build (11) (push) Waiting to run
supersonic windows CI / build (21) (push) Waiting to run
supersonic windows CI / build (8) (push) Waiting to run
* [improvement][launcher]Use API to get element ID avoiding hard-code. * [fix][launcher]Fix mysql scripts. * [improvement][launcher]Support DuckDB database and refactor translator code structure. * [improvement][headless-fe] Revamped the interaction for semantic modeling routing and successfully implemented the switching between dimension and dataset management. * [improvement][Headless] Add table ddl in Dbschema * [improvement][Headless] Add get database by type * [improvement][Headless] Supports automatic batch creation of models based on db table names. * [improvement][Headless] Supports getting domain by bizName * [improvement][launcher]Refactor unit tests and demo data. * [fix][launcher]Change default vector dimension to 512. * [improvement](Dict) add dimValueAliasMap info for KnowledgeBaseService * [improvement][headless]Use QueryWrapper to replace hard-code SQL in mapper xml. * [improvement][chat]Introduce ChatMemory to delegate ChatMemoryDO. * [fix][common]Fix embedding store sys configs. * [fix][common]Fix postgres schema, using varchar instead of char. * [improvement][launcher]Change supersonic docker deployment from mysql to postgres. * [Fix][launcher]Fix a number of issues related to semantic modeling. * [Fix][headless]Fix the evaluation logic of agg type. * [fix][assembly]Fix Dockerfile and add docker compose run script. * [fix][chat]Fix "multiple assignments to same column "similar_queries". * [improvement][headless]Use LamdaQueryWrapper to avoid hard-coded column names. * [improvement][headless]Refactor headless infra to support advanced semantic modelling. * [improvement][headless]Change class name `Dim` to `Dimension`. * [improvement][chat]Introduce `TimeFieldMapper` to always map time field. * [fix][headless]Remove unnecessary dimension existence check. * [fix][chat]Fix adjusted filters don't take effect. ---------
This commit is contained in:
@@ -4,9 +4,11 @@ import javax.validation.constraints.NotNull;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
public class ChatMemoryUpdateReq {
|
||||
|
||||
@NotNull(message = "id不可为空")
|
||||
|
||||
@@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.server.executor;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatMemory;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.service.ChatContextService;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
@@ -44,7 +44,7 @@ public class SqlExecutor implements ChatQueryExecutor {
|
||||
Text2SQLExemplar.class);
|
||||
|
||||
MemoryService memoryService = ContextUtils.getBean(MemoryService.class);
|
||||
memoryService.createMemory(ChatMemoryDO.builder()
|
||||
memoryService.createMemory(ChatMemory.builder()
|
||||
.agentId(executeContext.getAgent().getId()).status(MemoryStatus.PENDING)
|
||||
.question(exemplar.getQuestion()).sideInfo(exemplar.getSideInfo())
|
||||
.dbSchema(exemplar.getDbSchema()).s2sql(exemplar.getSql())
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
package com.tencent.supersonic.chat.server.memory;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatMemory;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
@@ -66,7 +67,7 @@ public class MemoryReviewTask {
|
||||
}
|
||||
ChatMemoryFilter chatMemoryFilter =
|
||||
ChatMemoryFilter.builder().agentId(agent.getId()).build();
|
||||
memoryService.getMemories(chatMemoryFilter).stream().forEach(memory -> {
|
||||
memoryService.getMemories(chatMemoryFilter).forEach(memory -> {
|
||||
try {
|
||||
processMemory(memory, agent);
|
||||
} catch (Exception e) {
|
||||
@@ -77,23 +78,19 @@ public class MemoryReviewTask {
|
||||
}
|
||||
}
|
||||
|
||||
private void processMemory(ChatMemoryDO m, Agent agent) {
|
||||
private void processMemory(ChatMemory m, Agent agent) {
|
||||
if (Objects.isNull(agent)) {
|
||||
log.warn("Agent id {} not found or memory review disabled", m.getAgentId());
|
||||
return;
|
||||
}
|
||||
|
||||
ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY);
|
||||
if (Objects.isNull(chatApp) || !chatApp.isEnable()) {
|
||||
// if either LLM or human has reviewed, just return
|
||||
if (Objects.nonNull(m.getLlmReviewRet()) || Objects.nonNull(m.getHumanReviewRet())) {
|
||||
return;
|
||||
}
|
||||
|
||||
// 如果大模型已经评估过,则不再评估
|
||||
if (Objects.nonNull(m.getLlmReviewRet())) {
|
||||
// directly enable memory if the LLM determines it positive
|
||||
if (MemoryReviewResult.POSITIVE.equals(m.getLlmReviewRet())) {
|
||||
memoryService.enableMemory(m);
|
||||
}
|
||||
ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY);
|
||||
if (Objects.isNull(chatApp) || !chatApp.isEnable()) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -112,19 +109,19 @@ public class MemoryReviewTask {
|
||||
}
|
||||
}
|
||||
|
||||
private String createPromptString(ChatMemoryDO m, String promptTemplate) {
|
||||
private String createPromptString(ChatMemory m, String promptTemplate) {
|
||||
return String.format(promptTemplate, m.getQuestion(), m.getDbSchema(), m.getSideInfo(),
|
||||
m.getS2sql());
|
||||
}
|
||||
|
||||
private void processResponse(String response, ChatMemoryDO m) {
|
||||
private void processResponse(String response, ChatMemory m) {
|
||||
Matcher matcher = OUTPUT_PATTERN.matcher(response);
|
||||
if (matcher.find()) {
|
||||
m.setLlmReviewRet(MemoryReviewResult.getMemoryReviewResult(matcher.group(1)));
|
||||
m.setLlmReviewCmt(matcher.group(2));
|
||||
// directly enable memory if the LLM determines it positive
|
||||
if (MemoryReviewResult.POSITIVE.equals(m.getLlmReviewRet())) {
|
||||
memoryService.enableMemory(m);
|
||||
m.setStatus(MemoryStatus.ENABLED);
|
||||
}
|
||||
memoryService.updateMemory(m);
|
||||
}
|
||||
|
||||
@@ -91,6 +91,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
// mapModes
|
||||
Set<Long> requestedDatasets = queryNLReq.getDataSetIds();
|
||||
List<SemanticParseInfo> candidateParses = Lists.newArrayList();
|
||||
StringBuilder errMsg = new StringBuilder();
|
||||
for (Long datasetId : requestedDatasets) {
|
||||
queryNLReq.setDataSetIds(Collections.singleton(datasetId));
|
||||
ChatParseResp parseResp = new ChatParseResp(parseContext.getRequest().getQueryId());
|
||||
@@ -104,6 +105,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
doParse(queryNLReq, parseResp);
|
||||
}
|
||||
if (parseResp.getSelectedParses().isEmpty()) {
|
||||
errMsg.append(parseResp.getErrorMsg());
|
||||
continue;
|
||||
}
|
||||
// for one dataset select the top 1 parse after sorting
|
||||
@@ -116,6 +118,10 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
SemanticParseInfo.sort(candidateParses);
|
||||
parseContext.getResponse().setSelectedParses(
|
||||
candidateParses.subList(0, Math.min(parserShowCount, candidateParses.size())));
|
||||
if (parseContext.getResponse().getSelectedParses().isEmpty()) {
|
||||
parseContext.getResponse().setState(ParseResp.ParseState.FAILED);
|
||||
parseContext.getResponse().setErrorMsg(errMsg.toString());
|
||||
}
|
||||
}
|
||||
|
||||
// next go with llm-based parsers unless LLM is disabled or use feedback is needed.
|
||||
|
||||
@@ -4,17 +4,17 @@ import com.baomidou.mybatisplus.annotation.IdType;
|
||||
import com.baomidou.mybatisplus.annotation.TableField;
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.Date;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@ToString
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
@TableName("s2_chat_memory")
|
||||
public class ChatMemoryDO {
|
||||
@TableId(type = IdType.AUTO)
|
||||
@@ -36,16 +36,16 @@ public class ChatMemoryDO {
|
||||
private String s2sql;
|
||||
|
||||
@TableField("status")
|
||||
private MemoryStatus status;
|
||||
private String status;
|
||||
|
||||
@TableField("llm_review")
|
||||
private MemoryReviewResult llmReviewRet;
|
||||
private String llmReviewRet;
|
||||
|
||||
@TableField("llm_comment")
|
||||
private String llmReviewCmt;
|
||||
|
||||
@TableField("human_review")
|
||||
private MemoryReviewResult humanReviewRet;
|
||||
private String humanReviewRet;
|
||||
|
||||
@TableField("human_comment")
|
||||
private String humanReviewCmt;
|
||||
|
||||
@@ -20,7 +20,6 @@ import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryReposi
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.PageUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseTimeCostResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
package com.tencent.supersonic.chat.server.pojo;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.util.Date;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
@ToString
|
||||
public class ChatMemory {
|
||||
private Long id;
|
||||
|
||||
private Integer agentId;
|
||||
|
||||
private String question;
|
||||
|
||||
private String sideInfo;
|
||||
|
||||
private String dbSchema;
|
||||
|
||||
private String s2sql;
|
||||
|
||||
private MemoryStatus status;
|
||||
|
||||
private MemoryReviewResult llmReviewRet;
|
||||
|
||||
private String llmReviewCmt;
|
||||
|
||||
private MemoryReviewResult humanReviewRet;
|
||||
|
||||
private String humanReviewCmt;
|
||||
|
||||
private String createdBy;
|
||||
|
||||
private Date createdAt;
|
||||
|
||||
private String updatedBy;
|
||||
|
||||
private Date updatedAt;
|
||||
}
|
||||
@@ -56,8 +56,7 @@ public class QueryRecommendProcessor implements ParseResultProcessor {
|
||||
private void updateChatQuery(ChatQueryDO chatQueryDO) {
|
||||
ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
|
||||
UpdateWrapper<ChatQueryDO> updateWrapper = new UpdateWrapper<>();
|
||||
updateWrapper.eq("question_id", chatQueryDO.getQuestionId());
|
||||
updateWrapper.set("similar_queries", chatQueryDO.getSimilarQueries());
|
||||
updateWrapper.lambda().eq(ChatQueryDO::getQuestionId, chatQueryDO.getQuestionId());
|
||||
chatQueryRepository.updateChatQuery(chatQueryDO, updateWrapper);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryCreateReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatMemory;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq;
|
||||
@@ -32,7 +32,7 @@ public class MemoryController {
|
||||
public Boolean createMemory(@RequestBody ChatMemoryCreateReq chatMemoryCreateReq,
|
||||
HttpServletRequest request, HttpServletResponse response) {
|
||||
User user = UserHolder.findUser(request, response);
|
||||
memoryService.createMemory(ChatMemoryDO.builder().agentId(chatMemoryCreateReq.getAgentId())
|
||||
memoryService.createMemory(ChatMemory.builder().agentId(chatMemoryCreateReq.getAgentId())
|
||||
.s2sql(chatMemoryCreateReq.getS2sql()).question(chatMemoryCreateReq.getQuestion())
|
||||
.dbSchema(chatMemoryCreateReq.getDbSchema()).status(chatMemoryCreateReq.getStatus())
|
||||
.humanReviewRet(MemoryReviewResult.POSITIVE).createdBy(user.getName())
|
||||
@@ -49,7 +49,7 @@ public class MemoryController {
|
||||
}
|
||||
|
||||
@RequestMapping("/pageMemories")
|
||||
public PageInfo<ChatMemoryDO> pageMemories(@RequestBody PageMemoryReq pageMemoryReq) {
|
||||
public PageInfo<ChatMemory> pageMemories(@RequestBody PageMemoryReq pageMemoryReq) {
|
||||
return memoryService.pageMemories(pageMemoryReq);
|
||||
}
|
||||
|
||||
|
||||
@@ -4,27 +4,22 @@ import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatMemory;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface MemoryService {
|
||||
void createMemory(ChatMemoryDO memory);
|
||||
void createMemory(ChatMemory memory);
|
||||
|
||||
void updateMemory(ChatMemoryUpdateReq chatMemoryUpdateReq, User user);
|
||||
|
||||
void updateMemory(ChatMemoryDO memory);
|
||||
|
||||
void enableMemory(ChatMemoryDO memory);
|
||||
|
||||
void disableMemory(ChatMemoryDO memory);
|
||||
void updateMemory(ChatMemory memory);
|
||||
|
||||
void batchDelete(List<Long> ids);
|
||||
|
||||
PageInfo<ChatMemoryDO> pageMemories(PageMemoryReq pageMemoryReq);
|
||||
PageInfo<ChatMemory> pageMemories(PageMemoryReq pageMemoryReq);
|
||||
|
||||
List<ChatMemoryDO> getMemories(ChatMemoryFilter chatMemoryFilter);
|
||||
List<ChatMemory> getMemories(ChatMemoryFilter chatMemoryFilter);
|
||||
|
||||
List<ChatMemoryDO> getMemoriesForLlmReview();
|
||||
}
|
||||
|
||||
@@ -6,8 +6,8 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.VisualConfig;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatMemory;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
@@ -121,7 +121,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
|
||||
ChatMemoryFilter chatMemoryFilter =
|
||||
ChatMemoryFilter.builder().agentId(agent.getId()).questions(examples).build();
|
||||
List<String> memoriesExisted = memoryService.getMemories(chatMemoryFilter).stream()
|
||||
.map(ChatMemoryDO::getQuestion).collect(Collectors.toList());
|
||||
.map(ChatMemory::getQuestion).collect(Collectors.toList());
|
||||
for (String example : examples) {
|
||||
if (memoriesExisted.contains(example)) {
|
||||
continue;
|
||||
|
||||
@@ -18,11 +18,7 @@ import com.tencent.supersonic.chat.server.service.ChatManageService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||
import com.tencent.supersonic.chat.server.util.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||
import com.tencent.supersonic.common.jsqlparser.FieldExpression;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.*;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
@@ -48,11 +44,7 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.expression.LongValue;
|
||||
import net.sf.jsqlparser.expression.StringValue;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
|
||||
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
|
||||
import net.sf.jsqlparser.expression.operators.relational.InExpression;
|
||||
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ParenthesedExpressionList;
|
||||
import net.sf.jsqlparser.expression.operators.relational.*;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
@@ -60,14 +52,7 @@ import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
@@ -210,20 +195,22 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
private void handleLLMQueryMode(ChatQueryDataReq chatQueryDataReq, SemanticQuery semanticQuery,
|
||||
DataSetSchema dataSetSchema, User user) throws Exception {
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
List<String> fields = getFieldsFromSql(parseInfo);
|
||||
if (checkMetricReplace(fields, chatQueryDataReq.getMetrics())) {
|
||||
log.info("llm begin replace metrics!");
|
||||
String rebuiltS2SQL;
|
||||
if (checkMetricReplace(chatQueryDataReq, parseInfo)) {
|
||||
log.info("rebuild S2SQL with adjusted metrics!");
|
||||
SchemaElement metricToReplace = chatQueryDataReq.getMetrics().iterator().next();
|
||||
replaceMetrics(parseInfo, metricToReplace);
|
||||
rebuiltS2SQL = replaceMetrics(parseInfo, metricToReplace);
|
||||
} else {
|
||||
log.info("llm begin revise filters!");
|
||||
String correctorSql = reviseCorrectS2SQL(chatQueryDataReq, parseInfo, dataSetSchema);
|
||||
parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
|
||||
semanticQuery.setParseInfo(parseInfo);
|
||||
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
|
||||
SemanticTranslateResp explain = semanticLayerService.translate(semanticQueryReq, user);
|
||||
parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
|
||||
log.info("rebuild S2SQL with adjusted filters!");
|
||||
rebuiltS2SQL = replaceFilters(chatQueryDataReq, parseInfo, dataSetSchema);
|
||||
}
|
||||
// reset SqlInfo and request re-translation
|
||||
parseInfo.getSqlInfo().setCorrectedS2SQL(rebuiltS2SQL);
|
||||
parseInfo.getSqlInfo().setParsedS2SQL(rebuiltS2SQL);
|
||||
parseInfo.getSqlInfo().setQuerySQL(null);
|
||||
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
|
||||
SemanticTranslateResp explain = semanticLayerService.translate(semanticQueryReq, user);
|
||||
parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
|
||||
}
|
||||
|
||||
private void handleRuleQueryMode(SemanticQuery semanticQuery, DataSetSchema dataSetSchema,
|
||||
@@ -243,7 +230,9 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
private boolean checkMetricReplace(List<String> oriFields, Set<SchemaElement> metrics) {
|
||||
private boolean checkMetricReplace(ChatQueryDataReq chatQueryDataReq, SemanticParseInfo parseInfo) {
|
||||
List<String> oriFields = getFieldsFromSql(parseInfo);
|
||||
Set<SchemaElement> metrics = chatQueryDataReq.getMetrics();
|
||||
if (CollectionUtils.isEmpty(oriFields) || CollectionUtils.isEmpty(metrics)) {
|
||||
return false;
|
||||
}
|
||||
@@ -252,8 +241,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
return !oriFields.containsAll(metricNames);
|
||||
}
|
||||
|
||||
private String reviseCorrectS2SQL(ChatQueryDataReq queryData, SemanticParseInfo parseInfo,
|
||||
DataSetSchema dataSetSchema) {
|
||||
private String replaceFilters(ChatQueryDataReq queryData, SemanticParseInfo parseInfo,
|
||||
DataSetSchema dataSetSchema) {
|
||||
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
log.info("correctorSql before replacing:{}", correctorSql);
|
||||
// get where filter and having filter
|
||||
@@ -290,7 +279,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
return correctorSql;
|
||||
}
|
||||
|
||||
private void replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) {
|
||||
private String replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) {
|
||||
List<String> oriMetrics = parseInfo.getMetrics().stream().map(SchemaElement::getName)
|
||||
.collect(Collectors.toList());
|
||||
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
@@ -302,7 +291,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
correctorSql = SqlReplaceHelper.replaceAggFields(correctorSql, fieldMap);
|
||||
}
|
||||
log.info("after replaceMetrics:{}", correctorSql);
|
||||
parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
|
||||
return correctorSql;
|
||||
}
|
||||
|
||||
private QueryResult doExecution(SemanticQueryReq semanticQueryReq, String queryMode, User user)
|
||||
@@ -477,6 +466,9 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
}
|
||||
|
||||
private void mergeParseInfo(SemanticParseInfo parseInfo, ChatQueryDataReq queryData) {
|
||||
if (Objects.nonNull(queryData.getDateInfo())) {
|
||||
parseInfo.setDateInfo(queryData.getDateInfo());
|
||||
}
|
||||
if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
|
||||
return;
|
||||
}
|
||||
@@ -492,9 +484,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
if (!CollectionUtils.isEmpty(queryData.getMetricFilters())) {
|
||||
parseInfo.setMetricFilters(queryData.getMetricFilters());
|
||||
}
|
||||
if (Objects.nonNull(queryData.getDateInfo())) {
|
||||
parseInfo.setDateInfo(queryData.getDateInfo());
|
||||
}
|
||||
|
||||
parseInfo.setSqlInfo(new SqlInfo());
|
||||
}
|
||||
|
||||
|
||||
@@ -3,12 +3,14 @@ package com.tencent.supersonic.chat.server.service.impl;
|
||||
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
||||
import com.github.pagehelper.PageHelper;
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatMemory;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
@@ -16,12 +18,15 @@ import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.service.ExemplarService;
|
||||
import com.tencent.supersonic.common.util.BeanMapper;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Service
|
||||
public class MemoryServiceImpl implements MemoryService {
|
||||
@@ -36,20 +41,22 @@ public class MemoryServiceImpl implements MemoryService {
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
|
||||
@Override
|
||||
public void createMemory(ChatMemoryDO memory) {
|
||||
public void createMemory(ChatMemory memory) {
|
||||
// if an existing enabled memory has the same question, just skip
|
||||
List<ChatMemoryDO> memories =
|
||||
List<ChatMemory> memories =
|
||||
getMemories(ChatMemoryFilter.builder().agentId(memory.getAgentId())
|
||||
.question(memory.getQuestion()).status(MemoryStatus.ENABLED).build());
|
||||
if (memories.size() == 0) {
|
||||
chatMemoryRepository.createMemory(memory);
|
||||
if (memories.isEmpty()) {
|
||||
ChatMemoryDO memoryDO = getMemoryDO(memory);
|
||||
chatMemoryRepository.createMemory(memoryDO);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateMemory(ChatMemoryUpdateReq chatMemoryUpdateReq, User user) {
|
||||
ChatMemoryDO chatMemoryDO = chatMemoryRepository.getMemory(chatMemoryUpdateReq.getId());
|
||||
boolean hadEnabled = MemoryStatus.ENABLED.equals(chatMemoryDO.getStatus());
|
||||
boolean hadEnabled =
|
||||
MemoryStatus.ENABLED.toString().equals(chatMemoryDO.getStatus().trim());
|
||||
chatMemoryDO.setUpdatedBy(user.getName());
|
||||
chatMemoryDO.setUpdatedAt(new Date());
|
||||
BeanMapper.mapper(chatMemoryUpdateReq, chatMemoryDO);
|
||||
@@ -58,12 +65,12 @@ public class MemoryServiceImpl implements MemoryService {
|
||||
} else if (MemoryStatus.DISABLED.equals(chatMemoryUpdateReq.getStatus()) && hadEnabled) {
|
||||
disableMemory(chatMemoryDO);
|
||||
}
|
||||
updateMemory(chatMemoryDO);
|
||||
chatMemoryRepository.updateMemory(chatMemoryDO);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateMemory(ChatMemoryDO memory) {
|
||||
chatMemoryRepository.updateMemory(memory);
|
||||
public void updateMemory(ChatMemory memory) {
|
||||
chatMemoryRepository.updateMemory(getMemoryDO(memory));
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -72,7 +79,7 @@ public class MemoryServiceImpl implements MemoryService {
|
||||
}
|
||||
|
||||
@Override
|
||||
public PageInfo<ChatMemoryDO> pageMemories(PageMemoryReq pageMemoryReq) {
|
||||
public PageInfo<ChatMemory> pageMemories(PageMemoryReq pageMemoryReq) {
|
||||
ChatMemoryFilter chatMemoryFilter = pageMemoryReq.getChatMemoryFilter();
|
||||
chatMemoryFilter.setSort(pageMemoryReq.getSort());
|
||||
chatMemoryFilter.setOrderCondition(pageMemoryReq.getOrderCondition());
|
||||
@@ -81,7 +88,7 @@ public class MemoryServiceImpl implements MemoryService {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatMemoryDO> getMemories(ChatMemoryFilter chatMemoryFilter) {
|
||||
public List<ChatMemory> getMemories(ChatMemoryFilter chatMemoryFilter) {
|
||||
QueryWrapper<ChatMemoryDO> queryWrapper = new QueryWrapper<>();
|
||||
if (chatMemoryFilter.getAgentId() != null) {
|
||||
queryWrapper.lambda().eq(ChatMemoryDO::getAgentId, chatMemoryFilter.getAgentId());
|
||||
@@ -109,32 +116,52 @@ public class MemoryServiceImpl implements MemoryService {
|
||||
queryWrapper.orderBy(true, chatMemoryFilter.isAsc(),
|
||||
chatMemoryFilter.getOrderCondition());
|
||||
}
|
||||
return chatMemoryRepository.getMemories(queryWrapper);
|
||||
List<ChatMemoryDO> chatMemoryDOS = chatMemoryRepository.getMemories(queryWrapper);
|
||||
return chatMemoryDOS.stream().map(this::getMemory).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatMemoryDO> getMemoriesForLlmReview() {
|
||||
QueryWrapper<ChatMemoryDO> queryWrapper = new QueryWrapper<>();
|
||||
queryWrapper.lambda().eq(ChatMemoryDO::getStatus, MemoryStatus.PENDING)
|
||||
.isNull(ChatMemoryDO::getLlmReviewRet);
|
||||
return chatMemoryRepository.getMemories(queryWrapper);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void enableMemory(ChatMemoryDO memory) {
|
||||
memory.setStatus(MemoryStatus.ENABLED);
|
||||
private void enableMemory(ChatMemoryDO memory) {
|
||||
memory.setStatus(MemoryStatus.ENABLED.toString());
|
||||
exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
||||
Text2SQLExemplar.builder().question(memory.getQuestion())
|
||||
.sideInfo(memory.getSideInfo()).dbSchema(memory.getDbSchema())
|
||||
.sql(memory.getS2sql()).build());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void disableMemory(ChatMemoryDO memory) {
|
||||
memory.setStatus(MemoryStatus.DISABLED);
|
||||
private void disableMemory(ChatMemoryDO memory) {
|
||||
memory.setStatus(MemoryStatus.DISABLED.toString());
|
||||
exemplarService.removeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
||||
Text2SQLExemplar.builder().question(memory.getQuestion())
|
||||
.sideInfo(memory.getSideInfo()).dbSchema(memory.getDbSchema())
|
||||
.sql(memory.getS2sql()).build());
|
||||
}
|
||||
|
||||
private ChatMemoryDO getMemoryDO(ChatMemory memory) {
|
||||
ChatMemoryDO memoryDO = new ChatMemoryDO();
|
||||
BeanUtils.copyProperties(memory, memoryDO);
|
||||
memoryDO.setStatus(memory.getStatus().toString().trim());
|
||||
if (Objects.nonNull(memory.getHumanReviewRet())) {
|
||||
memoryDO.setHumanReviewRet(memory.getHumanReviewRet().toString().trim());
|
||||
}
|
||||
if (Objects.nonNull(memory.getLlmReviewRet())) {
|
||||
memoryDO.setLlmReviewRet(memory.getLlmReviewRet().toString().trim());
|
||||
}
|
||||
|
||||
return memoryDO;
|
||||
}
|
||||
|
||||
private ChatMemory getMemory(ChatMemoryDO memoryDO) {
|
||||
ChatMemory memory = new ChatMemory();
|
||||
BeanUtils.copyProperties(memoryDO, memory);
|
||||
memory.setStatus(MemoryStatus.valueOf(memoryDO.getStatus().trim()));
|
||||
if (Objects.nonNull(memoryDO.getHumanReviewRet())) {
|
||||
memory.setHumanReviewRet(
|
||||
MemoryReviewResult.valueOf(memoryDO.getHumanReviewRet().trim()));
|
||||
}
|
||||
if (Objects.nonNull(memoryDO.getLlmReviewRet())) {
|
||||
memory.setLlmReviewRet(MemoryReviewResult.valueOf(memoryDO.getLlmReviewRet().trim()));
|
||||
}
|
||||
return memory;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user