7 Commits

Author SHA1 Message Date
zyclove
0eb8897d08 Merge b1dadb4a1a into 21e213fb19 2025-03-12 22:35:57 +08:00
guilinlewis
21e213fb19 (improvement)(headless | chat ) 向量数据被重置后,记忆不会再次添加到向量数据库 (#2164)
Some checks are pending
supersonic CentOS CI / build (21) (push) Waiting to run
supersonic mac CI / build (21) (push) Waiting to run
supersonic ubuntu CI / build (21) (push) Waiting to run
supersonic windows CI / build (21) (push) Waiting to run
2025-03-12 22:19:51 +08:00
jerryjzhang
f67bf3eeac (fix)(chat)Fix bug in creating chat model. 2025-03-12 16:47:40 +08:00
jerryjzhang
9d13038599 (fix)(headless)Fix schema corrector in that aliases should not be replaced. 2025-03-12 16:31:43 +08:00
beat4ocean
0c8c2d4804 [fix][headless] Fix issue filterSql is not working. (#2157)
Some checks are pending
supersonic CentOS CI / build (21) (push) Waiting to run
supersonic mac CI / build (21) (push) Waiting to run
supersonic ubuntu CI / build (21) (push) Waiting to run
supersonic windows CI / build (21) (push) Waiting to run
2025-03-12 13:53:08 +08:00
zhaoyingchao
b1dadb4a1a Merge remote-tracking branch 'origin/master' into hanlp-upgrade 2025-03-05 10:01:32 +08:00
zhaoyingchao
158a0a802a feat:upgrade 1.8.4 2025-03-04 18:55:31 +08:00
17 changed files with 113 additions and 34 deletions

View File

@@ -161,9 +161,11 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
JsonUtil.toMap(agentDO.getChatModelConfig(), String.class, ChatApp.class));
agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class));
agent.getChatAppConfig().values().forEach(c -> {
ChatModel chatModel = chatModelService.getChatModel(c.getChatModelId());
if (Objects.nonNull(chatModel)) {
c.setChatModelConfig(chatModelService.getChatModel(c.getChatModelId()).getConfig());
if (c.isEnable()) {// 优化,减少访问数据库的次数
ChatModel chatModel = chatModelService.getChatModel(c.getChatModelId());
if (Objects.nonNull(chatModel)) {
c.setChatModelConfig(chatModel.getConfig());
}
}
});
agent.setAdmins(JsonUtil.toList(agentDO.getAdmin(), String.class));

View File

@@ -18,19 +18,23 @@ import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.service.ExemplarService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.CommandLineRunner;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
@Service
public class MemoryServiceImpl implements MemoryService {
@Slf4j
public class MemoryServiceImpl implements MemoryService , CommandLineRunner {
@Autowired
private ChatMemoryRepository chatMemoryRepository;
@@ -187,4 +191,23 @@ public class MemoryServiceImpl implements MemoryService {
return memory;
}
@Override
public void run(String... args) { // 优化,启动时检查,向量数据,将记忆放到向量数据库
loadSysExemplars();
}
public void loadSysExemplars() {
try {
List<ChatMemory> memories =
this.getMemories(ChatMemoryFilter.builder().status(MemoryStatus.ENABLED).build());
for(ChatMemory memory:memories){
exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
Text2SQLExemplar.builder().question(memory.getQuestion())
.sideInfo(memory.getSideInfo()).dbSchema(memory.getDbSchema())
.sql(memory.getS2sql()).build());
}
} catch (Exception e) {
log.error("Failed to load system exemplars", e);
}
}
}

View File

@@ -146,6 +146,10 @@ public class SqlReplaceHelper {
public static String replaceFields(String sql, Map<String, String> fieldNameMap,
boolean exactReplace) {
Select selectStatement = SqlSelectHelper.getSelect(sql);
// alias field should not be replaced
Set<String> aliases = SqlSelectHelper.getAliasFields(sql);
aliases.forEach(alias -> fieldNameMap.put(alias, alias));
Set<Select> plainSelectList = SqlSelectHelper.getAllSelect(selectStatement);
for (Select plainSelect : plainSelectList) {
if (plainSelect instanceof PlainSelect) {

View File

@@ -57,6 +57,7 @@ public class ChatModelServiceImpl extends ServiceImpl<ChatModelMapper, ChatModel
chatModelDO.setViewer(JsonUtil.toString(chatModel.getViewers()));
}
save(chatModelDO);
chatModel.setId(chatModelDO.getId());
return chatModel;
}

View File

@@ -24,6 +24,8 @@ public class ModelDetail {
private String tableQuery;
private String filterSql;
private List<Identify> identifiers = Lists.newArrayList();
private List<Dimension> dimensions = Lists.newArrayList();

View File

@@ -19,6 +19,8 @@ public class ModelBuildReq {
private String sql;
private String filterSql;
private String catalog;
private String db;

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.jsqlparser.FieldExpression;
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
@@ -8,10 +9,7 @@ import com.tencent.supersonic.headless.chat.ChatQueryContext;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.*;
/** Perform SQL corrections on the "Select" section in S2SQL. */
@Slf4j
@@ -46,10 +44,28 @@ public class SelectCorrector extends BaseSemanticCorrector {
return correctS2SQL;
}
needAddFields.removeAll(selectFields);
String addFieldsToSelectSql =
SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(addFieldsToSelectSql);
return addFieldsToSelectSql;
if (!SqlSelectHelper.hasSubSelect(correctS2SQL)) { //优化内容 如果sql 条件包含了这个字段,而且是全等,则不再查询该字段
List<FieldExpression> tmp4 = SqlSelectHelper.getWhereExpressions(correctS2SQL);
Iterator<String> it = needAddFields.iterator();
while (it.hasNext()) {
String field = it.next();
long size = tmp4.stream()
.filter(e -> e.getFieldName().equals(field) && "=".equals(e.getOperator()))
.count();
if (size == 1) {
it.remove();
}
}
}
if (needAddFields.size() > 0) {
String addFieldsToSelectSql =
SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(addFieldsToSelectSql);
return addFieldsToSelectSql;
} else {
return correctS2SQL;
}
}
}

View File

@@ -75,6 +75,7 @@ public class KeywordMapper extends BaseMapper {
continue;
}
Long elementID = NatureHelper.getElementID(nature);
if (elementID == null)continue; // 判空优化
SchemaElement element = getSchemaElement(dataSetId, elementType, elementID,
chatQueryContext.getSemanticSchema());
if (Objects.isNull(element)) {

View File

@@ -41,6 +41,28 @@ class AggCorrectorTest {
semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
}
@Test
void testSchemaCorrector() {
SchemaCorrector corrector = new SchemaCorrector();
Long dataSetId = 1L;
ChatQueryContext chatQueryContext = buildQueryContext(dataSetId);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
SchemaElement dataSet = new SchemaElement();
dataSet.setDataSetId(dataSetId);
semanticParseInfo.setDataSet(dataSet);
SqlInfo sqlInfo = new SqlInfo();
String sql =
"WITH 总停留时长 AS (SELECT 用户, SUM(停留时长) AS _总停留时长_ FROM 超音数数据集 WHERE 用户 IN ('alice', 'lucy') AND 数据日期 >= '2025-03-01' AND 数据日期 <= '2025-03-12' GROUP BY 用户) SELECT 用户, _总停留时长_ FROM 总停留时长";
sqlInfo.setParsedS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
semanticParseInfo.setSqlInfo(sqlInfo);
corrector.correct(chatQueryContext, semanticParseInfo);
Assert.assertEquals(
"WITH 总停留时长 AS (SELECT 用户名, SUM(停留时长) AS _总停留时长_ FROM 超音数数据集 WHERE 用户名 IN ('alice', 'lucy') AND 数据日期 "
+ ">= '2025-03-01' AND 数据日期 <= '2025-03-12' GROUP BY 用户名) SELECT 用户名, _总停留时长_ FROM 总停留时长",
semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
}
private ChatQueryContext buildQueryContext(Long dataSetId) {
ChatQueryContext chatQueryContext = new ChatQueryContext();
List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
@@ -51,18 +73,18 @@ class AggCorrectorTest {
schemaElement.setDataSetId(dataSetId);
dataSetSchema.setDataSet(schemaElement);
Set<SchemaElement> dimensions = new HashSet<>();
SchemaElement element1 = new SchemaElement();
element1.setDataSetId(1L);
element1.setName("部门");
dimensions.add(element1);
dimensions.add(SchemaElement.builder().dataSetId(1L).name("部门").build());
dimensions.add(SchemaElement.builder().dataSetId(1L).name("用户名").build());
dataSetSchema.setDimensions(dimensions);
Set<SchemaElement> metrics = new HashSet<>();
SchemaElement metric1 = new SchemaElement();
metric1.setDataSetId(1L);
metric1.setName("访问次数");
metrics.add(metric1);
metrics.add(SchemaElement.builder().dataSetId(1L).name("访问次数").build());
metrics.add(SchemaElement.builder().dataSetId(1L).name("停留时长").build());
dataSetSchema.setMetrics(metrics);
dataSetSchemaList.add(dataSetSchema);

View File

@@ -40,11 +40,23 @@ public class DataModelNode extends SemanticNode {
.equalsIgnoreCase(EngineType.POSTGRESQL.getName())) {
String fullTableName = String.join(".public.",
dataModel.getModelDetail().getTableQuery().split("\\."));
sqlTable = "select * from " + fullTableName;
sqlTable = "SELECT * FROM " + fullTableName;
} else {
sqlTable = "select * from " + dataModel.getModelDetail().getTableQuery();
sqlTable = "SELECT * FROM " + dataModel.getModelDetail().getTableQuery();
}
}
// String filterSql = dataModel.getFilterSql();
String filterSql = dataModel.getModelDetail().getFilterSql();
if (filterSql != null && !filterSql.isEmpty()) {
boolean sqlContainWhere = sqlTable.toUpperCase().matches("(?s).*\\bWHERE\\b.*");
if (sqlContainWhere) {
sqlTable = String.format("%s AND %s", sqlTable, filterSql);
} else {
sqlTable = String.format("%s WHERE %s", sqlTable, filterSql);
}
}
if (sqlTable.isEmpty()) {
throw new Exception("DataModelNode build error [tableSqlNode not found]");
}

View File

@@ -36,6 +36,7 @@ public class ModelYamlManager {
} else {
dataModelYamlTpl.setTableQuery(modelDetail.getTableQuery());
}
dataModelYamlTpl.setFilterSql(modelDetail.getFilterSql());
dataModelYamlTpl.setFields(modelResp.getModelDetail().getFields());
dataModelYamlTpl.setId(modelResp.getId());
return dataModelYamlTpl;

View File

@@ -97,6 +97,7 @@ public class SemanticSchemaManager {
modelDetail.setDbType(d.getType());
modelDetail.setSqlQuery(d.getSqlQuery());
modelDetail.setTableQuery(d.getTableQuery());
modelDetail.setFilterSql(d.getFilterSql());
modelDetail.getIdentifiers().addAll(getIdentify(d.getIdentifiers()));
modelDetail.getMeasures().addAll(getMeasureParams(d.getMeasures()));
modelDetail.getDimensions().addAll(getDimensions(d.getDimensions()));

View File

@@ -21,6 +21,8 @@ public class DataModelYamlTpl {
private String tableQuery;
private String filterSql;
private List<IdentifyYamlTpl> identifiers;
private List<DimensionYamlTpl> dimensions;

View File

@@ -179,6 +179,7 @@ public class ModelConverter {
}
}
modelDetail.setFields(fields);
modelDetail.setFilterSql(modelBuildReq.getFilterSql());
modelReq.setModelDetail(modelDetail);
return modelReq;
}

View File

@@ -333,7 +333,6 @@ public class S2VisitsDemo extends S2BaseDemo {
metricReq.setSensitiveLevel(SensitiveLevelEnum.HIGH.getCode());
metricReq.setDescription("每个用户平均访问的次数");
metricReq.setClassifications(Collections.singletonList("核心指标"));
metricReq.setAlias("平均访问次数");
MetricDefineByMetricParams metricTypeParams = new MetricDefineByMetricParams();
metricTypeParams.setExpr("pv/uv");
List<MetricParam> metrics = new ArrayList<>();

View File

@@ -155,16 +155,6 @@ public class Text2SQLEval extends BaseTest {
assert result.getTextResult().contains("3");
}
@Test
public void test_detail_query() throws Exception {
long start = System.currentTimeMillis();
QueryResult result = submitNewChat("特斯拉旗下有哪些品牌", agent.getId());
durations.add(System.currentTimeMillis() - start);
assert result.getQueryColumns().size() >= 1;
assert result.getTextResult().contains("Model Y");
assert result.getTextResult().contains("Model 3");
}
public Agent getLLMAgent() {
Agent agent = new Agent();
agent.setName("Agent for Test");

View File

@@ -37,7 +37,7 @@
<pagehelper.spring.version>2.1.0</pagehelper.spring.version>
<mybatis.version>3.5.3</mybatis.version>
<guava.version>32.0.0-jre</guava.version>
<hanlp.version>portable-1.8.3</hanlp.version>
<hanlp.version>portable-1.8.4</hanlp.version>
<hadoop.version>2.7.2</hadoop.version>
<commons.lang.version>2.6</commons.lang.version>
<commons.lang3.version>3.7</commons.lang3.version>