[improvement][chat] Introduce AllFieldMapper to increase parsing robustness when the normal pipeline fails.

jerryjzhang
2024-12-26 21:33:40 +08:00
parent 8e03531424
commit d834e98a66
20 changed files with 163 additions and 42 deletions
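
In short, when a parse attempt comes back FAILED, NL2SQLParser now clears the selected parse info, switches the map mode to MapModeEnum.ALL and retries; in that mode the new AllFieldMapper matches every dimension and metric of each data set, so the LLM receives the full schema instead of only the recalled fields. The following is a minimal, self-contained Java sketch of that fallback flow; the types, method and sample query are illustrative stand-ins for the project's ChatParseResp, ParseResp.ParseState, MapModeEnum and NL2SQLParser#doParse, not the actual implementation.

// Minimal, self-contained sketch of the fallback introduced by this commit.
// ParseResult, ParseState and MapMode are simplified stand-ins for the
// project's ChatParseResp, ParseResp.ParseState and MapModeEnum.
public class AllFieldFallbackSketch {

    enum MapMode { STRICT, MODERATE, LOOSE, ALL }

    enum ParseState { COMPLETED, FAILED }

    static class ParseResult {
        ParseState state = ParseState.FAILED;
    }

    // Simulates NL2SQLParser#doParse: pretend the normal pipeline fails and
    // only the ALL mode succeeds, so the retry branch below is exercised.
    static ParseResult doParse(String queryText, MapMode mode) {
        ParseResult result = new ParseResult();
        result.state = (mode == MapMode.ALL) ? ParseState.COMPLETED : ParseState.FAILED;
        return result;
    }

    public static void main(String[] args) {
        String queryText = "top 3 departments by visits in the last 30 days";

        // First attempt with the regular mapping mode.
        ParseResult result = doParse(queryText, MapMode.MODERATE);

        // Fallback added in this commit: retry with all semantic fields
        // passed to the LLM (MapModeEnum.ALL), which the new AllFieldMapper
        // implements by matching every dimension and metric of each data set.
        if (result.state == ParseState.FAILED) {
            result = doParse(queryText, MapMode.ALL);
        }

        System.out.println("Final parse state: " + result.state);
    }
}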

View File

@@ -125,12 +125,9 @@ public class MemoryReviewTask {
if (MemoryReviewResult.POSITIVE.equals(m.getLlmReviewRet())) {
m.setStatus(MemoryStatus.ENABLED);
}
ChatMemoryUpdateReq memoryUpdateReq = ChatMemoryUpdateReq.builder()
.id(m.getId())
.status(m.getStatus())
.llmReviewRet(m.getLlmReviewRet())
.llmReviewCmt(m.getLlmReviewCmt())
.build();
ChatMemoryUpdateReq memoryUpdateReq = ChatMemoryUpdateReq.builder().id(m.getId())
.status(m.getStatus()).llmReviewRet(m.getLlmReviewRet())
.llmReviewCmt(m.getLlmReviewCmt()).build();
memoryService.updateMemory(memoryUpdateReq, User.getDefaultUser());
}
}

View File

@@ -100,10 +100,12 @@ public class NL2SQLParser implements ChatQueryParser {
queryNLReq.setMapModeEnum(mode);
doParse(queryNLReq, parseResp);
}
if (parseResp.getSelectedParses().isEmpty()) {
queryNLReq.setMapModeEnum(MapModeEnum.LOOSE);
doParse(queryNLReq, parseResp);
}
if (parseResp.getSelectedParses().isEmpty()) {
errMsg.append(parseResp.getErrorMsg());
continue;
@@ -137,11 +139,18 @@ public class NL2SQLParser implements ChatQueryParser {
SemanticParseInfo userSelectParse = parseContext.getRequest().getSelectedParse();
queryNLReq.setSelectedParseInfo(Objects.nonNull(userSelectParse) ? userSelectParse
: parseContext.getResponse().getSelectedParses().get(0));
parseContext.setResponse(new ChatParseResp(parseContext.getResponse().getQueryId()));
rewriteMultiTurn(parseContext, queryNLReq);
addDynamicExemplars(parseContext, queryNLReq);
doParse(queryNLReq, parseContext.getResponse());
// try again with all semantic fields passed to the LLM
if (parseContext.getResponse().getState().equals(ParseResp.ParseState.FAILED)) {
queryNLReq.setSelectedParseInfo(null);
queryNLReq.setMapModeEnum(MapModeEnum.ALL);
doParse(queryNLReq, parseContext.getResponse());
}
}
}

View File

@@ -73,16 +73,19 @@ public class MemoryServiceImpl implements MemoryService {
updateWrapper.set(ChatMemoryDO::getStatus, chatMemoryUpdateReq.getStatus());
}
if (Objects.nonNull(chatMemoryUpdateReq.getLlmReviewRet())) {
updateWrapper.set(ChatMemoryDO::getLlmReviewRet, chatMemoryUpdateReq.getLlmReviewRet().toString());
updateWrapper.set(ChatMemoryDO::getLlmReviewRet,
chatMemoryUpdateReq.getLlmReviewRet().toString());
}
if (Objects.nonNull(chatMemoryUpdateReq.getLlmReviewCmt())) {
updateWrapper.set(ChatMemoryDO::getLlmReviewCmt, chatMemoryUpdateReq.getLlmReviewCmt());
}
if (Objects.nonNull(chatMemoryUpdateReq.getHumanReviewRet())) {
updateWrapper.set(ChatMemoryDO::getHumanReviewRet, chatMemoryUpdateReq.getHumanReviewRet().toString());
updateWrapper.set(ChatMemoryDO::getHumanReviewRet,
chatMemoryUpdateReq.getHumanReviewRet().toString());
}
if (Objects.nonNull(chatMemoryUpdateReq.getHumanReviewCmt())) {
updateWrapper.set(ChatMemoryDO::getHumanReviewCmt, chatMemoryUpdateReq.getHumanReviewCmt());
updateWrapper.set(ChatMemoryDO::getHumanReviewCmt,
chatMemoryUpdateReq.getHumanReviewCmt());
}
updateWrapper.set(ChatMemoryDO::getUpdatedAt, new Date());
updateWrapper.set(ChatMemoryDO::getUpdatedBy, user.getName());

View File

@@ -0,0 +1,24 @@
package com.tencent.supersonic.common.jsqlparser;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
import net.sf.jsqlparser.statement.select.SelectItem;
import java.util.Set;
public class AliasAcquireVisitor extends ExpressionVisitorAdapter {
private Set<String> aliases;
public AliasAcquireVisitor(Set<String> aliases) {
this.aliases = aliases;
}
@Override
public void visit(SelectItem selectItem) {
Alias alias = selectItem.getAlias();
if (alias != null) {
aliases.add(alias.getName());
}
}
}

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.common.jsqlparser;
import com.google.common.collect.Sets;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
@@ -11,6 +12,7 @@ import java.util.Set;
public class FieldAcquireVisitor extends ExpressionVisitorAdapter {
private Set<String> fields;
private Set<String> aliases = Sets.newHashSet();
public FieldAcquireVisitor(Set<String> fields) {
this.fields = fields;
@@ -26,8 +28,9 @@ public class FieldAcquireVisitor extends ExpressionVisitorAdapter {
public void visit(SelectItem selectItem) {
Alias alias = selectItem.getAlias();
if (alias != null) {
fields.add(alias.getName());
aliases.add(alias.getName());
}
Expression expression = selectItem.getExpression();
if (expression != null) {
expression.accept(this);

View File

@@ -133,6 +133,15 @@ public class SqlSelectHelper {
return result;
}
public static Set<String> getAliasFields(PlainSelect plainSelect) {
Set<String> result = new HashSet<>();
List<SelectItem<?>> selectItems = plainSelect.getSelectItems();
for (SelectItem selectItem : selectItems) {
selectItem.accept(new AliasAcquireVisitor(result));
}
return result;
}
public static List<PlainSelect> getPlainSelect(Select selectStatement) {
if (selectStatement == null) {
return null;
@@ -264,10 +273,14 @@ public class SqlSelectHelper {
public static List<String> getAllSelectFields(String sql) {
List<PlainSelect> plainSelects = getPlainSelects(getPlainSelect(sql));
Set<String> results = new HashSet<>();
Set<String> aliases = new HashSet<>();
for (PlainSelect plainSelect : plainSelects) {
List<String> fields = getFieldsByPlainSelect(plainSelect);
results.addAll(fields);
aliases.addAll(getAliasFields(plainSelect));
}
// do not count aliases as fields
results.removeAll(aliases);
return new ArrayList<>(results);
}

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.headless.api.pojo.enums;
public enum MapModeEnum {
STRICT(0), MODERATE(2), LOOSE(4);
STRICT(0), MODERATE(2), LOOSE(4), ALL(6);
public int threshold;

View File

@@ -0,0 +1,40 @@
package com.tencent.supersonic.headless.chat.mapper;
import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import lombok.extern.slf4j.Slf4j;
import java.util.List;
import java.util.Map;
@Slf4j
public class AllFieldMapper extends BaseMapper {
@Override
public boolean accept(ChatQueryContext chatQueryContext) {
return MapModeEnum.ALL.equals(chatQueryContext.getRequest().getMapModeEnum());
}
@Override
public void doMap(ChatQueryContext chatQueryContext) {
Map<Long, DataSetSchema> schemaMap =
chatQueryContext.getSemanticSchema().getDataSetSchemaMap();
for (Map.Entry<Long, DataSetSchema> entry : schemaMap.entrySet()) {
List<SchemaElement> schemaElements = Lists.newArrayList();
schemaElements.addAll(entry.getValue().getDimensions());
schemaElements.addAll(entry.getValue().getMetrics());
for (SchemaElement schemaElement : schemaElements) {
chatQueryContext.getMapInfo().getMatchedElements(entry.getKey())
.add(SchemaElementMatch.builder().word(schemaElement.getName())
.element(schemaElement).detectWord(schemaElement.getName())
.similarity(1.0).build());
}
}
}
}

View File

@@ -27,6 +27,10 @@ public abstract class BaseMapper implements SchemaMapper {
@Override
public void map(ChatQueryContext chatQueryContext) {
if (!accept(chatQueryContext)) {
return;
}
String simpleName = this.getClass().getSimpleName();
long startTime = System.currentTimeMillis();
log.debug("before {},mapInfo:{}", simpleName,
@@ -46,6 +50,10 @@ public abstract class BaseMapper implements SchemaMapper {
public abstract void doMap(ChatQueryContext chatQueryContext);
protected boolean accept(ChatQueryContext chatQueryContext) {
return true;
}
public void addToSchemaMap(SchemaMapInfo schemaMap, Long dataSetId,
SchemaElementMatch newElementMatch) {
Map<Long, List<SchemaElementMatch>> dataSetElementMatches =

View File

@@ -20,12 +20,13 @@ import java.util.Objects;
*/
@Slf4j
public class EmbeddingMapper extends BaseMapper {
public void doMap(ChatQueryContext chatQueryContext) {
// Check if the map mode is LOOSE
if (!MapModeEnum.LOOSE.equals(chatQueryContext.getRequest().getMapModeEnum())) {
return;
@Override
public boolean accept(ChatQueryContext chatQueryContext) {
return MapModeEnum.LOOSE.equals(chatQueryContext.getRequest().getMapModeEnum());
}
public void doMap(ChatQueryContext chatQueryContext) {
// 1. Query from embedding by queryText
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
List<EmbeddingResult> matchResults = getMatches(chatQueryContext, matchStrategy);
@@ -62,4 +63,5 @@ public class EmbeddingMapper extends BaseMapper {
addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch);
}
}
}

View File

@@ -35,11 +35,11 @@ public class MapperConfig extends ParameterConfig {
"维度值相似度阈值在动态调整中的最低值", "number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_TEXT_SIZE =
new Parameter("s2.mapper.embedding.word.size", "4", "用于向量召回文本长度",
new Parameter("s2.mapper.embedding.word.size", "3", "用于向量召回文本长度",
"为提高向量召回效率, 按指定长度进行向量语义召回", "number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_TEXT_STEP =
new Parameter("s2.mapper.embedding.word.step", "3", "向量召回文本每步长度",
new Parameter("s2.mapper.embedding.word.step", "2", "向量召回文本每步长度",
"为提高向量召回效率, 按指定每步长度进行召回", "number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_BATCH =
@@ -51,7 +51,7 @@ public class MapperConfig extends ParameterConfig {
"每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_THRESHOLD =
new Parameter("s2.mapper.embedding.threshold", "0.98", "向量召回相似度阈值", "相似度小于该阈值的则舍弃",
new Parameter("s2.mapper.embedding.threshold", "0.8", "向量召回相似度阈值", "相似度小于该阈值的则舍弃",
"number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_ROUND_NUMBER =

View File

@@ -9,22 +9,23 @@ import lombok.extern.slf4j.Slf4j;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@Slf4j
public class TimeFieldMapper extends BaseMapper {
public class PartitionTimeMapper extends BaseMapper {
@Override
public boolean accept(ChatQueryContext chatQueryContext) {
return !(chatQueryContext.getRequest().getText2SQLType().equals(Text2SQLType.ONLY_RULE)
|| chatQueryContext.getMapInfo().isEmpty());
}
@Override
public void doMap(ChatQueryContext chatQueryContext) {
if (chatQueryContext.getRequest().getText2SQLType().equals(Text2SQLType.ONLY_RULE)) {
return;
}
Map<Long, DataSetSchema> schemaMap =
chatQueryContext.getSemanticSchema().getDataSetSchemaMap();
for (Map.Entry<Long, DataSetSchema> entry : schemaMap.entrySet()) {
List<SchemaElement> timeDims = entry.getValue().getDimensions().stream()
.filter(dim -> dim.getTimeFormat() != null).collect(Collectors.toList());
.filter(SchemaElement::isPartitionTime).toList();
for (SchemaElement schemaElement : timeDims) {
chatQueryContext.getMapInfo().getMatchedElements(entry.getKey())
.add(SchemaElementMatch.builder().word(schemaElement.getName())

View File

@@ -21,14 +21,16 @@ import java.util.stream.Collectors;
@Slf4j
public class QueryFilterMapper extends BaseMapper {
private double similarity = 1.0;
private final double similarity = 1.0;
@Override
public boolean accept(ChatQueryContext chatQueryContext) {
return !chatQueryContext.getRequest().getDataSetIds().isEmpty();
}
@Override
public void doMap(ChatQueryContext chatQueryContext) {
Set<Long> dataSetIds = chatQueryContext.getRequest().getDataSetIds();
if (CollectionUtils.isEmpty(dataSetIds)) {
return;
}
SchemaMapInfo schemaMapInfo = chatQueryContext.getMapInfo();
clearOtherSchemaElementMatch(dataSetIds, schemaMapInfo);
for (Long dataSetId : dataSetIds) {

View File

@@ -17,13 +17,14 @@ import java.util.List;
public class TermDescMapper extends BaseMapper {
@Override
public void doMap(ChatQueryContext chatQueryContext) {
SchemaMapInfo mapInfo = chatQueryContext.getMapInfo();
List<SchemaElement> termElements = mapInfo.getTermDescriptionToMap();
if (CollectionUtils.isEmpty(termElements)
|| chatQueryContext.getRequest().isDescriptionMapped()) {
return;
public boolean accept(ChatQueryContext chatQueryContext) {
return !(CollectionUtils.isEmpty(chatQueryContext.getMapInfo().getTermDescriptionToMap())
|| chatQueryContext.getRequest().isDescriptionMapped());
}
@Override
public void doMap(ChatQueryContext chatQueryContext) {
List<SchemaElement> termElements = chatQueryContext.getMapInfo().getTermDescriptionToMap();
for (SchemaElement schemaElement : termElements) {
ChatQueryContext queryCtx =
buildQueryContext(chatQueryContext, schemaElement.getDescription());

View File

@@ -30,6 +30,9 @@ public class DefaultSemanticTranslator implements SemanticTranslator {
if (parser.accept(queryStatement)) {
log.debug("QueryConverter accept [{}]", parser.getClass().getName());
parser.parse(queryStatement);
if (queryStatement.getStatus() != 0) {
break;
}
}
}
mergeOntologyQuery(queryStatement);

View File

@@ -45,6 +45,14 @@ public class SqlQueryParser implements QueryParser {
List<String> queryFields = SqlSelectHelper.getAllSelectFields(sqlQuery.getSql());
Ontology ontology = queryStatement.getOntology();
OntologyQuery ontologyQuery = buildOntologyQuery(ontology, queryFields);
// check if there are fields not matched with any metric or dimension
if (queryFields.size() > ontologyQuery.getMetrics().size()
+ ontologyQuery.getDimensions().size()) {
queryStatement
.setErrMsg("Some fields in the SQL could not be matched to any semantic column.");
queryStatement.setStatus(1);
return;
}
queryStatement.setOntologyQuery(ontologyQuery);
AggOption sqlQueryAggOption = getAggOption(sqlQuery.getSql(), ontologyQuery.getMetrics());

View File

@@ -76,7 +76,6 @@ public class ChatWorkflowEngine {
long start = System.currentTimeMillis();
performTranslating(queryCtx, parseResult);
parseResult.getParseTimeCost().setSqlTime(System.currentTimeMillis() - start);
parseResult.setState(ParseResp.ParseState.COMPLETED);
queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
break;
default:
@@ -137,7 +136,12 @@ public class ChatWorkflowEngine {
ContextUtils.getBean(SemanticLayerService.class);
SemanticTranslateResp explain =
queryService.translate(semanticQueryReq, queryCtx.getRequest().getUser());
if (explain.isOk()) {
parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
parseResult.setState(ParseResp.ParseState.COMPLETED);
} else {
parseResult.setState(ParseResp.ParseState.FAILED);
}
if (StringUtils.isNotBlank(explain.getErrMsg())) {
errorMsg.add(explain.getErrMsg());
}

View File

@@ -4,7 +4,7 @@ com.tencent.supersonic.headless.chat.mapper.SchemaMapper=\
com.tencent.supersonic.headless.chat.mapper.EmbeddingMapper, \
com.tencent.supersonic.headless.chat.mapper.KeywordMapper, \
com.tencent.supersonic.headless.chat.mapper.QueryFilterMapper, \
com.tencent.supersonic.headless.chat.mapper.TimeFieldMapper,\
com.tencent.supersonic.headless.chat.mapper.PartitionTimeMapper,\
com.tencent.supersonic.headless.chat.mapper.TermDescMapper
com.tencent.supersonic.headless.chat.parser.SemanticParser=\

View File

@@ -4,8 +4,9 @@ com.tencent.supersonic.headless.chat.mapper.SchemaMapper=\
com.tencent.supersonic.headless.chat.mapper.EmbeddingMapper, \
com.tencent.supersonic.headless.chat.mapper.KeywordMapper, \
com.tencent.supersonic.headless.chat.mapper.QueryFilterMapper, \
com.tencent.supersonic.headless.chat.mapper.TimeFieldMapper,\
com.tencent.supersonic.headless.chat.mapper.TermDescMapper
com.tencent.supersonic.headless.chat.mapper.PartitionTimeMapper,\
com.tencent.supersonic.headless.chat.mapper.TermDescMapper,\
com.tencent.supersonic.headless.chat.mapper.AllFieldMapper
com.tencent.supersonic.headless.chat.parser.SemanticParser=\
com.tencent.supersonic.headless.chat.parser.llm.LLMSqlParser,\

View File

@@ -24,6 +24,7 @@ import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junitpioneer.jupiter.SetSystemProperty;
import org.springframework.test.context.TestPropertySource;
import java.util.List;
@@ -95,6 +96,7 @@ public class Text2SQLEval extends BaseTest {
}
@Test
@SetSystemProperty(key = "s2.test", value = "true")
public void test_drilldown_and_topN() throws Exception {
long start = System.currentTimeMillis();
QueryResult result = submitNewChat("过去30天访问次数最高的部门top3", agent.getId());