mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 12:37:55 +00:00
(feature)(chat&common)Introduce ChatMemory module to support dynamic few-shot exemplars.#1097
This commit is contained in:
@@ -1,13 +1,18 @@
|
|||||||
package com.tencent.supersonic.chat.server.executor;
|
package com.tencent.supersonic.chat.server.executor;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||||
|
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||||
import com.tencent.supersonic.chat.server.util.ResultFormatter;
|
import com.tencent.supersonic.chat.server.util.ResultFormatter;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
|
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||||
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
||||||
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
|
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
|
||||||
import lombok.SneakyThrows;
|
import lombok.SneakyThrows;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
public class SqlExecutor implements ChatExecutor {
|
public class SqlExecutor implements ChatExecutor {
|
||||||
|
|
||||||
@@ -21,6 +26,18 @@ public class SqlExecutor implements ChatExecutor {
|
|||||||
String textResult = ResultFormatter.transform2TextNew(queryResult.getQueryColumns(),
|
String textResult = ResultFormatter.transform2TextNew(queryResult.getQueryColumns(),
|
||||||
queryResult.getQueryResults());
|
queryResult.getQueryResults());
|
||||||
queryResult.setTextResult(textResult);
|
queryResult.setTextResult(textResult);
|
||||||
|
|
||||||
|
if (queryResult.getQueryState().equals(QueryState.SUCCESS)
|
||||||
|
&& queryResult.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
|
||||||
|
MemoryService memoryService = ContextUtils.getBean(MemoryService.class);
|
||||||
|
memoryService.createMemory(ChatMemoryDO.builder()
|
||||||
|
.agentId(chatExecuteContext.getAgentId())
|
||||||
|
.status(ChatMemoryDO.Status.PENDING)
|
||||||
|
.question(chatExecuteContext.getQueryText())
|
||||||
|
.s2sql(chatExecuteContext.getParseInfo().getSqlInfo().getS2SQL())
|
||||||
|
.schema(buildSchemaStr(chatExecuteContext.getParseInfo()))
|
||||||
|
.build());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return queryResult;
|
return queryResult;
|
||||||
@@ -38,4 +55,36 @@ public class SqlExecutor implements ChatExecutor {
|
|||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public String buildSchemaStr(SemanticParseInfo parseInfo) {
|
||||||
|
String tableStr = parseInfo.getDataSet().getName();
|
||||||
|
StringBuilder metricStr = new StringBuilder();
|
||||||
|
StringBuilder dimensionStr = new StringBuilder();
|
||||||
|
|
||||||
|
parseInfo.getMetrics().stream().forEach(
|
||||||
|
metric -> {
|
||||||
|
metricStr.append(metric.getName());
|
||||||
|
if (StringUtils.isNotEmpty(metric.getDescription())) {
|
||||||
|
metricStr.append(" COMMENT '" + metric.getDescription() + "'");
|
||||||
|
}
|
||||||
|
if (StringUtils.isNotEmpty(metric.getDefaultAgg())) {
|
||||||
|
metricStr.append(" AGGREGATE '" + metric.getDefaultAgg().toUpperCase() + "'");
|
||||||
|
}
|
||||||
|
metricStr.append(",");
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
parseInfo.getDimensions().stream().forEach(
|
||||||
|
dimension -> {
|
||||||
|
dimensionStr.append(dimension.getName());
|
||||||
|
if (StringUtils.isNotEmpty(dimension.getDescription())) {
|
||||||
|
dimensionStr.append(" COMMENT '" + dimension.getDescription() + "'");
|
||||||
|
}
|
||||||
|
dimensionStr.append(",");
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
String template = "Table: %s, Metrics: [%s], Dimensions: [%s]";
|
||||||
|
return String.format(template, tableStr, metricStr, dimensionStr);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,71 @@
|
|||||||
|
package com.tencent.supersonic.chat.server.memory;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||||
|
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||||
|
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO.ReviewResult;
|
||||||
|
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||||
|
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||||
|
import com.tencent.supersonic.common.util.S2ChatModelProvider;
|
||||||
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
|
import dev.langchain4j.model.input.Prompt;
|
||||||
|
import dev.langchain4j.model.input.PromptTemplate;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.scheduling.annotation.Scheduled;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.regex.Matcher;
|
||||||
|
import java.util.regex.Pattern;
|
||||||
|
|
||||||
|
@Component
|
||||||
|
@Slf4j
|
||||||
|
public class MemoryReviewTask {
|
||||||
|
|
||||||
|
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||||
|
|
||||||
|
private static final String INSTRUCTION = ""
|
||||||
|
+ "#Role: You are a senior data engineer experienced in writing SQL.\n"
|
||||||
|
+ "#Task: Your will be provided with a user question and the SQL written by junior engineer,"
|
||||||
|
+ "please take a review and give your opinion.\n"
|
||||||
|
+ "#Rules: "
|
||||||
|
+ "1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),comment=(your comment)`."
|
||||||
|
+ "2.DO NOT check the usage of `数据日期` field and `datediff()` function.\n"
|
||||||
|
+ "#Question: %s\n"
|
||||||
|
+ "#Schema: %s\n"
|
||||||
|
+ "#SQL: %s\n"
|
||||||
|
+ "#Response: ";
|
||||||
|
|
||||||
|
private static final Pattern OUTPUT_PATTERN = Pattern.compile("opinion=(.*),.*comment=(.*)");
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private MemoryService memoryService;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private AgentService agentService;
|
||||||
|
|
||||||
|
@Scheduled(fixedDelay = 60 * 1000)
|
||||||
|
public void review() {
|
||||||
|
memoryService.getMemories().stream()
|
||||||
|
.filter(c -> c.getStatus() == ChatMemoryDO.Status.PENDING)
|
||||||
|
.forEach(m -> {
|
||||||
|
Agent chatAgent = agentService.getAgent(m.getAgentId());
|
||||||
|
String promptStr = String.format(INSTRUCTION, m.getQuestion(), m.getSchema(), m.getS2sql());
|
||||||
|
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
||||||
|
|
||||||
|
keyPipelineLog.info("MemoryReviewTask reqPrompt:{}", promptStr);
|
||||||
|
ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(chatAgent.getLlmConfig());
|
||||||
|
String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text();
|
||||||
|
keyPipelineLog.info("MemoryReviewTask modelResp:{}", response);
|
||||||
|
|
||||||
|
Matcher matcher = OUTPUT_PATTERN.matcher(response);
|
||||||
|
if (matcher.find()) {
|
||||||
|
m.setLlmReviewRet(ReviewResult.valueOf(matcher.group(1)));
|
||||||
|
m.setLlmReviewCmt(matcher.group(2));
|
||||||
|
memoryService.updateMemory(m);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
package com.tencent.supersonic.chat.server.parser;
|
package com.tencent.supersonic.chat.server.parser;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||||
|
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
|
||||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||||
@@ -12,6 +13,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||||
@@ -64,9 +66,11 @@ public class NL2SQLParser implements ChatParser {
|
|||||||
if (!chatParseContext.enableNL2SQL() || checkSkip(parseResp)) {
|
if (!chatParseContext.enableNL2SQL() || checkSkip(parseResp)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
processMultiTurn(chatParseContext, parseResp);
|
processMultiTurn(chatParseContext);
|
||||||
|
|
||||||
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||||
|
addExemplars(chatParseContext.getAgent().getId(), queryReq);
|
||||||
|
|
||||||
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
|
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
|
||||||
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq);
|
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq);
|
||||||
if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) {
|
if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) {
|
||||||
@@ -131,7 +135,7 @@ public class NL2SQLParser implements ChatParser {
|
|||||||
parseInfo.setTextInfo(textBuilder.toString());
|
parseInfo.setTextInfo(textBuilder.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
private void processMultiTurn(ChatParseContext chatParseContext, ParseResp parseResp) {
|
private void processMultiTurn(ChatParseContext chatParseContext) {
|
||||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||||
MultiTurnConfig agentMultiTurnConfig = chatParseContext.getAgent().getMultiTurnConfig();
|
MultiTurnConfig agentMultiTurnConfig = chatParseContext.getAgent().getMultiTurnConfig();
|
||||||
Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
|
Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
|
||||||
@@ -220,6 +224,13 @@ public class NL2SQLParser implements ChatParser {
|
|||||||
return contextualList;
|
return contextualList;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void addExemplars(Integer agentId, QueryReq queryReq) {
|
||||||
|
ExemplarServiceImpl exemplarManager = ContextUtils.getBean(ExemplarServiceImpl.class);
|
||||||
|
List<SqlExemplar> exemplars = exemplarManager.recallExemplars(agentId.toString(),
|
||||||
|
queryReq.getQueryText(), 5);
|
||||||
|
queryReq.getExemplars().addAll(exemplars);
|
||||||
|
}
|
||||||
|
|
||||||
@Builder
|
@Builder
|
||||||
@Data
|
@Data
|
||||||
public static class RewriteContext {
|
public static class RewriteContext {
|
||||||
|
|||||||
@@ -0,0 +1,70 @@
|
|||||||
|
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
||||||
|
|
||||||
|
import com.baomidou.mybatisplus.annotation.IdType;
|
||||||
|
import com.baomidou.mybatisplus.annotation.TableField;
|
||||||
|
import com.baomidou.mybatisplus.annotation.TableId;
|
||||||
|
import com.baomidou.mybatisplus.annotation.TableName;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.ToString;
|
||||||
|
|
||||||
|
import java.util.Date;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@Builder
|
||||||
|
@ToString
|
||||||
|
@TableName("s2_chat_memory")
|
||||||
|
public class ChatMemoryDO {
|
||||||
|
@TableId(type = IdType.AUTO)
|
||||||
|
private Long id;
|
||||||
|
|
||||||
|
@TableField("question")
|
||||||
|
private String question;
|
||||||
|
|
||||||
|
@TableField("agent_id")
|
||||||
|
private Integer agentId;
|
||||||
|
|
||||||
|
@TableField("db_schema")
|
||||||
|
private String schema;
|
||||||
|
|
||||||
|
@TableField("s2_sql")
|
||||||
|
private String s2sql;
|
||||||
|
|
||||||
|
@TableField("status")
|
||||||
|
private Status status;
|
||||||
|
|
||||||
|
@TableField("llm_review")
|
||||||
|
private ReviewResult llmReviewRet;
|
||||||
|
|
||||||
|
@TableField("llm_comment")
|
||||||
|
private String llmReviewCmt;
|
||||||
|
|
||||||
|
@TableField("human_review")
|
||||||
|
private ReviewResult humanReviewRet;
|
||||||
|
|
||||||
|
@TableField("human_comment")
|
||||||
|
private String humanReviewCmt;
|
||||||
|
|
||||||
|
@TableField("created_by")
|
||||||
|
private String createdBy;
|
||||||
|
|
||||||
|
@TableField("created_at")
|
||||||
|
private Date createdAt;
|
||||||
|
|
||||||
|
@TableField("updated_by")
|
||||||
|
private String updatedBy;
|
||||||
|
|
||||||
|
@TableField("updated_at")
|
||||||
|
private Date updatedAt;
|
||||||
|
|
||||||
|
public enum ReviewResult {
|
||||||
|
POSITIVE,
|
||||||
|
NEGATIVE
|
||||||
|
}
|
||||||
|
|
||||||
|
public enum Status {
|
||||||
|
PENDING,
|
||||||
|
ENABLED,
|
||||||
|
DISABLED;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
package com.tencent.supersonic.chat.server.persistence.mapper;
|
||||||
|
|
||||||
|
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
|
||||||
|
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||||
|
import org.apache.ibatis.annotations.Mapper;
|
||||||
|
|
||||||
|
@Mapper
|
||||||
|
public interface ChatMemoryMapper extends BaseMapper<ChatMemoryDO> {
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.server.persistence.repository;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
|
||||||
|
|
||||||
public interface ChatContextRepository {
|
|
||||||
|
|
||||||
ChatContext getOrCreateContext(int chatId);
|
|
||||||
|
|
||||||
void updateContext(ChatContext chatCtx);
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
package com.tencent.supersonic.chat.server.persistence.repository;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public interface ChatMemoryRepository {
|
||||||
|
void createMemory(ChatMemoryDO chatMemoryDO);
|
||||||
|
|
||||||
|
void updateMemory(ChatMemoryDO chatMemoryDO);
|
||||||
|
|
||||||
|
ChatMemoryDO getMemory(Long id);
|
||||||
|
|
||||||
|
List<ChatMemoryDO> getMemories();
|
||||||
|
}
|
||||||
@@ -0,0 +1,42 @@
|
|||||||
|
package com.tencent.supersonic.chat.server.persistence.repository.impl;
|
||||||
|
|
||||||
|
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
||||||
|
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||||
|
import com.tencent.supersonic.chat.server.persistence.mapper.ChatMemoryMapper;
|
||||||
|
import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository;
|
||||||
|
import org.springframework.context.annotation.Primary;
|
||||||
|
import org.springframework.stereotype.Repository;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@Repository
|
||||||
|
@Primary
|
||||||
|
public class ChatMemoryRepositoryImpl implements ChatMemoryRepository {
|
||||||
|
|
||||||
|
private final ChatMemoryMapper chatMemoryMapper;
|
||||||
|
|
||||||
|
public ChatMemoryRepositoryImpl(ChatMemoryMapper chatMemoryMapper) {
|
||||||
|
this.chatMemoryMapper = chatMemoryMapper;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void createMemory(ChatMemoryDO chatMemoryDO) {
|
||||||
|
chatMemoryMapper.insert(chatMemoryDO);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void updateMemory(ChatMemoryDO chatMemoryDO) {
|
||||||
|
chatMemoryMapper.updateById(chatMemoryDO);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ChatMemoryDO getMemory(Long id) {
|
||||||
|
return chatMemoryMapper.selectById(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<ChatMemoryDO> getMemories() {
|
||||||
|
return chatMemoryMapper.selectList(new QueryWrapper<>());
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
package com.tencent.supersonic.chat.server.service;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public interface MemoryService {
|
||||||
|
void createMemory(ChatMemoryDO memory);
|
||||||
|
|
||||||
|
void updateMemory(ChatMemoryDO memory);
|
||||||
|
|
||||||
|
List<ChatMemoryDO> getMemories();
|
||||||
|
}
|
||||||
@@ -0,0 +1,55 @@
|
|||||||
|
package com.tencent.supersonic.chat.server.service.impl;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||||
|
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO.ReviewResult;
|
||||||
|
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO.Status;
|
||||||
|
import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository;
|
||||||
|
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||||
|
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||||
|
import com.tencent.supersonic.common.service.ExemplarService;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@Service
|
||||||
|
public class MemoryServiceImpl implements MemoryService {
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private ChatMemoryRepository chatMemoryRepository;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private ExemplarService exemplarService;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void createMemory(ChatMemoryDO memory) {
|
||||||
|
if (ReviewResult.POSITIVE.equals(memory.getHumanReviewRet())) {
|
||||||
|
enableMemory(memory);
|
||||||
|
}
|
||||||
|
chatMemoryRepository.createMemory(memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void updateMemory(ChatMemoryDO memory) {
|
||||||
|
if (!ChatMemoryDO.Status.ENABLED.equals(memory.getStatus())
|
||||||
|
&& ReviewResult.POSITIVE.equals(memory.getHumanReviewRet())) {
|
||||||
|
enableMemory(memory);
|
||||||
|
}
|
||||||
|
chatMemoryRepository.updateMemory(memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<ChatMemoryDO> getMemories() {
|
||||||
|
return chatMemoryRepository.getMemories();
|
||||||
|
}
|
||||||
|
|
||||||
|
private void enableMemory(ChatMemoryDO memory) {
|
||||||
|
exemplarService.storeExemplar(memory.getAgentId().toString(),
|
||||||
|
SqlExemplar.builder()
|
||||||
|
.question(memory.getQuestion())
|
||||||
|
.dbSchema(memory.getSchema())
|
||||||
|
.sql(memory.getS2sql())
|
||||||
|
.build());
|
||||||
|
memory.setStatus(Status.ENABLED);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
package com.tencent.supersonic.common.pojo;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@Builder
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class SqlExemplar {
|
||||||
|
|
||||||
|
private String question;
|
||||||
|
|
||||||
|
private String dbSchema;
|
||||||
|
|
||||||
|
private String sql;
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
package com.tencent.supersonic.common.service;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public interface ExemplarService {
|
||||||
|
void storeExemplar(String collection, SqlExemplar exemplar);
|
||||||
|
|
||||||
|
List<SqlExemplar> recallExemplars(String collection, String query, int num);
|
||||||
|
|
||||||
|
List<SqlExemplar> recallExemplars(String query, int num);
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,87 @@
|
|||||||
|
package com.tencent.supersonic.common.service.impl;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.core.type.TypeReference;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
|
import com.tencent.supersonic.common.service.EmbeddingService;
|
||||||
|
import com.tencent.supersonic.common.service.ExemplarService;
|
||||||
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
|
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||||
|
import dev.langchain4j.data.document.Metadata;
|
||||||
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
|
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||||
|
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.boot.CommandLineRunner;
|
||||||
|
import org.springframework.core.io.ClassPathResource;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@Service
|
||||||
|
@Slf4j
|
||||||
|
public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner {
|
||||||
|
|
||||||
|
private static final String SYS_EXEMPLAR_FILE = "s2ql_exemplar.json";
|
||||||
|
|
||||||
|
private TypeReference<List<SqlExemplar>> valueTypeRef = new TypeReference<List<SqlExemplar>>() {
|
||||||
|
};
|
||||||
|
|
||||||
|
private final ObjectMapper objectMapper = JsonUtil.INSTANCE.getObjectMapper();
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private EmbeddingConfig embeddingConfig;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private EmbeddingService embeddingService;
|
||||||
|
|
||||||
|
public void storeExemplar(String collection, SqlExemplar exemplar) {
|
||||||
|
Metadata metadata = Metadata.from(JsonUtil.toMap(JsonUtil.toString(exemplar),
|
||||||
|
String.class, Object.class));
|
||||||
|
TextSegment segment = TextSegment.from(exemplar.getQuestion(), metadata);
|
||||||
|
|
||||||
|
embeddingService.addQuery(collection, Lists.newArrayList(segment));
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<SqlExemplar> recallExemplars(String query, int num) {
|
||||||
|
String collection = embeddingConfig.getText2sqlCollectionName();
|
||||||
|
return recallExemplars(collection, query, num);
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<SqlExemplar> recallExemplars(String collection, String query, int num) {
|
||||||
|
List<SqlExemplar> exemplars = Lists.newArrayList();
|
||||||
|
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
|
||||||
|
.queryTextsList(Lists.newArrayList(query))
|
||||||
|
.build();
|
||||||
|
List<RetrieveQueryResult> results = embeddingService.retrieveQuery(collection, retrieveQuery, num);
|
||||||
|
results.stream().forEach(ret -> {
|
||||||
|
ret.getRetrieval().stream().forEach(r -> {
|
||||||
|
exemplars.add(JsonUtil.mapToObject(r.getMetadata(), SqlExemplar.class));
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
return exemplars;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void run(String... args) {
|
||||||
|
try {
|
||||||
|
loadSysExemplars();
|
||||||
|
} catch (IOException e) {
|
||||||
|
log.error("Failed to load system exemplars", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void loadSysExemplars() throws IOException {
|
||||||
|
ClassPathResource resource = new ClassPathResource(SYS_EXEMPLAR_FILE);
|
||||||
|
InputStream inputStream = resource.getInputStream();
|
||||||
|
List<SqlExemplar> exemplars = objectMapper.readValue(inputStream, valueTypeRef);
|
||||||
|
String collection = embeddingConfig.getText2sqlCollectionName();
|
||||||
|
exemplars.stream().forEach(e -> storeExemplar(collection, e));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,14 +1,17 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo.request;
|
package com.tencent.supersonic.headless.api.pojo.request;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
import com.google.common.collect.Sets;
|
import com.google.common.collect.Sets;
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||||
import com.tencent.supersonic.common.config.LLMConfig;
|
import com.tencent.supersonic.common.config.LLMConfig;
|
||||||
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||||
|
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@@ -24,4 +27,5 @@ public class QueryReq {
|
|||||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||||
private QueryDataType queryDataType = QueryDataType.ALL;
|
private QueryDataType queryDataType = QueryDataType.ALL;
|
||||||
private LLMConfig llmConfig;
|
private LLMConfig llmConfig;
|
||||||
|
private List<SqlExemplar> exemplars = Lists.newArrayList();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import com.tencent.supersonic.common.config.LLMConfig;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||||
|
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||||
import com.tencent.supersonic.headless.api.pojo.enums.WorkflowState;
|
import com.tencent.supersonic.headless.api.pojo.enums.WorkflowState;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||||
@@ -49,6 +50,7 @@ public class QueryContext {
|
|||||||
private WorkflowState workflowState;
|
private WorkflowState workflowState;
|
||||||
private QueryDataType queryDataType = QueryDataType.ALL;
|
private QueryDataType queryDataType = QueryDataType.ALL;
|
||||||
private LLMConfig llmConfig;
|
private LLMConfig llmConfig;
|
||||||
|
private List<SqlExemplar> exemplars;
|
||||||
|
|
||||||
public List<SemanticQuery> getCandidateQueries() {
|
public List<SemanticQuery> getCandidateQueries() {
|
||||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||||
|
|||||||
@@ -1,16 +0,0 @@
|
|||||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public class Exemplar {
|
|
||||||
|
|
||||||
private String question;
|
|
||||||
|
|
||||||
private String questionAugmented;
|
|
||||||
|
|
||||||
private String dbSchema;
|
|
||||||
|
|
||||||
private String sql;
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,101 +0,0 @@
|
|||||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
|
||||||
|
|
||||||
|
|
||||||
import com.fasterxml.jackson.core.type.TypeReference;
|
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
|
||||||
import com.tencent.supersonic.common.service.EmbeddingService;
|
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
|
||||||
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
|
|
||||||
import dev.langchain4j.data.document.Metadata;
|
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
|
||||||
import dev.langchain4j.store.embedding.Retrieval;
|
|
||||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
|
||||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
|
||||||
import dev.langchain4j.store.embedding.TextSegmentConvert;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.springframework.boot.CommandLineRunner;
|
|
||||||
import org.springframework.core.annotation.Order;
|
|
||||||
import org.springframework.core.io.ClassPathResource;
|
|
||||||
import org.springframework.stereotype.Component;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.io.InputStream;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
@Component
|
|
||||||
@Order(0)
|
|
||||||
public class ExemplarManager implements CommandLineRunner {
|
|
||||||
|
|
||||||
private static final String EXAMPLE_JSON_FILE = "s2ql_exemplar.json";
|
|
||||||
|
|
||||||
@Autowired
|
|
||||||
private EmbeddingService embeddingService;
|
|
||||||
|
|
||||||
@Autowired
|
|
||||||
private EmbeddingConfig embeddingConfig;
|
|
||||||
|
|
||||||
private TypeReference<List<Exemplar>> valueTypeRef = new TypeReference<List<Exemplar>>() {
|
|
||||||
};
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void run(String... args) {
|
|
||||||
try {
|
|
||||||
if (ComponentFactory.getLLMProxy() instanceof JavaLLMProxy) {
|
|
||||||
loadDefaultExemplars();
|
|
||||||
}
|
|
||||||
} catch (Exception e) {
|
|
||||||
log.error("Failed to init examples", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public void addExemplars(List<Exemplar> exemplars, String collectionName) {
|
|
||||||
List<TextSegment> queries = new ArrayList<>();
|
|
||||||
for (int i = 0; i < exemplars.size(); i++) {
|
|
||||||
Exemplar exemplar = exemplars.get(i);
|
|
||||||
String question = exemplar.getQuestion();
|
|
||||||
Map<String, Object> metaDataMap = JsonUtil.toMap(JsonUtil.toString(exemplar), String.class, Object.class);
|
|
||||||
TextSegment embeddingQuery = TextSegment.from(question, new Metadata(metaDataMap));
|
|
||||||
TextSegmentConvert.addQueryId(embeddingQuery, String.valueOf(i));
|
|
||||||
queries.add(embeddingQuery);
|
|
||||||
}
|
|
||||||
embeddingService.addQuery(collectionName, queries);
|
|
||||||
}
|
|
||||||
|
|
||||||
public List<Map<String, String>> recallExemplars(String queryText, int maxResults) {
|
|
||||||
String collectionName = embeddingConfig.getText2sqlCollectionName();
|
|
||||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(Collections.singletonList(queryText))
|
|
||||||
.queryEmbeddings(null).build();
|
|
||||||
|
|
||||||
List<RetrieveQueryResult> resultList = embeddingService.retrieveQuery(collectionName, retrieveQuery,
|
|
||||||
maxResults);
|
|
||||||
List<Map<String, String>> result = new ArrayList<>();
|
|
||||||
if (CollectionUtils.isEmpty(resultList)) {
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
for (Retrieval retrieval : resultList.get(0).getRetrieval()) {
|
|
||||||
if (Objects.nonNull(retrieval.getMetadata()) && !retrieval.getMetadata().isEmpty()) {
|
|
||||||
Map<String, String> convertedMap = retrieval.getMetadata().entrySet().stream()
|
|
||||||
.collect(Collectors.toMap(Map.Entry::getKey, entry -> String.valueOf(entry.getValue())));
|
|
||||||
result.add(convertedMap);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void loadDefaultExemplars() throws IOException {
|
|
||||||
ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE);
|
|
||||||
InputStream inputStream = resource.getInputStream();
|
|
||||||
List<Exemplar> examples = JsonUtil.INSTANCE.getObjectMapper().readValue(inputStream, valueTypeRef);
|
|
||||||
String collectionName = embeddingConfig.getText2sqlCollectionName();
|
|
||||||
addExemplars(examples, collectionName);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -109,6 +109,8 @@ public class LLMRequestService {
|
|||||||
llmReq.setSqlGenType(LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
|
llmReq.setSqlGenType(LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
|
||||||
llmReq.setLlmConfig(queryCtx.getLlmConfig());
|
llmReq.setLlmConfig(queryCtx.getLlmConfig());
|
||||||
|
|
||||||
|
llmReq.setExemplars(queryCtx.getExemplars());
|
||||||
|
|
||||||
return llmReq;
|
return llmReq;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
|
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||||
import dev.langchain4j.data.message.AiMessage;
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
@@ -42,11 +43,11 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
|||||||
public LLMResp generate(LLMReq llmReq) {
|
public LLMResp generate(LLMReq llmReq) {
|
||||||
//1.recall exemplars
|
//1.recall exemplars
|
||||||
keyPipelineLog.info("OnePassSCSqlGenStrategy llmReq:\n{}", llmReq);
|
keyPipelineLog.info("OnePassSCSqlGenStrategy llmReq:\n{}", llmReq);
|
||||||
List<List<Map<String, String>>> exemplarsList = promptHelper.getFewShotExemplars(llmReq);
|
List<List<SqlExemplar>> exemplarsList = promptHelper.getFewShotExemplars(llmReq);
|
||||||
|
|
||||||
//2.generate sql generation prompt for each self-consistency inference
|
//2.generate sql generation prompt for each self-consistency inference
|
||||||
Map<Prompt, List<Map<String, String>>> prompt2Exemplar = new HashMap<>();
|
Map<Prompt, List<SqlExemplar>> prompt2Exemplar = new HashMap<>();
|
||||||
for (List<Map<String, String>> exemplars : exemplarsList) {
|
for (List<SqlExemplar> exemplars : exemplarsList) {
|
||||||
Prompt prompt = generatePrompt(llmReq, exemplars);
|
Prompt prompt = generatePrompt(llmReq, exemplars);
|
||||||
prompt2Exemplar.put(prompt, exemplars);
|
prompt2Exemplar.put(prompt, exemplars);
|
||||||
}
|
}
|
||||||
@@ -67,25 +68,24 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
|||||||
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(
|
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(
|
||||||
Lists.newArrayList(prompt2Output.values()));
|
Lists.newArrayList(prompt2Output.values()));
|
||||||
LLMResp llmResp = new LLMResp();
|
LLMResp llmResp = new LLMResp();
|
||||||
llmResp.setQuery(llmReq.getQueryText());
|
llmResp.setQuery(promptHelper.buildAugmentedQuestion(llmReq));
|
||||||
|
llmResp.setDbSchema(promptHelper.buildSchemaStr(llmReq));
|
||||||
|
llmResp.setSqlOutput(sqlMapPair.getLeft());
|
||||||
//TODO: should use the same few-shot exemplars as the one chose by self-consistency vote
|
//TODO: should use the same few-shot exemplars as the one chose by self-consistency vote
|
||||||
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(exemplarsList.get(0), sqlMapPair.getRight()));
|
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(exemplarsList.get(0), sqlMapPair.getRight()));
|
||||||
|
|
||||||
return llmResp;
|
return llmResp;
|
||||||
}
|
}
|
||||||
|
|
||||||
private Prompt generatePrompt(LLMReq llmReq, List<Map<String, String>> fewshotExampleList) {
|
private Prompt generatePrompt(LLMReq llmReq, List<SqlExemplar> fewshotExampleList) {
|
||||||
StringBuilder exemplarsStr = new StringBuilder();
|
StringBuilder exemplarsStr = new StringBuilder();
|
||||||
for (Map<String, String> example : fewshotExampleList) {
|
for (SqlExemplar exemplar : fewshotExampleList) {
|
||||||
String metadata = example.get("dbSchema");
|
|
||||||
String question = example.get("questionAugmented");
|
|
||||||
String sql = example.get("sql");
|
|
||||||
String exemplarStr = String.format("#UserQuery: %s #Schema: %s #SQL: %s\n",
|
String exemplarStr = String.format("#UserQuery: %s #Schema: %s #SQL: %s\n",
|
||||||
question, metadata, sql);
|
exemplar.getQuestion(), exemplar.getDbSchema(), exemplar.getSql());
|
||||||
exemplarsStr.append(exemplarStr);
|
exemplarsStr.append(exemplarStr);
|
||||||
}
|
}
|
||||||
|
|
||||||
String dataSemanticsStr = promptHelper.buildMetadataStr(llmReq);
|
String dataSemanticsStr = promptHelper.buildSchemaStr(llmReq);
|
||||||
String questionAugmented = promptHelper.buildAugmentedQuestion(llmReq);
|
String questionAugmented = promptHelper.buildAugmentedQuestion(llmReq);
|
||||||
String promptStr = String.format(INSTRUCTION, exemplarsStr, questionAugmented, dataSemanticsStr);
|
String promptStr = String.format(INSTRUCTION, exemplarsStr, questionAugmented, dataSemanticsStr);
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
@@ -108,7 +109,7 @@ public class OutputFormat {
|
|||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Map<String, LLMSqlResp> buildSqlRespMap(List<Map<String, String>> sqlExamples,
|
public static Map<String, LLMSqlResp> buildSqlRespMap(List<SqlExemplar> sqlExamples,
|
||||||
Map<String, Double> sqlMap) {
|
Map<String, Double> sqlMap) {
|
||||||
if (sqlMap == null) {
|
if (sqlMap == null) {
|
||||||
return new HashMap<>();
|
return new HashMap<>();
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||||
|
import com.tencent.supersonic.common.service.ExemplarService;
|
||||||
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
@@ -11,7 +14,6 @@ import org.springframework.util.CollectionUtils;
|
|||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER;
|
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER;
|
||||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_FEW_SHOT_NUMBER;
|
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_FEW_SHOT_NUMBER;
|
||||||
@@ -25,20 +27,27 @@ public class PromptHelper {
|
|||||||
private ParserConfig parserConfig;
|
private ParserConfig parserConfig;
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private ExemplarManager exemplarManager;
|
private ExemplarService exemplarService;
|
||||||
|
|
||||||
public List<List<Map<String, String>>> getFewShotExemplars(LLMReq llmReq) {
|
public List<List<SqlExemplar>> getFewShotExemplars(LLMReq llmReq) {
|
||||||
int exemplarRecallNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
|
int exemplarRecallNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
|
||||||
int fewShotNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_FEW_SHOT_NUMBER));
|
int fewShotNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_FEW_SHOT_NUMBER));
|
||||||
int selfConsistencyNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER));
|
int selfConsistencyNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER));
|
||||||
|
|
||||||
List<Map<String, String>> exemplars = exemplarManager.recallExemplars(llmReq.getQueryText(),
|
List<SqlExemplar> exemplars = Lists.newArrayList();
|
||||||
exemplarRecallNumber);
|
llmReq.getExemplars().stream().forEach(e -> {
|
||||||
List<List<Map<String, String>>> results = new ArrayList<>();
|
exemplars.add(e);
|
||||||
|
});
|
||||||
|
|
||||||
|
int recallSize = exemplarRecallNumber - llmReq.getExemplars().size();
|
||||||
|
if (recallSize > 0) {
|
||||||
|
exemplars.addAll(exemplarService.recallExemplars(llmReq.getQueryText(), recallSize));
|
||||||
|
}
|
||||||
|
|
||||||
|
List<List<SqlExemplar>> results = new ArrayList<>();
|
||||||
// use random collection of exemplars for each self-consistency inference
|
// use random collection of exemplars for each self-consistency inference
|
||||||
for (int i = 0; i < selfConsistencyNumber; i++) {
|
for (int i = 0; i < selfConsistencyNumber; i++) {
|
||||||
List<Map<String, String>> shuffledList = new ArrayList<>(exemplars);
|
List<SqlExemplar> shuffledList = new ArrayList<>(exemplars);
|
||||||
Collections.shuffle(shuffledList);
|
Collections.shuffle(shuffledList);
|
||||||
results.add(shuffledList.subList(0, fewShotNumber));
|
results.add(shuffledList.subList(0, fewShotNumber));
|
||||||
}
|
}
|
||||||
@@ -64,7 +73,7 @@ public class PromptHelper {
|
|||||||
linkingListStr, currentDataStr, termStr, priorExts);
|
linkingListStr, currentDataStr, termStr, priorExts);
|
||||||
}
|
}
|
||||||
|
|
||||||
public String buildMetadataStr(LLMReq llmReq) {
|
public String buildSchemaStr(LLMReq llmReq) {
|
||||||
String tableStr = llmReq.getSchema().getDataSetName();
|
String tableStr = llmReq.getSchema().getDataSetName();
|
||||||
StringBuilder metricStr = new StringBuilder();
|
StringBuilder metricStr = new StringBuilder();
|
||||||
StringBuilder dimensionStr = new StringBuilder();
|
StringBuilder dimensionStr = new StringBuilder();
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import com.fasterxml.jackson.annotation.JsonValue;
|
|||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.common.config.LLMConfig;
|
import com.tencent.supersonic.common.config.LLMConfig;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -26,6 +27,9 @@ public class LLMReq {
|
|||||||
private SqlGenType sqlGenType;
|
private SqlGenType sqlGenType;
|
||||||
|
|
||||||
private LLMConfig llmConfig;
|
private LLMConfig llmConfig;
|
||||||
|
|
||||||
|
private List<SqlExemplar> exemplars;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public static class ElementValue {
|
public static class ElementValue {
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ public class LLMResp {
|
|||||||
|
|
||||||
private String modelName;
|
private String modelName;
|
||||||
|
|
||||||
|
private String dbSchema;
|
||||||
|
|
||||||
private String sqlOutput;
|
private String sqlOutput;
|
||||||
|
|
||||||
private List<String> fields;
|
private List<String> fields;
|
||||||
|
|||||||
@@ -1,13 +1,12 @@
|
|||||||
package com.tencent.supersonic.headless.chat.query.llm.s2sql;
|
package com.tencent.supersonic.headless.chat.query.llm.s2sql;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
@@ -17,6 +16,6 @@ public class LLMSqlResp {
|
|||||||
|
|
||||||
private double sqlWeight;
|
private double sqlWeight;
|
||||||
|
|
||||||
private List<Map<String, String>> fewShots;
|
private List<SqlExemplar> fewShots;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -172,6 +172,7 @@ public abstract class S2BaseDemo implements CommandLineRunner {
|
|||||||
executeReq.setQueryText(queryText);
|
executeReq.setQueryText(queryText);
|
||||||
executeReq.setChatId(parseResp.getChatId());
|
executeReq.setChatId(parseResp.getChatId());
|
||||||
executeReq.setUser(User.getFakeUser());
|
executeReq.setUser(User.getFakeUser());
|
||||||
|
executeReq.setAgentId(agentId);
|
||||||
executeReq.setSaveAnswer(true);
|
executeReq.setSaveAnswer(true);
|
||||||
chatService.performExecution(executeReq);
|
chatService.performExecution(executeReq);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -83,6 +83,25 @@ CREATE TABLE IF NOT EXISTS `s2_chat_config` (
|
|||||||
COMMENT ON TABLE s2_chat_config IS 'chat config information table ';
|
COMMENT ON TABLE s2_chat_config IS 'chat config information table ';
|
||||||
|
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS `s2_chat_memory` (
|
||||||
|
`id` INT NOT NULL AUTO_INCREMENT,
|
||||||
|
`question` varchar(655) ,
|
||||||
|
`agent_id` INT ,
|
||||||
|
`db_schema` TEXT ,
|
||||||
|
`s2_sql` TEXT ,
|
||||||
|
`status` char(10) ,
|
||||||
|
`llm_review` char(10) ,
|
||||||
|
`llm_comment` TEXT,
|
||||||
|
`human_review` char(10) ,
|
||||||
|
`human_comment` TEXT ,
|
||||||
|
`created_at` TIMESTAMP ,
|
||||||
|
`updated_at` TIMESTAMP ,
|
||||||
|
`created_by` varchar(100) ,
|
||||||
|
`updated_by` varchar(100) ,
|
||||||
|
PRIMARY KEY (`id`)
|
||||||
|
) ;
|
||||||
|
COMMENT ON TABLE s2_chat_memory IS 'chat memory table ';
|
||||||
|
|
||||||
create table IF NOT EXISTS s2_user
|
create table IF NOT EXISTS s2_user
|
||||||
(
|
(
|
||||||
id INT AUTO_INCREMENT,
|
id INT AUTO_INCREMENT,
|
||||||
|
|||||||
@@ -136,6 +136,24 @@ CREATE TABLE `s2_chat_config` (
|
|||||||
PRIMARY KEY (`id`)
|
PRIMARY KEY (`id`)
|
||||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='主题域扩展信息表';
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='主题域扩展信息表';
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS `s2_chat_memory` (
|
||||||
|
`id` INT NOT NULL AUTO_INCREMENT,
|
||||||
|
`question` varchar(655) ,
|
||||||
|
`agent_id` INT ,
|
||||||
|
`db_schema` TEXT ,
|
||||||
|
`s2_sql` TEXT ,
|
||||||
|
`status` char(10) ,
|
||||||
|
`llm_review` char(10) ,
|
||||||
|
`llm_comment` TEXT,
|
||||||
|
`human_review` char(10) ,
|
||||||
|
`human_comment` TEXT ,
|
||||||
|
`created_at` TIMESTAMP NOT NULL ,
|
||||||
|
`updated_at` TIMESTAMP NOT NULL ,
|
||||||
|
`created_by` varchar(100) NOT NULL ,
|
||||||
|
`updated_by` varchar(100) NOT NULL ,
|
||||||
|
PRIMARY KEY (`id`)
|
||||||
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
|
||||||
|
|
||||||
CREATE TABLE `s2_chat_context` (
|
CREATE TABLE `s2_chat_context` (
|
||||||
`chat_id` bigint(20) NOT NULL COMMENT 'context chat id',
|
`chat_id` bigint(20) NOT NULL COMMENT 'context chat id',
|
||||||
`modified_at` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'row modify time',
|
`modified_at` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'row modify time',
|
||||||
|
|||||||
@@ -82,6 +82,24 @@ CREATE TABLE IF NOT EXISTS `s2_chat_config` (
|
|||||||
) ;
|
) ;
|
||||||
COMMENT ON TABLE s2_chat_config IS 'chat config information table ';
|
COMMENT ON TABLE s2_chat_config IS 'chat config information table ';
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS `s2_chat_memory` (
|
||||||
|
`id` INT NOT NULL AUTO_INCREMENT,
|
||||||
|
`question` varchar(655) ,
|
||||||
|
`agent_id` INT ,
|
||||||
|
`db_schema` TEXT ,
|
||||||
|
`s2_sql` TEXT ,
|
||||||
|
`status` char(10) ,
|
||||||
|
`llm_review` char(10) ,
|
||||||
|
`llm_comment` TEXT,
|
||||||
|
`human_review` char(10) ,
|
||||||
|
`human_comment` TEXT ,
|
||||||
|
`created_at` TIMESTAMP ,
|
||||||
|
`updated_at` TIMESTAMP ,
|
||||||
|
`created_by` varchar(100) ,
|
||||||
|
`updated_by` varchar(100) ,
|
||||||
|
PRIMARY KEY (`id`)
|
||||||
|
) ;
|
||||||
|
COMMENT ON TABLE s2_chat_memory IS 'chat memory table ';
|
||||||
|
|
||||||
create table IF NOT EXISTS s2_user
|
create table IF NOT EXISTS s2_user
|
||||||
(
|
(
|
||||||
|
|||||||
Reference in New Issue
Block a user