mirror of
https://github.com/tencentmusic/supersonic.git
synced 2026-04-22 14:54:21 +08:00
Compare commits
6 Commits
8ce7fc7dd6
...
f0c5d9e6e0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f0c5d9e6e0 | ||
|
|
21e213fb19 | ||
|
|
f67bf3eeac | ||
|
|
9d13038599 | ||
|
|
668f872743 | ||
|
|
acb9cef64e |
@@ -161,9 +161,11 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
|
||||
JsonUtil.toMap(agentDO.getChatModelConfig(), String.class, ChatApp.class));
|
||||
agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class));
|
||||
agent.getChatAppConfig().values().forEach(c -> {
|
||||
ChatModel chatModel = chatModelService.getChatModel(c.getChatModelId());
|
||||
if (Objects.nonNull(chatModel)) {
|
||||
c.setChatModelConfig(chatModelService.getChatModel(c.getChatModelId()).getConfig());
|
||||
if (c.isEnable()) {// 优化,减少访问数据库的次数
|
||||
ChatModel chatModel = chatModelService.getChatModel(c.getChatModelId());
|
||||
if (Objects.nonNull(chatModel)) {
|
||||
c.setChatModelConfig(chatModel.getConfig());
|
||||
}
|
||||
}
|
||||
});
|
||||
agent.setAdmins(JsonUtil.toList(agentDO.getAdmin(), String.class));
|
||||
|
||||
@@ -18,19 +18,23 @@ import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.service.ExemplarService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.CommandLineRunner;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Service
|
||||
public class MemoryServiceImpl implements MemoryService {
|
||||
@Slf4j
|
||||
public class MemoryServiceImpl implements MemoryService , CommandLineRunner {
|
||||
|
||||
@Autowired
|
||||
private ChatMemoryRepository chatMemoryRepository;
|
||||
@@ -187,4 +191,23 @@ public class MemoryServiceImpl implements MemoryService {
|
||||
return memory;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run(String... args) { // 优化,启动时检查,向量数据,将记忆放到向量数据库
|
||||
loadSysExemplars();
|
||||
}
|
||||
public void loadSysExemplars() {
|
||||
try {
|
||||
List<ChatMemory> memories =
|
||||
this.getMemories(ChatMemoryFilter.builder().status(MemoryStatus.ENABLED).build());
|
||||
for(ChatMemory memory:memories){
|
||||
exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
||||
Text2SQLExemplar.builder().question(memory.getQuestion())
|
||||
.sideInfo(memory.getSideInfo()).dbSchema(memory.getDbSchema())
|
||||
.sql(memory.getS2sql()).build());
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("Failed to load system exemplars", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -146,6 +146,10 @@ public class SqlReplaceHelper {
|
||||
public static String replaceFields(String sql, Map<String, String> fieldNameMap,
|
||||
boolean exactReplace) {
|
||||
Select selectStatement = SqlSelectHelper.getSelect(sql);
|
||||
// alias field should not be replaced
|
||||
Set<String> aliases = SqlSelectHelper.getAliasFields(sql);
|
||||
aliases.forEach(alias -> fieldNameMap.put(alias, alias));
|
||||
|
||||
Set<Select> plainSelectList = SqlSelectHelper.getAllSelect(selectStatement);
|
||||
for (Select plainSelect : plainSelectList) {
|
||||
if (plainSelect instanceof PlainSelect) {
|
||||
|
||||
@@ -225,7 +225,7 @@ public class SqlSelectHelper {
|
||||
public static Select getSelect(String sql) {
|
||||
Statement statement = null;
|
||||
try {
|
||||
statement = CCJSqlParserUtil.parse(sql);
|
||||
statement = CCJSqlParserUtil.parse(sql, parser -> parser.withTimeOut(20000));
|
||||
} catch (JSQLParserException e) {
|
||||
log.error("parse error, sql:{}", sql, e);
|
||||
throw new RuntimeException(e);
|
||||
|
||||
@@ -57,6 +57,7 @@ public class ChatModelServiceImpl extends ServiceImpl<ChatModelMapper, ChatModel
|
||||
chatModelDO.setViewer(JsonUtil.toString(chatModel.getViewers()));
|
||||
}
|
||||
save(chatModelDO);
|
||||
chatModel.setId(chatModelDO.getId());
|
||||
return chatModel;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.headless.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.common.jsqlparser.FieldExpression;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
|
||||
@@ -8,10 +9,7 @@ import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.*;
|
||||
|
||||
/** Perform SQL corrections on the "Select" section in S2SQL. */
|
||||
@Slf4j
|
||||
@@ -46,10 +44,28 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
||||
return correctS2SQL;
|
||||
}
|
||||
needAddFields.removeAll(selectFields);
|
||||
String addFieldsToSelectSql =
|
||||
SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(addFieldsToSelectSql);
|
||||
return addFieldsToSelectSql;
|
||||
|
||||
if (!SqlSelectHelper.hasSubSelect(correctS2SQL)) { //优化内容 , 如果sql 条件包含了这个字段,而且是全等,则不再查询该字段
|
||||
List<FieldExpression> tmp4 = SqlSelectHelper.getWhereExpressions(correctS2SQL);
|
||||
Iterator<String> it = needAddFields.iterator();
|
||||
while (it.hasNext()) {
|
||||
String field = it.next();
|
||||
long size = tmp4.stream()
|
||||
.filter(e -> e.getFieldName().equals(field) && "=".equals(e.getOperator()))
|
||||
.count();
|
||||
if (size == 1) {
|
||||
it.remove();
|
||||
}
|
||||
}
|
||||
}
|
||||
if (needAddFields.size() > 0) {
|
||||
String addFieldsToSelectSql =
|
||||
SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(addFieldsToSelectSql);
|
||||
return addFieldsToSelectSql;
|
||||
} else {
|
||||
return correctS2SQL;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -75,6 +75,7 @@ public class KeywordMapper extends BaseMapper {
|
||||
continue;
|
||||
}
|
||||
Long elementID = NatureHelper.getElementID(nature);
|
||||
if (elementID == null)continue; // 判空优化
|
||||
SchemaElement element = getSchemaElement(dataSetId, elementType, elementID,
|
||||
chatQueryContext.getSemanticSchema());
|
||||
if (Objects.isNull(element)) {
|
||||
|
||||
@@ -41,6 +41,28 @@ class AggCorrectorTest {
|
||||
semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testSchemaCorrector() {
|
||||
SchemaCorrector corrector = new SchemaCorrector();
|
||||
Long dataSetId = 1L;
|
||||
ChatQueryContext chatQueryContext = buildQueryContext(dataSetId);
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
SchemaElement dataSet = new SchemaElement();
|
||||
dataSet.setDataSetId(dataSetId);
|
||||
semanticParseInfo.setDataSet(dataSet);
|
||||
SqlInfo sqlInfo = new SqlInfo();
|
||||
String sql =
|
||||
"WITH 总停留时长 AS (SELECT 用户, SUM(停留时长) AS _总停留时长_ FROM 超音数数据集 WHERE 用户 IN ('alice', 'lucy') AND 数据日期 >= '2025-03-01' AND 数据日期 <= '2025-03-12' GROUP BY 用户) SELECT 用户, _总停留时长_ FROM 总停留时长";
|
||||
sqlInfo.setParsedS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||
corrector.correct(chatQueryContext, semanticParseInfo);
|
||||
Assert.assertEquals(
|
||||
"WITH 总停留时长 AS (SELECT 用户名, SUM(停留时长) AS _总停留时长_ FROM 超音数数据集 WHERE 用户名 IN ('alice', 'lucy') AND 数据日期 "
|
||||
+ ">= '2025-03-01' AND 数据日期 <= '2025-03-12' GROUP BY 用户名) SELECT 用户名, _总停留时长_ FROM 总停留时长",
|
||||
semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||
}
|
||||
|
||||
private ChatQueryContext buildQueryContext(Long dataSetId) {
|
||||
ChatQueryContext chatQueryContext = new ChatQueryContext();
|
||||
List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
|
||||
@@ -51,18 +73,18 @@ class AggCorrectorTest {
|
||||
schemaElement.setDataSetId(dataSetId);
|
||||
dataSetSchema.setDataSet(schemaElement);
|
||||
Set<SchemaElement> dimensions = new HashSet<>();
|
||||
SchemaElement element1 = new SchemaElement();
|
||||
element1.setDataSetId(1L);
|
||||
element1.setName("部门");
|
||||
dimensions.add(element1);
|
||||
|
||||
dimensions.add(SchemaElement.builder().dataSetId(1L).name("部门").build());
|
||||
|
||||
dimensions.add(SchemaElement.builder().dataSetId(1L).name("用户名").build());
|
||||
|
||||
dataSetSchema.setDimensions(dimensions);
|
||||
|
||||
Set<SchemaElement> metrics = new HashSet<>();
|
||||
SchemaElement metric1 = new SchemaElement();
|
||||
metric1.setDataSetId(1L);
|
||||
metric1.setName("访问次数");
|
||||
metrics.add(metric1);
|
||||
|
||||
metrics.add(SchemaElement.builder().dataSetId(1L).name("访问次数").build());
|
||||
|
||||
metrics.add(SchemaElement.builder().dataSetId(1L).name("停留时长").build());
|
||||
|
||||
dataSetSchema.setMetrics(metrics);
|
||||
dataSetSchemaList.add(dataSetSchema);
|
||||
|
||||
@@ -333,7 +333,6 @@ public class S2VisitsDemo extends S2BaseDemo {
|
||||
metricReq.setSensitiveLevel(SensitiveLevelEnum.HIGH.getCode());
|
||||
metricReq.setDescription("每个用户平均访问的次数");
|
||||
metricReq.setClassifications(Collections.singletonList("核心指标"));
|
||||
metricReq.setAlias("平均访问次数");
|
||||
MetricDefineByMetricParams metricTypeParams = new MetricDefineByMetricParams();
|
||||
metricTypeParams.setExpr("pv/uv");
|
||||
List<MetricParam> metrics = new ArrayList<>();
|
||||
|
||||
@@ -155,16 +155,6 @@ public class Text2SQLEval extends BaseTest {
|
||||
assert result.getTextResult().contains("3");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test_detail_query() throws Exception {
|
||||
long start = System.currentTimeMillis();
|
||||
QueryResult result = submitNewChat("特斯拉旗下有哪些品牌", agent.getId());
|
||||
durations.add(System.currentTimeMillis() - start);
|
||||
assert result.getQueryColumns().size() >= 1;
|
||||
assert result.getTextResult().contains("Model Y");
|
||||
assert result.getTextResult().contains("Model 3");
|
||||
}
|
||||
|
||||
public Agent getLLMAgent() {
|
||||
Agent agent = new Agent();
|
||||
agent.setName("Agent for Test");
|
||||
|
||||
2
pom.xml
2
pom.xml
@@ -32,7 +32,7 @@
|
||||
<maven.compiler.source>21</maven.compiler.source>
|
||||
<maven.compiler.target>21</maven.compiler.target>
|
||||
<file.encoding>UTF-8</file.encoding>
|
||||
<jsqlparser.version>4.7</jsqlparser.version>
|
||||
<jsqlparser.version>4.9</jsqlparser.version>
|
||||
<pagehelper.version>6.1.0</pagehelper.version>
|
||||
<pagehelper.spring.version>2.1.0</pagehelper.spring.version>
|
||||
<mybatis.version>3.5.3</mybatis.version>
|
||||
|
||||
Reference in New Issue
Block a user