mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 12:37:55 +00:00
(improvement)(headless&chat)Move ChatContext from Headless module to Chat module.
This commit is contained in:
@@ -1,14 +0,0 @@
|
||||
package com.tencent.supersonic.headless.chat;
|
||||
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ChatContext {
|
||||
|
||||
private Integer chatId;
|
||||
private String queryText;
|
||||
private SemanticParseInfo parseInfo = new SemanticParseInfo();
|
||||
private String user;
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||
@@ -35,7 +36,6 @@ import java.util.stream.Collectors;
|
||||
public class ChatQueryContext {
|
||||
|
||||
private String queryText;
|
||||
private Integer chatId;
|
||||
private Set<Long> dataSetIds;
|
||||
private Map<Long, List<Long>> modelIdToDataSetIds;
|
||||
private User user;
|
||||
@@ -54,6 +54,7 @@ public class ChatQueryContext {
|
||||
private ChatModelConfig modelConfig;
|
||||
private PromptConfig promptConfig;
|
||||
private List<SqlExemplar> dynamicExemplars;
|
||||
private SemanticParseInfo contextParseInfo;
|
||||
|
||||
public List<SemanticQuery> getCandidateQueries() {
|
||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||
|
||||
@@ -11,7 +11,6 @@ import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
@@ -29,7 +28,7 @@ import java.util.stream.Collectors;
|
||||
public class QueryTypeParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||
public void parse(ChatQueryContext chatQueryContext) {
|
||||
|
||||
List<SemanticQuery> candidateQueries = chatQueryContext.getCandidateQueries();
|
||||
User user = chatQueryContext.getUser();
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.headless.chat.parser;
|
||||
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
|
||||
/**
|
||||
@@ -10,5 +9,5 @@ import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
*/
|
||||
public interface SemanticParser {
|
||||
|
||||
void parse(ChatQueryContext chatQueryContext, ChatContext chatContext);
|
||||
void parse(ChatQueryContext chatQueryContext);
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
@@ -23,7 +22,7 @@ import org.apache.commons.collections.MapUtils;
|
||||
public class LLMSqlParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(ChatQueryContext queryCtx, ChatContext chatCtx) {
|
||||
public void parse(ChatQueryContext queryCtx) {
|
||||
try {
|
||||
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
|
||||
//1.determine whether to skip this parser.
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.rule;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||
@@ -41,7 +40,7 @@ public class AggregateTypeParser implements SemanticParser {
|
||||
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (k1, k2) -> k2));
|
||||
|
||||
@Override
|
||||
public void parse(ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||
public void parse(ChatQueryContext chatQueryContext) {
|
||||
String queryText = chatQueryContext.getQueryText();
|
||||
AggregateConf aggregateConf = resolveAggregateConf(queryText);
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricModelQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricIdQuery;
|
||||
@@ -43,11 +42,11 @@ public class ContextInheritParser implements SemanticParser {
|
||||
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
|
||||
|
||||
@Override
|
||||
public void parse(ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||
public void parse(ChatQueryContext chatQueryContext) {
|
||||
if (!shouldInherit(chatQueryContext)) {
|
||||
return;
|
||||
}
|
||||
Long dataSetId = getMatchedDataSet(chatQueryContext, chatContext);
|
||||
Long dataSetId = getMatchedDataSet(chatQueryContext);
|
||||
if (dataSetId == null) {
|
||||
return;
|
||||
}
|
||||
@@ -55,10 +54,11 @@ public class ContextInheritParser implements SemanticParser {
|
||||
List<SchemaElementMatch> elementMatches = chatQueryContext.getMapInfo().getMatchedElements(dataSetId);
|
||||
|
||||
List<SchemaElementMatch> matchesToInherit = new ArrayList<>();
|
||||
for (SchemaElementMatch match : chatContext.getParseInfo().getElementMatches()) {
|
||||
for (SchemaElementMatch match : chatQueryContext.getContextParseInfo().getElementMatches()) {
|
||||
SchemaElementType matchType = match.getElement().getType();
|
||||
// mutual exclusive element types should not be inherited
|
||||
RuleSemanticQuery ruleQuery = QueryManager.getRuleQuery(chatContext.getParseInfo().getQueryMode());
|
||||
RuleSemanticQuery ruleQuery = QueryManager.getRuleQuery(
|
||||
chatQueryContext.getContextParseInfo().getQueryMode());
|
||||
if (!containsTypes(elementMatches, matchType, ruleQuery)) {
|
||||
match.setInherited(true);
|
||||
matchesToInherit.add(match);
|
||||
@@ -68,7 +68,7 @@ public class ContextInheritParser implements SemanticParser {
|
||||
|
||||
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(dataSetId, elementMatches, chatQueryContext);
|
||||
for (RuleSemanticQuery query : queries) {
|
||||
query.fillParseInfo(chatQueryContext, chatContext);
|
||||
query.fillParseInfo(chatQueryContext);
|
||||
if (existSameQuery(query.getParseInfo().getDataSetId(), query.getQueryMode(), chatQueryContext)) {
|
||||
continue;
|
||||
}
|
||||
@@ -108,8 +108,8 @@ public class ContextInheritParser implements SemanticParser {
|
||||
return metricModelQueries.size() == chatQueryContext.getCandidateQueries().size();
|
||||
}
|
||||
|
||||
protected Long getMatchedDataSet(ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||
Long dataSetId = chatContext.getParseInfo().getDataSetId();
|
||||
protected Long getMatchedDataSet(ChatQueryContext chatQueryContext) {
|
||||
Long dataSetId = chatQueryContext.getContextParseInfo().getDataSetId();
|
||||
if (dataSetId == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
@@ -24,7 +23,7 @@ public class RuleSqlParser implements SemanticParser {
|
||||
);
|
||||
|
||||
@Override
|
||||
public void parse(ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||
public void parse(ChatQueryContext chatQueryContext) {
|
||||
if (!chatQueryContext.getText2SQLType().enableRule()) {
|
||||
return;
|
||||
}
|
||||
@@ -34,11 +33,11 @@ public class RuleSqlParser implements SemanticParser {
|
||||
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(dataSetId);
|
||||
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(dataSetId, elementMatches, chatQueryContext);
|
||||
for (RuleSemanticQuery query : queries) {
|
||||
query.fillParseInfo(chatQueryContext, chatContext);
|
||||
query.fillParseInfo(chatQueryContext);
|
||||
chatQueryContext.getCandidateQueries().add(query);
|
||||
}
|
||||
}
|
||||
|
||||
auxiliaryParsers.stream().forEach(p -> p.parse(chatQueryContext, chatContext));
|
||||
auxiliaryParsers.stream().forEach(p -> p.parse(chatQueryContext));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.chat.parser.rule;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||
@@ -42,7 +41,7 @@ public class TimeRangeParser implements SemanticParser {
|
||||
private static final DateFormat DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd");
|
||||
|
||||
@Override
|
||||
public void parse(ChatQueryContext queryContext, ChatContext chatContext) {
|
||||
public void parse(ChatQueryContext queryContext) {
|
||||
String queryText = queryContext.getQueryText();
|
||||
DateConf dateConf = parseRecent(queryText);
|
||||
if (dateConf == null) {
|
||||
@@ -59,14 +58,14 @@ public class TimeRangeParser implements SemanticParser {
|
||||
query.getParseInfo().setScore(query.getParseInfo().getScore()
|
||||
+ dateConf.getDetectWord().length());
|
||||
}
|
||||
} else if (QueryManager.containsRuleQuery(chatContext.getParseInfo().getQueryMode())) {
|
||||
} else if (QueryManager.containsRuleQuery(queryContext.getContextParseInfo().getQueryMode())) {
|
||||
RuleSemanticQuery semanticQuery = QueryManager.createRuleQuery(
|
||||
chatContext.getParseInfo().getQueryMode());
|
||||
queryContext.getContextParseInfo().getQueryMode());
|
||||
// inherit parse info from context
|
||||
chatContext.getParseInfo().setDateInfo(dateConf);
|
||||
chatContext.getParseInfo().setScore(chatContext.getParseInfo().getScore()
|
||||
queryContext.getContextParseInfo().setDateInfo(dateConf);
|
||||
queryContext.getContextParseInfo().setScore(queryContext.getContextParseInfo().getScore()
|
||||
+ dateConf.getDetectWord().length());
|
||||
semanticQuery.setParseInfo(chatContext.getParseInfo());
|
||||
semanticQuery.setParseInfo(queryContext.getContextParseInfo());
|
||||
queryContext.getCandidateQueries().add(semanticQuery);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,6 @@ import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
|
||||
import com.tencent.supersonic.headless.chat.query.BaseSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
@@ -49,13 +48,13 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
initS2SqlByStruct(semanticSchema);
|
||||
}
|
||||
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext) {
|
||||
parseInfo.setQueryMode(getQueryMode());
|
||||
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||
|
||||
fillSchemaElement(parseInfo, semanticSchema);
|
||||
fillScore(parseInfo);
|
||||
fillDateConf(parseInfo, chatContext.getParseInfo());
|
||||
fillDateConf(parseInfo, chatQueryContext.getContextParseInfo());
|
||||
}
|
||||
|
||||
private void fillDateConf(SemanticParseInfo queryParseInfo, SemanticParseInfo chatParseInfo) {
|
||||
|
||||
@@ -7,7 +7,6 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
@@ -19,8 +18,8 @@ import java.util.stream.Collectors;
|
||||
public abstract class DetailListQuery extends DetailSemanticQuery {
|
||||
|
||||
@Override
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||
super.fillParseInfo(chatQueryContext, chatContext);
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext) {
|
||||
super.fillParseInfo(chatQueryContext);
|
||||
this.addEntityDetailAndOrderByMetric(chatQueryContext, parseInfo);
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.time.LocalDate;
|
||||
@@ -35,8 +34,8 @@ public abstract class DetailSemanticQuery extends RuleSemanticQuery {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||
super.fillParseInfo(chatQueryContext, chatContext);
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext) {
|
||||
super.fillParseInfo(chatQueryContext);
|
||||
|
||||
parseInfo.setQueryType(QueryType.DETAIL);
|
||||
parseInfo.setLimit(DETAIL_MAX_RESULTS);
|
||||
|
||||
@@ -8,7 +8,6 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.time.LocalDate;
|
||||
@@ -36,8 +35,8 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||
super.fillParseInfo(chatQueryContext, chatContext);
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext) {
|
||||
super.fillParseInfo(chatQueryContext);
|
||||
parseInfo.setLimit(METRIC_MAX_RESULTS);
|
||||
if (parseInfo.getDateInfo() == null) {
|
||||
DataSetSchema dataSetSchema =
|
||||
|
||||
@@ -6,7 +6,6 @@ import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.ArrayList;
|
||||
@@ -50,8 +49,8 @@ public class MetricTopNQuery extends MetricSemanticQuery {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||
super.fillParseInfo(chatQueryContext, chatContext);
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext) {
|
||||
super.fillParseInfo(chatQueryContext);
|
||||
|
||||
parseInfo.setLimit(ORDERBY_MAX_RESULTS);
|
||||
parseInfo.setScore(parseInfo.getScore() + 2.0);
|
||||
|
||||
Reference in New Issue
Block a user