mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
(improvement)(build) Add spotless during the build process. (#1639)
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.chat.server.agent;
|
||||
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
@@ -28,10 +27,9 @@ public class Agent extends RecordInfo {
|
||||
private String name;
|
||||
private String description;
|
||||
|
||||
/**
|
||||
* 0 offline, 1 online
|
||||
*/
|
||||
/** 0 offline, 1 online */
|
||||
private Integer status;
|
||||
|
||||
private List<String> examples;
|
||||
private String agentConfig;
|
||||
private ChatModelConfig modelConfig;
|
||||
@@ -46,13 +44,13 @@ public class Agent extends RecordInfo {
|
||||
}
|
||||
List<Map> toolList = (List) map.get("tools");
|
||||
return toolList.stream()
|
||||
.filter(tool -> {
|
||||
.filter(
|
||||
tool -> {
|
||||
if (Objects.isNull(type)) {
|
||||
return true;
|
||||
}
|
||||
return type.name().equals(tool.get("type"));
|
||||
}
|
||||
)
|
||||
})
|
||||
.map(JSONObject::toJSONString)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
@@ -74,7 +72,8 @@ public class Agent extends RecordInfo {
|
||||
if (CollectionUtils.isEmpty(tools)) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
return tools.stream().map(tool -> JSONObject.parseObject(tool, NL2SQLTool.class))
|
||||
return tools.stream()
|
||||
.map(tool -> JSONObject.parseObject(tool, NL2SQLTool.class))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@@ -121,7 +120,8 @@ public class Agent extends RecordInfo {
|
||||
if (CollectionUtils.isEmpty(commonAgentTools)) {
|
||||
return new HashSet<>();
|
||||
}
|
||||
return commonAgentTools.stream().map(NL2SQLTool::getDataSetIds)
|
||||
return commonAgentTools.stream()
|
||||
.map(NL2SQLTool::getDataSetIds)
|
||||
.filter(modelIds -> !CollectionUtils.isEmpty(modelIds))
|
||||
.flatMap(Collection::stream)
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
@@ -4,8 +4,8 @@ import com.google.common.collect.Lists;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import java.util.List;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@@ -13,5 +13,4 @@ import java.util.List;
|
||||
public class AgentConfig {
|
||||
|
||||
List<AgentTool> tools = Lists.newArrayList();
|
||||
|
||||
}
|
||||
|
||||
@@ -21,5 +21,4 @@ public enum AgentToolType {
|
||||
map.put(PLUGIN, PLUGIN.title);
|
||||
return map;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -8,5 +8,4 @@ import java.util.List;
|
||||
public class LLMParserTool extends NL2SQLTool {
|
||||
|
||||
private List<String> exampleQuestions;
|
||||
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.chat.server.agent;
|
||||
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
@@ -11,5 +10,4 @@ import lombok.NoArgsConstructor;
|
||||
public class MultiTurnConfig {
|
||||
|
||||
private boolean enableMultiTurn;
|
||||
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.chat.server.agent;
|
||||
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
@@ -13,5 +12,4 @@ import java.util.List;
|
||||
public class NL2SQLTool extends AgentTool {
|
||||
|
||||
protected List<Long> dataSetIds;
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.chat.server.agent;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
@@ -9,5 +8,4 @@ import java.util.List;
|
||||
public class PluginTool extends AgentTool {
|
||||
|
||||
private List<Long> plugins;
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.chat.server.agent;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
@@ -9,7 +8,6 @@ import java.util.List;
|
||||
@Data
|
||||
public class RuleParserTool extends NL2SQLTool {
|
||||
|
||||
|
||||
private List<String> queryModes;
|
||||
|
||||
private List<String> queryTypes;
|
||||
@@ -17,5 +15,4 @@ public class RuleParserTool extends NL2SQLTool {
|
||||
public boolean isContainsAllModel() {
|
||||
return CollectionUtils.isNotEmpty(dataSetIds) && dataSetIds.contains(-1L);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -3,8 +3,8 @@ package com.tencent.supersonic.chat.server.config;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatDetailConfigReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.RecommendedQuestionReq;
|
||||
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
|
||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
@@ -14,33 +14,22 @@ import java.util.List;
|
||||
@ToString
|
||||
public class ChatConfig {
|
||||
|
||||
/**
|
||||
* database auto-increment primary key
|
||||
*/
|
||||
/** database auto-increment primary key */
|
||||
private Long id;
|
||||
|
||||
private Long modelId;
|
||||
|
||||
/**
|
||||
* the chatDetailConfig about the model
|
||||
*/
|
||||
/** the chatDetailConfig about the model */
|
||||
private ChatDetailConfigReq chatDetailConfig;
|
||||
|
||||
/**
|
||||
* the chatAggConfig about the model
|
||||
*/
|
||||
/** the chatAggConfig about the model */
|
||||
private ChatAggConfigReq chatAggConfig;
|
||||
|
||||
private List<RecommendedQuestionReq> recommendedQuestions;
|
||||
|
||||
/**
|
||||
* available status
|
||||
*/
|
||||
/** available status */
|
||||
private StatusEnum status;
|
||||
|
||||
/**
|
||||
* about createdBy, createdAt, updatedBy, updatedAt
|
||||
*/
|
||||
/** about createdBy, createdAt, updatedBy, updatedAt */
|
||||
private RecordInfo recordInfo;
|
||||
|
||||
}
|
||||
|
||||
@@ -8,4 +8,4 @@ public class ChatConfigFilterInternal {
|
||||
private Long id;
|
||||
private Long modelId;
|
||||
private Integer status;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
package com.tencent.supersonic.chat.server.executor;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
|
||||
public interface ChatQueryExecutor {
|
||||
|
||||
QueryResult execute(ExecuteContext executeContext);
|
||||
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.chat.server.executor;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||
import com.tencent.supersonic.chat.server.parser.ParserConfig;
|
||||
@@ -8,7 +9,6 @@ import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatManageService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
@@ -26,13 +26,14 @@ import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULT
|
||||
|
||||
public class PlainTextExecutor implements ChatQueryExecutor {
|
||||
|
||||
private static final String INSTRUCTION = ""
|
||||
+ "#Role: You are a nice person to talk to.\n"
|
||||
+ "#Task: Respond quickly and nicely to the user."
|
||||
+ "#Rules: 1.ALWAYS use the same language as the input.\n"
|
||||
+ "#History Inputs: %s\n"
|
||||
+ "#Current Input: %s\n"
|
||||
+ "#Your response: ";
|
||||
private static final String INSTRUCTION =
|
||||
""
|
||||
+ "#Role: You are a nice person to talk to.\n"
|
||||
+ "#Task: Respond quickly and nicely to the user."
|
||||
+ "#Rules: 1.ALWAYS use the same language as the input.\n"
|
||||
+ "#History Inputs: %s\n"
|
||||
+ "#Current Input: %s\n"
|
||||
+ "#Your response: ";
|
||||
|
||||
@Override
|
||||
public QueryResult execute(ExecuteContext executeContext) {
|
||||
@@ -40,14 +41,18 @@ public class PlainTextExecutor implements ChatQueryExecutor {
|
||||
return null;
|
||||
}
|
||||
|
||||
String promptStr = String.format(INSTRUCTION, getHistoryInputs(executeContext),
|
||||
executeContext.getQueryText());
|
||||
String promptStr =
|
||||
String.format(
|
||||
INSTRUCTION,
|
||||
getHistoryInputs(executeContext),
|
||||
executeContext.getQueryText());
|
||||
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
||||
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId());
|
||||
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatAgent.getModelConfig());
|
||||
ChatLanguageModel chatLanguageModel =
|
||||
ModelProvider.getChatModel(chatAgent.getModelConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||
|
||||
QueryResult result = new QueryResult();
|
||||
@@ -66,16 +71,21 @@ public class PlainTextExecutor implements ChatQueryExecutor {
|
||||
|
||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||
MultiTurnConfig agentMultiTurnConfig = chatAgent.getMultiTurnConfig();
|
||||
Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
|
||||
Boolean multiTurnConfig = agentMultiTurnConfig != null
|
||||
? agentMultiTurnConfig.isEnableMultiTurn() : globalMultiTurnConfig;
|
||||
Boolean globalMultiTurnConfig =
|
||||
Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
|
||||
Boolean multiTurnConfig =
|
||||
agentMultiTurnConfig != null
|
||||
? agentMultiTurnConfig.isEnableMultiTurn()
|
||||
: globalMultiTurnConfig;
|
||||
|
||||
if (Boolean.TRUE.equals(multiTurnConfig)) {
|
||||
List<QueryResp> queryResps = getHistoryQueries(executeContext.getChatId(), 5);
|
||||
queryResps.stream().forEach(p -> {
|
||||
historyInput.append(p.getQueryText());
|
||||
historyInput.append(";");
|
||||
});
|
||||
queryResps.stream()
|
||||
.forEach(
|
||||
p -> {
|
||||
historyInput.append(p.getQueryText());
|
||||
historyInput.append(";");
|
||||
});
|
||||
}
|
||||
|
||||
return historyInput.toString();
|
||||
@@ -83,17 +93,20 @@ public class PlainTextExecutor implements ChatQueryExecutor {
|
||||
|
||||
private List<QueryResp> getHistoryQueries(int chatId, int multiNum) {
|
||||
ChatManageService chatManageService = ContextUtils.getBean(ChatManageService.class);
|
||||
List<QueryResp> contextualParseInfoList = chatManageService.getChatQueries(chatId)
|
||||
.stream()
|
||||
.filter(q -> Objects.nonNull(q.getQueryResult())
|
||||
&& q.getQueryResult().getQueryState() == QueryState.SUCCESS)
|
||||
.collect(Collectors.toList());
|
||||
List<QueryResp> contextualParseInfoList =
|
||||
chatManageService.getChatQueries(chatId).stream()
|
||||
.filter(
|
||||
q ->
|
||||
Objects.nonNull(q.getQueryResult())
|
||||
&& q.getQueryResult().getQueryState()
|
||||
== QueryState.SUCCESS)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
List<QueryResp> contextualList = contextualParseInfoList.subList(0,
|
||||
Math.min(multiNum, contextualParseInfoList.size()));
|
||||
List<QueryResp> contextualList =
|
||||
contextualParseInfoList.subList(
|
||||
0, Math.min(multiNum, contextualParseInfoList.size()));
|
||||
Collections.reverse(contextualList);
|
||||
|
||||
return contextualList;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package com.tencent.supersonic.chat.server.executor;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||
import com.tencent.supersonic.chat.server.plugin.build.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
|
||||
public class PluginExecutor implements ChatQueryExecutor {
|
||||
|
||||
@@ -18,5 +18,4 @@ public class PluginExecutor implements ChatQueryExecutor {
|
||||
query.setParseInfo(parseInfo);
|
||||
return query.build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
package com.tencent.supersonic.chat.server.executor;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.service.ChatContextService;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.chat.server.util.ResultFormatter;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
@@ -10,13 +13,10 @@ import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatContextService;
|
||||
import lombok.SneakyThrows;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
@@ -31,28 +31,35 @@ public class SqlExecutor implements ChatQueryExecutor {
|
||||
QueryResult queryResult = doExecute(executeContext);
|
||||
|
||||
if (queryResult != null) {
|
||||
String textResult = ResultFormatter.transform2TextNew(queryResult.getQueryColumns(),
|
||||
queryResult.getQueryResults());
|
||||
String textResult =
|
||||
ResultFormatter.transform2TextNew(
|
||||
queryResult.getQueryColumns(), queryResult.getQueryResults());
|
||||
queryResult.setTextResult(textResult);
|
||||
|
||||
if (queryResult.getQueryState().equals(QueryState.SUCCESS)
|
||||
&& queryResult.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
|
||||
Text2SQLExemplar exemplar = JsonUtil.toObject(JsonUtil.toString(
|
||||
executeContext.getParseInfo().getProperties()
|
||||
.get(Text2SQLExemplar.PROPERTY_KEY)), Text2SQLExemplar.class);
|
||||
Text2SQLExemplar exemplar =
|
||||
JsonUtil.toObject(
|
||||
JsonUtil.toString(
|
||||
executeContext
|
||||
.getParseInfo()
|
||||
.getProperties()
|
||||
.get(Text2SQLExemplar.PROPERTY_KEY)),
|
||||
Text2SQLExemplar.class);
|
||||
|
||||
MemoryService memoryService = ContextUtils.getBean(MemoryService.class);
|
||||
memoryService.createMemory(ChatMemoryDO.builder()
|
||||
.agentId(executeContext.getAgent().getId())
|
||||
.status(MemoryStatus.PENDING)
|
||||
.question(exemplar.getQuestion())
|
||||
.sideInfo(exemplar.getSideInfo())
|
||||
.dbSchema(exemplar.getDbSchema())
|
||||
.s2sql(exemplar.getSql())
|
||||
.createdBy(executeContext.getUser().getName())
|
||||
.updatedBy(executeContext.getUser().getName())
|
||||
.createdAt(new Date())
|
||||
.build());
|
||||
memoryService.createMemory(
|
||||
ChatMemoryDO.builder()
|
||||
.agentId(executeContext.getAgent().getId())
|
||||
.status(MemoryStatus.PENDING)
|
||||
.question(exemplar.getQuestion())
|
||||
.sideInfo(exemplar.getSideInfo())
|
||||
.dbSchema(exemplar.getDbSchema())
|
||||
.s2sql(exemplar.getSql())
|
||||
.createdBy(executeContext.getUser().getName())
|
||||
.updatedBy(executeContext.getUser().getName())
|
||||
.createdAt(new Date())
|
||||
.build());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,9 +78,8 @@ public class SqlExecutor implements ChatQueryExecutor {
|
||||
return null;
|
||||
}
|
||||
|
||||
QuerySqlReq sqlReq = QuerySqlReq.builder()
|
||||
.sql(parseInfo.getSqlInfo().getCorrectedS2SQL())
|
||||
.build();
|
||||
QuerySqlReq sqlReq =
|
||||
QuerySqlReq.builder().sql(parseInfo.getSqlInfo().getCorrectedS2SQL()).build();
|
||||
sqlReq.setSqlInfo(parseInfo.getSqlInfo());
|
||||
sqlReq.setDataSetId(parseInfo.getDataSetId());
|
||||
|
||||
@@ -97,5 +103,4 @@ public class SqlExecutor implements ChatQueryExecutor {
|
||||
}
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -27,27 +27,26 @@ public class MemoryReviewTask {
|
||||
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
|
||||
private static final String INSTRUCTION = ""
|
||||
+ "#Role: You are a senior data engineer experienced in writing SQL.\n"
|
||||
+ "#Task: Your will be provided with a user question and the SQL written by junior engineer,"
|
||||
+ "please take a review and give your opinion.\n"
|
||||
+ "#Rules: "
|
||||
+ "1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),comment=(your comment)`."
|
||||
+ "2.ALWAYS recognize `数据日期` as the date field."
|
||||
+ "3.IGNORE `数据日期` if not expressed in the `Question`."
|
||||
+ "#Question: %s\n"
|
||||
+ "#Schema: %s\n"
|
||||
+ "#SideInfo: %s\n"
|
||||
+ "#SQL: %s\n"
|
||||
+ "#Response: ";
|
||||
private static final String INSTRUCTION =
|
||||
""
|
||||
+ "#Role: You are a senior data engineer experienced in writing SQL.\n"
|
||||
+ "#Task: Your will be provided with a user question and the SQL written by junior engineer,"
|
||||
+ "please take a review and give your opinion.\n"
|
||||
+ "#Rules: "
|
||||
+ "1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),comment=(your comment)`."
|
||||
+ "2.ALWAYS recognize `数据日期` as the date field."
|
||||
+ "3.IGNORE `数据日期` if not expressed in the `Question`."
|
||||
+ "#Question: %s\n"
|
||||
+ "#Schema: %s\n"
|
||||
+ "#SideInfo: %s\n"
|
||||
+ "#SQL: %s\n"
|
||||
+ "#Response: ";
|
||||
|
||||
private static final Pattern OUTPUT_PATTERN = Pattern.compile("opinion=(.*),.*comment=(.*)");
|
||||
|
||||
@Autowired
|
||||
private MemoryService memoryService;
|
||||
@Autowired private MemoryService memoryService;
|
||||
|
||||
@Autowired
|
||||
private AgentService agentService;
|
||||
@Autowired private AgentService agentService;
|
||||
|
||||
@Scheduled(fixedDelay = 60 * 1000)
|
||||
public void review() {
|
||||
@@ -68,7 +67,8 @@ public class MemoryReviewTask {
|
||||
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
||||
|
||||
keyPipelineLog.info("MemoryReviewTask reqPrompt:\n{}", promptStr);
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatAgent.getModelConfig());
|
||||
ChatLanguageModel chatLanguageModel =
|
||||
ModelProvider.getChatModel(chatAgent.getModelConfig());
|
||||
if (Objects.nonNull(chatLanguageModel)) {
|
||||
String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text();
|
||||
keyPipelineLog.info("MemoryReviewTask modelResp:\n{}", response);
|
||||
@@ -79,7 +79,8 @@ public class MemoryReviewTask {
|
||||
}
|
||||
|
||||
private String createPromptString(ChatMemoryDO m) {
|
||||
return String.format(INSTRUCTION, m.getQuestion(), m.getDbSchema(), m.getSideInfo(), m.getS2sql());
|
||||
return String.format(
|
||||
INSTRUCTION, m.getQuestion(), m.getDbSchema(), m.getSideInfo(), m.getS2sql());
|
||||
}
|
||||
|
||||
private void processResponse(String response, ChatMemoryDO m) {
|
||||
|
||||
@@ -6,5 +6,4 @@ import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
public interface ChatQueryParser {
|
||||
|
||||
void parse(ParseContext parseContext, ParseResp parseResp);
|
||||
|
||||
}
|
||||
|
||||
@@ -6,12 +6,14 @@ import com.tencent.supersonic.chat.server.util.ComponentFactory;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
public class NL2PluginParser implements ChatQueryParser {
|
||||
|
||||
private final List<PluginRecognizer> pluginRecognizers = ComponentFactory.getPluginRecognizers();
|
||||
private final List<PluginRecognizer> pluginRecognizers =
|
||||
ComponentFactory.getPluginRecognizers();
|
||||
|
||||
@Override
|
||||
public void parse(ParseContext parseContext, ParseResp parseResp) {
|
||||
@@ -19,11 +21,13 @@ public class NL2PluginParser implements ChatQueryParser {
|
||||
return;
|
||||
}
|
||||
|
||||
pluginRecognizers.forEach(pluginRecognizer -> {
|
||||
pluginRecognizer.recognize(parseContext, parseResp);
|
||||
log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(),
|
||||
JsonUtil.toString(parseResp));
|
||||
});
|
||||
pluginRecognizers.forEach(
|
||||
pluginRecognizer -> {
|
||||
pluginRecognizer.recognize(parseContext, parseResp);
|
||||
log.info(
|
||||
"{} recallResult:{}",
|
||||
pluginRecognizer.getClass().getSimpleName(),
|
||||
JsonUtil.toString(parseResp));
|
||||
});
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package com.tencent.supersonic.chat.server.parser;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.chat.server.service.ChatContextService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatManageService;
|
||||
@@ -19,7 +20,6 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
@@ -50,31 +50,33 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
|
||||
private static final String REWRITE_USER_QUESTION_INSTRUCTION = ""
|
||||
+ "#Role: You are a data product manager experienced in data requirements."
|
||||
+ "#Task: Your will be provided with current and history questions asked by a user,"
|
||||
+ "along with their mapped schema elements(metric, dimension and value),"
|
||||
+ "please try understanding the semantics and rewrite a question."
|
||||
+ "#Rules: "
|
||||
+ "1.ALWAYS keep relevant entities, metrics, dimensions, values and date ranges."
|
||||
+ "2.ONLY respond with the rewritten question."
|
||||
+ "#Current Question: {{current_question}}"
|
||||
+ "#Current Mapped Schema: {{current_schema}}"
|
||||
+ "#History Question: {{history_question}}"
|
||||
+ "#History Mapped Schema: {{history_schema}}"
|
||||
+ "#History SQL: {{history_sql}}"
|
||||
+ "#Rewritten Question: ";
|
||||
private static final String REWRITE_USER_QUESTION_INSTRUCTION =
|
||||
""
|
||||
+ "#Role: You are a data product manager experienced in data requirements."
|
||||
+ "#Task: Your will be provided with current and history questions asked by a user,"
|
||||
+ "along with their mapped schema elements(metric, dimension and value),"
|
||||
+ "please try understanding the semantics and rewrite a question."
|
||||
+ "#Rules: "
|
||||
+ "1.ALWAYS keep relevant entities, metrics, dimensions, values and date ranges."
|
||||
+ "2.ONLY respond with the rewritten question."
|
||||
+ "#Current Question: {{current_question}}"
|
||||
+ "#Current Mapped Schema: {{current_schema}}"
|
||||
+ "#History Question: {{history_question}}"
|
||||
+ "#History Mapped Schema: {{history_schema}}"
|
||||
+ "#History SQL: {{history_sql}}"
|
||||
+ "#Rewritten Question: ";
|
||||
|
||||
private static final String REWRITE_ERROR_MESSAGE_INSTRUCTION = ""
|
||||
+ "#Role: You are a data business partner who closely interacts with business people.\n"
|
||||
+ "#Task: Your will be provided with user input, system output and some examples, "
|
||||
+ "please respond shortly to teach user how to ask the right question, "
|
||||
+ "by using `Examples` as references."
|
||||
+ "#Rules: ALWAYS respond with the same language as the `Input`.\n"
|
||||
+ "#Input: {{user_question}}\n"
|
||||
+ "#Output: {{system_message}}\n"
|
||||
+ "#Examples: {{examples}}\n"
|
||||
+ "#Response: ";
|
||||
private static final String REWRITE_ERROR_MESSAGE_INSTRUCTION =
|
||||
""
|
||||
+ "#Role: You are a data business partner who closely interacts with business people.\n"
|
||||
+ "#Task: Your will be provided with user input, system output and some examples, "
|
||||
+ "please respond shortly to teach user how to ask the right question, "
|
||||
+ "by using `Examples` as references."
|
||||
+ "#Rules: ALWAYS respond with the same language as the `Input`.\n"
|
||||
+ "#Input: {{user_question}}\n"
|
||||
+ "#Output: {{system_message}}\n"
|
||||
+ "#Examples: {{examples}}\n"
|
||||
+ "#Response: ";
|
||||
|
||||
@Override
|
||||
public void parse(ParseContext parseContext, ParseResp parseResp) {
|
||||
@@ -84,8 +86,8 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
|
||||
ChatContext chatCtx = chatContextService.getOrCreateContext(parseContext.getChatId());
|
||||
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(
|
||||
parseContext.getAgent().getModelConfig());
|
||||
ChatLanguageModel chatLanguageModel =
|
||||
ModelProvider.getChatModel(parseContext.getAgent().getModelConfig());
|
||||
|
||||
processMultiTurn(chatLanguageModel, parseContext);
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext, chatCtx);
|
||||
@@ -96,11 +98,13 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) {
|
||||
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
|
||||
} else {
|
||||
parseResp.setErrorMsg(rewriteErrorMessage(chatLanguageModel,
|
||||
parseContext.getQueryText(),
|
||||
text2SqlParseResp.getErrorMsg(),
|
||||
queryNLReq.getDynamicExemplars(),
|
||||
parseContext.getAgent().getExamples()));
|
||||
parseResp.setErrorMsg(
|
||||
rewriteErrorMessage(
|
||||
chatLanguageModel,
|
||||
parseContext.getQueryText(),
|
||||
text2SqlParseResp.getErrorMsg(),
|
||||
queryNLReq.getDynamicExemplars(),
|
||||
parseContext.getAgent().getExamples()));
|
||||
}
|
||||
parseResp.setState(text2SqlParseResp.getState());
|
||||
parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime());
|
||||
@@ -134,24 +138,35 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
StringBuilder textBuilder = new StringBuilder();
|
||||
textBuilder.append("**数据集:** ").append(parseInfo.getDataSet().getName()).append(" ");
|
||||
Optional<SchemaElement> metric = parseInfo.getMetrics().stream().findFirst();
|
||||
metric.ifPresent(schemaElement ->
|
||||
textBuilder.append("**指标:** ").append(schemaElement.getName()).append(" "));
|
||||
List<String> dimensionNames = parseInfo.getDimensions().stream()
|
||||
.map(SchemaElement::getName).filter(Objects::nonNull).collect(Collectors.toList());
|
||||
metric.ifPresent(
|
||||
schemaElement ->
|
||||
textBuilder.append("**指标:** ").append(schemaElement.getName()).append(" "));
|
||||
List<String> dimensionNames =
|
||||
parseInfo.getDimensions().stream()
|
||||
.map(SchemaElement::getName)
|
||||
.filter(Objects::nonNull)
|
||||
.collect(Collectors.toList());
|
||||
if (!CollectionUtils.isEmpty(dimensionNames)) {
|
||||
textBuilder.append("**维度:** ").append(String.join(",", dimensionNames));
|
||||
}
|
||||
textBuilder.append("\n\n**筛选条件:** \n");
|
||||
if (parseInfo.getDateInfo() != null) {
|
||||
textBuilder.append("**数据时间:** ").append(parseInfo.getDateInfo().getStartDate()).append("~")
|
||||
.append(parseInfo.getDateInfo().getEndDate()).append(" ");
|
||||
textBuilder
|
||||
.append("**数据时间:** ")
|
||||
.append(parseInfo.getDateInfo().getStartDate())
|
||||
.append("~")
|
||||
.append(parseInfo.getDateInfo().getEndDate())
|
||||
.append(" ");
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(parseInfo.getDimensionFilters())
|
||||
|| CollectionUtils.isEmpty(parseInfo.getMetricFilters())) {
|
||||
Set<QueryFilter> queryFilters = parseInfo.getDimensionFilters();
|
||||
queryFilters.addAll(parseInfo.getMetricFilters());
|
||||
for (QueryFilter queryFilter : queryFilters) {
|
||||
textBuilder.append("**").append(queryFilter.getName()).append("**")
|
||||
textBuilder
|
||||
.append("**")
|
||||
.append(queryFilter.getName())
|
||||
.append("**")
|
||||
.append(" ")
|
||||
.append(queryFilter.getOperator().getValue())
|
||||
.append(" ")
|
||||
@@ -165,10 +180,13 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
private void processMultiTurn(ChatLanguageModel chatLanguageModel, ParseContext parseContext) {
|
||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||
MultiTurnConfig agentMultiTurnConfig = parseContext.getAgent().getMultiTurnConfig();
|
||||
Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
|
||||
Boolean globalMultiTurnConfig =
|
||||
Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
|
||||
|
||||
Boolean multiTurnConfig = agentMultiTurnConfig != null
|
||||
? agentMultiTurnConfig.isEnableMultiTurn() : globalMultiTurnConfig;
|
||||
Boolean multiTurnConfig =
|
||||
agentMultiTurnConfig != null
|
||||
? agentMultiTurnConfig.isEnableMultiTurn()
|
||||
: globalMultiTurnConfig;
|
||||
if (!Boolean.TRUE.equals(multiTurnConfig)) {
|
||||
return;
|
||||
}
|
||||
@@ -186,7 +204,8 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
SemanticParseInfo lastParseInfo = lastQuery.getParseInfos().get(0);
|
||||
Long dataId = lastParseInfo.getDataSetId();
|
||||
|
||||
String curtMapStr = generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataId));
|
||||
String curtMapStr =
|
||||
generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataId));
|
||||
String histMapStr = generateSchemaPrompt(lastParseInfo.getElementMatches());
|
||||
String histSQL = lastParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
|
||||
@@ -207,22 +226,31 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
QueryNLReq rewrittenQueryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
|
||||
MapResp rewrittenQueryMapResult = chatLayerService.performMapping(rewrittenQueryNLReq);
|
||||
parseContext.setMapInfo(rewrittenQueryMapResult.getMapInfo());
|
||||
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}",
|
||||
lastQuery.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery);
|
||||
log.info(
|
||||
"Last Query: {} Current Query: {}, Rewritten Query: {}",
|
||||
lastQuery.getQueryText(),
|
||||
currentMapResult.getQueryText(),
|
||||
rewrittenQuery);
|
||||
}
|
||||
|
||||
private String rewriteErrorMessage(ChatLanguageModel chatLanguageModel, String userQuestion,
|
||||
String errMsg, List<Text2SQLExemplar> similarExemplars,
|
||||
List<String> agentExamples) {
|
||||
private String rewriteErrorMessage(
|
||||
ChatLanguageModel chatLanguageModel,
|
||||
String userQuestion,
|
||||
String errMsg,
|
||||
List<Text2SQLExemplar> similarExemplars,
|
||||
List<String> agentExamples) {
|
||||
Map<String, Object> variables = new HashMap<>();
|
||||
variables.put("user_question", userQuestion);
|
||||
variables.put("system_message", errMsg);
|
||||
|
||||
StringBuilder exampleStr = new StringBuilder();
|
||||
similarExemplars.forEach(e ->
|
||||
exampleStr.append(String.format("<Question:{%s},Schema:{%s}> ", e.getQuestion(), e.getDbSchema())));
|
||||
agentExamples.forEach(e ->
|
||||
exampleStr.append(String.format("<Question:{%s}> ", e)));
|
||||
similarExemplars.forEach(
|
||||
e ->
|
||||
exampleStr.append(
|
||||
String.format(
|
||||
"<Question:{%s},Schema:{%s}> ",
|
||||
e.getQuestion(), e.getDbSchema())));
|
||||
agentExamples.forEach(e -> exampleStr.append(String.format("<Question:{%s}> ", e)));
|
||||
variables.put("examples", exampleStr);
|
||||
|
||||
Prompt prompt = PromptTemplate.from(REWRITE_ERROR_MESSAGE_INSTRUCTION).apply(variables);
|
||||
@@ -262,14 +290,18 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
|
||||
private List<QueryResp> getHistoryQueries(int chatId, int multiNum) {
|
||||
ChatManageService chatManageService = ContextUtils.getBean(ChatManageService.class);
|
||||
List<QueryResp> contextualParseInfoList = chatManageService.getChatQueries(chatId)
|
||||
.stream()
|
||||
.filter(q -> Objects.nonNull(q.getQueryResult())
|
||||
&& q.getQueryResult().getQueryState() == QueryState.SUCCESS)
|
||||
.collect(Collectors.toList());
|
||||
List<QueryResp> contextualParseInfoList =
|
||||
chatManageService.getChatQueries(chatId).stream()
|
||||
.filter(
|
||||
q ->
|
||||
Objects.nonNull(q.getQueryResult())
|
||||
&& q.getQueryResult().getQueryState()
|
||||
== QueryState.SUCCESS)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
List<QueryResp> contextualList = contextualParseInfoList.subList(0,
|
||||
Math.min(multiNum, contextualParseInfoList.size()));
|
||||
List<QueryResp> contextualList =
|
||||
contextualParseInfoList.subList(
|
||||
0, Math.min(multiNum, contextualParseInfoList.size()));
|
||||
Collections.reverse(contextualList);
|
||||
return contextualList;
|
||||
}
|
||||
@@ -278,9 +310,8 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
ExemplarServiceImpl exemplarManager = ContextUtils.getBean(ExemplarServiceImpl.class);
|
||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
|
||||
List<Text2SQLExemplar> exemplars = exemplarManager.recallExemplars(memoryCollectionName,
|
||||
queryNLReq.getQueryText(), 5);
|
||||
List<Text2SQLExemplar> exemplars =
|
||||
exemplarManager.recallExemplars(memoryCollectionName, queryNLReq.getQueryText(), 5);
|
||||
queryNLReq.getDynamicExemplars().addAll(exemplars);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -10,8 +10,11 @@ import org.springframework.stereotype.Service;
|
||||
public class ParserConfig extends ParameterConfig {
|
||||
|
||||
public static final Parameter PARSER_MULTI_TURN_ENABLE =
|
||||
new Parameter("s2.parser.multi-turn.enable", "false",
|
||||
"是否开启多轮对话", "开启多轮对话将消耗更多token",
|
||||
"bool", "Parser相关配置");
|
||||
|
||||
new Parameter(
|
||||
"s2.parser.multi-turn.enable",
|
||||
"false",
|
||||
"是否开启多轮对话",
|
||||
"开启多轮对话将消耗更多token",
|
||||
"bool",
|
||||
"Parser相关配置");
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
|
||||
|
||||
public class PlainTextParser implements ChatQueryParser {
|
||||
|
||||
@Override
|
||||
@@ -18,5 +17,4 @@ public class PlainTextParser implements ChatQueryParser {
|
||||
parseResp.getSelectedParses().add(parseInfo);
|
||||
parseResp.setState(ParseResp.ParseState.COMPLETED);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -10,61 +10,40 @@ import java.util.Date;
|
||||
@Data
|
||||
@TableName("s2_agent")
|
||||
public class AgentDO {
|
||||
/**
|
||||
*
|
||||
*/
|
||||
/** */
|
||||
@TableId(type = IdType.AUTO)
|
||||
private Integer id;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
/** */
|
||||
private String name;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
/** */
|
||||
private String description;
|
||||
|
||||
/**
|
||||
* 0 offline, 1 online
|
||||
*/
|
||||
/** 0 offline, 1 online */
|
||||
private Integer status;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
/** */
|
||||
private String examples;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
/** */
|
||||
private String config;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
/** */
|
||||
private String createdBy;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
/** */
|
||||
private Date createdAt;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
/** */
|
||||
private String updatedBy;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
/** */
|
||||
private Date updatedAt;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
/** */
|
||||
private Integer enableSearch;
|
||||
|
||||
private Integer enableMemoryReview;
|
||||
private String modelConfig;
|
||||
private String multiTurnConfig;
|
||||
@@ -72,5 +51,4 @@ public class AgentDO {
|
||||
private String visualConfig;
|
||||
|
||||
private String promptConfig;
|
||||
|
||||
}
|
||||
|
||||
@@ -1,18 +1,15 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
||||
|
||||
import java.util.Date;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.util.Date;
|
||||
|
||||
@Data
|
||||
@ToString
|
||||
public class ChatConfigDO {
|
||||
|
||||
/**
|
||||
* database auto-increment primary key
|
||||
*/
|
||||
/** database auto-increment primary key */
|
||||
private Long id;
|
||||
|
||||
private Long modelId;
|
||||
@@ -27,12 +24,10 @@ public class ChatConfigDO {
|
||||
|
||||
private String llmExamples;
|
||||
|
||||
/**
|
||||
* record info
|
||||
*/
|
||||
/** record info */
|
||||
private String createdBy;
|
||||
|
||||
private String updatedBy;
|
||||
private Date createdAt;
|
||||
private Date updatedAt;
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,5 +61,4 @@ public class ChatMemoryDO {
|
||||
|
||||
@TableField("updated_at")
|
||||
private Date updatedAt;
|
||||
|
||||
}
|
||||
|
||||
@@ -4,54 +4,44 @@ import com.baomidou.mybatisplus.annotation.IdType;
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.Date;
|
||||
|
||||
@Data
|
||||
@TableName("s2_chat_query")
|
||||
public class ChatQueryDO {
|
||||
/**
|
||||
*/
|
||||
/** */
|
||||
@TableId(type = IdType.AUTO)
|
||||
private Long questionId;
|
||||
|
||||
/**
|
||||
*/
|
||||
/** */
|
||||
private Integer agentId;
|
||||
|
||||
/**
|
||||
*/
|
||||
/** */
|
||||
private Date createTime;
|
||||
|
||||
/**
|
||||
*/
|
||||
/** */
|
||||
private String userName;
|
||||
|
||||
/**
|
||||
*/
|
||||
/** */
|
||||
private Integer queryState;
|
||||
|
||||
/**
|
||||
*/
|
||||
/** */
|
||||
private Long chatId;
|
||||
|
||||
/**
|
||||
*/
|
||||
/** */
|
||||
private Integer score;
|
||||
|
||||
/**
|
||||
*/
|
||||
/** */
|
||||
private String feedback;
|
||||
|
||||
/**
|
||||
*/
|
||||
/** */
|
||||
private String queryText;
|
||||
|
||||
/**
|
||||
*/
|
||||
/** */
|
||||
private String queryResult;
|
||||
|
||||
private String similarQueries;
|
||||
|
||||
private String parseTimeCost;
|
||||
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
||||
|
||||
import java.util.Date;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.Date;
|
||||
|
||||
@Data
|
||||
public class DictConfDO {
|
||||
|
||||
@@ -17,5 +17,4 @@ public class DictConfDO {
|
||||
private String updatedBy;
|
||||
private Date createdAt;
|
||||
private Date updatedAt;
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
||||
|
||||
import java.util.Date;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
import org.apache.commons.codec.digest.DigestUtils;
|
||||
|
||||
import java.util.Date;
|
||||
|
||||
@Data
|
||||
@ToString
|
||||
public class DictTaskDO {
|
||||
@@ -35,4 +35,4 @@ public class DictTaskDO {
|
||||
public String getCommandMd5() {
|
||||
return DigestUtils.md5Hex(command);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
||||
|
||||
|
||||
import com.tencent.supersonic.headless.core.config.DefaultMetric;
|
||||
import com.tencent.supersonic.headless.core.config.Dim4Dict;
|
||||
import lombok.Data;
|
||||
@@ -9,7 +8,6 @@ import lombok.ToString;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
|
||||
@Data
|
||||
@ToString
|
||||
public class DimValueDO {
|
||||
@@ -34,4 +32,4 @@ public class DimValueDO {
|
||||
this.dimensions = dimensions;
|
||||
return this;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import com.baomidou.mybatisplus.annotation.IdType;
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.Date;
|
||||
|
||||
@Data
|
||||
@@ -36,5 +37,4 @@ public class PluginDO {
|
||||
private String config;
|
||||
|
||||
private String comment;
|
||||
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.Date;
|
||||
|
||||
@@ -14,40 +14,25 @@ import java.util.Date;
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public class StatisticsDO {
|
||||
/**
|
||||
* questionId
|
||||
*/
|
||||
/** questionId */
|
||||
private Long questionId;
|
||||
|
||||
/**
|
||||
* chatId
|
||||
*/
|
||||
/** chatId */
|
||||
private Long chatId;
|
||||
|
||||
/**
|
||||
* createTime
|
||||
*/
|
||||
/** createTime */
|
||||
private Date createTime;
|
||||
|
||||
/**
|
||||
* queryText
|
||||
*/
|
||||
/** queryText */
|
||||
private String queryText;
|
||||
|
||||
/**
|
||||
* userName
|
||||
*/
|
||||
/** userName */
|
||||
private String userName;
|
||||
|
||||
|
||||
/**
|
||||
* interface
|
||||
*/
|
||||
/** interface */
|
||||
private String interfaceName;
|
||||
|
||||
/**
|
||||
* cost
|
||||
*/
|
||||
/** cost */
|
||||
private Integer cost;
|
||||
|
||||
private Integer type;
|
||||
|
||||
@@ -5,6 +5,4 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
|
||||
@Mapper
|
||||
public interface AgentDOMapper extends BaseMapper<AgentDO> {
|
||||
|
||||
}
|
||||
public interface AgentDOMapper extends BaseMapper<AgentDO> {}
|
||||
|
||||
@@ -2,9 +2,10 @@ package com.tencent.supersonic.chat.server.persistence.mapper;
|
||||
|
||||
import com.tencent.supersonic.chat.server.config.ChatConfigFilterInternal;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatConfigDO;
|
||||
import java.util.List;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Mapper
|
||||
public interface ChatConfigMapper {
|
||||
|
||||
|
||||
@@ -2,9 +2,10 @@ package com.tencent.supersonic.chat.server.persistence.mapper;
|
||||
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.QueryDO;
|
||||
import java.util.List;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Mapper
|
||||
public interface ChatMapper {
|
||||
|
||||
@@ -21,5 +22,4 @@ public interface ChatMapper {
|
||||
boolean updateFeedback(QueryDO queryDO);
|
||||
|
||||
Boolean deleteChat(Long chatId, String userName);
|
||||
|
||||
}
|
||||
|
||||
@@ -5,6 +5,4 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
|
||||
@Mapper
|
||||
public interface ChatMemoryMapper extends BaseMapper<ChatMemoryDO> {
|
||||
|
||||
}
|
||||
public interface ChatMemoryMapper extends BaseMapper<ChatMemoryDO> {}
|
||||
|
||||
@@ -6,7 +6,6 @@ import org.apache.ibatis.annotations.Param;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
||||
@Mapper
|
||||
public interface ChatParseMapper {
|
||||
|
||||
@@ -19,5 +18,4 @@ public interface ChatParseMapper {
|
||||
List<ChatParseDO> getParseInfoList(List<Long> questionIds);
|
||||
|
||||
List<ChatParseDO> getContextualParseInfo(Integer chatId);
|
||||
|
||||
}
|
||||
|
||||
@@ -5,6 +5,4 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
|
||||
@Mapper
|
||||
public interface ChatQueryDOMapper extends BaseMapper<ChatQueryDO> {
|
||||
|
||||
}
|
||||
public interface ChatQueryDOMapper extends BaseMapper<ChatQueryDO> {}
|
||||
|
||||
@@ -5,6 +5,4 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.PluginDO;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
|
||||
@Mapper
|
||||
public interface PluginDOMapper extends BaseMapper<PluginDO> {
|
||||
|
||||
}
|
||||
public interface PluginDOMapper extends BaseMapper<PluginDO> {}
|
||||
|
||||
@@ -2,11 +2,11 @@ package com.tencent.supersonic.chat.server.persistence.mapper.custom;
|
||||
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Mapper
|
||||
public interface ShowCaseCustomMapper {
|
||||
|
||||
List<ChatQueryDO> queryShowCase(int start, int limit, int agentId, String userName);
|
||||
|
||||
}
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.repository;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.server.config.ChatConfig;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
|
||||
import com.tencent.supersonic.chat.server.config.ChatConfig;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
||||
@@ -7,5 +7,4 @@ public interface ChatContextRepository {
|
||||
ChatContext getOrCreateContext(Integer chatId);
|
||||
|
||||
void updateContext(ChatContext chatCtx);
|
||||
|
||||
}
|
||||
|
||||
@@ -30,11 +30,12 @@ public interface ChatQueryRepository {
|
||||
|
||||
Long createChatQuery(ChatParseReq chatParseReq);
|
||||
|
||||
List<ChatParseDO> batchSaveParseInfo(ChatParseReq chatParseReq, ParseResp parseResult,
|
||||
List<SemanticParseInfo> candidateParses);
|
||||
List<ChatParseDO> batchSaveParseInfo(
|
||||
ChatParseReq chatParseReq,
|
||||
ParseResp parseResult,
|
||||
List<SemanticParseInfo> candidateParses);
|
||||
|
||||
ChatParseDO getParseInfo(Long questionId, int parseId);
|
||||
|
||||
List<ChatParseDO> getParseInfoList(List<Long> questionIds);
|
||||
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.repository;
|
||||
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.QueryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.QueryDO;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface ChatRepository {
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.repository.impl;
|
||||
|
||||
import com.tencent.supersonic.chat.server.config.ChatConfig;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
|
||||
import com.tencent.supersonic.chat.server.config.ChatConfigFilterInternal;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
|
||||
import com.tencent.supersonic.chat.server.config.ChatConfig;
|
||||
import com.tencent.supersonic.chat.server.config.ChatConfigFilterInternal;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatConfigDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.mapper.ChatConfigMapper;
|
||||
import com.tencent.supersonic.chat.server.util.ChatConfigHelper;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatConfigRepository;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import com.tencent.supersonic.chat.server.util.ChatConfigHelper;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.context.annotation.Primary;
|
||||
import org.springframework.stereotype.Repository;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Repository
|
||||
@Primary
|
||||
public class ChatConfigRepositoryImpl implements ChatConfigRepository {
|
||||
@@ -23,8 +23,8 @@ public class ChatConfigRepositoryImpl implements ChatConfigRepository {
|
||||
private final ChatConfigHelper chatConfigHelper;
|
||||
private final ChatConfigMapper chatConfigMapper;
|
||||
|
||||
public ChatConfigRepositoryImpl(ChatConfigHelper chatConfigHelper,
|
||||
ChatConfigMapper chatConfigMapper) {
|
||||
public ChatConfigRepositoryImpl(
|
||||
ChatConfigHelper chatConfigHelper, ChatConfigMapper chatConfigMapper) {
|
||||
this.chatConfigHelper = chatConfigHelper;
|
||||
this.chatConfigMapper = chatConfigMapper;
|
||||
}
|
||||
@@ -41,7 +41,6 @@ public class ChatConfigRepositoryImpl implements ChatConfigRepository {
|
||||
ChatConfigDO chaConfigDO = chatConfigHelper.chatConfig2DO(chaConfig);
|
||||
|
||||
return chatConfigMapper.editConfig(chaConfigDO);
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -52,9 +51,12 @@ public class ChatConfigRepositoryImpl implements ChatConfigRepository {
|
||||
filterInternal.setStatus(filter.getStatus().getCode());
|
||||
List<ChatConfigDO> chaConfigDOList = chatConfigMapper.search(filterInternal);
|
||||
if (!CollectionUtils.isEmpty(chaConfigDOList)) {
|
||||
chaConfigDOList.stream().forEach(chaConfigDO ->
|
||||
chaConfigDescriptorList.add(chatConfigHelper
|
||||
.chatConfigDO2Descriptor(chaConfigDO.getModelId(), chaConfigDO)));
|
||||
chaConfigDOList.stream()
|
||||
.forEach(
|
||||
chaConfigDO ->
|
||||
chaConfigDescriptorList.add(
|
||||
chatConfigHelper.chatConfigDO2Descriptor(
|
||||
chaConfigDO.getModelId(), chaConfigDO)));
|
||||
}
|
||||
return chaConfigDescriptorList;
|
||||
}
|
||||
@@ -64,5 +66,4 @@ public class ChatConfigRepositoryImpl implements ChatConfigRepository {
|
||||
ChatConfigDO chaConfigPO = chatConfigMapper.fetchConfigByModelId(modelId);
|
||||
return chatConfigHelper.chatConfigDO2Descriptor(modelId, chaConfigPO);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.repository.impl;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.mapper.ChatContextMapper;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatContextRepository;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.mapper.ChatContextMapper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.context.annotation.Primary;
|
||||
import org.springframework.stereotype.Repository;
|
||||
@@ -16,7 +16,6 @@ import org.springframework.stereotype.Repository;
|
||||
@Slf4j
|
||||
public class ChatContextRepositoryImpl implements ChatContextRepository {
|
||||
|
||||
|
||||
private final ChatContextMapper chatContextMapper;
|
||||
|
||||
public ChatContextRepositoryImpl(ChatContextMapper chatContextMapper) {
|
||||
@@ -50,8 +49,8 @@ public class ChatContextRepositoryImpl implements ChatContextRepository {
|
||||
chatContext.setUser(contextDO.getUser());
|
||||
chatContext.setQueryText(contextDO.getQueryText());
|
||||
if (contextDO.getSemanticParse() != null && !contextDO.getSemanticParse().isEmpty()) {
|
||||
SemanticParseInfo semanticParseInfo = JsonUtil.toObject(contextDO.getSemanticParse(),
|
||||
SemanticParseInfo.class);
|
||||
SemanticParseInfo semanticParseInfo =
|
||||
JsonUtil.toObject(contextDO.getSemanticParse(), SemanticParseInfo.class);
|
||||
chatContext.setParseInfo(semanticParseInfo);
|
||||
}
|
||||
return chatContext;
|
||||
|
||||
@@ -38,5 +38,4 @@ public class ChatMemoryRepositoryImpl implements ChatMemoryRepository {
|
||||
public List<ChatMemoryDO> getMemories(QueryWrapper<ChatMemoryDO> queryWrapper) {
|
||||
return chatMemoryMapper.selectList(queryWrapper);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@ import com.github.pagehelper.PageHelper;
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||
@@ -18,9 +20,7 @@ import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.PageUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseTimeCostResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
@@ -40,14 +40,11 @@ import java.util.stream.Collectors;
|
||||
@Slf4j
|
||||
public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
||||
|
||||
@Autowired
|
||||
private ChatQueryDOMapper chatQueryDOMapper;
|
||||
@Autowired private ChatQueryDOMapper chatQueryDOMapper;
|
||||
|
||||
@Autowired
|
||||
private ChatParseMapper chatParseMapper;
|
||||
@Autowired private ChatParseMapper chatParseMapper;
|
||||
|
||||
@Autowired
|
||||
private ShowCaseCustomMapper showCaseCustomMapper;
|
||||
@Autowired private ShowCaseCustomMapper showCaseCustomMapper;
|
||||
|
||||
@Override
|
||||
public PageInfo<QueryResp> getChatQuery(PageQueryInfoReq pageQueryInfoReq, Long chatId) {
|
||||
@@ -65,9 +62,9 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
||||
queryWrapper.lambda().ne(ChatQueryDO::getQueryResult, "");
|
||||
queryWrapper.lambda().orderByDesc(ChatQueryDO::getQuestionId);
|
||||
|
||||
PageInfo<ChatQueryDO> pageInfo = PageHelper.startPage(pageQueryInfoReq.getCurrent(),
|
||||
pageQueryInfoReq.getPageSize())
|
||||
.doSelectPageInfo(() -> chatQueryDOMapper.selectList(queryWrapper));
|
||||
PageInfo<ChatQueryDO> pageInfo =
|
||||
PageHelper.startPage(pageQueryInfoReq.getCurrent(), pageQueryInfoReq.getPageSize())
|
||||
.doSelectPageInfo(() -> chatQueryDOMapper.selectList(queryWrapper));
|
||||
|
||||
PageInfo<QueryResp> chatQueryVOPageInfo = PageUtils.pageInfo2PageInfoVo(pageInfo);
|
||||
chatQueryVOPageInfo.setList(
|
||||
@@ -104,24 +101,31 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
||||
|
||||
@Override
|
||||
public List<QueryResp> queryShowCase(PageQueryInfoReq pageQueryInfoReq, int agentId) {
|
||||
return showCaseCustomMapper.queryShowCase(pageQueryInfoReq.getLimitStart(),
|
||||
pageQueryInfoReq.getPageSize(), agentId, pageQueryInfoReq.getUserName())
|
||||
.stream().map(this::convertTo)
|
||||
return showCaseCustomMapper
|
||||
.queryShowCase(
|
||||
pageQueryInfoReq.getLimitStart(),
|
||||
pageQueryInfoReq.getPageSize(),
|
||||
agentId,
|
||||
pageQueryInfoReq.getUserName())
|
||||
.stream()
|
||||
.map(this::convertTo)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private QueryResp convertTo(ChatQueryDO chatQueryDO) {
|
||||
QueryResp queryResp = new QueryResp();
|
||||
BeanUtils.copyProperties(chatQueryDO, queryResp);
|
||||
QueryResult queryResult = JsonUtil.toObject(chatQueryDO.getQueryResult(), QueryResult.class);
|
||||
QueryResult queryResult =
|
||||
JsonUtil.toObject(chatQueryDO.getQueryResult(), QueryResult.class);
|
||||
if (queryResult != null) {
|
||||
queryResult.setQueryId(chatQueryDO.getQuestionId());
|
||||
queryResp.setQueryResult(queryResult);
|
||||
}
|
||||
queryResp.setSimilarQueries(JSONObject.parseArray(chatQueryDO.getSimilarQueries(),
|
||||
SimilarQueryRecallResp.class));
|
||||
queryResp.setParseTimeCost(JsonUtil.toObject(chatQueryDO.getParseTimeCost(),
|
||||
ParseTimeCostResp.class));
|
||||
queryResp.setSimilarQueries(
|
||||
JSONObject.parseArray(
|
||||
chatQueryDO.getSimilarQueries(), SimilarQueryRecallResp.class));
|
||||
queryResp.setParseTimeCost(
|
||||
JsonUtil.toObject(chatQueryDO.getParseTimeCost(), ParseTimeCostResp.class));
|
||||
return queryResp;
|
||||
}
|
||||
|
||||
@@ -143,8 +147,10 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatParseDO> batchSaveParseInfo(ChatParseReq chatParseReq,
|
||||
ParseResp parseResult, List<SemanticParseInfo> candidateParses) {
|
||||
public List<ChatParseDO> batchSaveParseInfo(
|
||||
ChatParseReq chatParseReq,
|
||||
ParseResp parseResult,
|
||||
List<SemanticParseInfo> candidateParses) {
|
||||
List<ChatParseDO> chatParseDOList = new ArrayList<>();
|
||||
getChatParseDO(chatParseReq, parseResult.getQueryId(), candidateParses, chatParseDOList);
|
||||
if (!CollectionUtils.isEmpty(candidateParses)) {
|
||||
@@ -153,8 +159,11 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
||||
return chatParseDOList;
|
||||
}
|
||||
|
||||
public void getChatParseDO(ChatParseReq chatParseReq, Long queryId,
|
||||
List<SemanticParseInfo> parses, List<ChatParseDO> chatParseDOList) {
|
||||
public void getChatParseDO(
|
||||
ChatParseReq chatParseReq,
|
||||
Long queryId,
|
||||
List<SemanticParseInfo> parses,
|
||||
List<ChatParseDO> chatParseDOList) {
|
||||
for (int i = 0; i < parses.size(); i++) {
|
||||
ChatParseDO chatParseDO = new ChatParseDO();
|
||||
chatParseDO.setChatId(chatParseReq.getChatId());
|
||||
@@ -190,5 +199,4 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
||||
public List<ChatParseDO> getParseInfoList(List<Long> questionIds) {
|
||||
return chatParseMapper.getParseInfoList(questionIds);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -56,5 +56,4 @@ public class ChatRepositoryImpl implements ChatRepository {
|
||||
public Boolean deleteChat(Long chatId, String userName) {
|
||||
return chatMapper.deleteChat(chatId, userName);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -77,5 +77,4 @@ public class PluginRepositoryImpl implements PluginRepository {
|
||||
public void deletePlugin(Long id) {
|
||||
pluginDOMapper.deleteById(id);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
package com.tencent.supersonic.chat.server.plugin;
|
||||
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||
import lombok.Data;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@@ -14,37 +14,30 @@ public class ChatPlugin extends RecordInfo {
|
||||
|
||||
private Long id;
|
||||
|
||||
/***
|
||||
* plugin type WEB_PAGE WEB_SERVICE
|
||||
*/
|
||||
/** * plugin type WEB_PAGE WEB_SERVICE */
|
||||
private String type;
|
||||
|
||||
private List<Long> dataSetList = Lists.newArrayList();
|
||||
|
||||
/**
|
||||
* description, for parsing
|
||||
*/
|
||||
/** description, for parsing */
|
||||
private String pattern;
|
||||
|
||||
/**
|
||||
* parse
|
||||
*/
|
||||
/** parse */
|
||||
private ParseMode parseMode;
|
||||
|
||||
private String parseModeConfig;
|
||||
|
||||
private String name;
|
||||
|
||||
/**
|
||||
* config for different plugin type
|
||||
*/
|
||||
/** config for different plugin type */
|
||||
private String config;
|
||||
|
||||
private String comment;
|
||||
|
||||
public List<String> getExampleQuestionList() {
|
||||
if (StringUtils.isNotBlank(parseModeConfig)) {
|
||||
PluginParseConfig pluginParseConfig = JSONObject.parseObject(parseModeConfig, PluginParseConfig.class);
|
||||
PluginParseConfig pluginParseConfig =
|
||||
JSONObject.parseObject(parseModeConfig, PluginParseConfig.class);
|
||||
return pluginParseConfig.getExamples();
|
||||
}
|
||||
return Lists.newArrayList();
|
||||
@@ -57,5 +50,4 @@ public class ChatPlugin extends RecordInfo {
|
||||
public Long getDefaultMode() {
|
||||
return -1L;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package com.tencent.supersonic.chat.server.plugin;
|
||||
|
||||
public enum ParseMode {
|
||||
|
||||
EMBEDDING_RECALL,
|
||||
FUNCTION_CALL;
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,11 +46,9 @@ import java.util.stream.Collectors;
|
||||
@Component
|
||||
public class PluginManager {
|
||||
|
||||
@Autowired
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
@Autowired private EmbeddingConfig embeddingConfig;
|
||||
|
||||
@Autowired
|
||||
private EmbeddingService embeddingService;
|
||||
@Autowired private EmbeddingService embeddingService;
|
||||
|
||||
public static List<ChatPlugin> getPluginAgentCanSupport(ParseContext parseContext) {
|
||||
PluginService pluginService = ContextUtils.getBean(PluginService.class);
|
||||
@@ -59,14 +57,21 @@ public class PluginManager {
|
||||
if (Objects.isNull(agent)) {
|
||||
return plugins;
|
||||
}
|
||||
List<Long> pluginIds = getPluginTools(agent).stream().map(PluginTool::getPlugins)
|
||||
.flatMap(Collection::stream).collect(Collectors.toList());
|
||||
List<Long> pluginIds =
|
||||
getPluginTools(agent).stream()
|
||||
.map(PluginTool::getPlugins)
|
||||
.flatMap(Collection::stream)
|
||||
.collect(Collectors.toList());
|
||||
if (CollectionUtils.isEmpty(pluginIds)) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
plugins = plugins.stream().filter(plugin -> pluginIds.contains(plugin.getId()))
|
||||
.collect(Collectors.toList());
|
||||
log.info("plugins witch can be supported by cur agent :{} {}", agent.getName(),
|
||||
plugins =
|
||||
plugins.stream()
|
||||
.filter(plugin -> pluginIds.contains(plugin.getId()))
|
||||
.collect(Collectors.toList());
|
||||
log.info(
|
||||
"plugins witch can be supported by cur agent :{} {}",
|
||||
agent.getName(),
|
||||
plugins.stream().map(ChatPlugin::getName).collect(Collectors.toList()));
|
||||
return plugins;
|
||||
}
|
||||
@@ -79,7 +84,8 @@ public class PluginManager {
|
||||
if (CollectionUtils.isEmpty(tools)) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
return tools.stream().map(tool -> JSONObject.parseObject(tool, PluginTool.class))
|
||||
return tools.stream()
|
||||
.map(tool -> JSONObject.parseObject(tool, PluginTool.class))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@@ -136,18 +142,23 @@ public class PluginManager {
|
||||
|
||||
public RetrieveQueryResult recognize(String embeddingText) {
|
||||
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
|
||||
.queryTextsList(Collections.singletonList(embeddingText))
|
||||
.build();
|
||||
RetrieveQuery retrieveQuery =
|
||||
RetrieveQuery.builder()
|
||||
.queryTextsList(Collections.singletonList(embeddingText))
|
||||
.build();
|
||||
|
||||
List<RetrieveQueryResult> resultList = embeddingService.retrieveQuery(embeddingConfig.getPresetCollection(),
|
||||
retrieveQuery, embeddingConfig.getNResult());
|
||||
List<RetrieveQueryResult> resultList =
|
||||
embeddingService.retrieveQuery(
|
||||
embeddingConfig.getPresetCollection(),
|
||||
retrieveQuery,
|
||||
embeddingConfig.getNResult());
|
||||
|
||||
if (CollectionUtils.isNotEmpty(resultList)) {
|
||||
for (RetrieveQueryResult embeddingResp : resultList) {
|
||||
List<Retrieval> embeddingRetrievals = embeddingResp.getRetrieval();
|
||||
for (Retrieval embeddingRetrieval : embeddingRetrievals) {
|
||||
embeddingRetrieval.setId(getPluginIdFromEmbeddingId(embeddingRetrieval.getId()));
|
||||
embeddingRetrieval.setId(
|
||||
getPluginIdFromEmbeddingId(embeddingRetrieval.getId()));
|
||||
}
|
||||
}
|
||||
return resultList.get(0);
|
||||
@@ -162,7 +173,8 @@ public class PluginManager {
|
||||
int num = 0;
|
||||
for (String pattern : exampleQuestions) {
|
||||
TextSegment query = TextSegment.from(pattern);
|
||||
TextSegmentConvert.addQueryId(query, generateUniqueEmbeddingId(num, plugin.getId()));
|
||||
TextSegmentConvert.addQueryId(
|
||||
query, generateUniqueEmbeddingId(num, plugin.getId()));
|
||||
queries.add(query);
|
||||
num++;
|
||||
}
|
||||
@@ -178,7 +190,7 @@ public class PluginManager {
|
||||
return embeddingIdSet;
|
||||
}
|
||||
|
||||
//num can not bigger than 100
|
||||
// num can not bigger than 100
|
||||
private String generateUniqueEmbeddingId(int num, Long pluginId) {
|
||||
if (num < 10) {
|
||||
return String.format("%s00%s", pluginId, num);
|
||||
@@ -202,8 +214,8 @@ public class PluginManager {
|
||||
return Pair.of(true, pluginMatchedDataSet);
|
||||
}
|
||||
Set<Long> matchedDataSet = Sets.newHashSet();
|
||||
Map<Long, List<ParamOption>> paramOptionMap = paramOptions.stream()
|
||||
.collect(Collectors.groupingBy(ParamOption::getDataSetId));
|
||||
Map<Long, List<ParamOption>> paramOptionMap =
|
||||
paramOptions.stream().collect(Collectors.groupingBy(ParamOption::getDataSetId));
|
||||
for (Long dataSetId : paramOptionMap.keySet()) {
|
||||
List<ParamOption> params = paramOptionMap.get(dataSetId);
|
||||
if (CollectionUtils.isEmpty(params)) {
|
||||
@@ -237,9 +249,13 @@ public class PluginManager {
|
||||
if (org.springframework.util.CollectionUtils.isEmpty(schemaElementMatches)) {
|
||||
return Sets.newHashSet();
|
||||
}
|
||||
return schemaElementMatches.stream().filter(schemaElementMatch ->
|
||||
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|
||||
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
|
||||
return schemaElementMatches.stream()
|
||||
.filter(
|
||||
schemaElementMatch ->
|
||||
SchemaElementType.VALUE.equals(
|
||||
schemaElementMatch.getElement().getType())
|
||||
|| SchemaElementType.ID.equals(
|
||||
schemaElementMatch.getElement().getType()))
|
||||
.map(SchemaElementMatch::getElement)
|
||||
.map(SchemaElement::getId)
|
||||
.collect(Collectors.toSet());
|
||||
@@ -255,7 +271,9 @@ public class PluginManager {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
return paramOptions.stream()
|
||||
.filter(paramOption -> ParamOption.ParamType.SEMANTIC.equals(paramOption.getParamType()))
|
||||
.filter(
|
||||
paramOption ->
|
||||
ParamOption.ParamType.SEMANTIC.equals(paramOption.getParamType()))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@@ -273,5 +291,4 @@ public class PluginManager {
|
||||
}
|
||||
return pluginMatchedDataSet;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.chat.server.plugin;
|
||||
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
@@ -22,5 +21,4 @@ public class PluginParseConfig implements Serializable {
|
||||
private String name;
|
||||
|
||||
private String description;
|
||||
|
||||
}
|
||||
|
||||
@@ -20,5 +20,4 @@ public class PluginQueryManager {
|
||||
public static PluginSemanticQuery getPluginQuery(String queryMode) {
|
||||
return pluginQueries.get(queryMode);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -20,5 +20,4 @@ public class PluginRecallResult {
|
||||
private double score;
|
||||
|
||||
private double distance;
|
||||
|
||||
}
|
||||
|
||||
@@ -22,16 +22,17 @@ public class ParamOption {
|
||||
private Object value;
|
||||
|
||||
/**
|
||||
* CUSTOM: the value is specified by the user
|
||||
* SEMANTIC: the value of element
|
||||
* FORWARD: only forward
|
||||
* CUSTOM: the value is specified by the user SEMANTIC: the value of element FORWARD: only
|
||||
* forward
|
||||
*/
|
||||
public enum ParamType {
|
||||
CUSTOM, SEMANTIC, FORWARD
|
||||
CUSTOM,
|
||||
SEMANTIC,
|
||||
FORWARD
|
||||
}
|
||||
|
||||
public enum OptionType {
|
||||
REQUIRED, OPTIONAL
|
||||
REQUIRED,
|
||||
OPTIONAL
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
package com.tencent.supersonic.chat.server.plugin.build;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@@ -43,29 +43,40 @@ public abstract class PluginSemanticQuery {
|
||||
protected Map<String, Object> getElementMap(PluginParseResult pluginParseResult) {
|
||||
Map<String, Object> elementValueMap = new HashMap<>();
|
||||
Map<Long, Object> filterValueMap = getFilterMap(pluginParseResult);
|
||||
List<SchemaElementMatch> schemaElementMatchList = parseInfo.getElementMatches()
|
||||
.stream().filter(schemaElementMatch -> schemaElementMatch.getFrequency() != null)
|
||||
.sorted(Comparator.comparingLong(SchemaElementMatch::getFrequency).reversed())
|
||||
.collect(Collectors.toList());
|
||||
List<SchemaElementMatch> schemaElementMatchList =
|
||||
parseInfo.getElementMatches().stream()
|
||||
.filter(schemaElementMatch -> schemaElementMatch.getFrequency() != null)
|
||||
.sorted(
|
||||
Comparator.comparingLong(SchemaElementMatch::getFrequency)
|
||||
.reversed())
|
||||
.collect(Collectors.toList());
|
||||
if (!CollectionUtils.isEmpty(schemaElementMatchList)) {
|
||||
schemaElementMatchList.stream().filter(schemaElementMatch ->
|
||||
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|
||||
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
|
||||
schemaElementMatchList.stream()
|
||||
.filter(
|
||||
schemaElementMatch ->
|
||||
SchemaElementType.VALUE.equals(
|
||||
schemaElementMatch.getElement().getType())
|
||||
|| SchemaElementType.ID.equals(
|
||||
schemaElementMatch.getElement().getType()))
|
||||
.filter(schemaElementMatch -> schemaElementMatch.getSimilarity() == 1.0)
|
||||
.forEach(schemaElementMatch -> {
|
||||
Object queryFilterValue = filterValueMap.get(schemaElementMatch.getElement().getId());
|
||||
if (queryFilterValue != null) {
|
||||
if (String.valueOf(queryFilterValue).equals(String.valueOf(schemaElementMatch.getWord()))) {
|
||||
elementValueMap.put(
|
||||
String.valueOf(schemaElementMatch.getElement().getId()),
|
||||
schemaElementMatch.getWord());
|
||||
}
|
||||
} else {
|
||||
elementValueMap.computeIfAbsent(
|
||||
String.valueOf(schemaElementMatch.getElement().getId()),
|
||||
k -> schemaElementMatch.getWord());
|
||||
}
|
||||
});
|
||||
.forEach(
|
||||
schemaElementMatch -> {
|
||||
Object queryFilterValue =
|
||||
filterValueMap.get(schemaElementMatch.getElement().getId());
|
||||
if (queryFilterValue != null) {
|
||||
if (String.valueOf(queryFilterValue)
|
||||
.equals(String.valueOf(schemaElementMatch.getWord()))) {
|
||||
elementValueMap.put(
|
||||
String.valueOf(
|
||||
schemaElementMatch.getElement().getId()),
|
||||
schemaElementMatch.getWord());
|
||||
}
|
||||
} else {
|
||||
elementValueMap.computeIfAbsent(
|
||||
String.valueOf(schemaElementMatch.getElement().getId()),
|
||||
k -> schemaElementMatch.getWord());
|
||||
}
|
||||
});
|
||||
}
|
||||
return elementValueMap;
|
||||
}
|
||||
@@ -97,5 +108,4 @@ public abstract class PluginSemanticQuery {
|
||||
public void setParseInfo(SemanticParseInfo parseInfo) {
|
||||
this.parseInfo = parseInfo;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -15,5 +15,4 @@ public class WebBase {
|
||||
public List<ParamOption> getParams() {
|
||||
return paramOptions;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.chat.server.plugin.build.webpage;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||
@@ -7,7 +8,6 @@ import com.tencent.supersonic.chat.server.plugin.build.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.server.plugin.build.WebBase;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
@@ -41,12 +41,13 @@ public class WebPageQuery extends PluginSemanticQuery {
|
||||
QueryResult queryResult = new QueryResult();
|
||||
queryResult.setQueryMode(QUERY_MODE);
|
||||
Map<String, Object> properties = parseInfo.getProperties();
|
||||
PluginParseResult pluginParseResult = JsonUtil.toObject(JsonUtil.toString(properties.get(Constants.CONTEXT)),
|
||||
PluginParseResult.class);
|
||||
PluginParseResult pluginParseResult =
|
||||
JsonUtil.toObject(
|
||||
JsonUtil.toString(properties.get(Constants.CONTEXT)),
|
||||
PluginParseResult.class);
|
||||
WebPageResp webPageResponse = buildResponse(pluginParseResult);
|
||||
queryResult.setResponse(webPageResponse);
|
||||
queryResult.setQueryState(QueryState.SUCCESS);
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
package com.tencent.supersonic.chat.server.plugin.build.webpage;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.server.plugin.build.WebBase;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
||||
@Data
|
||||
public class WebPageResp {
|
||||
|
||||
@@ -21,5 +19,4 @@ public class WebPageResp {
|
||||
private WebBase webPage;
|
||||
|
||||
private List<WebBase> moreWebPage;
|
||||
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.chat.server.plugin.build.webservice;
|
||||
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||
@@ -11,7 +12,6 @@ import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.http.HttpEntity;
|
||||
@@ -28,7 +28,6 @@ import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class WebServiceQuery extends PluginSemanticQuery {
|
||||
@@ -46,15 +45,18 @@ public class WebServiceQuery extends PluginSemanticQuery {
|
||||
QueryResult queryResult = new QueryResult();
|
||||
queryResult.setQueryMode(QUERY_MODE);
|
||||
Map<String, Object> properties = parseInfo.getProperties();
|
||||
PluginParseResult pluginParseResult = JsonUtil.toObject(
|
||||
JsonUtil.toString(properties.get(Constants.CONTEXT)), PluginParseResult.class);
|
||||
PluginParseResult pluginParseResult =
|
||||
JsonUtil.toObject(
|
||||
JsonUtil.toString(properties.get(Constants.CONTEXT)),
|
||||
PluginParseResult.class);
|
||||
WebServiceResp webServiceResponse = buildResponse(pluginParseResult);
|
||||
Object object = webServiceResponse.getResult();
|
||||
// in order to show webServiceQuery result int frontend conveniently,
|
||||
// webServiceResponse result format is consistent with queryByStruct result.
|
||||
log.info("webServiceResponse result:{}", JsonUtil.toString(object));
|
||||
try {
|
||||
Map<String, Object> data = JsonUtil.toMap(JsonUtil.toString(object), String.class, Object.class);
|
||||
Map<String, Object> data =
|
||||
JsonUtil.toMap(JsonUtil.toString(object), String.class, Object.class);
|
||||
if (data.get("resultList") != null) {
|
||||
queryResult.setQueryResults((List<Map<String, Object>>) data.get("resultList"));
|
||||
}
|
||||
@@ -72,7 +74,9 @@ public class WebServiceQuery extends PluginSemanticQuery {
|
||||
protected WebServiceResp buildResponse(PluginParseResult pluginParseResult) {
|
||||
WebServiceResp webServiceResponse = new WebServiceResp();
|
||||
ChatPlugin plugin = pluginParseResult.getPlugin();
|
||||
WebBase webBase = fillWebBaseResult(JsonUtil.toObject(plugin.getConfig(), WebBase.class), pluginParseResult);
|
||||
WebBase webBase =
|
||||
fillWebBaseResult(
|
||||
JsonUtil.toObject(plugin.getConfig(), WebBase.class), pluginParseResult);
|
||||
webServiceResponse.setWebBase(webBase);
|
||||
List<ParamOption> paramOptions = webBase.getParamOptions();
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
@@ -86,7 +90,8 @@ public class WebServiceQuery extends PluginSemanticQuery {
|
||||
Object objectResponse = null;
|
||||
restTemplate = ContextUtils.getBean(RestTemplate.class);
|
||||
try {
|
||||
responseEntity = restTemplate.exchange(requestUrl, HttpMethod.POST, entity, Object.class);
|
||||
responseEntity =
|
||||
restTemplate.exchange(requestUrl, HttpMethod.POST, entity, Object.class);
|
||||
objectResponse = responseEntity.getBody();
|
||||
log.info("objectResponse:{}", objectResponse);
|
||||
Map<String, Object> response = JsonUtil.objectToMap(objectResponse);
|
||||
@@ -96,5 +101,4 @@ public class WebServiceQuery extends PluginSemanticQuery {
|
||||
}
|
||||
return webServiceResponse;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
package com.tencent.supersonic.chat.server.plugin.build.webservice;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.server.plugin.build.WebBase;
|
||||
import lombok.Data;
|
||||
|
||||
|
||||
@Data
|
||||
public class WebServiceResp {
|
||||
|
||||
private WebBase webBase;
|
||||
|
||||
private Object result;
|
||||
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.chat.server.plugin.event;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
|
||||
import org.springframework.context.ApplicationEvent;
|
||||
|
||||
|
||||
@@ -22,5 +22,4 @@ public class PluginUpdateEvent extends ApplicationEvent {
|
||||
public ChatPlugin getNewPlugin() {
|
||||
return newPlugin;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -23,9 +23,7 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* PluginParser defines the basic process and common methods for recalling plugins.
|
||||
*/
|
||||
/** PluginParser defines the basic process and common methods for recalling plugins. */
|
||||
public abstract class PluginRecognizer {
|
||||
|
||||
public void recognize(ParseContext parseContext, ParseResp parseResp) {
|
||||
@@ -43,16 +41,17 @@ public abstract class PluginRecognizer {
|
||||
|
||||
public abstract PluginRecallResult recallPlugin(ParseContext parseContext);
|
||||
|
||||
public void buildQuery(ParseContext parseContext, ParseResp parseResp,
|
||||
PluginRecallResult pluginRecallResult) {
|
||||
public void buildQuery(
|
||||
ParseContext parseContext, ParseResp parseResp, PluginRecallResult pluginRecallResult) {
|
||||
ChatPlugin plugin = pluginRecallResult.getPlugin();
|
||||
Set<Long> dataSetIds = pluginRecallResult.getDataSetIds();
|
||||
if (plugin.isContainsAllDataSet()) {
|
||||
dataSetIds = Sets.newHashSet(-1L);
|
||||
}
|
||||
for (Long dataSetId : dataSetIds) {
|
||||
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(dataSetId, plugin,
|
||||
parseContext, pluginRecallResult.getDistance());
|
||||
SemanticParseInfo semanticParseInfo =
|
||||
buildSemanticParseInfo(
|
||||
dataSetId, plugin, parseContext, pluginRecallResult.getDistance());
|
||||
semanticParseInfo.setQueryMode(plugin.getType());
|
||||
semanticParseInfo.setScore(pluginRecallResult.getScore());
|
||||
parseResp.getSelectedParses().add(semanticParseInfo);
|
||||
@@ -63,9 +62,10 @@ public abstract class PluginRecognizer {
|
||||
return PluginManager.getPluginAgentCanSupport(parseContext);
|
||||
}
|
||||
|
||||
protected SemanticParseInfo buildSemanticParseInfo(Long dataSetId, ChatPlugin plugin,
|
||||
ParseContext parseContext, double distance) {
|
||||
List<SchemaElementMatch> schemaElementMatches = parseContext.getMapInfo().getMatchedElements(dataSetId);
|
||||
protected SemanticParseInfo buildSemanticParseInfo(
|
||||
Long dataSetId, ChatPlugin plugin, ParseContext parseContext, double distance) {
|
||||
List<SchemaElementMatch> schemaElementMatches =
|
||||
parseContext.getMapInfo().getMatchedElements(dataSetId);
|
||||
QueryFilters queryFilters = parseContext.getQueryFilters();
|
||||
if (schemaElementMatches == null) {
|
||||
schemaElementMatches = Lists.newArrayList();
|
||||
@@ -96,18 +96,22 @@ public abstract class PluginRecognizer {
|
||||
if (CollectionUtils.isEmpty(schemaElementMatches)) {
|
||||
return;
|
||||
}
|
||||
schemaElementMatches.stream().filter(schemaElementMatch ->
|
||||
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|
||||
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
|
||||
.forEach(schemaElementMatch -> {
|
||||
QueryFilter queryFilter = new QueryFilter();
|
||||
queryFilter.setValue(schemaElementMatch.getWord());
|
||||
queryFilter.setElementID(schemaElementMatch.getElement().getId());
|
||||
queryFilter.setName(schemaElementMatch.getElement().getName());
|
||||
queryFilter.setOperator(FilterOperatorEnum.EQUALS);
|
||||
queryFilter.setBizName(schemaElementMatch.getElement().getBizName());
|
||||
semanticParseInfo.getDimensionFilters().add(queryFilter);
|
||||
});
|
||||
schemaElementMatches.stream()
|
||||
.filter(
|
||||
schemaElementMatch ->
|
||||
SchemaElementType.VALUE.equals(
|
||||
schemaElementMatch.getElement().getType())
|
||||
|| SchemaElementType.ID.equals(
|
||||
schemaElementMatch.getElement().getType()))
|
||||
.forEach(
|
||||
schemaElementMatch -> {
|
||||
QueryFilter queryFilter = new QueryFilter();
|
||||
queryFilter.setValue(schemaElementMatch.getWord());
|
||||
queryFilter.setElementID(schemaElementMatch.getElement().getId());
|
||||
queryFilter.setName(schemaElementMatch.getElement().getName());
|
||||
queryFilter.setOperator(FilterOperatorEnum.EQUALS);
|
||||
queryFilter.setBizName(schemaElementMatch.getElement().getBizName());
|
||||
semanticParseInfo.getDimensionFilters().add(queryFilter);
|
||||
});
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package com.tencent.supersonic.chat.server.plugin.recognize.embedding;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.server.plugin.ParseMode;
|
||||
import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
|
||||
import com.tencent.supersonic.chat.server.plugin.ParseMode;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
|
||||
@@ -20,9 +20,7 @@ import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* EmbeddingRecallParser is an implementation of a recall plugin based on Embedding
|
||||
*/
|
||||
/** EmbeddingRecallParser is an implementation of a recall plugin based on Embedding */
|
||||
@Slf4j
|
||||
public class EmbeddingRecallRecognizer extends PluginRecognizer {
|
||||
|
||||
@@ -38,7 +36,8 @@ public class EmbeddingRecallRecognizer extends PluginRecognizer {
|
||||
return null;
|
||||
}
|
||||
List<ChatPlugin> plugins = getPluginList(parseContext);
|
||||
Map<Long, ChatPlugin> pluginMap = plugins.stream().collect(Collectors.toMap(ChatPlugin::getId, p -> p));
|
||||
Map<Long, ChatPlugin> pluginMap =
|
||||
plugins.stream().collect(Collectors.toMap(ChatPlugin::getId, p -> p));
|
||||
for (Retrieval embeddingRetrieval : embeddingRetrievals) {
|
||||
ChatPlugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId()));
|
||||
if (plugin == null) {
|
||||
@@ -55,7 +54,11 @@ public class EmbeddingRecallRecognizer extends PluginRecognizer {
|
||||
double distance = embeddingRetrieval.getDistance();
|
||||
double score = parseContext.getQueryText().length() * (1 - distance);
|
||||
return PluginRecallResult.builder()
|
||||
.plugin(plugin).dataSetIds(dataSetList).score(score).distance(distance).build();
|
||||
.plugin(plugin)
|
||||
.dataSetIds(dataSetList)
|
||||
.score(score)
|
||||
.distance(distance)
|
||||
.build();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
@@ -68,8 +71,10 @@ public class EmbeddingRecallRecognizer extends PluginRecognizer {
|
||||
|
||||
List<Retrieval> embeddingRetrievals = embeddingResp.getRetrieval();
|
||||
if (!CollectionUtils.isEmpty(embeddingRetrievals)) {
|
||||
embeddingRetrievals = embeddingRetrievals.stream().sorted(Comparator.comparingDouble(o ->
|
||||
Math.abs(o.getDistance()))).collect(Collectors.toList());
|
||||
embeddingRetrievals =
|
||||
embeddingRetrievals.stream()
|
||||
.sorted(Comparator.comparingDouble(o -> Math.abs(o.getDistance())))
|
||||
.collect(Collectors.toList());
|
||||
embeddingResp.setRetrieval(embeddingRetrievals);
|
||||
}
|
||||
return embeddingRetrievals;
|
||||
@@ -78,5 +83,4 @@ public class EmbeddingRecallRecognizer extends PluginRecognizer {
|
||||
}
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.chat.server.plugin.recognize.embedding;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
@@ -15,5 +14,4 @@ public class RecallRetrieval {
|
||||
private String presetId;
|
||||
|
||||
private String query;
|
||||
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.chat.server.plugin.recognize.embedding;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@@ -10,6 +10,4 @@ public class RecallRetrievalResp {
|
||||
private String query;
|
||||
|
||||
private List<RecallRetrieval> retrieval;
|
||||
|
||||
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.chat.server.pojo;
|
||||
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import lombok.Data;
|
||||
|
||||
|
||||
@@ -1,8 +1,4 @@
|
||||
package com.tencent.supersonic.chat.server.processor;
|
||||
|
||||
/**
|
||||
* A ResultProcessor wraps things up before returning results to users.
|
||||
*/
|
||||
public interface ResultProcessor {
|
||||
|
||||
}
|
||||
/** A ResultProcessor wraps things up before returning results to users. */
|
||||
public interface ResultProcessor {}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.chat.server.processor.execute;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
@@ -8,7 +9,6 @@ import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.RelatedSchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@@ -20,8 +20,7 @@ import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* DimensionRecommendProcessor recommend some dimensions
|
||||
* related to metrics based on configuration
|
||||
* DimensionRecommendProcessor recommend some dimensions related to metrics based on configuration
|
||||
*/
|
||||
public class DimensionRecommendProcessor implements ExecuteResultProcessor {
|
||||
|
||||
@@ -50,13 +49,20 @@ public class DimensionRecommendProcessor implements ExecuteResultProcessor {
|
||||
List<Long> drillDownDimensions = Lists.newArrayList();
|
||||
Set<SchemaElement> metricElements = dataSetSchema.getMetrics();
|
||||
if (!CollectionUtils.isEmpty(metricElements)) {
|
||||
Optional<SchemaElement> metric = metricElements.stream().filter(schemaElement ->
|
||||
metricId.equals(schemaElement.getId())
|
||||
&& !CollectionUtils.isEmpty(schemaElement.getRelatedSchemaElements()))
|
||||
.findFirst();
|
||||
Optional<SchemaElement> metric =
|
||||
metricElements.stream()
|
||||
.filter(
|
||||
schemaElement ->
|
||||
metricId.equals(schemaElement.getId())
|
||||
&& !CollectionUtils.isEmpty(
|
||||
schemaElement
|
||||
.getRelatedSchemaElements()))
|
||||
.findFirst();
|
||||
if (metric.isPresent()) {
|
||||
drillDownDimensions = metric.get().getRelatedSchemaElements().stream()
|
||||
.map(RelatedSchemaElement::getDimensionId).collect(Collectors.toList());
|
||||
drillDownDimensions =
|
||||
metric.get().getRelatedSchemaElements().stream()
|
||||
.map(RelatedSchemaElement::getDimensionId)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
}
|
||||
final List<Long> drillDownDimensionsFinal = drillDownDimensions;
|
||||
@@ -76,5 +82,4 @@ public class DimensionRecommendProcessor implements ExecuteResultProcessor {
|
||||
}
|
||||
return Objects.nonNull(dimension.getUseCnt());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,15 +1,11 @@
|
||||
package com.tencent.supersonic.chat.server.processor.execute;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.processor.ResultProcessor;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
|
||||
/**
|
||||
* A ExecuteResultProcessor wraps things up before returning
|
||||
* execution results to the users.
|
||||
*/
|
||||
/** A ExecuteResultProcessor wraps things up before returning execution results to the users. */
|
||||
public interface ExecuteResultProcessor extends ResultProcessor {
|
||||
|
||||
void process(ExecuteContext executeContext, QueryResult queryResult);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,17 +1,9 @@
|
||||
package com.tencent.supersonic.chat.server.processor.execute;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.Constants.DAY;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.DAY_FORMAT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.DAY_FORMAT_INT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.MONTH;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.MONTH_FORMAT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.MONTH_FORMAT_INT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.TIMES_FORMAT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.TIME_FORMAT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.WEEK;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
@@ -20,17 +12,19 @@ import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.RatioOverType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.headless.api.pojo.AggregateInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.MetricInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.core.config.AggregatorConfig;
|
||||
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
|
||||
import com.tencent.supersonic.headless.core.config.AggregatorConfig;
|
||||
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.text.DecimalFormat;
|
||||
import java.time.DayOfWeek;
|
||||
import java.time.LocalDate;
|
||||
@@ -49,13 +43,17 @@ import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.DAY;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.DAY_FORMAT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.DAY_FORMAT_INT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.MONTH;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.MONTH_FORMAT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.MONTH_FORMAT_INT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.TIMES_FORMAT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.TIME_FORMAT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.WEEK;
|
||||
|
||||
/**
|
||||
* Add ratio queries for metric queries.
|
||||
*/
|
||||
/** Add ratio queries for metric queries. */
|
||||
@Slf4j
|
||||
public class MetricRatioProcessor implements ExecuteResultProcessor {
|
||||
|
||||
@@ -68,18 +66,24 @@ public class MetricRatioProcessor implements ExecuteResultProcessor {
|
||||
|| !QueryType.METRIC.equals(semanticParseInfo.getQueryType())) {
|
||||
return;
|
||||
}
|
||||
AggregateInfo aggregateInfo = getAggregateInfo(executeContext.getUser(),
|
||||
semanticParseInfo, queryResult);
|
||||
AggregateInfo aggregateInfo =
|
||||
getAggregateInfo(executeContext.getUser(), semanticParseInfo, queryResult);
|
||||
queryResult.setAggregateInfo(aggregateInfo);
|
||||
}
|
||||
|
||||
public AggregateInfo getAggregateInfo(User user, SemanticParseInfo semanticParseInfo, QueryResult queryResult) {
|
||||
public AggregateInfo getAggregateInfo(
|
||||
User user, SemanticParseInfo semanticParseInfo, QueryResult queryResult) {
|
||||
|
||||
Set<String> resultMetricNames = new HashSet<>();
|
||||
queryResult.getQueryColumns()
|
||||
.stream().forEach(c -> resultMetricNames.addAll(SqlSelectHelper.getColumnFromExpr(c.getNameEn())));
|
||||
Optional<SchemaElement> ratioMetric = semanticParseInfo.getMetrics().stream()
|
||||
.filter(m -> resultMetricNames.contains(m.getBizName())).findFirst();
|
||||
queryResult.getQueryColumns().stream()
|
||||
.forEach(
|
||||
c ->
|
||||
resultMetricNames.addAll(
|
||||
SqlSelectHelper.getColumnFromExpr(c.getNameEn())));
|
||||
Optional<SchemaElement> ratioMetric =
|
||||
semanticParseInfo.getMetrics().stream()
|
||||
.filter(m -> resultMetricNames.contains(m.getBizName()))
|
||||
.findFirst();
|
||||
|
||||
AggregateInfo aggregateInfo = new AggregateInfo();
|
||||
if (!ratioMetric.isPresent()) {
|
||||
@@ -88,31 +92,48 @@ public class MetricRatioProcessor implements ExecuteResultProcessor {
|
||||
|
||||
try {
|
||||
String dateField = QueryReqBuilder.getDateField(semanticParseInfo.getDateInfo());
|
||||
Optional<String> lastDayOp = queryResult.getQueryResults().stream()
|
||||
.filter(r -> r.containsKey(dateField))
|
||||
.map(r -> r.get(dateField).toString())
|
||||
.sorted(Comparator.reverseOrder()).findFirst();
|
||||
Optional<String> lastDayOp =
|
||||
queryResult.getQueryResults().stream()
|
||||
.filter(r -> r.containsKey(dateField))
|
||||
.map(r -> r.get(dateField).toString())
|
||||
.sorted(Comparator.reverseOrder())
|
||||
.findFirst();
|
||||
|
||||
if (!lastDayOp.isPresent()) {
|
||||
return new AggregateInfo();
|
||||
}
|
||||
Optional<Map<String, Object>> lastValue = queryResult.getQueryResults().stream()
|
||||
.filter(r -> r.get(dateField).toString().equals(lastDayOp.get())).findFirst();
|
||||
Optional<Map<String, Object>> lastValue =
|
||||
queryResult.getQueryResults().stream()
|
||||
.filter(r -> r.get(dateField).toString().equals(lastDayOp.get()))
|
||||
.findFirst();
|
||||
|
||||
MetricInfo metricInfo = new MetricInfo();
|
||||
metricInfo.setStatistics(new HashMap<>());
|
||||
if (lastValue.isPresent() && lastValue.get().containsKey(ratioMetric.get().getBizName())) {
|
||||
if (lastValue.isPresent()
|
||||
&& lastValue.get().containsKey(ratioMetric.get().getBizName())) {
|
||||
DecimalFormat df = new DecimalFormat("#.####");
|
||||
metricInfo.setValue(df.format(lastValue.get().get(ratioMetric.get().getBizName())));
|
||||
}
|
||||
metricInfo.setDate(lastValue.get().get(dateField).toString());
|
||||
|
||||
CompletableFuture<MetricInfo> metricInfoRoll = CompletableFuture.supplyAsync(
|
||||
() -> queryRatio(user, semanticParseInfo, ratioMetric.get(), AggOperatorEnum.RATIO_ROLL,
|
||||
queryResult));
|
||||
CompletableFuture<MetricInfo> metricInfoOver = CompletableFuture.supplyAsync(
|
||||
() -> queryRatio(user, semanticParseInfo, ratioMetric.get(), AggOperatorEnum.RATIO_OVER,
|
||||
queryResult));
|
||||
CompletableFuture<MetricInfo> metricInfoRoll =
|
||||
CompletableFuture.supplyAsync(
|
||||
() ->
|
||||
queryRatio(
|
||||
user,
|
||||
semanticParseInfo,
|
||||
ratioMetric.get(),
|
||||
AggOperatorEnum.RATIO_ROLL,
|
||||
queryResult));
|
||||
CompletableFuture<MetricInfo> metricInfoOver =
|
||||
CompletableFuture.supplyAsync(
|
||||
() ->
|
||||
queryRatio(
|
||||
user,
|
||||
semanticParseInfo,
|
||||
ratioMetric.get(),
|
||||
AggOperatorEnum.RATIO_OVER,
|
||||
queryResult));
|
||||
CompletableFuture.allOf(metricInfoRoll, metricInfoOver);
|
||||
metricInfo.setName(metricInfoRoll.get().getName());
|
||||
metricInfo.setValue(metricInfoRoll.get().getValue());
|
||||
@@ -126,13 +147,19 @@ public class MetricRatioProcessor implements ExecuteResultProcessor {
|
||||
}
|
||||
|
||||
@SneakyThrows
|
||||
private MetricInfo queryRatio(User user, SemanticParseInfo semanticParseInfo, SchemaElement metric,
|
||||
AggOperatorEnum aggOperatorEnum, QueryResult queryResult) {
|
||||
private MetricInfo queryRatio(
|
||||
User user,
|
||||
SemanticParseInfo semanticParseInfo,
|
||||
SchemaElement metric,
|
||||
AggOperatorEnum aggOperatorEnum,
|
||||
QueryResult queryResult) {
|
||||
|
||||
QueryStructReq queryStructReq = QueryReqBuilder.buildStructRatioReq(semanticParseInfo, metric, aggOperatorEnum);
|
||||
QueryStructReq queryStructReq =
|
||||
QueryReqBuilder.buildStructRatioReq(semanticParseInfo, metric, aggOperatorEnum);
|
||||
String dateField = QueryReqBuilder.getDateField(semanticParseInfo.getDateInfo());
|
||||
queryStructReq.setGroups(new ArrayList<>(Arrays.asList(dateField)));
|
||||
queryStructReq.setDateInfo(getRatioDateConf(aggOperatorEnum, semanticParseInfo, queryResult));
|
||||
queryStructReq.setDateInfo(
|
||||
getRatioDateConf(aggOperatorEnum, semanticParseInfo, queryResult));
|
||||
queryStructReq.setConvertToSql(false);
|
||||
SemanticLayerService queryService = ContextUtils.getBean(SemanticLayerService.class);
|
||||
SemanticQueryResp queryResp = queryService.queryByReq(queryStructReq, user);
|
||||
@@ -143,21 +170,26 @@ public class MetricRatioProcessor implements ExecuteResultProcessor {
|
||||
}
|
||||
|
||||
Map<String, Object> result = queryResp.getResultList().get(0);
|
||||
Optional<QueryColumn> valueColumn = queryResp.getColumns().stream()
|
||||
.filter(c -> c.getNameEn().equals(metric.getBizName())).findFirst();
|
||||
Optional<QueryColumn> valueColumn =
|
||||
queryResp.getColumns().stream()
|
||||
.filter(c -> c.getNameEn().equals(metric.getBizName()))
|
||||
.findFirst();
|
||||
|
||||
if (!valueColumn.isPresent()) {
|
||||
return metricInfo;
|
||||
}
|
||||
String valueField = String.format("%s_%s", valueColumn.get().getNameEn(), aggOperatorEnum.getOperator());
|
||||
String valueField =
|
||||
String.format(
|
||||
"%s_%s", valueColumn.get().getNameEn(), aggOperatorEnum.getOperator());
|
||||
if (result.containsKey(valueColumn.get().getNameEn())) {
|
||||
DecimalFormat df = new DecimalFormat("#.####");
|
||||
metricInfo.setValue(df.format(result.get(valueColumn.get().getNameEn())));
|
||||
}
|
||||
String ratio = "";
|
||||
if (Objects.nonNull(result.get(valueField))) {
|
||||
ratio = String.format("%.2f",
|
||||
(Double.valueOf(result.get(valueField).toString()) * 100)) + "%";
|
||||
ratio =
|
||||
String.format("%.2f", (Double.valueOf(result.get(valueField).toString()) * 100))
|
||||
+ "%";
|
||||
}
|
||||
String statisticsRollName = RatioOverType.DAY_ON_DAY.getShowName();
|
||||
String statisticsOverName = RatioOverType.WEEK_ON_DAY.getShowName();
|
||||
@@ -169,19 +201,28 @@ public class MetricRatioProcessor implements ExecuteResultProcessor {
|
||||
statisticsRollName = RatioOverType.WEEK_ON_WEEK.getShowName();
|
||||
statisticsOverName = RatioOverType.MONTH_ON_WEEK.getShowName();
|
||||
}
|
||||
metricInfo.getStatistics().put(aggOperatorEnum.equals(AggOperatorEnum.RATIO_ROLL)
|
||||
? statisticsRollName : statisticsOverName, ratio);
|
||||
metricInfo
|
||||
.getStatistics()
|
||||
.put(
|
||||
aggOperatorEnum.equals(AggOperatorEnum.RATIO_ROLL)
|
||||
? statisticsRollName
|
||||
: statisticsOverName,
|
||||
ratio);
|
||||
metricInfo.setName(metric.getName());
|
||||
return metricInfo;
|
||||
}
|
||||
|
||||
private DateConf getRatioDateConf(AggOperatorEnum aggOperatorEnum, SemanticParseInfo semanticParseInfo,
|
||||
private DateConf getRatioDateConf(
|
||||
AggOperatorEnum aggOperatorEnum,
|
||||
SemanticParseInfo semanticParseInfo,
|
||||
QueryResult queryResult) {
|
||||
String dateField = QueryReqBuilder.getDateField(semanticParseInfo.getDateInfo());
|
||||
|
||||
Optional<String> lastDayOp = queryResult.getQueryResults()
|
||||
.stream().map(r -> r.get(dateField).toString())
|
||||
.sorted(Comparator.reverseOrder()).findFirst();
|
||||
Optional<String> lastDayOp =
|
||||
queryResult.getQueryResults().stream()
|
||||
.map(r -> r.get(dateField).toString())
|
||||
.sorted(Comparator.reverseOrder())
|
||||
.findFirst();
|
||||
|
||||
if (!lastDayOp.isPresent()) {
|
||||
return semanticParseInfo.getDateInfo();
|
||||
@@ -194,29 +235,37 @@ public class MetricRatioProcessor implements ExecuteResultProcessor {
|
||||
dayList.add(lastDay);
|
||||
String start = "";
|
||||
if (DAY.equalsIgnoreCase(semanticParseInfo.getDateInfo().getPeriod())) {
|
||||
DateTimeFormatter formatter = DateUtils.getDateFormatter(lastDay,
|
||||
new String[]{DAY_FORMAT, DAY_FORMAT_INT});
|
||||
DateTimeFormatter formatter =
|
||||
DateUtils.getDateFormatter(lastDay, new String[] {DAY_FORMAT, DAY_FORMAT_INT});
|
||||
LocalDate end = LocalDate.parse(lastDay, formatter);
|
||||
start = aggOperatorEnum.equals(AggOperatorEnum.RATIO_ROLL) ? end.minusDays(1).format(formatter)
|
||||
: end.minusWeeks(1).format(formatter);
|
||||
start =
|
||||
aggOperatorEnum.equals(AggOperatorEnum.RATIO_ROLL)
|
||||
? end.minusDays(1).format(formatter)
|
||||
: end.minusWeeks(1).format(formatter);
|
||||
}
|
||||
if (WEEK.equalsIgnoreCase(semanticParseInfo.getDateInfo().getPeriod())) {
|
||||
DateTimeFormatter formatter = DateUtils.getTimeFormatter(lastDay,
|
||||
new String[]{TIMES_FORMAT, DAY_FORMAT, TIME_FORMAT, DAY_FORMAT_INT});
|
||||
DateTimeFormatter formatter =
|
||||
DateUtils.getTimeFormatter(
|
||||
lastDay,
|
||||
new String[] {TIMES_FORMAT, DAY_FORMAT, TIME_FORMAT, DAY_FORMAT_INT});
|
||||
LocalDateTime end = LocalDateTime.parse(lastDay, formatter);
|
||||
start = aggOperatorEnum.equals(AggOperatorEnum.RATIO_ROLL) ? end.minusWeeks(1).format(formatter)
|
||||
: end.minusMonths(1).with(DayOfWeek.MONDAY).format(formatter);
|
||||
start =
|
||||
aggOperatorEnum.equals(AggOperatorEnum.RATIO_ROLL)
|
||||
? end.minusWeeks(1).format(formatter)
|
||||
: end.minusMonths(1).with(DayOfWeek.MONDAY).format(formatter);
|
||||
}
|
||||
if (MONTH.equalsIgnoreCase(semanticParseInfo.getDateInfo().getPeriod())) {
|
||||
DateTimeFormatter formatter = DateUtils.getDateFormatter(lastDay,
|
||||
new String[]{MONTH_FORMAT, MONTH_FORMAT_INT});
|
||||
DateTimeFormatter formatter =
|
||||
DateUtils.getDateFormatter(
|
||||
lastDay, new String[] {MONTH_FORMAT, MONTH_FORMAT_INT});
|
||||
YearMonth end = YearMonth.parse(lastDay, formatter);
|
||||
start = aggOperatorEnum.equals(AggOperatorEnum.RATIO_ROLL) ? end.minusMonths(1).format(formatter)
|
||||
: end.minusYears(1).format(formatter);
|
||||
start =
|
||||
aggOperatorEnum.equals(AggOperatorEnum.RATIO_ROLL)
|
||||
? end.minusMonths(1).format(formatter)
|
||||
: end.minusYears(1).format(formatter);
|
||||
}
|
||||
dayList.add(start);
|
||||
dateConf.setDateList(dayList);
|
||||
return dateConf;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -26,9 +26,7 @@ import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* MetricRecommendProcessor fills recommended metrics based on embedding similarity.
|
||||
*/
|
||||
/** MetricRecommendProcessor fills recommended metrics based on embedding similarity. */
|
||||
public class MetricRecommendProcessor implements ExecuteResultProcessor {
|
||||
|
||||
private static final int METRIC_RECOMMEND_SIZE = 5;
|
||||
@@ -44,24 +42,36 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
|
||||
|| CollectionUtils.isEmpty(parseInfo.getMetrics())) {
|
||||
return;
|
||||
}
|
||||
List<String> metricNames = Collections.singletonList(parseInfo.getMetrics().iterator().next().getName());
|
||||
List<String> metricNames =
|
||||
Collections.singletonList(parseInfo.getMetrics().iterator().next().getName());
|
||||
Map<String, Object> filterCondition = new HashMap<>();
|
||||
filterCondition.put("modelId", parseInfo.getMetrics().iterator().next().getDataSetId().toString());
|
||||
filterCondition.put(
|
||||
"modelId", parseInfo.getMetrics().iterator().next().getDataSetId().toString());
|
||||
filterCondition.put("type", SchemaElementType.METRIC.name());
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(metricNames)
|
||||
.filterCondition(filterCondition).queryEmbeddings(null).build();
|
||||
MetaEmbeddingService metaEmbeddingService = ContextUtils.getBean(MetaEmbeddingService.class);
|
||||
RetrieveQuery retrieveQuery =
|
||||
RetrieveQuery.builder()
|
||||
.queryTextsList(metricNames)
|
||||
.filterCondition(filterCondition)
|
||||
.queryEmbeddings(null)
|
||||
.build();
|
||||
MetaEmbeddingService metaEmbeddingService =
|
||||
ContextUtils.getBean(MetaEmbeddingService.class);
|
||||
List<RetrieveQueryResult> retrieveQueryResults =
|
||||
metaEmbeddingService.retrieveQuery(retrieveQuery, METRIC_RECOMMEND_SIZE + 1, new HashMap<>(),
|
||||
new HashSet<>());
|
||||
metaEmbeddingService.retrieveQuery(
|
||||
retrieveQuery, METRIC_RECOMMEND_SIZE + 1, new HashMap<>(), new HashSet<>());
|
||||
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
||||
return;
|
||||
}
|
||||
List<Retrieval> retrievals = retrieveQueryResults.stream()
|
||||
.flatMap(retrieveQueryResult -> retrieveQueryResult.getRetrieval().stream())
|
||||
.sorted(Comparator.comparingDouble(Retrieval::getDistance))
|
||||
.distinct().collect(Collectors.toList());
|
||||
Set<Long> metricIds = parseInfo.getMetrics().stream().map(SchemaElement::getId).collect(Collectors.toSet());
|
||||
List<Retrieval> retrievals =
|
||||
retrieveQueryResults.stream()
|
||||
.flatMap(retrieveQueryResult -> retrieveQueryResult.getRetrieval().stream())
|
||||
.sorted(Comparator.comparingDouble(Retrieval::getDistance))
|
||||
.distinct()
|
||||
.collect(Collectors.toList());
|
||||
Set<Long> metricIds =
|
||||
parseInfo.getMetrics().stream()
|
||||
.map(SchemaElement::getId)
|
||||
.collect(Collectors.toSet());
|
||||
int metricOrder = 0;
|
||||
for (SchemaElement metric : parseInfo.getMetrics()) {
|
||||
metric.setOrder(metricOrder++);
|
||||
@@ -69,15 +79,23 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
|
||||
for (Retrieval retrieval : retrievals) {
|
||||
if (!metricIds.contains(Retrieval.getLongId(retrieval.getId()))) {
|
||||
if (Objects.nonNull(retrieval.getMetadata().get("id"))) {
|
||||
String idStr = retrieval.getMetadata().get("id").toString()
|
||||
.replaceAll(DictWordType.NATURE_SPILT, "");
|
||||
String idStr =
|
||||
retrieval
|
||||
.getMetadata()
|
||||
.get("id")
|
||||
.toString()
|
||||
.replaceAll(DictWordType.NATURE_SPILT, "");
|
||||
retrieval.getMetadata().put("id", idStr);
|
||||
}
|
||||
String metaStr = JSONObject.toJSONString(retrieval.getMetadata());
|
||||
SchemaElement schemaElement = JSONObject.parseObject(metaStr, SchemaElement.class);
|
||||
if (retrieval.getMetadata().containsKey("dataSetId")) {
|
||||
String dataSetId = retrieval.getMetadata().get("dataSetId").toString()
|
||||
.replace(Constants.UNDERLINE, "");
|
||||
String dataSetId =
|
||||
retrieval
|
||||
.getMetadata()
|
||||
.get("dataSetId")
|
||||
.toString()
|
||||
.replace(Constants.UNDERLINE, "");
|
||||
schemaElement.setDataSetId(Long.parseLong(dataSetId));
|
||||
}
|
||||
schemaElement.setOrder(++metricOrder);
|
||||
@@ -85,5 +103,4 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -4,12 +4,8 @@ import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.chat.server.processor.ResultProcessor;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
|
||||
/**
|
||||
* A ParseResultProcessor wraps things up before returning
|
||||
* parsing results to the users.
|
||||
*/
|
||||
/** A ParseResultProcessor wraps things up before returning parsing results to the users. */
|
||||
public interface ParseResultProcessor extends ResultProcessor {
|
||||
|
||||
void process(ParseContext parseContext, ParseResp parseResp);
|
||||
|
||||
}
|
||||
|
||||
@@ -18,9 +18,7 @@ import java.util.List;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* MetricRecommendProcessor fills recommended query based on embedding similarity.
|
||||
*/
|
||||
/** MetricRecommendProcessor fills recommended query based on embedding similarity. */
|
||||
@Slf4j
|
||||
public class QueryRecommendProcessor implements ParseResultProcessor {
|
||||
|
||||
@@ -32,8 +30,8 @@ public class QueryRecommendProcessor implements ParseResultProcessor {
|
||||
@SneakyThrows
|
||||
private void doProcess(ParseResp parseResp, ParseContext parseContext) {
|
||||
Long queryId = parseResp.getQueryId();
|
||||
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(parseContext.getQueryText(),
|
||||
parseContext.getAgent().getId());
|
||||
List<SimilarQueryRecallResp> solvedQueries =
|
||||
getSimilarQueries(parseContext.getQueryText(), parseContext.getAgent().getId());
|
||||
ChatQueryDO chatQueryDO = getChatQuery(queryId);
|
||||
chatQueryDO.setSimilarQueries(JSONObject.toJSONString(solvedQueries));
|
||||
updateChatQuery(chatQueryDO);
|
||||
@@ -43,9 +41,14 @@ public class QueryRecommendProcessor implements ParseResultProcessor {
|
||||
ExemplarService exemplarService = ContextUtils.getBean(ExemplarService.class);
|
||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
|
||||
List<Text2SQLExemplar> exemplars = exemplarService.recallExemplars(memoryCollectionName, queryText, 5);
|
||||
return exemplars.stream().map(sqlExemplar ->
|
||||
SimilarQueryRecallResp.builder().queryText(sqlExemplar.getQuestion()).build())
|
||||
List<Text2SQLExemplar> exemplars =
|
||||
exemplarService.recallExemplars(memoryCollectionName, queryText, 5);
|
||||
return exemplars.stream()
|
||||
.map(
|
||||
sqlExemplar ->
|
||||
SimilarQueryRecallResp.builder()
|
||||
.queryText(sqlExemplar.getQuestion())
|
||||
.build())
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@@ -61,5 +64,4 @@ public class QueryRecommendProcessor implements ParseResultProcessor {
|
||||
updateWrapper.set("similar_queries", chatQueryDO.getSimilarQueries());
|
||||
chatQueryRepository.updateChatQuery(chatQueryDO, updateWrapper);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
package com.tencent.supersonic.chat.server.processor.parse;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
/**
|
||||
* TimeCostProcessor adds time cost of parsing.
|
||||
**/
|
||||
/** TimeCostProcessor adds time cost of parsing. */
|
||||
@Slf4j
|
||||
public class TimeCostProcessor implements ParseResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ParseContext parseContext, ParseResp parseResp) {
|
||||
long parseStartTime = parseResp.getParseTimeCost().getParseStartTime();
|
||||
parseResp.getParseTimeCost().setParseTime(
|
||||
System.currentTimeMillis() - parseStartTime - parseResp.getParseTimeCost().getSqlTime());
|
||||
parseResp
|
||||
.getParseTimeCost()
|
||||
.setParseTime(
|
||||
System.currentTimeMillis()
|
||||
- parseStartTime
|
||||
- parseResp.getParseTimeCost().getSqlTime());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
package com.tencent.supersonic.chat.server.rest;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
@@ -16,8 +19,6 @@ import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@@ -25,21 +26,22 @@ import java.util.Map;
|
||||
@RequestMapping({"/api/chat/agent", "/openapi/chat/agent"})
|
||||
public class AgentController {
|
||||
|
||||
@Autowired
|
||||
private AgentService agentService;
|
||||
@Autowired private AgentService agentService;
|
||||
|
||||
@PostMapping
|
||||
public Agent createAgent(@RequestBody Agent agent,
|
||||
HttpServletRequest httpServletRequest,
|
||||
HttpServletResponse httpServletResponse) {
|
||||
public Agent createAgent(
|
||||
@RequestBody Agent agent,
|
||||
HttpServletRequest httpServletRequest,
|
||||
HttpServletResponse httpServletResponse) {
|
||||
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
|
||||
return agentService.createAgent(agent, user);
|
||||
}
|
||||
|
||||
@PutMapping
|
||||
public Agent updateAgent(@RequestBody Agent agent,
|
||||
HttpServletRequest httpServletRequest,
|
||||
HttpServletResponse httpServletResponse) {
|
||||
public Agent updateAgent(
|
||||
@RequestBody Agent agent,
|
||||
HttpServletRequest httpServletRequest,
|
||||
HttpServletResponse httpServletResponse) {
|
||||
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
|
||||
return agentService.updateAgent(agent, user);
|
||||
}
|
||||
@@ -64,5 +66,4 @@ public class AgentController {
|
||||
public Map<AgentToolType, String> getToolTypes() {
|
||||
return AgentToolType.getToolTypes();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
package com.tencent.supersonic.chat.server.rest;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigBaseReq;
|
||||
@@ -20,41 +23,39 @@ import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.util.List;
|
||||
|
||||
|
||||
@RestController
|
||||
@RequestMapping({"/api/chat/conf", "/openapi/chat/conf"})
|
||||
public class ChatConfigController {
|
||||
|
||||
@Autowired
|
||||
private ConfigService configService;
|
||||
@Autowired private ConfigService configService;
|
||||
|
||||
@Autowired
|
||||
private SemanticLayerService semanticLayerService;
|
||||
@Autowired private SemanticLayerService semanticLayerService;
|
||||
|
||||
@PostMapping
|
||||
public Long addChatConfig(@RequestBody ChatConfigBaseReq extendBaseCmd,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
public Long addChatConfig(
|
||||
@RequestBody ChatConfigBaseReq extendBaseCmd,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
User user = UserHolder.findUser(request, response);
|
||||
return configService.addConfig(extendBaseCmd, user);
|
||||
}
|
||||
|
||||
@PutMapping
|
||||
public Long editModelExtend(@RequestBody ChatConfigEditReqReq extendEditCmd,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
public Long editModelExtend(
|
||||
@RequestBody ChatConfigEditReqReq extendEditCmd,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
User user = UserHolder.findUser(request, response);
|
||||
return configService.editConfig(extendEditCmd, user);
|
||||
}
|
||||
|
||||
@PostMapping("/search")
|
||||
public List<ChatConfigResp> search(@RequestBody ChatConfigFilter filter,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
public List<ChatConfigResp> search(
|
||||
@RequestBody ChatConfigFilter filter,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
User user = UserHolder.findUser(request, response);
|
||||
return configService.search(filter, user);
|
||||
}
|
||||
@@ -78,5 +79,4 @@ public class ChatConfigController {
|
||||
public DataSetSchema getDataSetSchema(@PathVariable("id") Long id) {
|
||||
return semanticLayerService.getDataSetSchema(id);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package com.tencent.supersonic.chat.server.rest;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
@@ -17,65 +19,72 @@ import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RequestParam;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.util.List;
|
||||
|
||||
@RestController
|
||||
@RequestMapping({"/api/chat/manage", "/openapi/chat/manage"})
|
||||
public class ChatController {
|
||||
|
||||
@Autowired
|
||||
private ChatManageService chatService;
|
||||
@Autowired private ChatManageService chatService;
|
||||
|
||||
@PostMapping("/save")
|
||||
public Boolean save(@RequestParam(value = "chatName") String chatName,
|
||||
@RequestParam(value = "agentId", required = false) Integer agentId,
|
||||
HttpServletRequest request, HttpServletResponse response) {
|
||||
public Boolean save(
|
||||
@RequestParam(value = "chatName") String chatName,
|
||||
@RequestParam(value = "agentId", required = false) Integer agentId,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
chatService.addChat(UserHolder.findUser(request, response), chatName, agentId);
|
||||
return true;
|
||||
}
|
||||
|
||||
@GetMapping("/getAll")
|
||||
public List<ChatDO> getAllConversions(@RequestParam(value = "agentId", required = false) Integer agentId,
|
||||
HttpServletRequest request, HttpServletResponse response) {
|
||||
public List<ChatDO> getAllConversions(
|
||||
@RequestParam(value = "agentId", required = false) Integer agentId,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
String userName = UserHolder.findUser(request, response).getName();
|
||||
return chatService.getAll(userName, agentId);
|
||||
}
|
||||
|
||||
@PostMapping("/delete")
|
||||
public Boolean deleteConversion(@RequestParam(value = "chatId") long chatId,
|
||||
HttpServletRequest request, HttpServletResponse response) {
|
||||
public Boolean deleteConversion(
|
||||
@RequestParam(value = "chatId") long chatId,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
String userName = UserHolder.findUser(request, response).getName();
|
||||
return chatService.deleteChat(chatId, userName);
|
||||
}
|
||||
|
||||
@PostMapping("/updateChatName")
|
||||
public Boolean updateConversionName(@RequestParam(value = "chatId") Long chatId,
|
||||
public Boolean updateConversionName(
|
||||
@RequestParam(value = "chatId") Long chatId,
|
||||
@RequestParam(value = "chatName") String chatName,
|
||||
HttpServletRequest request, HttpServletResponse response) {
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
String userName = UserHolder.findUser(request, response).getName();
|
||||
return chatService.updateChatName(chatId, chatName, userName);
|
||||
}
|
||||
|
||||
@PostMapping("/updateQAFeedback")
|
||||
public Boolean updateQAFeedback(@RequestParam(value = "id") Integer id,
|
||||
public Boolean updateQAFeedback(
|
||||
@RequestParam(value = "id") Integer id,
|
||||
@RequestParam(value = "score") Integer score,
|
||||
@RequestParam(value = "feedback", required = false) String feedback) {
|
||||
return chatService.updateFeedback(id, score, feedback);
|
||||
}
|
||||
|
||||
@PostMapping("/updateChatIsTop")
|
||||
public Boolean updateConversionIsTop(@RequestParam(value = "chatId") Long chatId,
|
||||
@RequestParam(value = "isTop") int isTop) {
|
||||
public Boolean updateConversionIsTop(
|
||||
@RequestParam(value = "chatId") Long chatId, @RequestParam(value = "isTop") int isTop) {
|
||||
return chatService.updateChatIsTop(chatId, isTop);
|
||||
}
|
||||
|
||||
@PostMapping("/pageQueryInfo")
|
||||
public PageInfo<QueryResp> pageQueryInfo(@RequestBody PageQueryInfoReq pageQueryInfoCommand,
|
||||
@RequestParam(value = "chatId") long chatId,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
public PageInfo<QueryResp> pageQueryInfo(
|
||||
@RequestBody PageQueryInfoReq pageQueryInfoCommand,
|
||||
@RequestParam(value = "chatId") long chatId,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
pageQueryInfoCommand.setUserName(UserHolder.findUser(request, response).getName());
|
||||
return chatService.queryInfo(pageQueryInfoCommand, chatId);
|
||||
}
|
||||
@@ -86,9 +95,9 @@ public class ChatController {
|
||||
}
|
||||
|
||||
@PostMapping("/queryShowCase")
|
||||
public ShowCaseResp queryShowCase(@RequestBody PageQueryInfoReq pageQueryInfoCommand,
|
||||
@RequestParam(value = "agentId") int agentId) {
|
||||
public ShowCaseResp queryShowCase(
|
||||
@RequestBody PageQueryInfoReq pageQueryInfoCommand,
|
||||
@RequestParam(value = "agentId") int agentId) {
|
||||
return chatService.queryShowCase(pageQueryInfoCommand, agentId);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
package com.tencent.supersonic.chat.server.rest;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import javax.validation.Valid;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
@@ -19,45 +22,47 @@ import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import javax.validation.Valid;
|
||||
|
||||
/**
|
||||
* query controller
|
||||
*/
|
||||
/** query controller */
|
||||
@RestController
|
||||
@RequestMapping({"/api/chat/query", "/openapi/chat/query"})
|
||||
public class ChatQueryController {
|
||||
|
||||
@Autowired
|
||||
private ChatQueryService chatQueryService;
|
||||
@Autowired private ChatQueryService chatQueryService;
|
||||
|
||||
@PostMapping("search")
|
||||
public Object search(@RequestBody ChatParseReq chatParseReq, HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
public Object search(
|
||||
@RequestBody ChatParseReq chatParseReq,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
chatParseReq.setUser(UserHolder.findUser(request, response));
|
||||
return chatQueryService.search(chatParseReq);
|
||||
}
|
||||
|
||||
@PostMapping("parse")
|
||||
public Object parse(@RequestBody ChatParseReq chatParseReq,
|
||||
HttpServletRequest request, HttpServletResponse response) throws Exception {
|
||||
public Object parse(
|
||||
@RequestBody ChatParseReq chatParseReq,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response)
|
||||
throws Exception {
|
||||
chatParseReq.setUser(UserHolder.findUser(request, response));
|
||||
return chatQueryService.performParsing(chatParseReq);
|
||||
}
|
||||
|
||||
@PostMapping("execute")
|
||||
public Object execute(@RequestBody ChatExecuteReq chatExecuteReq,
|
||||
HttpServletRequest request, HttpServletResponse response)
|
||||
public Object execute(
|
||||
@RequestBody ChatExecuteReq chatExecuteReq,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response)
|
||||
throws Exception {
|
||||
chatExecuteReq.setUser(UserHolder.findUser(request, response));
|
||||
return chatQueryService.performExecution(chatExecuteReq);
|
||||
}
|
||||
|
||||
@PostMapping("/")
|
||||
public Object query(@RequestBody ChatParseReq chatParseReq,
|
||||
HttpServletRequest request, HttpServletResponse response)
|
||||
public Object query(
|
||||
@RequestBody ChatParseReq chatParseReq,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response)
|
||||
throws Exception {
|
||||
User user = UserHolder.findUser(request, response);
|
||||
chatParseReq.setUser(user);
|
||||
@@ -75,17 +80,22 @@ public class ChatQueryController {
|
||||
}
|
||||
|
||||
@PostMapping("queryData")
|
||||
public Object queryData(@RequestBody ChatQueryDataReq chatQueryDataReq,
|
||||
HttpServletRequest request, HttpServletResponse response) throws Exception {
|
||||
public Object queryData(
|
||||
@RequestBody ChatQueryDataReq chatQueryDataReq,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response)
|
||||
throws Exception {
|
||||
chatQueryDataReq.setUser(UserHolder.findUser(request, response));
|
||||
return chatQueryService.queryData(chatQueryDataReq, UserHolder.findUser(request, response));
|
||||
}
|
||||
|
||||
@PostMapping("queryDimensionValue")
|
||||
public Object queryDimensionValue(@RequestBody @Valid DimensionValueReq dimensionValueReq,
|
||||
HttpServletRequest request, HttpServletResponse response) throws Exception {
|
||||
return chatQueryService.queryDimensionValue(dimensionValueReq, UserHolder.findUser(request, response));
|
||||
public Object queryDimensionValue(
|
||||
@RequestBody @Valid DimensionValueReq dimensionValueReq,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response)
|
||||
throws Exception {
|
||||
return chatQueryService.queryDimensionValue(
|
||||
dimensionValueReq, UserHolder.findUser(request, response));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
package com.tencent.supersonic.chat.server.rest;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
@@ -13,20 +16,17 @@ import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
|
||||
@RestController
|
||||
@RequestMapping({"/api/chat/memory"})
|
||||
public class MemoryController {
|
||||
|
||||
@Autowired
|
||||
private MemoryService memoryService;
|
||||
@Autowired private MemoryService memoryService;
|
||||
|
||||
@PostMapping("/updateMemory")
|
||||
public Boolean updateMemory(@RequestBody ChatMemoryUpdateReq chatMemoryUpdateReq,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
public Boolean updateMemory(
|
||||
@RequestBody ChatMemoryUpdateReq chatMemoryUpdateReq,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
User user = UserHolder.findUser(request, response);
|
||||
memoryService.updateMemory(chatMemoryUpdateReq, user);
|
||||
return true;
|
||||
@@ -36,5 +36,4 @@ public class MemoryController {
|
||||
public PageInfo<ChatMemoryDO> pageMemories(@RequestBody PageMemoryReq pageMemoryReq) {
|
||||
return memoryService.pageMemories(pageMemoryReq);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
package com.tencent.supersonic.chat.server.rest;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.annotation.AuthenticationIgnore;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
@@ -16,30 +19,29 @@ import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RequestParam;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.util.List;
|
||||
|
||||
@RestController
|
||||
@RequestMapping("/api/chat/plugin")
|
||||
public class PluginController {
|
||||
|
||||
@Autowired
|
||||
protected PluginService pluginService;
|
||||
@Autowired protected PluginService pluginService;
|
||||
|
||||
@PostMapping
|
||||
public boolean createPlugin(@RequestBody ChatPlugin plugin,
|
||||
HttpServletRequest httpServletRequest,
|
||||
HttpServletResponse httpServletResponse) {
|
||||
public boolean createPlugin(
|
||||
@RequestBody ChatPlugin plugin,
|
||||
HttpServletRequest httpServletRequest,
|
||||
HttpServletResponse httpServletResponse) {
|
||||
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
|
||||
pluginService.createPlugin(plugin, user);
|
||||
return true;
|
||||
}
|
||||
|
||||
@PutMapping
|
||||
public boolean updatePlugin(@RequestBody ChatPlugin plugin,
|
||||
HttpServletRequest httpServletRequest,
|
||||
HttpServletResponse httpServletResponse) {
|
||||
public boolean updatePlugin(
|
||||
@RequestBody ChatPlugin plugin,
|
||||
HttpServletRequest httpServletRequest,
|
||||
HttpServletResponse httpServletResponse) {
|
||||
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
|
||||
pluginService.updatePlugin(plugin, user);
|
||||
return true;
|
||||
@@ -57,18 +59,18 @@ public class PluginController {
|
||||
}
|
||||
|
||||
@PostMapping("/query")
|
||||
List<ChatPlugin> query(@RequestBody PluginQueryReq pluginQueryReq,
|
||||
HttpServletRequest httpServletRequest,
|
||||
HttpServletResponse httpServletResponse) {
|
||||
List<ChatPlugin> query(
|
||||
@RequestBody PluginQueryReq pluginQueryReq,
|
||||
HttpServletRequest httpServletRequest,
|
||||
HttpServletResponse httpServletResponse) {
|
||||
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
|
||||
return pluginService.queryWithAuthCheck(pluginQueryReq, user);
|
||||
}
|
||||
|
||||
@AuthenticationIgnore
|
||||
@PostMapping("/pluginDemo")
|
||||
public String pluginDemo(@RequestParam("queryText") String queryText,
|
||||
@RequestBody Object object) {
|
||||
public String pluginDemo(
|
||||
@RequestParam("queryText") String queryText, @RequestBody Object object) {
|
||||
return String.format("已收到您的问题:%s, 但这只是一个demo~", queryText);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.server.service;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface AgentService {
|
||||
@@ -15,5 +16,4 @@ public interface AgentService {
|
||||
Agent getAgent(Integer id);
|
||||
|
||||
void deleteAgent(Integer id);
|
||||
|
||||
}
|
||||
|
||||
@@ -7,5 +7,4 @@ public interface ChatContextService {
|
||||
ChatContext getOrCreateContext(Integer chatId);
|
||||
|
||||
void updateContext(ChatContext chatCtx);
|
||||
|
||||
}
|
||||
|
||||
@@ -6,13 +6,13 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
||||
@@ -4,9 +4,9 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
|
||||
|
||||
import java.util.List;
|
||||
@@ -24,5 +24,4 @@ public interface ChatQueryService {
|
||||
Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exception;
|
||||
|
||||
Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception;
|
||||
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.chat.server.service;
|
||||
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigBaseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigEditReqReq;
|
||||
@@ -27,5 +26,4 @@ public interface ConfigService {
|
||||
ChatConfigResp fetchConfigByModelId(Long modelId);
|
||||
|
||||
List<ChatConfigRichResp> getAllChatRichConfig();
|
||||
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.chat.server.service;
|
||||
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PluginQueryReq;
|
||||
import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
|
||||
@@ -28,5 +27,4 @@ public interface PluginService {
|
||||
List<ChatPlugin> queryWithAuthCheck(PluginQueryReq pluginQueryReq, User user);
|
||||
|
||||
Map<String, ChatPlugin> getNameToPlugin();
|
||||
|
||||
}
|
||||
|
||||
@@ -3,11 +3,10 @@ package com.tencent.supersonic.chat.server.service;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.RecommendReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.RecommendQuestionResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.RecommendResp;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/***
|
||||
* Recommend Service
|
||||
*/
|
||||
/** * Recommend Service */
|
||||
public interface RecommendService {
|
||||
|
||||
RecommendResp recommend(RecommendReq recommendReq, Long limit);
|
||||
|
||||
@@ -29,21 +29,17 @@ import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
||||
implements AgentService {
|
||||
public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implements AgentService {
|
||||
|
||||
@Autowired
|
||||
private MemoryService memoryService;
|
||||
@Autowired private MemoryService memoryService;
|
||||
|
||||
@Autowired
|
||||
private ChatQueryService chatQueryService;
|
||||
@Autowired private ChatQueryService chatQueryService;
|
||||
|
||||
private ExecutorService executorService = Executors.newFixedThreadPool(1);
|
||||
|
||||
@Override
|
||||
public List<Agent> getAgents() {
|
||||
return getAgentDOList().stream()
|
||||
.map(this::convert).collect(Collectors.toList());
|
||||
return getAgentDOList().stream().map(this::convert).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -78,8 +74,8 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
||||
}
|
||||
|
||||
/**
|
||||
* the example in the agent will be executed by default,
|
||||
* if the result is correct, it will be put into memory as a reference for LLM
|
||||
* the example in the agent will be executed by default, if the result is correct, it will be
|
||||
* put into memory as a reference for LLM
|
||||
*
|
||||
* @param agent
|
||||
*/
|
||||
@@ -88,16 +84,19 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
||||
}
|
||||
|
||||
private synchronized void doExecuteAgentExamples(Agent agent) {
|
||||
if (!agent.containsLLMParserTool() || !LLMConnHelper.testConnection(agent.getModelConfig())
|
||||
if (!agent.containsLLMParserTool()
|
||||
|| !LLMConnHelper.testConnection(agent.getModelConfig())
|
||||
|| CollectionUtils.isEmpty(agent.getExamples())) {
|
||||
return;
|
||||
}
|
||||
|
||||
List<String> examples = agent.getExamples();
|
||||
ChatMemoryFilter chatMemoryFilter = ChatMemoryFilter.builder().agentId(agent.getId())
|
||||
.questions(examples).build();
|
||||
List<String> memoriesExisted = memoryService.getMemories(chatMemoryFilter)
|
||||
.stream().map(ChatMemoryDO::getQuestion).collect(Collectors.toList());
|
||||
ChatMemoryFilter chatMemoryFilter =
|
||||
ChatMemoryFilter.builder().agentId(agent.getId()).questions(examples).build();
|
||||
List<String> memoriesExisted =
|
||||
memoryService.getMemories(chatMemoryFilter).stream()
|
||||
.map(ChatMemoryDO::getQuestion)
|
||||
.collect(Collectors.toList());
|
||||
for (String example : examples) {
|
||||
if (memoriesExisted.contains(example)) {
|
||||
continue;
|
||||
@@ -124,7 +123,8 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
||||
agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class));
|
||||
agent.setModelConfig(JsonUtil.toObject(agentDO.getModelConfig(), ChatModelConfig.class));
|
||||
agent.setPromptConfig(JsonUtil.toObject(agentDO.getPromptConfig(), PromptConfig.class));
|
||||
agent.setMultiTurnConfig(JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class));
|
||||
agent.setMultiTurnConfig(
|
||||
JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class));
|
||||
agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class));
|
||||
return agent;
|
||||
}
|
||||
@@ -143,5 +143,4 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
||||
}
|
||||
return agentDO;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package com.tencent.supersonic.chat.server.service.impl;
|
||||
|
||||
import com.tencent.supersonic.chat.server.service.ChatContextService;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatContextRepository;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.server.service.ChatContextService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@@ -26,5 +26,4 @@ public class ChatContextServiceImpl implements ChatContextService {
|
||||
log.debug("save ChatContext {}", chatCtx);
|
||||
chatContextRepository.updateContext(chatCtx);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
|
||||
@@ -18,7 +19,6 @@ import com.tencent.supersonic.chat.server.service.ChatManageService;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
@@ -33,15 +33,12 @@ import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class ChatManageServiceImpl implements ChatManageService {
|
||||
|
||||
@Autowired
|
||||
private ChatRepository chatRepository;
|
||||
@Autowired
|
||||
private ChatQueryRepository chatQueryRepository;
|
||||
@Autowired private ChatRepository chatRepository;
|
||||
@Autowired private ChatQueryRepository chatQueryRepository;
|
||||
|
||||
@Override
|
||||
public Long addChat(User user, String chatName, Integer agentId) {
|
||||
@@ -88,7 +85,8 @@ public class ChatManageServiceImpl implements ChatManageService {
|
||||
|
||||
@Override
|
||||
public PageInfo<QueryResp> queryInfo(PageQueryInfoReq pageQueryInfoReq, long chatId) {
|
||||
PageInfo<QueryResp> queryRespPageInfo = chatQueryRepository.getChatQuery(pageQueryInfoReq, chatId);
|
||||
PageInfo<QueryResp> queryRespPageInfo =
|
||||
chatQueryRepository.getChatQuery(pageQueryInfoReq, chatId);
|
||||
if (CollectionUtils.isEmpty(queryRespPageInfo.getList())) {
|
||||
return queryRespPageInfo;
|
||||
}
|
||||
@@ -123,47 +121,62 @@ public class ChatManageServiceImpl implements ChatManageService {
|
||||
if (CollectionUtils.isEmpty(queryResps)) {
|
||||
return showCaseResp;
|
||||
}
|
||||
queryResps.removeIf(queryResp -> {
|
||||
if (queryResp.getQueryResult() == null) {
|
||||
return true;
|
||||
}
|
||||
if (queryResp.getQueryResult().getResponse() != null) {
|
||||
return false;
|
||||
}
|
||||
if (CollectionUtils.isEmpty(queryResp.getQueryResult().getQueryResults())) {
|
||||
return true;
|
||||
}
|
||||
Map<String, Object> data = queryResp.getQueryResult().getQueryResults().get(0);
|
||||
return CollectionUtils.isEmpty(data);
|
||||
});
|
||||
queryResps = new ArrayList<>(queryResps.stream()
|
||||
.collect(Collectors.toMap(QueryResp::getQueryText, Function.identity(),
|
||||
(existing, replacement) -> existing, LinkedHashMap::new)).values());
|
||||
queryResps.removeIf(
|
||||
queryResp -> {
|
||||
if (queryResp.getQueryResult() == null) {
|
||||
return true;
|
||||
}
|
||||
if (queryResp.getQueryResult().getResponse() != null) {
|
||||
return false;
|
||||
}
|
||||
if (CollectionUtils.isEmpty(queryResp.getQueryResult().getQueryResults())) {
|
||||
return true;
|
||||
}
|
||||
Map<String, Object> data = queryResp.getQueryResult().getQueryResults().get(0);
|
||||
return CollectionUtils.isEmpty(data);
|
||||
});
|
||||
queryResps =
|
||||
new ArrayList<>(
|
||||
queryResps.stream()
|
||||
.collect(
|
||||
Collectors.toMap(
|
||||
QueryResp::getQueryText,
|
||||
Function.identity(),
|
||||
(existing, replacement) -> existing,
|
||||
LinkedHashMap::new))
|
||||
.values());
|
||||
fillParseInfo(queryResps);
|
||||
Map<Long, List<QueryResp>> showCaseMap = queryResps.stream()
|
||||
.collect(Collectors.groupingBy(QueryResp::getChatId));
|
||||
Map<Long, List<QueryResp>> showCaseMap =
|
||||
queryResps.stream().collect(Collectors.groupingBy(QueryResp::getChatId));
|
||||
showCaseResp.setShowCaseMap(showCaseMap);
|
||||
return showCaseResp;
|
||||
}
|
||||
|
||||
private void fillParseInfo(List<QueryResp> queryResps) {
|
||||
List<Long> queryIds = queryResps.stream()
|
||||
.map(QueryResp::getQuestionId).collect(Collectors.toList());
|
||||
List<Long> queryIds =
|
||||
queryResps.stream().map(QueryResp::getQuestionId).collect(Collectors.toList());
|
||||
List<ChatParseDO> chatParseDOs = chatQueryRepository.getParseInfoList(queryIds);
|
||||
if (CollectionUtils.isEmpty(chatParseDOs)) {
|
||||
return;
|
||||
}
|
||||
Map<Long, List<ChatParseDO>> chatParseMap = chatParseDOs.stream()
|
||||
.collect(Collectors.groupingBy(ChatParseDO::getQuestionId));
|
||||
Map<Long, List<ChatParseDO>> chatParseMap =
|
||||
chatParseDOs.stream().collect(Collectors.groupingBy(ChatParseDO::getQuestionId));
|
||||
for (QueryResp queryResp : queryResps) {
|
||||
List<ChatParseDO> chatParseDOList = chatParseMap.get(queryResp.getQuestionId());
|
||||
if (CollectionUtils.isEmpty(chatParseDOList)) {
|
||||
continue;
|
||||
}
|
||||
List<SemanticParseInfo> parseInfos = chatParseDOList.stream().map(chatParseDO ->
|
||||
JsonUtil.toObject(chatParseDO.getParseInfo(), SemanticParseInfo.class))
|
||||
.sorted(Comparator.comparingDouble(SemanticParseInfo::getScore).reversed())
|
||||
.collect(Collectors.toList());
|
||||
List<SemanticParseInfo> parseInfos =
|
||||
chatParseDOList.stream()
|
||||
.map(
|
||||
chatParseDO ->
|
||||
JsonUtil.toObject(
|
||||
chatParseDO.getParseInfo(),
|
||||
SemanticParseInfo.class))
|
||||
.sorted(
|
||||
Comparator.comparingDouble(SemanticParseInfo::getScore)
|
||||
.reversed())
|
||||
.collect(Collectors.toList());
|
||||
queryResp.setParseInfos(parseInfos);
|
||||
}
|
||||
}
|
||||
@@ -175,8 +188,10 @@ public class ChatManageServiceImpl implements ChatManageService {
|
||||
chatQueryDO.setQueryResult(JsonUtil.toString(queryResult));
|
||||
chatQueryDO.setQueryState(1);
|
||||
updateQuery(chatQueryDO);
|
||||
chatRepository.updateLastQuestion(chatExecuteReq.getChatId().longValue(),
|
||||
chatExecuteReq.getQueryText(), getCurrentTime());
|
||||
chatRepository.updateLastQuestion(
|
||||
chatExecuteReq.getChatId().longValue(),
|
||||
chatExecuteReq.getQueryText(),
|
||||
getCurrentTime());
|
||||
return chatQueryDO;
|
||||
}
|
||||
|
||||
@@ -208,5 +223,4 @@ public class ChatManageServiceImpl implements ChatManageService {
|
||||
ChatParseDO chatParseDO = chatQueryRepository.getParseInfo(questionId, parseId);
|
||||
return JSONObject.parseObject(chatParseDO.getParseInfo(), SemanticParseInfo.class);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -75,24 +75,21 @@ import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
|
||||
@Autowired
|
||||
private ChatManageService chatManageService;
|
||||
@Autowired
|
||||
private ChatLayerService chatLayerService;
|
||||
@Autowired
|
||||
private SemanticLayerService semanticLayerService;
|
||||
@Autowired
|
||||
private AgentService agentService;
|
||||
@Autowired private ChatManageService chatManageService;
|
||||
@Autowired private ChatLayerService chatLayerService;
|
||||
@Autowired private SemanticLayerService semanticLayerService;
|
||||
@Autowired private AgentService agentService;
|
||||
|
||||
private List<ChatQueryParser> chatQueryParsers = ComponentFactory.getChatParsers();
|
||||
private List<ChatQueryExecutor> chatQueryExecutors = ComponentFactory.getChatExecutors();
|
||||
private List<ParseResultProcessor> parseResultProcessors = ComponentFactory.getParseProcessors();
|
||||
private List<ExecuteResultProcessor> executeResultProcessors = ComponentFactory.getExecuteProcessors();
|
||||
private List<ParseResultProcessor> parseResultProcessors =
|
||||
ComponentFactory.getParseProcessors();
|
||||
private List<ExecuteResultProcessor> executeResultProcessors =
|
||||
ComponentFactory.getExecuteProcessors();
|
||||
|
||||
@Override
|
||||
public List<SearchResult> search(ChatParseReq chatParseReq) {
|
||||
@@ -153,8 +150,11 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
chatParseReq.setUser(User.getFakeUser());
|
||||
ParseResp parseResp = performParsing(chatParseReq);
|
||||
if (CollectionUtils.isEmpty(parseResp.getSelectedParses())) {
|
||||
log.debug("chatId:{}, agentId:{}, queryText:{}, parseResp.getSelectedParses() is empty",
|
||||
chatId, agentId, queryText);
|
||||
log.debug(
|
||||
"chatId:{}, agentId:{}, queryText:{}, parseResp.getSelectedParses() is empty",
|
||||
chatId,
|
||||
agentId,
|
||||
queryText);
|
||||
return null;
|
||||
}
|
||||
ChatExecuteReq executeReq = new ChatExecuteReq();
|
||||
@@ -185,8 +185,9 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
private ExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) {
|
||||
ExecuteContext executeContext = new ExecuteContext();
|
||||
BeanMapper.mapper(chatExecuteReq, executeContext);
|
||||
SemanticParseInfo parseInfo = chatManageService.getParseInfo(
|
||||
chatExecuteReq.getQueryId(), chatExecuteReq.getParseId());
|
||||
SemanticParseInfo parseInfo =
|
||||
chatManageService.getParseInfo(
|
||||
chatExecuteReq.getQueryId(), chatExecuteReq.getParseId());
|
||||
Agent agent = agentService.getAgent(chatExecuteReq.getAgentId());
|
||||
executeContext.setAgent(agent);
|
||||
executeContext.setParseInfo(parseInfo);
|
||||
@@ -196,9 +197,11 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
@Override
|
||||
public Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exception {
|
||||
Integer parseId = chatQueryDataReq.getParseId();
|
||||
SemanticParseInfo parseInfo = chatManageService.getParseInfo(chatQueryDataReq.getQueryId(), parseId);
|
||||
SemanticParseInfo parseInfo =
|
||||
chatManageService.getParseInfo(chatQueryDataReq.getQueryId(), parseId);
|
||||
parseInfo = mergeParseInfo(parseInfo, chatQueryDataReq);
|
||||
DataSetSchema dataSetSchema = semanticLayerService.getDataSetSchema(parseInfo.getDataSetId());
|
||||
DataSetSchema dataSetSchema =
|
||||
semanticLayerService.getDataSetSchema(parseInfo.getDataSetId());
|
||||
|
||||
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
|
||||
semanticQuery.setParseInfo(parseInfo);
|
||||
@@ -220,9 +223,9 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
return SqlSelectHelper.getAllSelectFields(sqlInfo.getCorrectedS2SQL());
|
||||
}
|
||||
|
||||
private void handleLLMQueryMode(ChatQueryDataReq chatQueryDataReq,
|
||||
SemanticQuery semanticQuery,
|
||||
User user) throws Exception {
|
||||
private void handleLLMQueryMode(
|
||||
ChatQueryDataReq chatQueryDataReq, SemanticQuery semanticQuery, User user)
|
||||
throws Exception {
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
List<String> fields = getFieldsFromSql(parseInfo);
|
||||
if (checkMetricReplace(fields, chatQueryDataReq.getMetrics())) {
|
||||
@@ -240,18 +243,16 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
}
|
||||
}
|
||||
|
||||
private void handleRuleQueryMode(SemanticQuery semanticQuery,
|
||||
DataSetSchema dataSetSchema,
|
||||
User user) {
|
||||
private void handleRuleQueryMode(
|
||||
SemanticQuery semanticQuery, DataSetSchema dataSetSchema, User user) {
|
||||
log.info("rule begin replace metrics and revise filters!");
|
||||
validFilter(semanticQuery.getParseInfo().getDimensionFilters());
|
||||
validFilter(semanticQuery.getParseInfo().getMetricFilters());
|
||||
semanticQuery.initS2Sql(dataSetSchema, user);
|
||||
}
|
||||
|
||||
private QueryResult executeQuery(SemanticQuery semanticQuery,
|
||||
User user,
|
||||
DataSetSchema dataSetSchema) throws Exception {
|
||||
private QueryResult executeQuery(
|
||||
SemanticQuery semanticQuery, User user, DataSetSchema dataSetSchema) throws Exception {
|
||||
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
QueryResult queryResult = doExecution(semanticQueryReq, parseInfo.getQueryMode(), user);
|
||||
@@ -266,7 +267,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
if (CollectionUtils.isEmpty(oriFields) || CollectionUtils.isEmpty(metrics)) {
|
||||
return false;
|
||||
}
|
||||
List<String> metricNames = metrics.stream().map(SchemaElement::getName).collect(Collectors.toList());
|
||||
List<String> metricNames =
|
||||
metrics.stream().map(SchemaElement::getName).collect(Collectors.toList());
|
||||
return !oriFields.containsAll(metricNames);
|
||||
}
|
||||
|
||||
@@ -274,26 +276,41 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
log.info("correctorSql before replacing:{}", correctorSql);
|
||||
// get where filter and having filter
|
||||
List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctorSql);
|
||||
List<FieldExpression> whereExpressionList =
|
||||
SqlSelectHelper.getWhereExpressions(correctorSql);
|
||||
|
||||
// replace where filter
|
||||
List<Expression> addWhereConditions = new ArrayList<>();
|
||||
Set<String> removeWhereFieldNames = updateFilters(whereExpressionList, queryData.getDimensionFilters(),
|
||||
parseInfo.getDimensionFilters(), addWhereConditions);
|
||||
Set<String> removeWhereFieldNames =
|
||||
updateFilters(
|
||||
whereExpressionList,
|
||||
queryData.getDimensionFilters(),
|
||||
parseInfo.getDimensionFilters(),
|
||||
addWhereConditions);
|
||||
|
||||
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
|
||||
Set<String> removeDataFieldNames = updateDateInfo(queryData, parseInfo, filedNameToValueMap,
|
||||
whereExpressionList, addWhereConditions);
|
||||
Set<String> removeDataFieldNames =
|
||||
updateDateInfo(
|
||||
queryData,
|
||||
parseInfo,
|
||||
filedNameToValueMap,
|
||||
whereExpressionList,
|
||||
addWhereConditions);
|
||||
removeWhereFieldNames.addAll(removeDataFieldNames);
|
||||
|
||||
correctorSql = SqlReplaceHelper.replaceValue(correctorSql, filedNameToValueMap);
|
||||
correctorSql = SqlRemoveHelper.removeWhereCondition(correctorSql, removeWhereFieldNames);
|
||||
|
||||
// replace having filter
|
||||
List<FieldExpression> havingExpressionList = SqlSelectHelper.getHavingExpressions(correctorSql);
|
||||
List<FieldExpression> havingExpressionList =
|
||||
SqlSelectHelper.getHavingExpressions(correctorSql);
|
||||
List<Expression> addHavingConditions = new ArrayList<>();
|
||||
Set<String> removeHavingFieldNames = updateFilters(havingExpressionList,
|
||||
queryData.getDimensionFilters(), parseInfo.getDimensionFilters(), addHavingConditions);
|
||||
Set<String> removeHavingFieldNames =
|
||||
updateFilters(
|
||||
havingExpressionList,
|
||||
queryData.getDimensionFilters(),
|
||||
parseInfo.getDimensionFilters(),
|
||||
addHavingConditions);
|
||||
correctorSql = SqlReplaceHelper.replaceHavingValue(correctorSql, new HashMap<>());
|
||||
correctorSql = SqlRemoveHelper.removeHavingCondition(correctorSql, removeHavingFieldNames);
|
||||
|
||||
@@ -304,8 +321,10 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
}
|
||||
|
||||
private void replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) {
|
||||
List<String> oriMetrics = parseInfo.getMetrics().stream()
|
||||
.map(SchemaElement::getName).collect(Collectors.toList());
|
||||
List<String> oriMetrics =
|
||||
parseInfo.getMetrics().stream()
|
||||
.map(SchemaElement::getName)
|
||||
.collect(Collectors.toList());
|
||||
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
log.info("before replaceMetrics:{}", correctorSql);
|
||||
log.info("filteredMetrics:{},metrics:{}", oriMetrics, metric);
|
||||
@@ -318,7 +337,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
|
||||
}
|
||||
|
||||
private QueryResult doExecution(SemanticQueryReq semanticQueryReq, String queryMode, User user) throws Exception {
|
||||
private QueryResult doExecution(SemanticQueryReq semanticQueryReq, String queryMode, User user)
|
||||
throws Exception {
|
||||
SemanticQueryResp queryResp = semanticLayerService.queryByReq(semanticQueryReq, user);
|
||||
QueryResult queryResult = new QueryResult();
|
||||
|
||||
@@ -337,16 +357,20 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
private Set<String> updateDateInfo(ChatQueryDataReq queryData, SemanticParseInfo parseInfo,
|
||||
Map<String, Map<String, String>> filedNameToValueMap,
|
||||
List<FieldExpression> fieldExpressionList,
|
||||
List<Expression> addConditions) {
|
||||
private Set<String> updateDateInfo(
|
||||
ChatQueryDataReq queryData,
|
||||
SemanticParseInfo parseInfo,
|
||||
Map<String, Map<String, String>> filedNameToValueMap,
|
||||
List<FieldExpression> fieldExpressionList,
|
||||
List<Expression> addConditions) {
|
||||
Set<String> removeFieldNames = new HashSet<>();
|
||||
if (Objects.isNull(queryData.getDateInfo())) {
|
||||
return removeFieldNames;
|
||||
}
|
||||
if (queryData.getDateInfo().getUnit() > 1) {
|
||||
queryData.getDateInfo().setStartDate(DateUtils.getBeforeDate(queryData.getDateInfo().getUnit() + 1));
|
||||
queryData
|
||||
.getDateInfo()
|
||||
.setStartDate(DateUtils.getBeforeDate(queryData.getDateInfo().getUnit() + 1));
|
||||
queryData.getDateInfo().setEndDate(DateUtils.getBeforeDate(1));
|
||||
}
|
||||
// startDate equals to endDate
|
||||
@@ -355,17 +379,20 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
// first remove,then add
|
||||
removeFieldNames.add(TimeDimensionEnum.DAY.getChName());
|
||||
GreaterThanEquals greaterThanEquals = new GreaterThanEquals();
|
||||
addTimeFilters(queryData.getDateInfo().getStartDate(), greaterThanEquals, addConditions);
|
||||
addTimeFilters(
|
||||
queryData.getDateInfo().getStartDate(), greaterThanEquals, addConditions);
|
||||
MinorThanEquals minorThanEquals = new MinorThanEquals();
|
||||
addTimeFilters(queryData.getDateInfo().getEndDate(), minorThanEquals, addConditions);
|
||||
addTimeFilters(
|
||||
queryData.getDateInfo().getEndDate(), minorThanEquals, addConditions);
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (FieldExpression fieldExpression : fieldExpressionList) {
|
||||
for (QueryFilter queryFilter : queryData.getDimensionFilters()) {
|
||||
if (queryFilter.getOperator().equals(FilterOperatorEnum.LIKE)
|
||||
&& FilterOperatorEnum.LIKE.getValue().equalsIgnoreCase(
|
||||
fieldExpression.getOperator())) {
|
||||
&& FilterOperatorEnum.LIKE
|
||||
.getValue()
|
||||
.equalsIgnoreCase(fieldExpression.getOperator())) {
|
||||
Map<String, String> replaceMap = new HashMap<>();
|
||||
String preValue = fieldExpression.getFieldValue().toString();
|
||||
String curValue = queryFilter.getValue().toString();
|
||||
@@ -385,9 +412,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
return removeFieldNames;
|
||||
}
|
||||
|
||||
private <T extends ComparisonOperator> void addTimeFilters(String date,
|
||||
T comparisonExpression,
|
||||
List<Expression> addConditions) {
|
||||
private <T extends ComparisonOperator> void addTimeFilters(
|
||||
String date, T comparisonExpression, List<Expression> addConditions) {
|
||||
Column column = new Column(TimeDimensionEnum.DAY.getChName());
|
||||
StringValue stringValue = new StringValue(date);
|
||||
comparisonExpression.setLeftExpression(column);
|
||||
@@ -395,10 +421,11 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
addConditions.add(comparisonExpression);
|
||||
}
|
||||
|
||||
private Set<String> updateFilters(List<FieldExpression> fieldExpressionList,
|
||||
Set<QueryFilter> metricFilters,
|
||||
Set<QueryFilter> contextMetricFilters,
|
||||
List<Expression> addConditions) {
|
||||
private Set<String> updateFilters(
|
||||
List<FieldExpression> fieldExpressionList,
|
||||
Set<QueryFilter> metricFilters,
|
||||
Set<QueryFilter> contextMetricFilters,
|
||||
List<Expression> addConditions) {
|
||||
Set<String> removeFieldNames = new HashSet<>();
|
||||
if (CollectionUtils.isEmpty(metricFilters)) {
|
||||
return removeFieldNames;
|
||||
@@ -417,13 +444,15 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
return removeFieldNames;
|
||||
}
|
||||
|
||||
private void handleFilter(QueryFilter dslQueryFilter,
|
||||
Set<QueryFilter> contextMetricFilters,
|
||||
List<Expression> addConditions) {
|
||||
private void handleFilter(
|
||||
QueryFilter dslQueryFilter,
|
||||
Set<QueryFilter> contextMetricFilters,
|
||||
List<Expression> addConditions) {
|
||||
FilterOperatorEnum operator = dslQueryFilter.getOperator();
|
||||
|
||||
if (operator == FilterOperatorEnum.IN) {
|
||||
addWhereInFilters(dslQueryFilter, new InExpression(), contextMetricFilters, addConditions);
|
||||
addWhereInFilters(
|
||||
dslQueryFilter, new InExpression(), contextMetricFilters, addConditions);
|
||||
} else {
|
||||
ComparisonOperator expression = FilterOperatorEnum.createExpression(operator);
|
||||
if (Objects.nonNull(expression)) {
|
||||
@@ -433,37 +462,43 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
}
|
||||
|
||||
// add in condition to sql where condition
|
||||
private void addWhereInFilters(QueryFilter dslQueryFilter,
|
||||
InExpression inExpression,
|
||||
Set<QueryFilter> contextMetricFilters,
|
||||
List<Expression> addConditions) {
|
||||
private void addWhereInFilters(
|
||||
QueryFilter dslQueryFilter,
|
||||
InExpression inExpression,
|
||||
Set<QueryFilter> contextMetricFilters,
|
||||
List<Expression> addConditions) {
|
||||
Column column = new Column(dslQueryFilter.getName());
|
||||
ParenthesedExpressionList parenthesedExpressionList = new ParenthesedExpressionList<>();
|
||||
List<String> valueList = JsonUtil.toList(
|
||||
JsonUtil.toString(dslQueryFilter.getValue()), String.class);
|
||||
List<String> valueList =
|
||||
JsonUtil.toList(JsonUtil.toString(dslQueryFilter.getValue()), String.class);
|
||||
if (CollectionUtils.isEmpty(valueList)) {
|
||||
return;
|
||||
}
|
||||
valueList.stream().forEach(o -> {
|
||||
StringValue stringValue = new StringValue(o);
|
||||
parenthesedExpressionList.add(stringValue);
|
||||
});
|
||||
valueList.stream()
|
||||
.forEach(
|
||||
o -> {
|
||||
StringValue stringValue = new StringValue(o);
|
||||
parenthesedExpressionList.add(stringValue);
|
||||
});
|
||||
inExpression.setLeftExpression(column);
|
||||
inExpression.setRightExpression(parenthesedExpressionList);
|
||||
addConditions.add(inExpression);
|
||||
contextMetricFilters.stream().forEach(o -> {
|
||||
if (o.getName().equals(dslQueryFilter.getName())) {
|
||||
o.setValue(dslQueryFilter.getValue());
|
||||
o.setOperator(dslQueryFilter.getOperator());
|
||||
}
|
||||
});
|
||||
contextMetricFilters.stream()
|
||||
.forEach(
|
||||
o -> {
|
||||
if (o.getName().equals(dslQueryFilter.getName())) {
|
||||
o.setValue(dslQueryFilter.getValue());
|
||||
o.setOperator(dslQueryFilter.getOperator());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// add where filter
|
||||
private void addWhereFilters(QueryFilter dslQueryFilter,
|
||||
ComparisonOperator comparisonExpression,
|
||||
Set<QueryFilter> contextMetricFilters,
|
||||
List<Expression> addConditions) {
|
||||
private void addWhereFilters(
|
||||
QueryFilter dslQueryFilter,
|
||||
ComparisonOperator comparisonExpression,
|
||||
Set<QueryFilter> contextMetricFilters,
|
||||
List<Expression> addConditions) {
|
||||
String columnName = dslQueryFilter.getName();
|
||||
if (StringUtils.isNotBlank(dslQueryFilter.getFunction())) {
|
||||
columnName = dslQueryFilter.getFunction() + "(" + dslQueryFilter.getName() + ")";
|
||||
@@ -474,23 +509,26 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
Column column = new Column(columnName);
|
||||
comparisonExpression.setLeftExpression(column);
|
||||
if (StringUtils.isNumeric(dslQueryFilter.getValue().toString())) {
|
||||
LongValue longValue = new LongValue(Long.parseLong(dslQueryFilter.getValue().toString()));
|
||||
LongValue longValue =
|
||||
new LongValue(Long.parseLong(dslQueryFilter.getValue().toString()));
|
||||
comparisonExpression.setRightExpression(longValue);
|
||||
} else {
|
||||
StringValue stringValue = new StringValue(dslQueryFilter.getValue().toString());
|
||||
comparisonExpression.setRightExpression(stringValue);
|
||||
}
|
||||
addConditions.add(comparisonExpression);
|
||||
contextMetricFilters.stream().forEach(o -> {
|
||||
if (o.getName().equals(dslQueryFilter.getName())) {
|
||||
o.setValue(dslQueryFilter.getValue());
|
||||
o.setOperator(dslQueryFilter.getOperator());
|
||||
}
|
||||
});
|
||||
contextMetricFilters.stream()
|
||||
.forEach(
|
||||
o -> {
|
||||
if (o.getName().equals(dslQueryFilter.getName())) {
|
||||
o.setValue(dslQueryFilter.getValue());
|
||||
o.setOperator(dslQueryFilter.getOperator());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private SemanticParseInfo mergeParseInfo(SemanticParseInfo parseInfo,
|
||||
ChatQueryDataReq queryData) {
|
||||
private SemanticParseInfo mergeParseInfo(
|
||||
SemanticParseInfo parseInfo, ChatQueryDataReq queryData) {
|
||||
if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
|
||||
return parseInfo;
|
||||
}
|
||||
@@ -521,7 +559,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
iterator.remove();
|
||||
continue;
|
||||
}
|
||||
List<String> collection = JsonUtil.toList(JsonUtil.toString(queryFilterValue), String.class);
|
||||
List<String> collection =
|
||||
JsonUtil.toList(JsonUtil.toString(queryFilterValue), String.class);
|
||||
if (FilterOperatorEnum.IN.equals(queryFilter.getOperator())
|
||||
&& CollectionUtils.isEmpty(collection)) {
|
||||
iterator.remove();
|
||||
@@ -538,11 +577,10 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
}
|
||||
|
||||
public void saveQueryResult(ChatExecuteReq chatExecuteReq, QueryResult queryResult) {
|
||||
//The history record only retains the query result of the first parse
|
||||
// The history record only retains the query result of the first parse
|
||||
if (chatExecuteReq.getParseId() > 1) {
|
||||
return;
|
||||
}
|
||||
chatManageService.saveQueryResult(chatExecuteReq, queryResult);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.chat.server.service.impl;
|
||||
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
|
||||
@@ -26,11 +25,11 @@ import com.tencent.supersonic.chat.server.service.ConfigService;
|
||||
import com.tencent.supersonic.chat.server.util.ChatConfigHelper;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.MetaFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.MetaFilter;
|
||||
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
@@ -44,7 +43,6 @@ import java.util.Objects;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class ConfigServiceImpl implements ConfigService {
|
||||
@@ -53,9 +51,10 @@ public class ConfigServiceImpl implements ConfigService {
|
||||
private final ChatConfigHelper chatConfigHelper;
|
||||
private final SemanticLayerService semanticLayerService;
|
||||
|
||||
|
||||
public ConfigServiceImpl(ChatConfigRepository chatConfigRepository,
|
||||
ChatConfigHelper chatConfigHelper, SemanticLayerService semanticLayerService) {
|
||||
public ConfigServiceImpl(
|
||||
ChatConfigRepository chatConfigRepository,
|
||||
ChatConfigHelper chatConfigHelper,
|
||||
SemanticLayerService semanticLayerService) {
|
||||
this.chatConfigRepository = chatConfigRepository;
|
||||
this.chatConfigHelper = chatConfigHelper;
|
||||
this.semanticLayerService = semanticLayerService;
|
||||
@@ -81,9 +80,11 @@ public class ConfigServiceImpl implements ConfigService {
|
||||
@Override
|
||||
public Long editConfig(ChatConfigEditReqReq configEditCmd, User user) {
|
||||
log.info("[edit model extend] object:{}", JsonUtil.toString(configEditCmd, true));
|
||||
if (Objects.isNull(configEditCmd) || Objects.isNull(configEditCmd.getId()) && Objects.isNull(
|
||||
configEditCmd.getModelId())) {
|
||||
throw new RuntimeException("editConfig, id and modelId are not allowed to be empty at the same time");
|
||||
if (Objects.isNull(configEditCmd)
|
||||
|| Objects.isNull(configEditCmd.getId())
|
||||
&& Objects.isNull(configEditCmd.getModelId())) {
|
||||
throw new RuntimeException(
|
||||
"editConfig, id and modelId are not allowed to be empty at the same time");
|
||||
}
|
||||
ChatConfig chaConfig = chatConfigHelper.editChatConfig(configEditCmd, user);
|
||||
chatConfigRepository.updateConfig(chaConfig);
|
||||
@@ -106,39 +107,51 @@ public class ConfigServiceImpl implements ConfigService {
|
||||
List<Long> blackDimIdList = new ArrayList<>();
|
||||
if (Objects.nonNull(chatConfig.getChatAggConfig())
|
||||
&& Objects.nonNull(chatConfig.getChatAggConfig().getVisibility())) {
|
||||
blackDimIdList.addAll(chatConfig.getChatAggConfig().getVisibility().getBlackDimIdList());
|
||||
blackDimIdList.addAll(
|
||||
chatConfig.getChatAggConfig().getVisibility().getBlackDimIdList());
|
||||
}
|
||||
if (Objects.nonNull(chatConfig.getChatDetailConfig())
|
||||
&& Objects.nonNull(chatConfig.getChatDetailConfig().getVisibility())) {
|
||||
blackDimIdList.addAll(chatConfig.getChatDetailConfig().getVisibility().getBlackDimIdList());
|
||||
blackDimIdList.addAll(
|
||||
chatConfig.getChatDetailConfig().getVisibility().getBlackDimIdList());
|
||||
}
|
||||
List<Long> filterDimIdList = blackDimIdList.stream().distinct().collect(Collectors.toList());
|
||||
List<Long> filterDimIdList =
|
||||
blackDimIdList.stream().distinct().collect(Collectors.toList());
|
||||
|
||||
List<Long> blackMetricIdList = new ArrayList<>();
|
||||
if (Objects.nonNull(chatConfig.getChatAggConfig())
|
||||
&& Objects.nonNull(chatConfig.getChatAggConfig().getVisibility())) {
|
||||
blackMetricIdList.addAll(chatConfig.getChatAggConfig().getVisibility().getBlackMetricIdList());
|
||||
blackMetricIdList.addAll(
|
||||
chatConfig.getChatAggConfig().getVisibility().getBlackMetricIdList());
|
||||
}
|
||||
if (Objects.nonNull(chatConfig.getChatDetailConfig())
|
||||
&& Objects.nonNull(chatConfig.getChatDetailConfig().getVisibility())) {
|
||||
blackMetricIdList.addAll(chatConfig.getChatDetailConfig().getVisibility().getBlackMetricIdList());
|
||||
blackMetricIdList.addAll(
|
||||
chatConfig.getChatDetailConfig().getVisibility().getBlackMetricIdList());
|
||||
}
|
||||
List<Long> filterMetricIdList = blackMetricIdList.stream().distinct().collect(Collectors.toList());
|
||||
List<Long> filterMetricIdList =
|
||||
blackMetricIdList.stream().distinct().collect(Collectors.toList());
|
||||
|
||||
ItemNameVisibilityInfo itemNameVisibility = new ItemNameVisibilityInfo();
|
||||
MetaFilter metaFilter = new MetaFilter();
|
||||
metaFilter.setModelIds(Lists.newArrayList(modelId));
|
||||
if (!CollectionUtils.isEmpty(blackDimIdList)) {
|
||||
List<DimensionResp> dimensionRespList = semanticLayerService.getDimensions(metaFilter);
|
||||
List<String> blackDimNameList = dimensionRespList.stream().filter(o -> filterDimIdList.contains(o.getId()))
|
||||
.map(SchemaItem::getName).collect(Collectors.toList());
|
||||
List<String> blackDimNameList =
|
||||
dimensionRespList.stream()
|
||||
.filter(o -> filterDimIdList.contains(o.getId()))
|
||||
.map(SchemaItem::getName)
|
||||
.collect(Collectors.toList());
|
||||
itemNameVisibility.setBlackDimNameList(blackDimNameList);
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(blackMetricIdList)) {
|
||||
|
||||
List<MetricResp> metricRespList = semanticLayerService.getMetrics(metaFilter);
|
||||
List<String> blackMetricList = metricRespList.stream().filter(o -> filterMetricIdList.contains(o.getId()))
|
||||
.map(SchemaItem::getName).collect(Collectors.toList());
|
||||
List<String> blackMetricList =
|
||||
metricRespList.stream()
|
||||
.filter(o -> filterMetricIdList.contains(o.getId()))
|
||||
.map(SchemaItem::getName)
|
||||
.collect(Collectors.toList());
|
||||
itemNameVisibility.setBlackMetricNameList(blackMetricList);
|
||||
}
|
||||
return itemNameVisibility;
|
||||
@@ -156,8 +169,8 @@ public class ConfigServiceImpl implements ConfigService {
|
||||
return chatConfigRepository.getConfigByModelId(modelId);
|
||||
}
|
||||
|
||||
private ItemVisibilityInfo fetchVisibilityDescByConfig(ItemVisibility visibility,
|
||||
DataSetSchema modelSchema) {
|
||||
private ItemVisibilityInfo fetchVisibilityDescByConfig(
|
||||
ItemVisibility visibility, DataSetSchema modelSchema) {
|
||||
ItemVisibilityInfo itemVisibilityDesc = new ItemVisibilityInfo();
|
||||
|
||||
List<Long> dimIdAllList = chatConfigHelper.generateAllDimIdList(modelSchema);
|
||||
@@ -173,16 +186,22 @@ public class ConfigServiceImpl implements ConfigService {
|
||||
blackMetricIdList.addAll(visibility.getBlackMetricIdList());
|
||||
}
|
||||
}
|
||||
List<Long> whiteMetricIdList = metricIdAllList.stream()
|
||||
.filter(id -> !blackMetricIdList.contains(id) && metricIdAllList.contains(id))
|
||||
.collect(Collectors.toList());
|
||||
List<Long> whiteDimIdList = dimIdAllList.stream()
|
||||
.filter(id -> !blackDimIdList.contains(id) && dimIdAllList.contains(id))
|
||||
.collect(Collectors.toList());
|
||||
List<Long> whiteMetricIdList =
|
||||
metricIdAllList.stream()
|
||||
.filter(
|
||||
id ->
|
||||
!blackMetricIdList.contains(id)
|
||||
&& metricIdAllList.contains(id))
|
||||
.collect(Collectors.toList());
|
||||
List<Long> whiteDimIdList =
|
||||
dimIdAllList.stream()
|
||||
.filter(id -> !blackDimIdList.contains(id) && dimIdAllList.contains(id))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
itemVisibilityDesc.setBlackDimIdList(blackDimIdList);
|
||||
itemVisibilityDesc.setBlackMetricIdList(blackMetricIdList);
|
||||
itemVisibilityDesc.setWhiteDimIdList(Objects.isNull(whiteDimIdList) ? new ArrayList<>() : whiteDimIdList);
|
||||
itemVisibilityDesc.setWhiteDimIdList(
|
||||
Objects.isNull(whiteDimIdList) ? new ArrayList<>() : whiteDimIdList);
|
||||
itemVisibilityDesc.setWhiteMetricIdList(
|
||||
Objects.isNull(whiteMetricIdList) ? new ArrayList<>() : whiteMetricIdList);
|
||||
|
||||
@@ -207,26 +226,31 @@ public class ConfigServiceImpl implements ConfigService {
|
||||
chatConfigRich.setModelName(dataSetSchema.getDataSet().getName());
|
||||
|
||||
chatConfigRich.setChatAggRichConfig(fillChatAggRichConfig(dataSetSchema, chatConfigResp));
|
||||
chatConfigRich.setChatDetailRichConfig(fillChatDetailRichConfig(dataSetSchema, chatConfigRich, chatConfigResp));
|
||||
chatConfigRich.setChatDetailRichConfig(
|
||||
fillChatDetailRichConfig(dataSetSchema, chatConfigRich, chatConfigResp));
|
||||
|
||||
return chatConfigRich;
|
||||
}
|
||||
|
||||
private ChatDetailRichConfigResp fillChatDetailRichConfig(DataSetSchema modelSchema,
|
||||
ChatConfigRichResp chatConfigRich,
|
||||
ChatConfigResp chatConfigResp) {
|
||||
if (Objects.isNull(chatConfigResp) || Objects.isNull(chatConfigResp.getChatDetailConfig())) {
|
||||
private ChatDetailRichConfigResp fillChatDetailRichConfig(
|
||||
DataSetSchema modelSchema,
|
||||
ChatConfigRichResp chatConfigRich,
|
||||
ChatConfigResp chatConfigResp) {
|
||||
if (Objects.isNull(chatConfigResp)
|
||||
|| Objects.isNull(chatConfigResp.getChatDetailConfig())) {
|
||||
return null;
|
||||
}
|
||||
ChatDetailRichConfigResp detailRichConfig = new ChatDetailRichConfigResp();
|
||||
ChatDetailConfigReq chatDetailConfig = chatConfigResp.getChatDetailConfig();
|
||||
ItemVisibilityInfo itemVisibilityInfo = fetchVisibilityDescByConfig(
|
||||
chatDetailConfig.getVisibility(), modelSchema);
|
||||
ItemVisibilityInfo itemVisibilityInfo =
|
||||
fetchVisibilityDescByConfig(chatDetailConfig.getVisibility(), modelSchema);
|
||||
detailRichConfig.setVisibility(itemVisibilityInfo);
|
||||
detailRichConfig.setKnowledgeInfos(fillKnowledgeBizName(chatDetailConfig.getKnowledgeInfos(), modelSchema));
|
||||
detailRichConfig.setKnowledgeInfos(
|
||||
fillKnowledgeBizName(chatDetailConfig.getKnowledgeInfos(), modelSchema));
|
||||
detailRichConfig.setGlobalKnowledgeConfig(chatDetailConfig.getGlobalKnowledgeConfig());
|
||||
detailRichConfig.setChatDefaultConfig(fetchDefaultConfig(chatDetailConfig.getChatDefaultConfig(),
|
||||
modelSchema, itemVisibilityInfo));
|
||||
detailRichConfig.setChatDefaultConfig(
|
||||
fetchDefaultConfig(
|
||||
chatDetailConfig.getChatDefaultConfig(), modelSchema, itemVisibilityInfo));
|
||||
|
||||
return detailRichConfig;
|
||||
}
|
||||
@@ -237,30 +261,38 @@ public class ConfigServiceImpl implements ConfigService {
|
||||
return entityRichInfo;
|
||||
}
|
||||
BeanUtils.copyProperties(entity, entityRichInfo);
|
||||
Map<Long, SchemaElement> dimIdAndRespPair = modelSchema.getDimensions().stream()
|
||||
.collect(Collectors.toMap(SchemaElement::getId, Function.identity(), (k1, k2) -> k1));
|
||||
Map<Long, SchemaElement> dimIdAndRespPair =
|
||||
modelSchema.getDimensions().stream()
|
||||
.collect(
|
||||
Collectors.toMap(
|
||||
SchemaElement::getId, Function.identity(), (k1, k2) -> k1));
|
||||
|
||||
entityRichInfo.setDimItem(dimIdAndRespPair.get(entity.getEntityId()));
|
||||
return entityRichInfo;
|
||||
}
|
||||
|
||||
private ChatAggRichConfigResp fillChatAggRichConfig(DataSetSchema modelSchema, ChatConfigResp chatConfigResp) {
|
||||
private ChatAggRichConfigResp fillChatAggRichConfig(
|
||||
DataSetSchema modelSchema, ChatConfigResp chatConfigResp) {
|
||||
if (Objects.isNull(chatConfigResp) || Objects.isNull(chatConfigResp.getChatAggConfig())) {
|
||||
return null;
|
||||
}
|
||||
ChatAggConfigReq chatAggConfig = chatConfigResp.getChatAggConfig();
|
||||
ChatAggRichConfigResp chatAggRichConfig = new ChatAggRichConfigResp();
|
||||
ItemVisibilityInfo itemVisibilityInfo = fetchVisibilityDescByConfig(chatAggConfig.getVisibility(), modelSchema);
|
||||
ItemVisibilityInfo itemVisibilityInfo =
|
||||
fetchVisibilityDescByConfig(chatAggConfig.getVisibility(), modelSchema);
|
||||
chatAggRichConfig.setVisibility(itemVisibilityInfo);
|
||||
chatAggRichConfig.setKnowledgeInfos(fillKnowledgeBizName(chatAggConfig.getKnowledgeInfos(), modelSchema));
|
||||
chatAggRichConfig.setKnowledgeInfos(
|
||||
fillKnowledgeBizName(chatAggConfig.getKnowledgeInfos(), modelSchema));
|
||||
chatAggRichConfig.setGlobalKnowledgeConfig(chatAggConfig.getGlobalKnowledgeConfig());
|
||||
chatAggRichConfig.setChatDefaultConfig(fetchDefaultConfig(chatAggConfig.getChatDefaultConfig(),
|
||||
modelSchema, itemVisibilityInfo));
|
||||
chatAggRichConfig.setChatDefaultConfig(
|
||||
fetchDefaultConfig(
|
||||
chatAggConfig.getChatDefaultConfig(), modelSchema, itemVisibilityInfo));
|
||||
|
||||
return chatAggRichConfig;
|
||||
}
|
||||
|
||||
private ChatDefaultRichConfigResp fetchDefaultConfig(ChatDefaultConfigReq chatDefaultConfig,
|
||||
private ChatDefaultRichConfigResp fetchDefaultConfig(
|
||||
ChatDefaultConfigReq chatDefaultConfig,
|
||||
DataSetSchema modelSchema,
|
||||
ItemVisibilityInfo itemVisibilityInfo) {
|
||||
ChatDefaultRichConfigResp defaultRichConfig = new ChatDefaultRichConfigResp();
|
||||
@@ -268,41 +300,56 @@ public class ConfigServiceImpl implements ConfigService {
|
||||
return defaultRichConfig;
|
||||
}
|
||||
BeanUtils.copyProperties(chatDefaultConfig, defaultRichConfig);
|
||||
Map<Long, SchemaElement> dimIdAndRespPair = modelSchema.getDimensions().stream()
|
||||
.collect(Collectors.toMap(SchemaElement::getId, Function.identity(), (k1, k2) -> k1));
|
||||
Map<Long, SchemaElement> dimIdAndRespPair =
|
||||
modelSchema.getDimensions().stream()
|
||||
.collect(
|
||||
Collectors.toMap(
|
||||
SchemaElement::getId, Function.identity(), (k1, k2) -> k1));
|
||||
|
||||
Map<Long, SchemaElement> metricIdAndRespPair = modelSchema.getMetrics().stream()
|
||||
.collect(Collectors.toMap(SchemaElement::getId, Function.identity(), (k1, k2) -> k1));
|
||||
Map<Long, SchemaElement> metricIdAndRespPair =
|
||||
modelSchema.getMetrics().stream()
|
||||
.collect(
|
||||
Collectors.toMap(
|
||||
SchemaElement::getId, Function.identity(), (k1, k2) -> k1));
|
||||
|
||||
List<SchemaElement> dimensions = new ArrayList<>();
|
||||
List<SchemaElement> metrics = new ArrayList<>();
|
||||
if (!CollectionUtils.isEmpty(chatDefaultConfig.getDimensionIds())) {
|
||||
chatDefaultConfig.getDimensionIds().stream()
|
||||
.filter(dimId -> dimIdAndRespPair.containsKey(dimId)
|
||||
&& itemVisibilityInfo.getWhiteDimIdList().contains(dimId))
|
||||
.forEach(dimId -> {
|
||||
SchemaElement dimSchemaResp = dimIdAndRespPair.get(dimId);
|
||||
if (Objects.nonNull(dimSchemaResp)) {
|
||||
SchemaElement dimSchema = new SchemaElement();
|
||||
BeanUtils.copyProperties(dimSchemaResp, dimSchema);
|
||||
dimensions.add(dimSchema);
|
||||
}
|
||||
|
||||
});
|
||||
.filter(
|
||||
dimId ->
|
||||
dimIdAndRespPair.containsKey(dimId)
|
||||
&& itemVisibilityInfo
|
||||
.getWhiteDimIdList()
|
||||
.contains(dimId))
|
||||
.forEach(
|
||||
dimId -> {
|
||||
SchemaElement dimSchemaResp = dimIdAndRespPair.get(dimId);
|
||||
if (Objects.nonNull(dimSchemaResp)) {
|
||||
SchemaElement dimSchema = new SchemaElement();
|
||||
BeanUtils.copyProperties(dimSchemaResp, dimSchema);
|
||||
dimensions.add(dimSchema);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if (!CollectionUtils.isEmpty(chatDefaultConfig.getMetricIds())) {
|
||||
chatDefaultConfig.getMetricIds().stream()
|
||||
.filter(metricId -> metricIdAndRespPair.containsKey(metricId)
|
||||
&& itemVisibilityInfo.getWhiteMetricIdList().contains(metricId))
|
||||
.forEach(metricId -> {
|
||||
SchemaElement metricSchemaResp = metricIdAndRespPair.get(metricId);
|
||||
if (Objects.nonNull(metricSchemaResp)) {
|
||||
SchemaElement metricSchema = new SchemaElement();
|
||||
BeanUtils.copyProperties(metricSchemaResp, metricSchema);
|
||||
metrics.add(metricSchema);
|
||||
}
|
||||
});
|
||||
.filter(
|
||||
metricId ->
|
||||
metricIdAndRespPair.containsKey(metricId)
|
||||
&& itemVisibilityInfo
|
||||
.getWhiteMetricIdList()
|
||||
.contains(metricId))
|
||||
.forEach(
|
||||
metricId -> {
|
||||
SchemaElement metricSchemaResp = metricIdAndRespPair.get(metricId);
|
||||
if (Objects.nonNull(metricSchemaResp)) {
|
||||
SchemaElement metricSchema = new SchemaElement();
|
||||
BeanUtils.copyProperties(metricSchemaResp, metricSchema);
|
||||
metrics.add(metricSchema);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
defaultRichConfig.setDimensions(dimensions);
|
||||
@@ -310,21 +357,27 @@ public class ConfigServiceImpl implements ConfigService {
|
||||
return defaultRichConfig;
|
||||
}
|
||||
|
||||
private List<KnowledgeInfoReq> fillKnowledgeBizName(List<KnowledgeInfoReq> knowledgeInfos,
|
||||
DataSetSchema modelSchema) {
|
||||
private List<KnowledgeInfoReq> fillKnowledgeBizName(
|
||||
List<KnowledgeInfoReq> knowledgeInfos, DataSetSchema modelSchema) {
|
||||
if (CollectionUtils.isEmpty(knowledgeInfos)) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
Map<Long, SchemaElement> dimIdAndRespPair = modelSchema.getDimensions().stream()
|
||||
.collect(Collectors.toMap(SchemaElement::getId, Function.identity(), (k1, k2) -> k1));
|
||||
knowledgeInfos.stream().forEach(knowledgeInfo -> {
|
||||
if (Objects.nonNull(knowledgeInfo)) {
|
||||
SchemaElement dimSchemaResp = dimIdAndRespPair.get(knowledgeInfo.getItemId());
|
||||
if (Objects.nonNull(dimSchemaResp)) {
|
||||
knowledgeInfo.setBizName(dimSchemaResp.getBizName());
|
||||
}
|
||||
}
|
||||
});
|
||||
Map<Long, SchemaElement> dimIdAndRespPair =
|
||||
modelSchema.getDimensions().stream()
|
||||
.collect(
|
||||
Collectors.toMap(
|
||||
SchemaElement::getId, Function.identity(), (k1, k2) -> k1));
|
||||
knowledgeInfos.stream()
|
||||
.forEach(
|
||||
knowledgeInfo -> {
|
||||
if (Objects.nonNull(knowledgeInfo)) {
|
||||
SchemaElement dimSchemaResp =
|
||||
dimIdAndRespPair.get(knowledgeInfo.getItemId());
|
||||
if (Objects.nonNull(dimSchemaResp)) {
|
||||
knowledgeInfo.setBizName(dimSchemaResp.getBizName());
|
||||
}
|
||||
}
|
||||
});
|
||||
return knowledgeInfos;
|
||||
}
|
||||
|
||||
@@ -332,5 +385,4 @@ public class ConfigServiceImpl implements ConfigService {
|
||||
public List<ChatConfigRichResp> getAllChatRichConfig() {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -15,23 +15,21 @@ import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.service.ExemplarService;
|
||||
import com.tencent.supersonic.common.util.BeanMapper;
|
||||
import java.util.List;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Service
|
||||
public class MemoryServiceImpl implements MemoryService {
|
||||
|
||||
@Autowired
|
||||
private ChatMemoryRepository chatMemoryRepository;
|
||||
@Autowired private ChatMemoryRepository chatMemoryRepository;
|
||||
|
||||
@Autowired
|
||||
private ExemplarService exemplarService;
|
||||
@Autowired private ExemplarService exemplarService;
|
||||
|
||||
@Autowired
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
@Autowired private EmbeddingConfig embeddingConfig;
|
||||
|
||||
@Override
|
||||
public void createMemory(ChatMemoryDO memory) {
|
||||
@@ -59,8 +57,7 @@ public class MemoryServiceImpl implements MemoryService {
|
||||
|
||||
@Override
|
||||
public PageInfo<ChatMemoryDO> pageMemories(PageMemoryReq pageMemoryReq) {
|
||||
return PageHelper.startPage(pageMemoryReq.getCurrent(),
|
||||
pageMemoryReq.getPageSize())
|
||||
return PageHelper.startPage(pageMemoryReq.getCurrent(), pageMemoryReq.getPageSize())
|
||||
.doSelectPageInfo(() -> getMemories(pageMemoryReq.getChatMemoryFilter()));
|
||||
}
|
||||
|
||||
@@ -80,10 +77,14 @@ public class MemoryServiceImpl implements MemoryService {
|
||||
queryWrapper.lambda().eq(ChatMemoryDO::getStatus, chatMemoryFilter.getStatus());
|
||||
}
|
||||
if (chatMemoryFilter.getHumanReviewRet() != null) {
|
||||
queryWrapper.lambda().eq(ChatMemoryDO::getHumanReviewRet, chatMemoryFilter.getHumanReviewRet());
|
||||
queryWrapper
|
||||
.lambda()
|
||||
.eq(ChatMemoryDO::getHumanReviewRet, chatMemoryFilter.getHumanReviewRet());
|
||||
}
|
||||
if (chatMemoryFilter.getLlmReviewRet() != null) {
|
||||
queryWrapper.lambda().eq(ChatMemoryDO::getLlmReviewRet, chatMemoryFilter.getLlmReviewRet());
|
||||
queryWrapper
|
||||
.lambda()
|
||||
.eq(ChatMemoryDO::getLlmReviewRet, chatMemoryFilter.getLlmReviewRet());
|
||||
}
|
||||
return chatMemoryRepository.getMemories(queryWrapper);
|
||||
}
|
||||
@@ -91,7 +92,9 @@ public class MemoryServiceImpl implements MemoryService {
|
||||
@Override
|
||||
public List<ChatMemoryDO> getMemoriesForLlmReview() {
|
||||
QueryWrapper<ChatMemoryDO> queryWrapper = new QueryWrapper<>();
|
||||
queryWrapper.lambda().eq(ChatMemoryDO::getStatus, MemoryStatus.PENDING)
|
||||
queryWrapper
|
||||
.lambda()
|
||||
.eq(ChatMemoryDO::getStatus, MemoryStatus.PENDING)
|
||||
.isNull(ChatMemoryDO::getLlmReviewRet);
|
||||
return chatMemoryRepository.getMemories(queryWrapper);
|
||||
}
|
||||
@@ -99,7 +102,8 @@ public class MemoryServiceImpl implements MemoryService {
|
||||
@Override
|
||||
public void enableMemory(ChatMemoryDO memory) {
|
||||
memory.setStatus(MemoryStatus.ENABLED);
|
||||
exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
||||
exemplarService.storeExemplar(
|
||||
embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
||||
Text2SQLExemplar.builder()
|
||||
.question(memory.getQuestion())
|
||||
.sideInfo(memory.getSideInfo())
|
||||
@@ -111,7 +115,8 @@ public class MemoryServiceImpl implements MemoryService {
|
||||
@Override
|
||||
public void disableMemory(ChatMemoryDO memory) {
|
||||
memory.setStatus(MemoryStatus.DISABLED);
|
||||
exemplarService.removeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
||||
exemplarService.removeExemplar(
|
||||
embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
||||
Text2SQLExemplar.builder()
|
||||
.question(memory.getQuestion())
|
||||
.sideInfo(memory.getSideInfo())
|
||||
@@ -119,5 +124,4 @@ public class MemoryServiceImpl implements MemoryService {
|
||||
.sql(memory.getS2sql())
|
||||
.build());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -4,13 +4,13 @@ import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PluginQueryReq;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.PluginDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.PluginRepository;
|
||||
import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginParseConfig;
|
||||
import com.tencent.supersonic.chat.server.plugin.event.PluginAddEvent;
|
||||
import com.tencent.supersonic.chat.server.plugin.event.PluginDelEvent;
|
||||
import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.PluginDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.PluginRepository;
|
||||
import com.tencent.supersonic.chat.server.service.PluginService;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -36,8 +36,8 @@ public class PluginServiceImpl implements PluginService {
|
||||
|
||||
private ApplicationEventPublisher publisher;
|
||||
|
||||
public PluginServiceImpl(PluginRepository pluginRepository,
|
||||
ApplicationEventPublisher publisher) {
|
||||
public PluginServiceImpl(
|
||||
PluginRepository pluginRepository, ApplicationEventPublisher publisher) {
|
||||
this.pluginRepository = pluginRepository;
|
||||
this.publisher = publisher;
|
||||
}
|
||||
@@ -46,7 +46,7 @@ public class PluginServiceImpl implements PluginService {
|
||||
public synchronized void createPlugin(ChatPlugin plugin, User user) {
|
||||
PluginDO pluginDO = convert(plugin, user);
|
||||
pluginRepository.createPlugin(pluginDO);
|
||||
//compatible with H2 db
|
||||
// compatible with H2 db
|
||||
List<ChatPlugin> plugins = getPluginList();
|
||||
publisher.publishEvent(new PluginAddEvent(this, plugins.get(plugins.size() - 1)));
|
||||
}
|
||||
@@ -110,11 +110,18 @@ public class PluginServiceImpl implements PluginService {
|
||||
}
|
||||
List<PluginDO> pluginDOS = pluginRepository.query(queryWrapper);
|
||||
if (StringUtils.isNotBlank(pluginQueryReq.getPattern())) {
|
||||
pluginDOS = pluginDOS.stream().filter(pluginDO ->
|
||||
pluginDO.getPattern().contains(pluginQueryReq.getPattern())
|
||||
|| (pluginDO.getName() != null
|
||||
&& pluginDO.getName().contains(pluginQueryReq.getPattern())))
|
||||
.collect(Collectors.toList());
|
||||
pluginDOS =
|
||||
pluginDOS.stream()
|
||||
.filter(
|
||||
pluginDO ->
|
||||
pluginDO.getPattern()
|
||||
.contains(pluginQueryReq.getPattern())
|
||||
|| (pluginDO.getName() != null
|
||||
&& pluginDO.getName()
|
||||
.contains(
|
||||
pluginQueryReq
|
||||
.getPattern())))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
return convertList(pluginDOS);
|
||||
}
|
||||
@@ -123,13 +130,14 @@ public class PluginServiceImpl implements PluginService {
|
||||
public Optional<ChatPlugin> getPluginByName(String name) {
|
||||
log.info("name:{}", name);
|
||||
return getPluginList().stream()
|
||||
.filter(plugin -> {
|
||||
PluginParseConfig functionCallConfig = getPluginParseConfig(plugin);
|
||||
if (functionCallConfig == null) {
|
||||
return false;
|
||||
}
|
||||
return functionCallConfig.getName().equalsIgnoreCase(name);
|
||||
})
|
||||
.filter(
|
||||
plugin -> {
|
||||
PluginParseConfig functionCallConfig = getPluginParseConfig(plugin);
|
||||
if (functionCallConfig == null) {
|
||||
return false;
|
||||
}
|
||||
return functionCallConfig.getName().equalsIgnoreCase(name);
|
||||
})
|
||||
.findFirst();
|
||||
}
|
||||
|
||||
@@ -137,9 +145,10 @@ public class PluginServiceImpl implements PluginService {
|
||||
if (StringUtils.isBlank(plugin.getParseModeConfig())) {
|
||||
return null;
|
||||
}
|
||||
PluginParseConfig functionCallConfig = JsonUtil.toObject(
|
||||
plugin.getParseModeConfig(), PluginParseConfig.class);
|
||||
if (Objects.isNull(functionCallConfig) || StringUtils.isEmpty(functionCallConfig.getName())) {
|
||||
PluginParseConfig functionCallConfig =
|
||||
JsonUtil.toObject(plugin.getParseModeConfig(), PluginParseConfig.class);
|
||||
if (Objects.isNull(functionCallConfig)
|
||||
|| StringUtils.isEmpty(functionCallConfig.getName())) {
|
||||
return null;
|
||||
}
|
||||
if (StringUtils.isBlank(functionCallConfig.getName())) {
|
||||
@@ -158,21 +167,28 @@ public class PluginServiceImpl implements PluginService {
|
||||
List<ChatPlugin> pluginList = getPluginList();
|
||||
|
||||
return pluginList.stream()
|
||||
.filter(plugin -> {
|
||||
PluginParseConfig functionCallConfig = getPluginParseConfig(plugin);
|
||||
if (functionCallConfig == null) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
})
|
||||
.collect(Collectors.toMap(a -> {
|
||||
PluginParseConfig functionCallConfig = JsonUtil.toObject(
|
||||
a.getParseModeConfig(), PluginParseConfig.class);
|
||||
return functionCallConfig.getName();
|
||||
}, a -> a, (k1, k2) -> k1));
|
||||
.filter(
|
||||
plugin -> {
|
||||
PluginParseConfig functionCallConfig = getPluginParseConfig(plugin);
|
||||
if (functionCallConfig == null) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
})
|
||||
.collect(
|
||||
Collectors.toMap(
|
||||
a -> {
|
||||
PluginParseConfig functionCallConfig =
|
||||
JsonUtil.toObject(
|
||||
a.getParseModeConfig(),
|
||||
PluginParseConfig.class);
|
||||
return functionCallConfig.getName();
|
||||
},
|
||||
a -> a,
|
||||
(k1, k2) -> k1));
|
||||
}
|
||||
|
||||
//todo
|
||||
// todo
|
||||
private List<ChatPlugin> authCheck(List<ChatPlugin> plugins, User user) {
|
||||
return plugins;
|
||||
}
|
||||
@@ -181,8 +197,10 @@ public class PluginServiceImpl implements PluginService {
|
||||
ChatPlugin plugin = new ChatPlugin();
|
||||
BeanUtils.copyProperties(pluginDO, plugin);
|
||||
if (StringUtils.isNotBlank(pluginDO.getDataSet())) {
|
||||
plugin.setDataSetList(Arrays.stream(pluginDO.getDataSet().split(","))
|
||||
.map(Long::parseLong).collect(Collectors.toList()));
|
||||
plugin.setDataSetList(
|
||||
Arrays.stream(pluginDO.getDataSet().split(","))
|
||||
.map(Long::parseLong)
|
||||
.collect(Collectors.toList()));
|
||||
}
|
||||
return plugin;
|
||||
}
|
||||
@@ -212,5 +230,4 @@ public class PluginServiceImpl implements PluginService {
|
||||
}
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package com.tencent.supersonic.chat.server.service.impl;
|
||||
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO;
|
||||
import com.tencent.supersonic.chat.server.service.StatisticsService;
|
||||
import com.tencent.supersonic.chat.server.persistence.mapper.StatisticsMapper;
|
||||
import com.tencent.supersonic.chat.server.service.StatisticsService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.scheduling.annotation.Async;
|
||||
@@ -14,8 +14,7 @@ import java.util.List;
|
||||
@Slf4j
|
||||
public class StatisticsServiceImpl implements StatisticsService {
|
||||
|
||||
@Autowired
|
||||
private StatisticsMapper statisticsMapper;
|
||||
@Autowired private StatisticsMapper statisticsMapper;
|
||||
|
||||
@Async
|
||||
@Override
|
||||
|
||||
@@ -29,7 +29,6 @@ import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.Constants.ADMIN_LOWER;
|
||||
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
public class ChatConfigHelper {
|
||||
@@ -38,7 +37,10 @@ public class ChatConfigHelper {
|
||||
ChatConfig chatConfig = new ChatConfig();
|
||||
BeanUtils.copyProperties(extendBaseCmd, chatConfig);
|
||||
RecordInfo recordInfo = new RecordInfo();
|
||||
String creator = (Objects.isNull(user) || StringUtils.isEmpty(user.getName())) ? ADMIN_LOWER : user.getName();
|
||||
String creator =
|
||||
(Objects.isNull(user) || StringUtils.isEmpty(user.getName()))
|
||||
? ADMIN_LOWER
|
||||
: user.getName();
|
||||
recordInfo.createdBy(creator);
|
||||
chatConfig.setRecordInfo(recordInfo);
|
||||
chatConfig.setStatus(StatusEnum.ONLINE);
|
||||
@@ -50,8 +52,10 @@ public class ChatConfigHelper {
|
||||
|
||||
BeanUtils.copyProperties(extendEditCmd, chatConfig);
|
||||
RecordInfo recordInfo = new RecordInfo();
|
||||
String user = (Objects.isNull(facadeUser) || StringUtils.isEmpty(facadeUser.getName()))
|
||||
? ADMIN_LOWER : facadeUser.getName();
|
||||
String user =
|
||||
(Objects.isNull(facadeUser) || StringUtils.isEmpty(facadeUser.getName()))
|
||||
? ADMIN_LOWER
|
||||
: facadeUser.getName();
|
||||
recordInfo.updatedBy(user);
|
||||
chatConfig.setRecordInfo(recordInfo);
|
||||
return chatConfig;
|
||||
@@ -61,8 +65,9 @@ public class ChatConfigHelper {
|
||||
if (Objects.isNull(modelSchema) || CollectionUtils.isEmpty(modelSchema.getDimensions())) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
Map<Long, List<SchemaElement>> dimIdAndDescPair = modelSchema.getDimensions()
|
||||
.stream().collect(Collectors.groupingBy(SchemaElement::getId));
|
||||
Map<Long, List<SchemaElement>> dimIdAndDescPair =
|
||||
modelSchema.getDimensions().stream()
|
||||
.collect(Collectors.groupingBy(SchemaElement::getId));
|
||||
return new ArrayList<>(dimIdAndDescPair.keySet());
|
||||
}
|
||||
|
||||
@@ -70,8 +75,9 @@ public class ChatConfigHelper {
|
||||
if (Objects.isNull(modelSchema) || CollectionUtils.isEmpty(modelSchema.getMetrics())) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
Map<Long, List<SchemaElement>> metricIdAndDescPair = modelSchema.getMetrics()
|
||||
.stream().collect(Collectors.groupingBy(SchemaElement::getId));
|
||||
Map<Long, List<SchemaElement>> metricIdAndDescPair =
|
||||
modelSchema.getMetrics().stream()
|
||||
.collect(Collectors.groupingBy(SchemaElement::getId));
|
||||
return new ArrayList<>(metricIdAndDescPair.keySet());
|
||||
}
|
||||
|
||||
@@ -81,7 +87,8 @@ public class ChatConfigHelper {
|
||||
|
||||
chatConfigDO.setChatAggConfig(JsonUtil.toString(chatConfig.getChatAggConfig()));
|
||||
chatConfigDO.setChatDetailConfig(JsonUtil.toString(chatConfig.getChatDetailConfig()));
|
||||
chatConfigDO.setRecommendedQuestions(JsonUtil.toString(chatConfig.getRecommendedQuestions()));
|
||||
chatConfigDO.setRecommendedQuestions(
|
||||
JsonUtil.toString(chatConfig.getRecommendedQuestions()));
|
||||
|
||||
if (Objects.isNull(chatConfig.getStatus())) {
|
||||
chatConfigDO.setStatus(null);
|
||||
@@ -112,7 +119,8 @@ public class ChatConfigHelper {
|
||||
chatConfigDescriptor.setChatAggConfig(
|
||||
JsonUtil.toObject(chatConfigDO.getChatAggConfig(), ChatAggConfigReq.class));
|
||||
chatConfigDescriptor.setRecommendedQuestions(
|
||||
JsonUtil.toList(chatConfigDO.getRecommendedQuestions(), RecommendedQuestionReq.class));
|
||||
JsonUtil.toList(
|
||||
chatConfigDO.getRecommendedQuestions(), RecommendedQuestionReq.class));
|
||||
chatConfigDescriptor.setStatusEnum(StatusEnum.of(chatConfigDO.getStatus()));
|
||||
|
||||
chatConfigDescriptor.setCreatedBy(chatConfigDO.getCreatedBy());
|
||||
@@ -120,7 +128,6 @@ public class ChatConfigHelper {
|
||||
chatConfigDescriptor.setUpdatedBy(chatConfigDO.getUpdatedBy());
|
||||
chatConfigDescriptor.setUpdatedAt(chatConfigDO.getUpdatedAt());
|
||||
|
||||
|
||||
if (StringUtils.isEmpty(chatConfigDO.getChatAggConfig())) {
|
||||
chatConfigDescriptor.setChatAggConfig(generateEmptyChatAggConfigResp());
|
||||
}
|
||||
|
||||
@@ -21,39 +21,45 @@ public class ComponentFactory {
|
||||
private static List<PluginRecognizer> pluginRecognizers = new ArrayList<>();
|
||||
|
||||
public static List<ParseResultProcessor> getParseProcessors() {
|
||||
return CollectionUtils.isEmpty(parseProcessors) ? init(ParseResultProcessor.class,
|
||||
parseProcessors) : parseProcessors;
|
||||
return CollectionUtils.isEmpty(parseProcessors)
|
||||
? init(ParseResultProcessor.class, parseProcessors)
|
||||
: parseProcessors;
|
||||
}
|
||||
|
||||
public static List<ExecuteResultProcessor> getExecuteProcessors() {
|
||||
return CollectionUtils.isEmpty(executeProcessors)
|
||||
? init(ExecuteResultProcessor.class, executeProcessors) : executeProcessors;
|
||||
? init(ExecuteResultProcessor.class, executeProcessors)
|
||||
: executeProcessors;
|
||||
}
|
||||
|
||||
public static List<ChatQueryParser> getChatParsers() {
|
||||
return CollectionUtils.isEmpty(chatQueryParsers)
|
||||
? init(ChatQueryParser.class, chatQueryParsers) : chatQueryParsers;
|
||||
? init(ChatQueryParser.class, chatQueryParsers)
|
||||
: chatQueryParsers;
|
||||
}
|
||||
|
||||
public static List<ChatQueryExecutor> getChatExecutors() {
|
||||
return CollectionUtils.isEmpty(chatQueryExecutors)
|
||||
? init(ChatQueryExecutor.class, chatQueryExecutors) : chatQueryExecutors;
|
||||
? init(ChatQueryExecutor.class, chatQueryExecutors)
|
||||
: chatQueryExecutors;
|
||||
}
|
||||
|
||||
public static List<PluginRecognizer> getPluginRecognizers() {
|
||||
return CollectionUtils.isEmpty(pluginRecognizers)
|
||||
? init(PluginRecognizer.class, pluginRecognizers) : pluginRecognizers;
|
||||
? init(PluginRecognizer.class, pluginRecognizers)
|
||||
: pluginRecognizers;
|
||||
}
|
||||
|
||||
private static <T> List<T> init(Class<T> factoryType, List list) {
|
||||
list.addAll(SpringFactoriesLoader.loadFactories(factoryType,
|
||||
Thread.currentThread().getContextClassLoader()));
|
||||
list.addAll(
|
||||
SpringFactoriesLoader.loadFactories(
|
||||
factoryType, Thread.currentThread().getContextClassLoader()));
|
||||
return list;
|
||||
}
|
||||
|
||||
private static <T> T init(Class<T> factoryType) {
|
||||
return SpringFactoriesLoader.loadFactories(factoryType,
|
||||
Thread.currentThread().getContextClassLoader()).get(0);
|
||||
return SpringFactoriesLoader.loadFactories(
|
||||
factoryType, Thread.currentThread().getContextClassLoader())
|
||||
.get(0);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
package com.tencent.supersonic.chat.server.util;
|
||||
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.common.util.BeanMapper;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
import org.apache.commons.collections.MapUtils;
|
||||
|
||||
import java.util.Objects;
|
||||
@@ -49,5 +49,4 @@ public class QueryReqConverter {
|
||||
}
|
||||
return queryNLReq;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -8,7 +8,8 @@ import java.util.Map;
|
||||
|
||||
public class ResultFormatter {
|
||||
|
||||
public static String transform2TextNew(List<QueryColumn> queryColumns, List<Map<String, Object>> queryResults) {
|
||||
public static String transform2TextNew(
|
||||
List<QueryColumn> queryColumns, List<Map<String, Object>> queryResults) {
|
||||
if (CollectionUtils.isEmpty(queryColumns)) {
|
||||
return "";
|
||||
}
|
||||
@@ -35,5 +36,4 @@ public class ResultFormatter {
|
||||
}
|
||||
return table.toString();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user