(improvement)(headless)Introduce side_information to the prompt and exemplar.

This commit is contained in:
jerryjzhang
2024-07-18 11:29:07 +08:00
parent f30c74c18f
commit 2eac301076
16 changed files with 128 additions and 165 deletions

View File

@@ -6,7 +6,9 @@ import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.chat.server.util.ResultFormatter;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
@@ -39,13 +41,18 @@ public class SqlExecutor implements ChatQueryExecutor {
if (queryResult.getQueryState().equals(QueryState.SUCCESS)
&& queryResult.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
SqlExemplar exemplar = JsonUtil.toObject(JsonUtil.toString(
executeContext.getParseInfo().getProperties()
.get(SqlExemplar.PROPERTY_KEY)), SqlExemplar.class);
MemoryService memoryService = ContextUtils.getBean(MemoryService.class);
memoryService.createMemory(ChatMemoryDO.builder()
.agentId(executeContext.getAgent().getId())
.status(MemoryStatus.PENDING)
.question(executeContext.getQueryText())
.s2sql(executeContext.getParseInfo().getSqlInfo().getParsedS2SQL())
.dbSchema(buildSchemaStr(executeContext.getParseInfo()))
.question(exemplar.getQuestion())
.sideInfo(exemplar.getSideInfo())
.dbSchema(exemplar.getDbSchema())
.s2sql(exemplar.getSql())
.createdBy(executeContext.getUser().getName())
.updatedBy(executeContext.getUser().getName())
.createdAt(new Date())
@@ -98,36 +105,4 @@ public class SqlExecutor implements ChatQueryExecutor {
return queryResult;
}
public String buildSchemaStr(SemanticParseInfo parseInfo) {
String tableStr = parseInfo.getDataSet().getName();
StringBuilder metricStr = new StringBuilder();
StringBuilder dimensionStr = new StringBuilder();
parseInfo.getMetrics().stream().forEach(
metric -> {
metricStr.append(metric.getName());
if (StringUtils.isNotEmpty(metric.getDescription())) {
metricStr.append(" COMMENT '" + metric.getDescription() + "'");
}
if (StringUtils.isNotEmpty(metric.getDefaultAgg())) {
metricStr.append(" AGGREGATE '" + metric.getDefaultAgg().toUpperCase() + "'");
}
metricStr.append(",");
}
);
parseInfo.getDimensions().stream().forEach(
dimension -> {
dimensionStr.append(dimension.getName());
if (StringUtils.isNotEmpty(dimension.getDescription())) {
dimensionStr.append(" COMMENT '" + dimension.getDescription() + "'");
}
dimensionStr.append(",");
}
);
String template = "Table: %s, Metrics: [%s], Dimensions: [%s]";
return String.format(template, tableStr, metricStr, dimensionStr);
}
}

View File

@@ -35,6 +35,7 @@ public class MemoryReviewTask {
+ "2.ALWAYS recognize `数据日期` as the date field.\n"
+ "#Question: %s\n"
+ "#Schema: %s\n"
+ "#SideInfo: %s\n"
+ "#SQL: %s\n"
+ "#Response: ";
@@ -52,7 +53,8 @@ public class MemoryReviewTask {
.forEach(m -> {
Agent chatAgent = agentService.getAgent(m.getAgentId());
if (Objects.nonNull(chatAgent) && chatAgent.enableMemoryReview()) {
String promptStr = String.format(INSTRUCTION, m.getQuestion(), m.getDbSchema(), m.getS2sql());
String promptStr = String.format(INSTRUCTION, m.getQuestion(), m.getDbSchema(),
m.getSideInfo(), m.getS2sql());
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
keyPipelineLog.info("MemoryReviewTask reqPrompt:{}", promptStr);

View File

@@ -20,11 +20,14 @@ public class ChatMemoryDO {
@TableId(type = IdType.AUTO)
private Long id;
@TableField("agent_id")
private Integer agentId;
@TableField("question")
private String question;
@TableField("agent_id")
private Integer agentId;
@TableField("side_info")
private String sideInfo;
@TableField("db_schema")
private String dbSchema;

View File

@@ -102,6 +102,7 @@ public class MemoryServiceImpl implements MemoryService {
exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
SqlExemplar.builder()
.question(memory.getQuestion())
.sideInfo(memory.getSideInfo())
.dbSchema(memory.getDbSchema())
.sql(memory.getS2sql())
.build());
@@ -113,6 +114,7 @@ public class MemoryServiceImpl implements MemoryService {
exemplarService.removeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
SqlExemplar.builder()
.question(memory.getQuestion())
.sideInfo(memory.getSideInfo())
.dbSchema(memory.getDbSchema())
.sql(memory.getS2sql())
.build());