mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-15 22:46:49 +00:00
(improvement)(headless)Introduce side_information to the prompt and exemplar.
This commit is contained in:
@@ -6,7 +6,9 @@ import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.chat.server.util.ResultFormatter;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
@@ -39,13 +41,18 @@ public class SqlExecutor implements ChatQueryExecutor {
|
||||
|
||||
if (queryResult.getQueryState().equals(QueryState.SUCCESS)
|
||||
&& queryResult.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
|
||||
SqlExemplar exemplar = JsonUtil.toObject(JsonUtil.toString(
|
||||
executeContext.getParseInfo().getProperties()
|
||||
.get(SqlExemplar.PROPERTY_KEY)), SqlExemplar.class);
|
||||
|
||||
MemoryService memoryService = ContextUtils.getBean(MemoryService.class);
|
||||
memoryService.createMemory(ChatMemoryDO.builder()
|
||||
.agentId(executeContext.getAgent().getId())
|
||||
.status(MemoryStatus.PENDING)
|
||||
.question(executeContext.getQueryText())
|
||||
.s2sql(executeContext.getParseInfo().getSqlInfo().getParsedS2SQL())
|
||||
.dbSchema(buildSchemaStr(executeContext.getParseInfo()))
|
||||
.question(exemplar.getQuestion())
|
||||
.sideInfo(exemplar.getSideInfo())
|
||||
.dbSchema(exemplar.getDbSchema())
|
||||
.s2sql(exemplar.getSql())
|
||||
.createdBy(executeContext.getUser().getName())
|
||||
.updatedBy(executeContext.getUser().getName())
|
||||
.createdAt(new Date())
|
||||
@@ -98,36 +105,4 @@ public class SqlExecutor implements ChatQueryExecutor {
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
public String buildSchemaStr(SemanticParseInfo parseInfo) {
|
||||
String tableStr = parseInfo.getDataSet().getName();
|
||||
StringBuilder metricStr = new StringBuilder();
|
||||
StringBuilder dimensionStr = new StringBuilder();
|
||||
|
||||
parseInfo.getMetrics().stream().forEach(
|
||||
metric -> {
|
||||
metricStr.append(metric.getName());
|
||||
if (StringUtils.isNotEmpty(metric.getDescription())) {
|
||||
metricStr.append(" COMMENT '" + metric.getDescription() + "'");
|
||||
}
|
||||
if (StringUtils.isNotEmpty(metric.getDefaultAgg())) {
|
||||
metricStr.append(" AGGREGATE '" + metric.getDefaultAgg().toUpperCase() + "'");
|
||||
}
|
||||
metricStr.append(",");
|
||||
}
|
||||
);
|
||||
|
||||
parseInfo.getDimensions().stream().forEach(
|
||||
dimension -> {
|
||||
dimensionStr.append(dimension.getName());
|
||||
if (StringUtils.isNotEmpty(dimension.getDescription())) {
|
||||
dimensionStr.append(" COMMENT '" + dimension.getDescription() + "'");
|
||||
}
|
||||
dimensionStr.append(",");
|
||||
}
|
||||
);
|
||||
|
||||
String template = "Table: %s, Metrics: [%s], Dimensions: [%s]";
|
||||
return String.format(template, tableStr, metricStr, dimensionStr);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -35,6 +35,7 @@ public class MemoryReviewTask {
|
||||
+ "2.ALWAYS recognize `数据日期` as the date field.\n"
|
||||
+ "#Question: %s\n"
|
||||
+ "#Schema: %s\n"
|
||||
+ "#SideInfo: %s\n"
|
||||
+ "#SQL: %s\n"
|
||||
+ "#Response: ";
|
||||
|
||||
@@ -52,7 +53,8 @@ public class MemoryReviewTask {
|
||||
.forEach(m -> {
|
||||
Agent chatAgent = agentService.getAgent(m.getAgentId());
|
||||
if (Objects.nonNull(chatAgent) && chatAgent.enableMemoryReview()) {
|
||||
String promptStr = String.format(INSTRUCTION, m.getQuestion(), m.getDbSchema(), m.getS2sql());
|
||||
String promptStr = String.format(INSTRUCTION, m.getQuestion(), m.getDbSchema(),
|
||||
m.getSideInfo(), m.getS2sql());
|
||||
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
||||
|
||||
keyPipelineLog.info("MemoryReviewTask reqPrompt:{}", promptStr);
|
||||
|
||||
@@ -20,11 +20,14 @@ public class ChatMemoryDO {
|
||||
@TableId(type = IdType.AUTO)
|
||||
private Long id;
|
||||
|
||||
@TableField("agent_id")
|
||||
private Integer agentId;
|
||||
|
||||
@TableField("question")
|
||||
private String question;
|
||||
|
||||
@TableField("agent_id")
|
||||
private Integer agentId;
|
||||
@TableField("side_info")
|
||||
private String sideInfo;
|
||||
|
||||
@TableField("db_schema")
|
||||
private String dbSchema;
|
||||
|
||||
@@ -102,6 +102,7 @@ public class MemoryServiceImpl implements MemoryService {
|
||||
exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
||||
SqlExemplar.builder()
|
||||
.question(memory.getQuestion())
|
||||
.sideInfo(memory.getSideInfo())
|
||||
.dbSchema(memory.getDbSchema())
|
||||
.sql(memory.getS2sql())
|
||||
.build());
|
||||
@@ -113,6 +114,7 @@ public class MemoryServiceImpl implements MemoryService {
|
||||
exemplarService.removeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
||||
SqlExemplar.builder()
|
||||
.question(memory.getQuestion())
|
||||
.sideInfo(memory.getSideInfo())
|
||||
.dbSchema(memory.getDbSchema())
|
||||
.sql(memory.getS2sql())
|
||||
.build());
|
||||
|
||||
Reference in New Issue
Block a user