Merge branch 'feature/showcase' into feature/lxw

# Conflicts:
#	chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java
#	chat/core/src/main/resources/mapper/ChatQueryDOMapper.xml
Author: jolunoluo
Date: 2023-09-25 16:59:31 +08:00
106 changed files with 2495 additions and 1162 deletions

View File

@@ -7,6 +7,7 @@ import java.util.Map;
import java.util.stream.Collectors;
public class SemanticSchema implements Serializable {
private List<ModelSchema> modelSchemaList;
public SemanticSchema(List<ModelSchema> modelSchemaList) {
@@ -34,12 +35,28 @@ public class SemanticSchema implements Serializable {
return dimensions;
}
public List<SchemaElement> getDimensions(Long modelId) {
List<SchemaElement> dimensions = getDimensions();
return getElementsByModelId(modelId, dimensions);
}
public List<SchemaElement> getMetrics() {
List<SchemaElement> metrics = new ArrayList<>();
modelSchemaList.stream().forEach(d -> metrics.addAll(d.getMetrics()));
return metrics;
}
public List<SchemaElement> getMetrics(Long modelId) {
List<SchemaElement> metrics = getMetrics();
return getElementsByModelId(modelId, metrics);
}
private List<SchemaElement> getElementsByModelId(Long modelId, List<SchemaElement> elements) {
return elements.stream()
.filter(schemaElement -> modelId.equals(schemaElement.getModel()))
.collect(Collectors.toList());
}
public List<SchemaElement> getModels() {
List<SchemaElement> models = new ArrayList<>();
modelSchemaList.stream().forEach(d -> models.add(d.getModel()));

View File

@@ -1,25 +1,20 @@
package com.tencent.supersonic.chat.api.pojo.request;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import java.util.HashSet;
import java.util.Set;
import lombok.Data;
@Data
public class QueryDataReq {
String queryMode;
SchemaElement model;
Set<SchemaElement> metrics = new HashSet<>();
Set<SchemaElement> dimensions = new HashSet<>();
Set<QueryFilter> dimensionFilters = new HashSet<>();
Set<QueryFilter> metricFilters = new HashSet<>();
private AggregateTypeEnum aggType = AggregateTypeEnum.NONE;
private Set<Order> orders = new HashSet<>();
private User user;
private Set<SchemaElement> metrics = new HashSet<>();
private Set<SchemaElement> dimensions = new HashSet<>();
private Set<QueryFilter> dimensionFilters = new HashSet<>();
private DateConf dateInfo;
private Long limit;
private Boolean nativeQuery = false;
private Long queryId = 7L;
private Integer parseId = 2;
}

View File

@@ -1,27 +0,0 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@Slf4j
public class DateFieldCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
String sql = semanticCorrectInfo.getSql();
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sql);
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(DATE_FIELD)) {
String currentDate = DSLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId());
sql = SqlParserUpdateHelper.addWhere(sql, DATE_FIELD, currentDate);
}
semanticCorrectInfo.setPreSql(semanticCorrectInfo.getSql());
semanticCorrectInfo.setSql(sql);
}
}

View File

@@ -1,18 +0,0 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class FieldCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
String preSql = semanticCorrectInfo.getSql();
semanticCorrectInfo.setPreSql(preSql);
String sql = SqlParserUpdateHelper.replaceFields(preSql,
getFieldToBizName(semanticCorrectInfo.getParseInfo().getModelId()));
semanticCorrectInfo.setSql(sql);
}
}

View File

@@ -1,16 +0,0 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class FunctionAliasCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
String replaceAlias = SqlParserUpdateHelper.replaceAlias(semanticCorrectInfo.getSql());
semanticCorrectInfo.setSql(replaceAlias);
}
}

View File

@@ -1,17 +0,0 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class FunctionCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
String preSql = semanticCorrectInfo.getSql();
semanticCorrectInfo.setPreSql(preSql);
String sql = SqlParserUpdateHelper.replaceFunction(preSql);
semanticCorrectInfo.setSql(sql);
}
}

View File

@@ -16,11 +16,39 @@ import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@Slf4j
public class FieldNameCorrector extends BaseSemanticCorrector {
public class GlobalCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
replaceAlias(semanticCorrectInfo);
updateFieldNameByLinkingValue(semanticCorrectInfo);
updateFieldNameByBizName(semanticCorrectInfo);
addAggregateToMetric(semanticCorrectInfo);
}
private void addAggregateToMetric(SemanticCorrectInfo semanticCorrectInfo) {
}
private void replaceAlias(SemanticCorrectInfo semanticCorrectInfo) {
String replaceAlias = SqlParserUpdateHelper.replaceAlias(semanticCorrectInfo.getSql());
semanticCorrectInfo.setSql(replaceAlias);
}
private void updateFieldNameByBizName(SemanticCorrectInfo semanticCorrectInfo) {
Map<String, String> fieldToBizName = getFieldToBizName(semanticCorrectInfo.getParseInfo().getModelId());
String sql = SqlParserUpdateHelper.replaceFields(semanticCorrectInfo.getSql(), fieldToBizName);
semanticCorrectInfo.setSql(sql);
}
private void updateFieldNameByLinkingValue(SemanticCorrectInfo semanticCorrectInfo) {
Object context = semanticCorrectInfo.getParseInfo().getProperties().get(Constants.CONTEXT);
if (Objects.isNull(context)) {
return;
@@ -45,5 +73,4 @@ public class FieldNameCorrector extends BaseSemanticCorrector {
String sql = SqlParserUpdateHelper.replaceFieldNameByValue(preSql, fieldValueToFieldNames);
semanticCorrectInfo.setSql(sql);
}
}
}

View File

@@ -0,0 +1,15 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class GroupByCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
}
}

View File

@@ -0,0 +1,14 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class HavingCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
}
}

View File

@@ -1,48 +0,0 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.StringUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
@Slf4j
public class QueryFilterAppend extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException {
String queryFilter = getQueryFilter(semanticCorrectInfo.getQueryFilters());
String preSql = semanticCorrectInfo.getSql();
if (StringUtils.isNotEmpty(queryFilter)) {
log.info("add queryFilter to preSql :{}", queryFilter);
Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
String sql = SqlParserUpdateHelper.addWhere(preSql, expression);
semanticCorrectInfo.setPreSql(preSql);
semanticCorrectInfo.setSql(sql);
}
}
private String getQueryFilter(QueryFilters queryFilters) {
if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) {
return null;
}
return queryFilters.getFilters().stream()
.map(filter -> {
String bizNameWrap = StringUtil.getSpaceWrap(filter.getBizName());
String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue());
String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString());
return bizNameWrap + operatorWrap + valueWrap;
})
.collect(Collectors.joining(Constants.AND_UPPER));
}
}

View File

@@ -13,11 +13,12 @@ import net.sf.jsqlparser.expression.Expression;
import org.springframework.util.CollectionUtils;
@Slf4j
public class SelectFieldAppendCorrector extends BaseSemanticCorrector {
public class SelectCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
String preSql = semanticCorrectInfo.getSql();
if (SqlParserSelectHelper.hasAggregateFunction(preSql)) {
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(preSql);
if (Objects.nonNull(havingExpression)) {

View File

@@ -5,7 +5,7 @@ import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class TableNameCorrector extends BaseSemanticCorrector {
public class TableCorrector extends BaseSemanticCorrector {
public static final String TABLE_PREFIX = "t_";

View File

@@ -1,26 +1,92 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.StringUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.util.Strings;
import org.springframework.util.CollectionUtils;
@Slf4j
public class FieldValueCorrector extends BaseSemanticCorrector {
public class WhereCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
public void correct(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException {
addDateIfNotExist(semanticCorrectInfo);
parserDateDiffFunction(semanticCorrectInfo);
addQueryFilter(semanticCorrectInfo);
updateFieldValueByTechName(semanticCorrectInfo);
}
private void addQueryFilter(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException {
String queryFilter = getQueryFilter(semanticCorrectInfo.getQueryFilters());
String preSql = semanticCorrectInfo.getSql();
if (StringUtils.isNotEmpty(queryFilter)) {
log.info("add queryFilter to preSql :{}", queryFilter);
Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
String sql = SqlParserUpdateHelper.addWhere(preSql, expression);
semanticCorrectInfo.setPreSql(preSql);
semanticCorrectInfo.setSql(sql);
}
}
private void parserDateDiffFunction(SemanticCorrectInfo semanticCorrectInfo) {
String preSql = semanticCorrectInfo.getSql();
semanticCorrectInfo.setPreSql(preSql);
String sql = SqlParserUpdateHelper.replaceFunction(preSql);
semanticCorrectInfo.setSql(sql);
}
private void addDateIfNotExist(SemanticCorrectInfo semanticCorrectInfo) {
String sql = semanticCorrectInfo.getSql();
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sql);
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(TimeDimensionEnum.DAY.getName())) {
String currentDate = DSLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId());
sql = SqlParserUpdateHelper.addWhere(sql, TimeDimensionEnum.DAY.getName(), currentDate);
}
semanticCorrectInfo.setSql(sql);
}
private String getQueryFilter(QueryFilters queryFilters) {
if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) {
return null;
}
return queryFilters.getFilters().stream()
.map(filter -> {
String bizNameWrap = StringUtil.getSpaceWrap(filter.getBizName());
String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue());
String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString());
return bizNameWrap + operatorWrap + valueWrap;
})
.collect(Collectors.joining(Constants.AND_UPPER));
}
private void updateFieldValueByTechName(SemanticCorrectInfo semanticCorrectInfo) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
Long modelId = semanticCorrectInfo.getParseInfo().getModel().getId();
List<SchemaElement> dimensions = semanticSchema.getDimensions().stream()
@@ -39,7 +105,6 @@ public class FieldValueCorrector extends BaseSemanticCorrector {
return;
}
private Map<String, Map<String, String>> getAliasAndBizNameToTechName(List<SchemaElement> dimensions) {
if (CollectionUtils.isEmpty(dimensions)) {
return new HashMap<>();

View File

@@ -408,27 +408,20 @@ public class LLMDslParser implements SemanticParser {
protected List<String> getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema,
LLMParserConfig llmParserConfig) {
Set<String> results = getTopNFieldNames(modelId, semanticSchema, llmParserConfig);
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, modelId, semanticSchema);
results.addAll(fieldNameList);
return new ArrayList<>(results);
}
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
Set<String> results = semanticSchema.getDimensions().stream()
.filter(schemaElement -> modelId.equals(schemaElement.getModel()))
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getDimensionTopN())
.map(entry -> entry.getName())
.collect(Collectors.toSet());
Set<String> metrics = semanticSchema.getMetrics().stream()
.filter(schemaElement -> modelId.equals(schemaElement.getModel()))
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getMetricTopN())
.map(entry -> entry.getName())
.collect(Collectors.toSet());
results.addAll(metrics);
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
if (CollectionUtils.isEmpty(matchedElements)) {
return new ArrayList<>(results);
return new HashSet<>();
}
Set<String> fieldNameList = matchedElements.stream()
.filter(schemaElementMatch -> {
@@ -447,13 +440,29 @@ public class LLMDslParser implements SemanticParser {
})
.filter(name -> StringUtils.isNotEmpty(name) && !name.contains("%"))
.collect(Collectors.toSet());
results.addAll(fieldNameList);
return new ArrayList<>(results);
return fieldNameList;
}
private Set<String> getTopNFieldNames(Long modelId, SemanticSchema semanticSchema,
LLMParserConfig llmParserConfig) {
Set<String> results = semanticSchema.getDimensions(modelId).stream()
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getDimensionTopN())
.map(entry -> entry.getName())
.collect(Collectors.toSet());
Set<String> metrics = semanticSchema.getMetrics(modelId).stream()
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getMetricTopN())
.map(entry -> entry.getName())
.collect(Collectors.toSet());
results.addAll(metrics);
return results;
}
protected Map<Long, String> getItemIdToName(Long modelId, SemanticSchema semanticSchema) {
return semanticSchema.getDimensions().stream()
.filter(entry -> modelId.equals(entry.getModel()))
return semanticSchema.getDimensions(modelId).stream()
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
}

View File

@@ -72,6 +72,7 @@ public class ChatQueryController {
public Object queryData(@RequestBody QueryDataReq queryData,
HttpServletRequest request, HttpServletResponse response)
throws Exception {
queryData.setUser(UserHolder.findUser(request, response));
return queryService.executeDirectQuery(queryData, UserHolder.findUser(request, response));
}

View File

@@ -3,11 +3,14 @@ package com.tencent.supersonic.chat.service.impl;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.component.SemanticLayer;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
@@ -15,13 +18,15 @@ import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.persistence.dataobject.CostType;
import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.chat.query.QuerySelector;
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
import com.tencent.supersonic.chat.query.llm.dsl.LLMResp;
import com.tencent.supersonic.chat.queryresponder.QueryResponder;
import com.tencent.supersonic.chat.service.ChatService;
import com.tencent.supersonic.chat.service.QueryService;
@@ -29,25 +34,29 @@ import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.chat.service.StatisticsService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import java.util.Map;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import java.util.List;
import java.util.ArrayList;
import java.util.Set;
import java.util.HashSet;
import java.util.HashMap;
import java.util.Comparator;
import java.util.Objects;
import java.util.stream.Collectors;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import com.tencent.supersonic.semantic.api.query.pojo.Filter;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.sql.parser.SqlParseException;
import org.springframework.beans.BeanUtils;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Primary;
@@ -175,34 +184,26 @@ public class QueryServiceImpl implements QueryService {
ChatContext chatCtx = chatService.getOrCreateContext(queryReq.getChatId());
chatCtx.setAgentId(queryReq.getAgentId());
Long startTime = System.currentTimeMillis();
QueryResult queryResult = null;
try {
queryResult = semanticQuery.execute(queryReq.getUser());
} catch (Exception e) {
log.error("query execute failed, queryText:{}", queryReq.getQueryText(), e);
queryResult = new QueryResult();
queryResult.setQueryState(QueryState.INVALID);
QueryResult queryResult = semanticQuery.execute(queryReq.getUser());
if (queryResult != null) {
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
.interfaceName(semanticQuery.getClass().getSimpleName()).type(CostType.QUERY.getType()).build());
saveInfo(timeCostDOList, queryReq.getQueryText(), queryReq.getQueryId(),
queryReq.getUser().getName(), queryReq.getChatId().longValue());
queryResult.setChatContext(parseInfo);
// update chat context after a successful semantic query
if (queryReq.isSaveAnswer() && QueryState.SUCCESS.equals(queryResult.getQueryState())) {
chatCtx.setParseInfo(parseInfo);
chatService.updateContext(chatCtx);
}
chatCtx.setQueryText(queryReq.getQueryText());
chatCtx.setUser(queryReq.getUser().getName());
chatService.updateQuery(queryReq.getQueryId(), queryResult, chatCtx);
} else {
chatService.deleteChatQuery(queryReq.getQueryId());
}
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
.interfaceName(semanticQuery.getClass().getSimpleName()).type(CostType.QUERY.getType()).build());
saveInfo(timeCostDOList, queryReq.getQueryText(), queryReq.getQueryId(),
queryReq.getUser().getName(), queryReq.getChatId().longValue());
queryResult.setChatContext(parseInfo);
// update chat context after a successful semantic query
if (queryReq.isSaveAnswer() && QueryState.SUCCESS.equals(queryResult.getQueryState())) {
chatCtx.setParseInfo(parseInfo);
chatService.updateContext(chatCtx);
queryResponder.saveSolvedQuery(queryReq.getQueryText(), queryReq.getQueryId(), queryReq.getParseId());
}
chatCtx.setQueryText(queryReq.getQueryText());
chatCtx.setUser(queryReq.getUser().getName());
chatService.updateQuery(queryReq.getQueryId(), queryResult, chatCtx);
if (!QueryState.SUCCESS.equals(queryResult.getQueryState())) {
List<SolvedQueryRecallResp> solvedQueryRecallResps =
queryResponder.recallSolvedQuery(queryReq.getQueryText());
queryResult.setSimilarSolvedQuery(solvedQueryRecallResps);
}
return queryResult;
}
@@ -273,8 +274,52 @@ public class QueryServiceImpl implements QueryService {
@Override
public QueryResult executeDirectQuery(QueryDataReq queryData, User user) throws SqlParseException {
SemanticQuery semanticQuery = QueryManager.createRuleQuery(queryData.getQueryMode());
BeanUtils.copyProperties(queryData, semanticQuery.getParseInfo());
ChatParseDO chatParseDO = chatService.getParseInfo(queryData.getQueryId(),
queryData.getUser().getName(), queryData.getParseId());
SemanticParseInfo parseInfo = JsonUtil.toObject(chatParseDO.getParseInfo(), SemanticParseInfo.class);
if (!parseInfo.getQueryMode().equals(DslQuery.QUERY_MODE)) {
if (CollectionUtils.isNotEmpty(queryData.getDimensions())) {
parseInfo.setDimensions(queryData.getDimensions());
}
if (CollectionUtils.isNotEmpty(queryData.getMetrics())) {
parseInfo.setMetrics(queryData.getMetrics());
}
if (CollectionUtils.isNotEmpty(queryData.getDimensionFilters())) {
parseInfo.setDimensionFilters(queryData.getDimensionFilters());
}
}
if (Objects.nonNull(queryData.getDateInfo())) {
parseInfo.setDateInfo(queryData.getDateInfo());
}
if (parseInfo.getQueryMode().equals(DslQuery.QUERY_MODE)
&& CollectionUtils.isNotEmpty(queryData.getDimensionFilters())) {
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
String json = JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT));
DSLParseResult dslParseResult = JsonUtil.toObject(json, DSLParseResult.class);
LLMResp llmResp = dslParseResult.getLlmResp();
String correctorSql = llmResp.getCorrectorSql();
log.info("correctorSql before replacing:{}", correctorSql);
for (QueryFilter dslQueryFilter : queryData.getDimensionFilters()) {
for (QueryFilter queryFilter : parseInfo.getDimensionFilters()) {
if (dslQueryFilter.getBizName().equals(queryFilter.getBizName())) {
Map<String, String> map = new HashMap<>();
map.put(queryFilter.getValue().toString(), dslQueryFilter.getValue().toString());
fieldNameToValueMap.put(dslQueryFilter.getBizName(), map);
break;
}
}
}
log.info("filedNameToValueMap:{}", filedNameToValueMap);
correctorSql = SqlParserUpdateHelper.replaceValue(correctorSql, filedNameToValueMap);
log.info("correctorSql after replacing:{}", correctorSql);
llmResp.setCorrectorSql(correctorSql);
dslParseResult.setLlmResp(llmResp);
Map<String, Object> properties = new HashMap<>();
properties.put(Constants.CONTEXT, dslParseResult);
parseInfo.setProperties(properties);
}
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
semanticQuery.setParseInfo(parseInfo);
QueryResult queryResult = semanticQuery.execute(user);
queryResult.setChatContext(semanticQuery.getParseInfo());
return queryResult;
@@ -282,8 +327,6 @@ public class QueryServiceImpl implements QueryService {
@Override
public Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception {
com.tencent.supersonic.semantic.query.service.QueryService queryService =
ContextUtils.getBean(com.tencent.supersonic.semantic.query.service.QueryService.class);
QueryStructReq queryStructReq = new QueryStructReq();
DateConf dateConf = new DateConf();
@@ -307,7 +350,8 @@ public class QueryServiceImpl implements QueryService {
dimensionFilters.add(dimensionFilter);
queryStructReq.setDimensionFilters(dimensionFilters);
}
QueryResultWithSchemaResp queryResultWithSchemaResp = queryService.queryByStructWithAuth(queryStructReq, user);
SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer();
QueryResultWithSchemaResp queryResultWithSchemaResp = semanticLayer.queryByStruct(queryStructReq, user);
Set<String> dimensionValues = new HashSet<>();
queryResultWithSchemaResp.getResultList().removeIf(o -> {
if (dimensionValues.contains(o.get(dimensionValueReq.getBizName()))) {

View File

@@ -1,348 +1,371 @@
examplars= [
{ "current_date":"2020-12-01",
"table_name":"内容库产品",
"fields_list":"""["部门", "模块", "用户名", "访问次数", "访问人数", "访问时长", "数据日期"]""",
"question":"比较jackjchen和robinlee在内容库的访问次数",
"prior_schema_links":"""['jackjchen'->用户名, 'robinlee'->用户名]""",
examplars = [
{
"current_date": "2020-12-01",
"table_name": "内容库产品",
"fields_list": """["部门", "模块", "用户名", "访问次数", "访问人数", "访问时长", "数据日期"]""",
"question": "比较jackjchen和robinlee在内容库的访问次数",
"prior_schema_links": """['jackjchen'->用户名, 'robinlee'->用户名]""",
"analysis": """让我们一步一步地思考。在问题“比较jackjchen和robinlee在内容库的访问次数“中我们被问
“比较jackjchen和robinlee”所以我们需要column=[用户名]
”内容库的访问次数“所以我们需要column=[访问次数]
基于table和columns可能的cell values 是 = ['jackjchen', 'robinlee']。""",
"schema_links":"""["用户名", "访问次数", "'jackjchen'", "'robinlee'"]""",
"sql":"""select 用户名, 访问次数 from 内容库产品 where 用户名 in ('jackjchen', 'robinlee') and 数据日期 = '2020-12-01' """
},
{ "current_date":"2022-11-06",
"table_name":"内容库产品",
"fields_list":"""["部门", "模块", "用户名", "访问次数", "访问人数", "访问时长", "数据日期"]""",
"question":"内容库近12个月访问人数 按部门",
"prior_schema_links":"""[]""",
"schema_links": """["用户名", "访问次数", "'jackjchen'", "'robinlee'"]""",
"sql": """select 用户名, 访问次数 from 内容库产品 where 用户名 in ('jackjchen', 'robinlee') and 数据日期 = '2020-12-01' """,
},
{
"current_date": "2022-11-06",
"table_name": "内容库产品",
"fields_list": """["部门", "模块", "用户名", "访问次数", "访问人数", "访问时长", "数据日期"]""",
"question": "内容库近12个月访问人数 按部门",
"prior_schema_links": """[]""",
"analysis": """让我们一步一步地思考。在问题“内容库近12个月访问人数 按部门“中,我们被问:
”内容库近12个月“所以我们需要column=[数据日期]
“访问人数”所以我们需要column=[访问人数]
”按部门“所以我们需要column=[部门]
基于table和columns可能的cell values 是 = [12]。""",
"schema_links":"""["访问人数", "部门", "数据日期", 12]""",
"sql":"""select 部门, 数据日期, 访问人数 from 内容库产品 where datediff('month', 数据日期, '2022-11-06') <= 12 """
},
{ "current_date":"2023-04-21",
"table_name":"内容库产品",
"fields_list":"""["部门", "模块", "用户名", "访问次数", "访问人数", "访问时长", "数据日期"]""",
"question":"内容库美术部、技术研发部的访问时长",
"prior_schema_links":"""['美术部'->部门, '技术研发部'->部门]""",
"schema_links": """["访问人数", "部门", "数据日期", 12]""",
"sql": """select 部门, 数据日期, 访问人数 from 内容库产品 where datediff('month', 数据日期, '2022-11-06') <= 12 """,
},
{
"current_date": "2023-04-21",
"table_name": "内容库产品",
"fields_list": """["部门", "模块", "用户名", "访问次数", "访问人数", "访问时长", "数据日期"]""",
"question": "内容库美术部、技术研发部的访问时长",
"prior_schema_links": """['美术部'->部门, '技术研发部'->部门]""",
"analysis": """让我们一步一步地思考。在问题“内容库美术部、技术研发部的访问时长“中,我们被问:
“访问时长”所以我们需要column=[访问时长]
”内容库美术部、技术研发部“所以我们需要column=[部门]
基于table和columns可能的cell values 是 = ['美术部', '技术研发部']。""",
"schema_links":"""["访问时长", "部门", "'美术部'", "'技术研发部'"]""",
"sql":"""select 部门, 访问时长 from 内容库产品 where 部门 in ('美术部', '技术研发部') and 数据日期 = '2023-04-21' """
},
{ "current_date":"2023-08-21",
"table_name":"严选",
"fields_list":"""["严选版权归属系", "付费模式", "结算播放份额", "付费用户结算播放份额", "数据日期"]""",
"question":"近3天海田飞系MPPM结算播放份额",
"prior_schema_links":"""['海田飞系'->严选版权归属系]""",
"schema_links": """["访问时长", "部门", "'美术部'", "'技术研发部'"]""",
"sql": """select 部门, 访问时长 from 内容库产品 where 部门 in ('美术部', '技术研发部') and 数据日期 = '2023-04-21' """,
},
{
"current_date": "2023-08-21",
"table_name": "严选",
"fields_list": """["严选版权归属系", "付费模式", "结算播放份额", "付费用户结算播放份额", "数据日期"]""",
"question": "近3天海田飞系MPPM结算播放份额",
"prior_schema_links": """['海田飞系'->严选版权归属系]""",
"analysis": """让我们一步一步地思考。在问题“近3天海田飞系MPPM结算播放份额“中我们被问
“MPPM结算播放份额”所以我们需要column=[结算播放份额]
”海田飞系“所以我们需要column=[严选版权归属系]
”近3天“所以我们需要column=[数据日期]
基于table和columns可能的cell values 是 = ['海田飞系', 3]。""",
"schema_links":"""["结算播放份额", "严选版权归属系", "数据日期", "'海田飞系'", 3]""",
"sql":"""select 严选版权归属系, 结算播放份额 from 严选 where 严选版权归属系 = '海田飞系' and datediff('day', 数据日期, '2023-08-21') <= 3 """
},
{ "current_date":"2023-05-22",
"table_name":"歌曲库",
"fields_list":"""["是否潮流人歌曲", "C音歌曲ID", "C音歌曲MID", "歌曲名", "歌曲版本", "语种", "歌曲类型", "翻唱类型", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "结算播放量", "运营播放量", "付费用户结算播放量", "历史累计结算播放量", "运营搜播量", "结算搜播量", "运营完播量", "运营推播量", "近7日复播率", "日均搜播量", "数据日期"]""",
"question":"对比近7天翻唱版和纯音乐的歌曲播放量",
"prior_schema_links":"""['纯音乐'->语种, '翻唱版'->歌曲版本]""",
"schema_links": """["结算播放份额", "严选版权归属系", "数据日期", "'海田飞系'", 3]""",
"sql": """select 严选版权归属系, 结算播放份额 from 严选 where 严选版权归属系 = '海田飞系' and datediff('day', 数据日期, '2023-08-21') <= 3 """,
},
{
"current_date": "2023-05-22",
"table_name": "歌曲库",
"fields_list": """["是否潮流人歌曲", "C音歌曲ID", "C音歌曲MID", "歌曲名", "歌曲版本", "语种", "歌曲类型", "翻唱类型", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "结算播放量", "运营播放量", "付费用户结算播放量", "历史累计结算播放量", "运营搜播量", "结算搜播量", "运营完播量", "运营推播量", "近7日复播率", "日均搜播量", "数据日期"]""",
"question": "对比近7天翻唱版和纯音乐的歌曲播放量",
"prior_schema_links": """['纯音乐'->语种, '翻唱版'->歌曲版本]""",
"analysis": """让我们一步一步地思考。在问题“对比近3天翻唱版和纯音乐的歌曲播放量“中我们被问
“歌曲播放量”所以我们需要column=[结算播放量]
”翻唱版“所以我们需要column=[歌曲版本]
”和纯音乐的歌曲“所以我们需要column=[语种]
”近7天“所以我们需要column=[数据日期]
基于table和columns可能的cell values 是 = ['翻唱版', '纯音乐', 7]。""",
"schema_links":"""["结算播放量", "歌曲版本", "语种", "数据日期", "'翻唱版'", "'纯音乐'", 7]""",
"sql":"""select 歌曲版本, 语种, 结算播放量 from 歌曲库 where 歌曲版本 = '翻唱版' and 语种 = '纯音乐' and datediff('day', 数据日期, '2023-05-22') <= 7 """
},
{ "current_date":"2023-05-31",
"table_name":"艺人库",
"fields_list":"""["上下架状态", "歌手名", "歌手等级", "歌手类型", "歌手来源", "MPPM潮流人等级", "活跃区域", "年龄", "歌手才能", "歌手风格", "粉丝数", "潮音粉丝数", "超声波粉丝数", "推博粉丝数", "超声波歌曲数", "在架歌曲数", "超声波分享数", "独占歌曲数", "超声波在架歌曲评论数", "有播放量歌曲数", "数据日期"]""",
"question":"对比一下陈拙悬、孟梅琦、赖媚韵的粉丝数",
"prior_schema_links":"""['1527896'->MPPM歌手ID, '1565463'->MPPM歌手ID, '2141459'->MPPM歌手ID]""",
"schema_links": """["结算播放量", "歌曲版本", "语种", "数据日期", "'翻唱版'", "'纯音乐'", 7]""",
"sql": """select 歌曲版本, 语种, 结算播放量 from 歌曲库 where 歌曲版本 = '翻唱版' and 语种 = '纯音乐' and datediff('day', 数据日期, '2023-05-22') <= 7 """,
},
{
"current_date": "2023-05-31",
"table_name": "艺人库",
"fields_list": """["上下架状态", "歌手名", "歌手等级", "歌手类型", "歌手来源", "MPPM潮流人等级", "活跃区域", "年龄", "歌手才能", "歌手风格", "粉丝数", "潮音粉丝数", "超声波粉丝数", "推博粉丝数", "超声波歌曲数", "在架歌曲数", "超声波分享数", "独占歌曲数", "超声波在架歌曲评论数", "有播放量歌曲数", "数据日期"]""",
"question": "对比一下陈拙悬、孟梅琦、赖媚韵的粉丝数",
"prior_schema_links": """['1527896'->MPPM歌手ID, '1565463'->MPPM歌手ID, '2141459'->MPPM歌手ID]""",
"analysis": """让我们一步一步地思考。在问题“对比一下陈拙悬、孟梅琦、赖媚韵的粉丝数“中,我们被问:
“粉丝数”所以我们需要column=[粉丝数]
”陈拙悬、孟梅琦、赖媚韵“所以我们需要column=[歌手名]
基于table和columns可能的cell values 是 = ['陈拙悬', '孟梅琦', '赖媚韵']。""",
"schema_links":"""["粉丝数", "歌手名", "'陈拙悬'", "'孟梅琦'", "'赖媚韵'"]""",
"sql":"""select 歌手名, 粉丝数 from 艺人库 where 歌手名 in ('陈拙悬', '孟梅琦', '赖媚韵') and 数据日期 = '2023-05-31' """
},
{ "current_date":"2023-07-31",
"table_name":"歌曲库",
"fields_list":"""["歌曲名", "歌曲版本", "歌曲类型", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
"question":"播放量大于1万的歌曲有多少",
"prior_schema_links":"""[]""",
"schema_links": """["粉丝数", "歌手名", "'陈拙悬'", "'孟梅琦'", "'赖媚韵'"]""",
"sql": """select 歌手名, 粉丝数 from 艺人库 where 歌手名 in ('陈拙悬', '孟梅琦', '赖媚韵') and 数据日期 = '2023-05-31' """,
},
{
"current_date": "2023-07-31",
"table_name": "歌曲库",
"fields_list": """["歌曲名", "歌曲版本", "歌曲类型", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
"question": "播放量大于1万的歌曲有多少",
"prior_schema_links": """[]""",
"analysis": """让我们一步一步地思考。在问题“播放量大于1万的歌曲有多少“中我们被问
“歌曲有多少”所以我们需要column=[歌曲名]
”播放量大于1万的“所以我们需要column=[结算播放量]
基于table和columns可能的cell values 是 = [10000]。""",
"schema_links":"""["歌曲名", "结算播放量", 10000]""",
"sql":"""select 歌曲名 from 歌曲库 where 结算播放量 > 10000 and 数据日期 = '2023-07-31' """
},
{ "current_date":"2023-07-31",
"table_name":"内容库产品",
"fields_list":"""["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""",
"question":"内容库访问时长小于1小时且来自美术部的用户是哪些",
"prior_schema_links":"""['美术部'->部门]""",
"schema_links": """["歌曲名", "结算播放量", 10000]""",
"sql": """select 歌曲名 from 歌曲库 where 结算播放量 > 10000 and 数据日期 = '2023-07-31' """,
},
{
"current_date": "2023-07-31",
"table_name": "内容库产品",
"fields_list": """["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""",
"question": "内容库访问时长小于1小时且来自美术部的用户是哪些",
"prior_schema_links": """['美术部'->部门]""",
"analysis": """让我们一步一步地思考。在问题“内容库访问时长小于1小时且来自美术部的用户是哪些“中我们被问
“用户是哪些”所以我们需要column=[用户名]
”美术部的“所以我们需要column=[部门]
”访问时长小于1小时“所以我们需要column=[访问时长]
基于table和columns可能的cell values 是 = ['美术部', 1]。""",
"schema_links":"""["用户名", "部门", "访问时长", "'美术部'", 1]""",
"sql":"""select 用户名 from 内容库产品 where 部门 = '美术部' and 访问时长 < 1 and 数据日期 = '2023-07-31' """
},
{ "current_date":"2023-08-31",
"table_name":"内容库产品",
"fields_list":"""["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""",
"question":"内容库pv最高的用户有哪些",
"prior_schema_links":"""[]""",
"schema_links": """["用户名", "部门", "访问时长", "'美术部'", 1]""",
"sql": """select 用户名 from 内容库产品 where 部门 = '美术部' and 访问时长 < 1 and 数据日期 = '2023-07-31' """,
},
{
"current_date": "2023-08-31",
"table_name": "内容库产品",
"fields_list": """["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""",
"question": "内容库pv最高的用户有哪些",
"prior_schema_links": """[]""",
"analysis": """让我们一步一步地思考。在问题“内容库pv最高的用户有哪些“中我们被问
“用户有哪些”所以我们需要column=[用户名]
”pv最高的“所以我们需要column=[访问次数]
基于table和columns可能的cell values 是 = []。""",
"schema_links":"""["用户名", "访问次数"]""",
"sql":"""select 用户名 from 内容库产品 where 数据日期 = '2023-08-31' order by 访问次数 desc limit 10 """
},
{ "current_date":"2023-08-31",
"table_name":"艺人库",
"fields_list":"""["播放量层级", "播放量单调性", "播放量方差", "播放量突增类型", "播放量集中度", "歌手名", "歌手等级", "歌手类型", "歌手来源", "MPPM潮流人等级", "结算播放量", "运营播放量", "历史累计结算播放量", "有播放量歌曲数", "历史累计运营播放量", "付费用户结算播放量", "结算播放量占比", "运营播放份额", "免费用户结算播放占比", "完播量", "数据日期"]""",
"question":"近90天袁亚伟播放量平均值是多少",
"prior_schema_links":"""['152789226'->MPPM歌手ID]""",
"schema_links": """["用户名", "访问次数"]""",
"sql": """select 用户名 from 内容库产品 where 数据日期 = '2023-08-31' order by 访问次数 desc limit 10 """,
},
{
"current_date": "2023-08-31",
"table_name": "艺人库",
"fields_list": """["播放量层级", "播放量单调性", "播放量方差", "播放量突增类型", "播放量集中度", "歌手名", "歌手等级", "歌手类型", "歌手来源", "MPPM潮流人等级", "结算播放量", "运营播放量", "历史累计结算播放量", "有播放量歌曲数", "历史累计运营播放量", "付费用户结算播放量", "结算播放量占比", "运营播放份额", "免费用户结算播放占比", "完播量", "数据日期"]""",
"question": "近90天袁亚伟播放量平均值是多少",
"prior_schema_links": """['152789226'->MPPM歌手ID]""",
"analysis": """让我们一步一步地思考。在问题“近90天袁亚伟播放量平均值是多少“中我们被问
“播放量平均值是多少”所以我们需要column=[结算播放量]
”袁亚伟“所以我们需要column=[歌手名]
”近90天“所以我们需要column=[数据日期]
基于table和columns可能的cell values 是 = ['袁亚伟', 90]。""",
"schema_links":"""["结算播放量", "歌手名", "数据日期", "'袁亚伟'", 90]""",
"sql":"""select avg(结算播放量) from 艺人库 where 歌手名 = '袁亚伟' and datediff('day', 数据日期, '2023-08-31') <= 90 """
},
{ "current_date":"2023-08-31",
"table_name":"艺人库",
"fields_list":"""["播放量层级", "播放量单调性", "播放量方差", "播放量突增类型", "播放量集中度", "歌手名", "歌手等级", "歌手类型", "歌手来源", "MPPM潮流人等级", "结算播放量", "运营播放量", "历史累计结算播放量", "有播放量歌曲数", "历史累计运营播放量", "付费用户结算播放量", "结算播放量占比", "运营播放份额", "免费用户结算播放占比", "完播量", "数据日期"]""",
"question":"周倩倩近7天结算播放量总和是多少",
"prior_schema_links":"""['199509'->MPPM歌手ID]""",
"schema_links": """["结算播放量", "歌手名", "数据日期", "'袁亚伟'", 90]""",
"sql": """select avg(结算播放量) from 艺人库 where 歌手名 = '袁亚伟' and datediff('day', 数据日期, '2023-08-31') <= 90 """,
},
{
"current_date": "2023-08-31",
"table_name": "艺人库",
"fields_list": """["播放量层级", "播放量单调性", "播放量方差", "播放量突增类型", "播放量集中度", "歌手名", "歌手等级", "歌手类型", "歌手来源", "MPPM潮流人等级", "结算播放量", "运营播放量", "历史累计结算播放量", "有播放量歌曲数", "历史累计运营播放量", "付费用户结算播放量", "结算播放量占比", "运营播放份额", "免费用户结算播放占比", "完播量", "数据日期"]""",
"question": "周倩倩近7天结算播放量总和是多少",
"prior_schema_links": """['199509'->MPPM歌手ID]""",
"analysis": """让我们一步一步地思考。在问题“周倩倩近7天结算播放量总和是多少“中我们被问
“结算播放量总和是多少”所以我们需要column=[结算播放量]
”周倩倩“所以我们需要column=[歌手名]
”近7天“所以我们需要column=[数据日期]
基于table和columns可能的cell values 是 = ['周倩倩', 7]。""",
"schema_links":"""["结算播放量", "歌手名", "数据日期", "'周倩倩'", 7]""",
"sql":"""select sum(结算播放量) from 艺人库 where 歌手名 = '周倩倩' and datediff('day', 数据日期, '2023-08-31') <= 7 """
},
{ "current_date":"2023-09-14",
"table_name":"内容库产品",
"fields_list":"""["部门", "模块", "用户名", "访问次数", "访问人数", "访问时长", "数据日期"]""",
"question":"内容库访问次数大于1k的部门是哪些",
"prior_schema_links":"""[]""",
"schema_links": """["结算播放量", "歌手名", "数据日期", "'周倩倩'", 7]""",
"sql": """select sum(结算播放量) from 艺人库 where 歌手名 = '周倩倩' and datediff('day', 数据日期, '2023-08-31') <= 7 """,
},
{
"current_date": "2023-09-14",
"table_name": "内容库产品",
"fields_list": """["部门", "模块", "用户名", "访问次数", "访问人数", "访问时长", "数据日期"]""",
"question": "内容库访问次数大于1k的部门是哪些",
"prior_schema_links": """[]""",
"analysis": """让我们一步一步地思考。在问题“内容库访问次数大于1k的部门是哪些“中我们被问
“部门是哪些”所以我们需要column=[部门]
”访问次数大于1k的“所以我们需要column=[访问次数]
基于table和columns可能的cell values 是 = [1000]。""",
"schema_links":"""["部门", "访问次数", 1000]""",
"sql":"""select 部门 from 内容库产品 where 访问次数 > 1000 and 数据日期 = '2023-09-14' """
},
{ "current_date":"2023-09-18",
"table_name":"歌曲库",
"fields_list":"""["歌曲名", "MPPM歌手ID", "歌曲版本", "歌曲类型", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
"question":"陈亿训唱的所有的播放量大于20k的孤勇者有哪些",
"prior_schema_links":"""['199509'->MPPM歌手ID, '1527123'->MPPM歌曲ID]""",
"schema_links": """["部门", "访问次数", 1000]""",
"sql": """select 部门 from 内容库产品 where 访问次数 > 1000 and 数据日期 = '2023-09-14' """,
},
{
"current_date": "2023-09-18",
"table_name": "歌曲库",
"fields_list": """["歌曲名", "MPPM歌手ID", "歌曲版本", "歌曲类型", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
"question": "陈亿训唱的所有的播放量大于20k的孤勇者有哪些",
"prior_schema_links": """['199509'->MPPM歌手ID, '1527123'->MPPM歌曲ID]""",
"analysis": """让我们一步一步地思考。在问题“陈亿训唱的所有的播放量大于20k的孤勇者有哪些“中我们被问
“孤勇者有哪些”所以我们需要column=[歌曲名]
”播放量大于20k的“所以我们需要column=[结算播放量]
”陈亿训唱的“所以我们需要column=[歌手名]
基于table和columns可能的cell values 是 = [20000, '陈亿训', '孤勇者']。""",
"schema_links":"""["歌曲名", "结算播放量", "歌手名", 20000, "'陈亿训'", "'孤勇者'"]""",
"sql":"""select 歌曲名 from 歌曲库 where 结算播放量 > 20000 and 歌手名 = '陈亿训' and 歌曲名 = '孤勇者' and 数据日期 = '2023-09-18' """
},
{ "current_date":"2023-09-18",
"table_name":"歌曲库",
"fields_list":"""["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
"question":"周洁轮去年发布的歌曲有哪些",
"prior_schema_links":"""['23109'->MPPM歌手ID]""",
"schema_links": """["歌曲名", "结算播放量", "歌手名", 20000, "'陈亿训'", "'孤勇者'"]""",
"sql": """select 歌曲名 from 歌曲库 where 结算播放量 > 20000 and 歌手名 = '陈亿训' and 歌曲名 = '孤勇者' and 数据日期 = '2023-09-18' """,
},
{
"current_date": "2023-09-18",
"table_name": "歌曲库",
"fields_list": """["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
"question": "周洁轮去年发布的歌曲有哪些",
"prior_schema_links": """['23109'->MPPM歌手ID]""",
"analysis": """让我们一步一步地思考。在问题“周洁轮去年发布的歌曲有哪些“中,我们被问:
“歌曲有哪些”所以我们需要column=[歌曲名]
”去年发布的“所以我们需要column=[发布时间]
”周洁轮“所以我们需要column=[歌手名]
基于table和columns可能的cell values 是 = ['周洁轮', 1]。""",
"schema_links":"""["歌曲名", "发布时间", "歌手名", 1, "'周洁轮'"]""",
"sql":"""select 歌曲名 from 歌曲库 where datediff('year', 发布时间, '2023-09-18') <= 1 and 歌手名 = '周洁轮' and 数据日期 = '2023-09-18' """
},
{ "current_date":"2023-09-11",
"table_name":"艺人库",
"fields_list":"""["播放量层级", "播放量单调性", "播放量方差", "播放量突增类型", "播放量集中度", "歌手名", "歌手等级", "歌手类型", "歌手来源", "签约日期", "MPPM潮流人等级", "结算播放量", "运营播放量", "历史累计结算播放量", "有播放量歌曲数", "历史累计运营播放量", "付费用户结算播放量", "结算播放量占比", "运营播放份额", "免费用户结算播放占比", "完播量", "数据日期"]""",
"question":"我想要近半年签约的播放量前十的歌手有哪些",
"prior_schema_links":"""[]""",
"schema_links": """["歌曲名", "发布时间", "歌手名", 1, "'周洁轮'"]""",
"sql": """select 歌曲名 from 歌曲库 where datediff('year', 发布时间, '2023-09-18') <= 1 and 歌手名 = '周洁轮' and 数据日期 = '2023-09-18' """,
},
{
"current_date": "2023-09-11",
"table_name": "艺人库",
"fields_list": """["播放量层级", "播放量单调性", "播放量方差", "播放量突增类型", "播放量集中度", "歌手名", "歌手等级", "歌手类型", "歌手来源", "签约日期", "MPPM潮流人等级", "结算播放量", "运营播放量", "历史累计结算播放量", "有播放量歌曲数", "历史累计运营播放量", "付费用户结算播放量", "结算播放量占比", "运营播放份额", "免费用户结算播放占比", "完播量", "数据日期"]""",
"question": "我想要近半年签约的播放量前十的歌手有哪些",
"prior_schema_links": """[]""",
"analysis": """让我们一步一步地思考。在问题“我想要近半年签约的播放量前十的歌手“中,我们被问:
“歌手有哪些”所以我们需要column=[歌手名]
”播放量前十的“所以我们需要column=[结算播放量]
”近半年签约的“所以我们需要column=[签约日期]
基于table和columns可能的cell values 是 = [0.5, 10]。""",
"schema_links":"""["歌手名", "结算播放量", "签约日期", 0.5, 10]""",
"sql":"""select 歌手名 from 艺人库 where datediff('year', 签约日期, '2023-09-11') <= 0.5 and 数据日期 = '2023-09-11' order by 结算播放量 desc limit 10"""
},
{ "current_date":"2023-08-12",
"table_name":"歌曲库",
"schema_links": """["歌手名", "结算播放量", "签约日期", 0.5, 10]""",
"sql": """select 歌手名 from 艺人库 where datediff('year', 签约日期, '2023-09-11') <= 0.5 and 数据日期 = '2023-09-11' order by 结算播放量 desc limit 10""",
},
{
"current_date": "2023-08-12",
"table_name": "歌曲库",
"fields_list": """["发行日期", "歌曲语言", "歌曲来源", "歌曲流派", "歌曲名", "歌曲版本", "歌曲类型", "发行时间", "数据日期"]""",
"question":"最近一年发行的歌曲中有哪些在近7天播放超过一千万的",
"prior_schema_links":"""[]""",
"question": "最近一年发行的歌曲中有哪些在近7天播放超过一千万的",
"prior_schema_links": """[]""",
"analysis": """让我们一步一步地思考。在问题“最近一年发行的歌曲中有哪些在近7天播放超过一千万的“中我们被问
“发行的歌曲中有哪些”所以我们需要column=[歌曲名]
”最近一年发行的“所以我们需要column=[发行日期]
”在近7天播放超过一千万的“所以我们需要column=[数据日期, 结算播放量]
基于table和columns可能的cell values 是 = [1, 10000000]""",
"schema_links":"""["歌曲名", "发行日期", "数据日期", "结算播放量", 1, 10000000]""",
"sql":"""select 歌曲名 from 歌曲库 where datediff('year', 发行日期, '2023-08-12') <= 1 and datediff('day', 数据日期, '2023-08-12') <= 7 and 结算播放量 > 10000000"""
},
{ "current_date":"2023-08-12",
"table_name":"歌曲库",
"schema_links": """["歌曲名", "发行日期", "数据日期", "结算播放量", 1, 10000000]""",
"sql": """select 歌曲名 from 歌曲库 where datediff('year', 发行日期, '2023-08-12') <= 1 and datediff('day', 数据日期, '2023-08-12') <= 7 and 结算播放量 > 10000000""",
},
{
"current_date": "2023-08-12",
"table_name": "歌曲库",
"fields_list": """["发行日期", "歌曲语言", "歌曲来源", "歌曲流派", "歌曲名", "歌曲版本", "歌曲类型", "发行时间", "数据日期"]""",
"question":"今年以来发行的歌曲中有哪些在近7天播放超过一千万的",
"prior_schema_links":"""[]""",
"question": "今年以来发行的歌曲中有哪些在近7天播放超过一千万的",
"prior_schema_links": """[]""",
"analysis": """让我们一步一步地思考。在问题“今年以来发行的歌曲中有哪些在近7天播放超过一千万的“中我们被问
“发行的歌曲中有哪些”所以我们需要column=[歌曲名]
”今年以来发行的“所以我们需要column=[发行日期]
”在近7天播放超过一千万的“所以我们需要column=[数据日期, 结算播放量]
基于table和columns可能的cell values 是 = [0, 7, 10000000]""",
"schema_links":"""["歌曲名", "发行日期", "数据日期", "结算播放量", 0, 7, 10000000]""",
"sql":"""select 歌曲名 from 歌曲库 where datediff('year', 发行日期, '2023-08-12') <= 0 and datediff('day', 数据日期, '2023-08-12') <= 7 and 结算播放量 > 10000000"""
},
{ "current_date":"2023-08-12",
"table_name":"歌曲库",
"schema_links": """["歌曲名", "发行日期", "数据日期", "结算播放量", 0, 7, 10000000]""",
"sql": """select 歌曲名 from 歌曲库 where datediff('year', 发行日期, '2023-08-12') <= 0 and datediff('day', 数据日期, '2023-08-12') <= 7 and 结算播放量 > 10000000""",
},
{
"current_date": "2023-08-12",
"table_name": "歌曲库",
"fields_list": """["发行日期", "歌曲语言", "歌曲来源", "歌曲流派", "歌曲名", "歌曲版本", "歌曲类型", "发行时间", "数据日期"]""",
"question":"2023年以来发行的歌曲中有哪些在近7天播放超过一千万的",
"prior_schema_links":"""['514129144'->MPPM歌曲ID]""",
"question": "2023年以来发行的歌曲中有哪些在近7天播放超过一千万的",
"prior_schema_links": """['514129144'->MPPM歌曲ID]""",
"analysis": """让我们一步一步地思考。在问题“2023年以来发行的歌曲中有哪些在近7天播放超过一千万的“中我们被问
“发行的歌曲中有哪些”所以我们需要column=[歌曲名]
”2023年以来发行的“所以我们需要column=[发行日期]
”在近7天播放超过一千万的“所以我们需要column=[数据日期, 结算播放量]
基于table和columns可能的cell values 是 = [2023, 7, 10000000]""",
"schema_links":"""["歌曲名", "发行日期", "数据日期", "结算播放量", 2023, 7, 10000000]""",
"sql":"""select 歌曲名 from 歌曲库 where YEAR(发行日期) >= 2023 and datediff('day', 数据日期, '2023-08-12') <= 7 and 结算播放量 > 10000000"""
},
{ "current_date":"2023-08-01",
"table_name":"歌曲库",
"fields_list":"""["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
"question":"周洁轮2023年6月之后发布的歌曲有哪些",
"prior_schema_links":"""['23109'->MPPM歌手ID]""",
"schema_links": """["歌曲名", "发行日期", "数据日期", "结算播放量", 2023, 7, 10000000]""",
"sql": """select 歌曲名 from 歌曲库 where YEAR(发行日期) >= 2023 and datediff('day', 数据日期, '2023-08-12') <= 7 and 结算播放量 > 10000000""",
},
{
"current_date": "2023-08-01",
"table_name": "歌曲库",
"fields_list": """["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
"question": "周洁轮2023年6月之后发布的歌曲有哪些",
"prior_schema_links": """['23109'->MPPM歌手ID]""",
"analysis": """让我们一步一步地思考。在问题“周洁轮2023年6月之后发布的歌曲有哪些“中我们被问
“歌曲有哪些”所以我们需要column=[歌曲名]
”2023年6月之后发布的“所以我们需要column=[发布时间]
”周洁轮“所以我们需要column=[歌手名]
基于table和columns可能的cell values 是 = ['周洁轮', 2023, 6]。""",
"schema_links":"""["歌曲名", "发布时间", "歌手名", "周洁轮", 2023, 6]""",
"sql":"""select 歌曲名 from 歌曲库 where YEAR(发布时间) >= 2023 and MONTH(发布时间) >= 6 and 歌手名 = '周洁轮' and 数据日期 = '2023-08-01' """
},
{ "current_date":"2023-08-01",
"table_name":"歌曲库",
"fields_list":"""["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
"question":"邓梓琦在2023年1月5日之后发布的歌曲中有哪些播放量大于500W的",
"prior_schema_links":"""['2312311'->MPPM歌手ID]""",
"schema_links": """["歌曲名", "发布时间", "歌手名", "周洁轮", 2023, 6]""",
"sql": """select 歌曲名 from 歌曲库 where YEAR(发布时间) >= 2023 and MONTH(发布时间) >= 6 and 歌手名 = '周洁轮' and 数据日期 = '2023-08-01' """,
},
{
"current_date": "2023-08-01",
"table_name": "歌曲库",
"fields_list": """["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
"question": "邓梓琦在2023年1月5日之后发布的歌曲中有哪些播放量大于500W的",
"prior_schema_links": """['2312311'->MPPM歌手ID]""",
"analysis": """让我们一步一步地思考。在问题“邓梓琦在2023年1月5日之后发布的歌曲中有哪些播放量大于500W的“中我们被问
“播放量大于500W的”所以我们需要column=[结算播放量]
”邓梓琦在2023年1月5日之后发布的“所以我们需要column=[发布时间]
”邓梓琦“所以我们需要column=[歌手名]
基于table和columns可能的cell values 是 = ['邓梓琦', 2023, 1, 5, 5000000]。""",
"schema_links":"""["结算播放量", "发布时间", "歌手名", "邓梓琦", 2023, 1, 5, 5000000]""",
"sql":"""select 歌曲名 from 歌曲库 where YEAR(发布时间) >= 2023 and MONTH(发布时间) >= 1 and DAY(发布时间) >= 5 and 歌手名 = '邓梓琦' and 结算播放量 > 5000000 and 数据日期 = '2023-08-01'"""
},
{ "current_date":"2023-09-17",
"table_name":"歌曲库",
"fields_list":"""["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
"question":"2023年6月以后张亮英播放量大于200万的歌曲有哪些",
"prior_schema_links":"""['45453'->MPPM歌手ID]""",
"schema_links": """["结算播放量", "发布时间", "歌手名", "邓梓琦", 2023, 1, 5, 5000000]""",
"sql": """select 歌曲名 from 歌曲库 where YEAR(发布时间) >= 2023 and MONTH(发布时间) >= 1 and DAY(发布时间) >= 5 and 歌手名 = '邓梓琦' and 结算播放量 > 5000000 and 数据日期 = '2023-08-01'""",
},
{
"current_date": "2023-09-17",
"table_name": "歌曲库",
"fields_list": """["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
"question": "2023年6月以后张亮英播放量大于200万的歌曲有哪些",
"prior_schema_links": """['45453'->MPPM歌手ID]""",
"analysis": """让我们一步一步地思考。在问题“2023年6月以后张亮英播放量大于200万的歌曲有哪些“中我们被问
“播放量大于200万的”所以我们需要column=[结算播放量]
”2023年6月以后张亮英“所以我们需要column=[数据日期, 歌手名]
”歌曲有哪些“所以我们需要column=[歌曲名]
基于table和columns可能的cell values 是 = ['张亮英', 2023, 6, 2000000]。""",
"schema_links":"""["结算播放量", "数据日期", "歌手名", "张亮英", 2023, 6, 2000000]""",
"sql":"""select 歌曲名 from 歌曲库 where YEAR(数据日期) >= 2023 and MONTH(数据日期) >= 6 and 歌手名 = '张亮英' and 结算播放量 > 2000000 """
},
{ "current_date":"2023-08-16",
"table_name":"歌曲库",
"fields_list":"""["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
"question":"2021年6月以后发布的李雨纯的播放量大于20万的歌曲有哪些",
"prior_schema_links":"""['23109'->MPPM歌手ID]""",
"schema_links": """["结算播放量", "数据日期", "歌手名", "张亮英", 2023, 6, 2000000]""",
"sql": """select 歌曲名 from 歌曲库 where YEAR(数据日期) >= 2023 and MONTH(数据日期) >= 6 and 歌手名 = '张亮英' and 结算播放量 > 2000000 """,
},
{
"current_date": "2023-08-16",
"table_name": "歌曲库",
"fields_list": """["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
"question": "2021年6月以后发布的李雨纯的播放量大于20万的歌曲有哪些",
"prior_schema_links": """['23109'->MPPM歌手ID]""",
"analysis": """让我们一步一步地思考。在问题“2021年6月以后发布的李雨纯的播放量大于20万的歌曲有哪些“中我们被问
“播放量大于20万的”所以我们需要column=[结算播放量]
”2021年6月以后发布的“所以我们需要column=[发布时间]
”李雨纯“所以我们需要column=[歌手名]
基于table和columns可能的cell values 是 = ['李雨纯', 2021, 6, 200000]。""",
"schema_links":"""["结算播放量", "发布时间", "歌手名", "李雨纯", 2021, 6, 200000]""",
"sql":"""select 歌曲名 from 歌曲库 where YEAR(发布时间) >= 2021 and MONTH(发布时间) >= 6 and 歌手名 = '李雨纯' and 结算播放量 > 200000 and 数据日期 = '2023-08-16'"""
},
{ "current_date":"2023-08-16",
"table_name":"歌曲库",
"fields_list":"""["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
"question":"刘锝桦在1992年4月2日到2020年5月2日之间发布的播放量大于20万的歌曲有哪些",
"prior_schema_links":"""['4234234'->MPPM歌手ID]""",
"schema_links": """["结算播放量", "发布时间", "歌手名", "李雨纯", 2021, 6, 200000]""",
"sql": """select 歌曲名 from 歌曲库 where YEAR(发布时间) >= 2021 and MONTH(发布时间) >= 6 and 歌手名 = '李雨纯' and 结算播放量 > 200000 and 数据日期 = '2023-08-16'""",
},
{
"current_date": "2023-08-16",
"table_name": "歌曲库",
"fields_list": """["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
"question": "刘锝桦在1992年4月2日到2020年5月2日之间发布的播放量大于20万的歌曲有哪些",
"prior_schema_links": """['4234234'->MPPM歌手ID]""",
"analysis": """让我们一步一步地思考。在问题“刘锝桦在1992年4月2日到2020年5月2日之间发布的播放量大于20万的歌曲有哪些“中我们被问
“播放量大于20万的”所以我们需要column=[结算播放量]
”1992年4月2日到2020年5月2日之间发布的“所以我们需要column=[发布时间]
”刘锝桦“所以我们需要column=[歌手名]
基于table和columns可能的cell values 是 = ['刘锝桦', 1992, 4, 2, 2020, 5, 2, 200000]。""",
"schema_links":"""["结算播放量", "发布时间", "歌手名", "刘锝桦", 1992, 4, 2, 2020, 5, 2, 200000]""",
"sql":"""select 歌曲名 from 歌曲库 where YEAR(发布时间) >= 1992 and MONTH(发布时间) >= 4 and DAY(发布时间) >= 2 and YEAR(发布时间) <= 2020 and MONTH(发布时间) <= 5 and DAY(发布时间) <= 2 and 歌手名 = '刘锝桦' and 结算播放量 > 200000 and 数据日期 = '2023-08-16'"""
},
"schema_links": """["结算播放量", "发布时间", "歌手名", "刘锝桦", 1992, 4, 2, 2020, 5, 2, 200000]""",
"sql": """select 歌曲名 from 歌曲库 where YEAR(发布时间) >= 1992 and MONTH(发布时间) >= 4 and DAY(发布时间) >= 2 and YEAR(发布时间) <= 2020 and MONTH(发布时间) <= 5 and DAY(发布时间) <= 2 and 歌手名 = '刘锝桦' and 结算播放量 > 200000 and 数据日期 = '2023-08-16'""",
},
{
"current_date":"2023-09-04",
"table_name":"内容库产品",
"fields_list":"""["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""",
"question":"内容库近30天访问次数的平均数",
"prior_schema_links":"""[]""",
"current_date": "2023-09-04",
"table_name": "内容库产品",
"fields_list": """["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""",
"question": "内容库近30天访问次数的平均数",
"prior_schema_links": """[]""",
"analysis": """让我们一步一步地思考。在问题“内容库近30天访问次数的平均数“中我们被问
“访问次数的平均数”所以我们需要column=[访问次数]
”内容库近30天“所以我们需要column=[数据日期]
基于table和columns可能的cell values 是 = [30]。""",
"schema_links":"""["访问次数", "数据日期", 30]""",
"sql":"""select avg(访问次数) from 内容库产品 where datediff('day', 数据日期, '2023-09-04') <= 30 """
},
"schema_links": """["访问次数", "数据日期", 30]""",
"sql": """select avg(访问次数) from 内容库产品 where datediff('day', 数据日期, '2023-09-04') <= 30 """,
},
{
"current_date":"2023-09-04",
"table_name":"内容库产品",
"fields_list":"""["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""",
"question":"内容库近半年哪个月的访问次数汇总最高",
"prior_schema_links":"""[]""",
"current_date": "2023-09-04",
"table_name": "内容库产品",
"fields_list": """["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""",
"question": "内容库近半年哪个月的访问次数汇总最高",
"prior_schema_links": """[]""",
"analysis": """让我们一步一步地思考。在问题“内容库近半年哪个月的访问次数汇总最高“中,我们被问:
“访问次数汇总最高”所以我们需要column=[访问次数]
”内容库近半年“所以我们需要column=[数据日期]
基于table和columns可能的cell values 是 = [0.5]。""",
"schema_links":"""["访问次数", "数据日期", 0.5]""",
"sql":"""select MONTH(数据日期), sum(访问次数) from 内容库产品 where datediff('year', 数据日期, '2023-09-04') <= 0.5 group by MONTH(数据日期) order by sum(访问次数) desc limit 1 """
},
"schema_links": """["访问次数", "数据日期", 0.5]""",
"sql": """select MONTH(数据日期), sum(访问次数) from 内容库产品 where datediff('year', 数据日期, '2023-09-04') <= 0.5 group by MONTH(数据日期) order by sum(访问次数) desc limit 1 """,
},
{
"current_date":"2023-09-04",
"table_name":"内容库产品",
"fields_list":"""["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""",
"question":"内容库近半年每个月的平均访问次数",
"prior_schema_links":"""[]""",
"current_date": "2023-09-04",
"table_name": "内容库产品",
"fields_list": """["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""",
"question": "内容库近半年每个月的平均访问次数",
"prior_schema_links": """[]""",
"analysis": """让我们一步一步地思考。在问题“内容库近半年每个月的平均访问次数“中,我们被问:
“每个月的平均访问次数”所以我们需要column=[访问次数]
”内容库近半年“所以我们需要column=[数据日期]
基于table和columns可能的cell values 是 = [0.5]。""",
"schema_links":"""["访问次数", "数据日期", 0.5]""",
"sql":"""select MONTH(数据日期), avg(访问次数) from 内容库产品 where datediff('year', 数据日期, '2023-09-04') <= 0.5 group by MONTH(数据日期) """
},
"schema_links": """["访问次数", "数据日期", 0.5]""",
"sql": """select MONTH(数据日期), avg(访问次数) from 内容库产品 where datediff('year', 数据日期, '2023-09-04') <= 0.5 group by MONTH(数据日期) """,
},
{
"current_date":"2023-09-10",
"table_name":"内容库产品",
"fields_list":"""["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""",
"question":"内容库 按部门统计访问次数 top10 的部门",
"prior_schema_links":"""[]""",
"current_date": "2023-09-10",
"table_name": "内容库产品",
"fields_list": """["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""",
"question": "内容库 按部门统计访问次数 top10 的部门",
"prior_schema_links": """[]""",
"analysis": """让我们一步一步地思考。在问题“内容库 按部门统计访问次数 top10 的部门“中,我们被问:
“访问次数 top10 的部门”所以我们需要column=[访问次数]
”内容库 按部门统计“所以我们需要column=[部门]
基于table和columns可能的cell values 是 = [10]。""",
"schema_links":"""["访问次数", "部门", 10]""",
"sql":"""select 部门, sum(访问次数) from 内容库产品 group by 部门 order by sum(访问次数) desc limit 10 """
}
]
"schema_links": """["访问次数", "部门", 10]""",
"sql": """select 部门, sum(访问次数) from 内容库产品 group by 部门 order by sum(访问次数) desc limit 10 """,
},
]
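For orientation, each examplar above is a plain dict whose keys match the example_keys wired into SemanticSimilarityExampleSelector later in this change. A minimal sketch of that shape (the question/sql values here are illustrative placeholders, not entries from this diff):

# Minimal sketch of the examplar shape consumed by the few-shot selector.
examplar = {
    "current_date": "2023-09-04",
    "table_name": "内容库产品",
    "fields_list": ["部门", "访问次数", "数据日期"],
    "question": "内容库昨天的访问次数",
    "prior_schema_links": "[]",
    "analysis": "让我们一步一步地思考。",
    "schema_links": '["访问次数", "数据日期"]',
    "sql": "select sum(访问次数) from 内容库产品 where 数据日期 = '2023-09-03'",
}
# every entry must carry all eight keys, or prompt formatting will fail
assert set(examplar) == {
    "current_date", "table_name", "fields_list", "question",
    "prior_schema_links", "analysis", "schema_links", "sql",
}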

View File

@@ -14,7 +14,7 @@ def construct_plugin_prompt(tool_config):
tool_name = tool_config["name"]
tool_description = tool_config["description"]
tool_examples = tool_config["examples"]
prompt = "【工具名称】\n" + tool_name + "\n"
prompt += "【工具描述】\n" + tool_description + "\n"
@@ -23,6 +23,7 @@ def construct_plugin_prompt(tool_config):
prompt += example + "\n"
return prompt
def construct_plugin_pool_prompt(tool_config_list):
tool_explain_list = []
for tool_config in tool_config_list:
@@ -35,15 +36,20 @@ def construct_plugin_pool_prompt(tool_config_list):
def construct_task_prompt(query_text, tool_explain_list_str):
instruction = """问题为:{query_text}\n请根据问题和工具的描述选择对应的工具完成任务。请注意只能选择1个工具。请一步一步地分析选择工具的原因(每个工具的【工具适用问题示例】是选择的重要参考依据)并给出最终选择输出格式为json,key为分析过程, ’选择工具‘""".format(query_text=query_text)
instruction = """问题为:{query_text}\n请根据问题和工具的描述选择对应的工具完成任务。请注意只能选择1个工具。请一步一步地分析选择工具的原因(每个工具的【工具适用问题示例】是选择的重要参考依据)并给出最终选择输出格式为json,key为分析过程, ’选择工具‘""".format(
query_text=query_text
)
prompt = "工具选择如下:\n\n{tool_explain_list_str}\n\n【任务说明】\n{instruction}".format(
instruction=instruction, tool_explain_list_str=tool_explain_list_str
)
prompt = "工具选择如下:\n\n{tool_explain_list_str}\n\n【任务说明】\n{instruction}".format(instruction=instruction, tool_explain_list_str=tool_explain_list_str)
return prompt
def plugin_selection_output_parse(llm_output: str)-> Union[Mapping[str, str], None]:
def plugin_selection_output_parse(llm_output: str) -> Union[Mapping[str, str], None]:
try:
pattern = r'\{[^{}]+\}'
pattern = r"\{[^{}]+\}"
find_result = re.findall(pattern, llm_output)
result = find_result[0].strip()
@@ -52,20 +58,24 @@ def plugin_selection_output_parse(llm_output: str)-> Union[Mapping[str, str], No
result_dict = json.loads(result)
print("result_dict: ", result_dict)
key_mapping = {
"分析过程":"analysis",
"选择工具":"toolSelection"
}
key_mapping = {"分析过程": "analysis", "选择工具": "toolSelection"}
converted_result_dict = {key_mapping[key]: value for key, value in result_dict.items() if key in key_mapping}
converted_result_dict = {
key_mapping[key]: value
for key, value in result_dict.items()
if key in key_mapping
}
except Exception as e:
print(e)
converted_result_dict = None
return converted_result_dict
def plugins_config_format_convert(plugin_config_list: List[Mapping[str, Any]]) -> List[Mapping[str, Any]]:
def plugins_config_format_convert(
plugin_config_list: List[Mapping[str, Any]]
) -> List[Mapping[str, Any]]:
plugin_config_list_new = []
for plugin_config in plugin_config_list:
plugin_config_new = dict()
@@ -75,7 +85,9 @@ def plugins_config_format_convert(plugin_config_list: List[Mapping[str, Any]]) -
parameters = plugin_config["parameters"]
examples_str = "\n".join(examples)
description_new = """{plugin_desc}\n\n例如能够处理如下问题:\n{examples_str}""".format(plugin_desc=description, examples_str=examples_str)
description_new = """{plugin_desc}\n\n例如能够处理如下问题:\n{examples_str}""".format(
plugin_desc=description, examples_str=examples_str
)
plugin_config_new["name"] = name
plugin_config_new["description"] = description_new
@@ -84,4 +96,3 @@ def plugins_config_format_convert(plugin_config_list: List[Mapping[str, Any]]) -
plugin_config_list_new.append(plugin_config_new)
return plugin_config_list_new

View File

@@ -10,12 +10,19 @@ import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from plugin_call.prompt_construct import construct_plugin_pool_prompt, construct_task_prompt, plugin_selection_output_parse, plugins_config_format_convert
from plugin_call.prompt_construct import (
construct_plugin_pool_prompt,
construct_task_prompt,
plugin_selection_output_parse,
plugins_config_format_convert,
)
from util.llm_instance import llm
def plugin_selection_run(query_text: str, plugin_configs: List[Mapping[str, Any]])-> Union[Mapping[str, str], None]:
def plugin_selection_run(
query_text: str, plugin_configs: List[Mapping[str, Any]]
) -> Union[Mapping[str, str], None]:
tools_prompt = construct_plugin_pool_prompt(plugin_configs)
task_prompt = construct_task_prompt(query_text, tools_prompt)
@@ -23,4 +30,3 @@ def plugin_selection_run(query_text: str, plugin_configs: List[Mapping[str, Any]
parsed_output = plugin_selection_output_parse(llm_output)
return parsed_output
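For context, a hypothetical call; the config fields are the ones construct_plugin_prompt reads (name, description, examples), and a configured llm instance is assumed:

# Illustrative plugin pool, not a config shipped with this change.
plugin_configs = [
    {
        "name": "内容库查询",
        "description": "查询内容库产品的访问指标",
        "examples": ["内容库近30天访问次数的平均数"],
    },
]
result = plugin_selection_run("内容库昨天的访问人数", plugin_configs)
# result: {"analysis": ..., "toolSelection": ...} or None if parsing fails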

View File

@@ -11,7 +11,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
def get_ids(documents:List[str]) -> List[str]:
def get_ids(documents: List[str]) -> List[str]:
ids = []
for doc in documents:
ids.append(str(uuid.uuid5(uuid.NAMESPACE_URL, doc)))
@@ -19,25 +19,23 @@ def get_ids(documents:List[str]) -> List[str]:
return ids
def add2preset_query_collection(collection:Collection,
preset_queries:List[str],
preset_query_ids:List[str]
) -> None:
def add2preset_query_collection(
collection: Collection, preset_queries: List[str], preset_query_ids: List[str]
) -> None:
collection.add(documents=preset_queries,
ids=preset_query_ids)
collection.add(documents=preset_queries, ids=preset_query_ids)
def update_preset_query_collection(collection:Collection,
preset_queries:List[str],
preset_query_ids:List[str]
) -> None:
collection.update(documents=preset_queries,
ids=preset_query_ids)
def update_preset_query_collection(
collection: Collection, preset_queries: List[str], preset_query_ids: List[str]
) -> None:
    collection.update(documents=preset_queries, ids=preset_query_ids)
def query2preset_query_collection(collection:Collection, query_texts:List[str], n_results:int=10):
def query2preset_query_collection(
collection: Collection, query_texts: List[str], n_results: int = 10
):
collection_cnt = collection.count()
min_n_results = 10
min_n_results = min(collection_cnt, min_n_results)
@@ -56,12 +54,13 @@ def query2preset_query_collection(collection:Collection, query_texts:List[str],
return res
def parse_retrieval_preset_query(res:List[Mapping[str, Any]]):
parsed_res = [[] for _ in range(0, len(res['ids']))]
retrieval_ids = res['ids']
retrieval_distances = res['distances']
retrieval_sentences = res['documents']
def parse_retrieval_preset_query(res: List[Mapping[str, Any]]):
parsed_res = [[] for _ in range(0, len(res["ids"]))]
retrieval_ids = res["ids"]
retrieval_distances = res["distances"]
retrieval_sentences = res["documents"]
for query_idx in range(0, len(retrieval_ids)):
id_ls = retrieval_ids[query_idx]
@@ -73,43 +72,41 @@ def parse_retrieval_preset_query(res:List[Mapping[str, Any]]):
distance = distance_ls[idx]
sentence = sentence_ls[idx]
parsed_res[query_idx].append({
'id': id,
'distance': distance,
'presetQuery': sentence
})
parsed_res[query_idx].append(
{"id": id, "distance": distance, "presetQuery": sentence}
)
return parsed_res
def preset_query_retrieval_format(query_list:List[str], retrieval_list:List[Mapping[str, Any]]):
def preset_query_retrieval_format(
query_list: List[str], retrieval_list: List[Mapping[str, Any]]
):
res = []
for query_idx in range(0, len(query_list)):
query = query_list[query_idx]
retrieval = retrieval_list[query_idx]
res.append({
'query': query,
'retrieval': retrieval
})
res.append({"query": query, "retrieval": retrieval})
return res
def empty_preset_query_collection(collection:Collection) -> None:
def empty_preset_query_collection(collection: Collection) -> None:
collection.delete()
def delete_preset_query_by_ids(collection:Collection, preset_query_ids:List[str]) -> None:
def delete_preset_query_by_ids(
collection: Collection, preset_query_ids: List[str]
) -> None:
collection.delete(ids=preset_query_ids)
def get_preset_query_by_ids(collection:Collection, preset_query_ids:List[str]):
def get_preset_query_by_ids(collection: Collection, preset_query_ids: List[str]):
res = collection.get(ids=preset_query_ids)
return res
def preset_query_collection_size(collection:Collection) -> int:
def preset_query_collection_size(collection: Collection) -> int:
return collection.count()
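Taken together, a minimal round trip over these helpers, assuming `collection` is the chromadb Collection created in preset_retrieval/run.py and the query strings are placeholders:

# Add two preset queries, retrieve against a new query, and format the hits.
queries = ["内容库近30天访问次数的平均数", "内容库按部门统计访问次数"]
ids = get_ids(queries)  # deterministic uuid5 ids derived from the texts
add2preset_query_collection(collection, queries, ids)
raw = query2preset_query_collection(collection, ["内容库的访问次数"], n_results=2)
hits = parse_retrieval_preset_query(raw)
print(preset_query_retrieval_format(["内容库的访问次数"], hits))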

View File

@@ -13,34 +13,45 @@ from chromadb.api import Collection, Documents, Embeddings
from langchain.llms import OpenAI
from preset_query_db import (get_ids, add2preset_query_collection,
query2preset_query_collection, parse_retrieval_preset_query,
preset_query_retrieval_format, empty_preset_query_collection, preset_query_collection_size)
from preset_query_db import (
get_ids,
add2preset_query_collection,
query2preset_query_collection,
parse_retrieval_preset_query,
preset_query_retrieval_format,
empty_preset_query_collection,
preset_query_collection_size,
)
from util.text2vec import Text2VecEmbeddingFunction
from run_config import CHROMA_DB_PERSIST_PATH, PRESET_QUERY_COLLECTION_NAME
from util.chromadb_instance import client
emb_func = Text2VecEmbeddingFunction()
collection = client.get_or_create_collection(name=PRESET_QUERY_COLLECTION_NAME,
embedding_function=emb_func,
metadata={"hnsw:space": "cosine"}
) # Get a collection object from an existing collection, by name. If it doesn't exist, create it.
collection = client.get_or_create_collection(
name=PRESET_QUERY_COLLECTION_NAME,
embedding_function=emb_func,
metadata={"hnsw:space": "cosine"},
) # Get a collection object from an existing collection, by name. If it doesn't exist, create it.
print("init_preset_query_collection_size: ", preset_query_collection_size(collection))
def preset_query_retrieval_run(collection:Collection, query_texts_list:List[str], n_results:int=5):
retrieval_res = query2preset_query_collection(collection=collection,
query_texts=query_texts_list,
n_results=n_results)
def preset_query_retrieval_run(
collection: Collection, query_texts_list: List[str], n_results: int = 5
):
retrieval_res = query2preset_query_collection(
collection=collection, query_texts=query_texts_list, n_results=n_results
)
parsed_retrieval_res = parse_retrieval_preset_query(retrieval_res)
parsed_retrieval_res_format = preset_query_retrieval_format(query_texts_list, parsed_retrieval_res)
parsed_retrieval_res_format = preset_query_retrieval_format(
query_texts_list, parsed_retrieval_res
)
print('parsed_retrieval_res_format: ', parsed_retrieval_res_format)
print("parsed_retrieval_res_format: ", parsed_retrieval_res_format)
return parsed_retrieval_res_format

View File

@@ -11,7 +11,7 @@ OPENAI_API_KEY = "YOUR_API_KEY"
TEMPERATURE = 0.0
CHROMA_DB_PERSIST_DIR = 'chm_db'
CHROMA_DB_PERSIST_DIR = "chm_db"
PRESET_QUERY_COLLECTION_NAME = "preset_query_collection"
TEXT2DSL_COLLECTION_NAME = "text2dsl_collection"
TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = 15
@@ -21,9 +21,9 @@ CHROMA_DB_PERSIST_PATH = os.path.join(PROJECT_DIR_PATH, CHROMA_DB_PERSIST_DIR)
HF_TEXT2VEC_MODEL_NAME = "GanymedeNil/text2vec-large-chinese"
if __name__ == '__main__':
print('PROJECT_DIR_PATH: ', PROJECT_DIR_PATH)
print('EMB_MODEL_PATH: ', HF_TEXT2VEC_MODEL_NAME)
print('CHROMA_DB_PERSIST_PATH: ', CHROMA_DB_PERSIST_PATH)
print('LLMPARSER_HOST: ', LLMPARSER_HOST)
print('LLMPARSER_PORT: ', LLMPARSER_PORT)
if __name__ == "__main__":
print("PROJECT_DIR_PATH: ", PROJECT_DIR_PATH)
print("EMB_MODEL_PATH: ", HF_TEXT2VEC_MODEL_NAME)
print("CHROMA_DB_PERSIST_PATH: ", CHROMA_DB_PERSIST_PATH)
print("LLMPARSER_HOST: ", LLMPARSER_HOST)
print("LLMPARSER_PORT: ", LLMPARSER_PORT)

View File

@@ -22,20 +22,34 @@ from util.text2vec import Text2VecEmbeddingFunction, hg_embedding
from util.chromadb_instance import client as chromadb_client, empty_chroma_collection_2
from run_config import TEXT2DSL_COLLECTION_NAME, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM
def reload_sql_example_collection(vectorstore:Chroma,
sql_examplars:List[Mapping[str, str]],
sql_example_selector:SemanticSimilarityExampleSelector,
example_nums:int
):
def reload_sql_example_collection(
vectorstore: Chroma,
sql_examplars: List[Mapping[str, str]],
sql_example_selector: SemanticSimilarityExampleSelector,
example_nums: int,
):
print("original sql_examples_collection size:", vectorstore._collection.count())
new_collection = empty_chroma_collection_2(collection=vectorstore._collection)
vectorstore._collection = new_collection
print("emptied sql_examples_collection size:", vectorstore._collection.count())
sql_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums,
input_keys=["question"],
example_keys=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links", "current_date", "sql"])
sql_example_selector = SemanticSimilarityExampleSelector(
vectorstore=sql_examples_vectorstore,
k=example_nums,
input_keys=["question"],
example_keys=[
"table_name",
"fields_list",
"prior_schema_links",
"question",
"analysis",
"schema_links",
"current_date",
"sql",
],
)
for example in sql_examplars:
sql_example_selector.add_example(example)
@@ -45,20 +59,36 @@ def reload_sql_example_collection(vectorstore:Chroma,
return vectorstore, sql_example_selector
sql_examples_vectorstore = Chroma(collection_name=TEXT2DSL_COLLECTION_NAME,
embedding_function=hg_embedding,
client=chromadb_client)
sql_examples_vectorstore = Chroma(
collection_name=TEXT2DSL_COLLECTION_NAME,
embedding_function=hg_embedding,
client=chromadb_client,
)
example_nums = TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM
sql_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums,
input_keys=["question"],
example_keys=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links", "current_date", "sql"])
sql_example_selector = SemanticSimilarityExampleSelector(
vectorstore=sql_examples_vectorstore,
k=example_nums,
input_keys=["question"],
example_keys=[
"table_name",
"fields_list",
"prior_schema_links",
"question",
"analysis",
"schema_links",
"current_date",
"sql",
],
)
if sql_examples_vectorstore._collection.count() > 0:
print("examples already in sql_vectorstore")
print("init sql_vectorstore size:", sql_examples_vectorstore._collection.count())
print("init sql_vectorstore size:", sql_examples_vectorstore._collection.count())
print("sql_examplars size:", len(sql_examplars))
sql_examples_vectorstore, sql_example_selector = reload_sql_example_collection(sql_examples_vectorstore, sql_examplars, sql_example_selector, example_nums)
sql_examples_vectorstore, sql_example_selector = reload_sql_example_collection(
sql_examples_vectorstore, sql_examplars, sql_example_selector, example_nums
)
print("added sql_vectorstore size:", sql_examples_vectorstore._collection.count())

View File

@@ -13,17 +13,31 @@ from few_shot_example.sql_exampler import examplars as sql_examplars
from run_config import LLMPARSER_HOST, LLMPARSER_PORT
def text2dsl_setting_update(llm_parser_host:str, llm_parser_port:str,
sql_examplars:List[Mapping[str, str]], example_nums:int, is_shortcut:bool):
def text2dsl_setting_update(
llm_parser_host: str,
llm_parser_port: str,
sql_examplars: List[Mapping[str, str]],
example_nums: int,
is_shortcut: bool,
):
url = f"http://{llm_parser_host}:{llm_parser_port}/query2sql_setting_update/"
print("url: ", url)
payload = {"sqlExamplars":sql_examplars, "exampleNums":example_nums, "isShortcut":is_shortcut}
headers = {'content-type': 'application/json'}
payload = {
"sqlExamplars": sql_examplars,
"exampleNums": example_nums,
"isShortcut": is_shortcut,
}
headers = {"content-type": "application/json"}
response = requests.post(url, data=json.dumps(payload), headers=headers)
print(response.text)
if __name__ == "__main__":
text2dsl_setting_update(LLMPARSER_HOST, LLMPARSER_PORT,
sql_examplars, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM, TEXT2DSL_IS_SHORTCUT)
text2dsl_setting_update(
LLMPARSER_HOST,
LLMPARSER_PORT,
sql_examplars,
TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM,
TEXT2DSL_IS_SHORTCUT,
)

View File

@@ -1,21 +1,25 @@
# -*- coding:utf-8 -*-
import re
def schema_link_parse(schema_link_output):
try:
schema_link_output = schema_link_output.strip()
pattern = r'Schema_links:(.*)'
schema_link_output = re.findall(pattern, schema_link_output, re.DOTALL)[0].strip()
pattern = r"Schema_links:(.*)"
schema_link_output = re.findall(pattern, schema_link_output, re.DOTALL)[
0
].strip()
except Exception as e:
print(e)
schema_link_output = None
return schema_link_output
def combo_schema_link_parse(schema_linking_sql_combo_output: str):
try:
schema_linking_sql_combo_output = schema_linking_sql_combo_output.strip()
pattern = r'Schema_links:(\[.*?\])'
pattern = r"Schema_links:(\[.*?\])"
schema_links_match = re.search(pattern, schema_linking_sql_combo_output)
if schema_links_match:
@@ -28,10 +32,11 @@ def combo_schema_link_parse(schema_linking_sql_combo_output: str):
return schema_links
def combo_sql_parse(schema_linking_sql_combo_output: str):
try:
schema_linking_sql_combo_output = schema_linking_sql_combo_output.strip()
pattern = r'SQL:(.*)'
pattern = r"SQL:(.*)"
sql_match = re.search(pattern, schema_linking_sql_combo_output)
if sql_match:

View File

@@ -11,17 +11,31 @@ from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.example_selector import SemanticSimilarityExampleSelector
def schema_linking_exampler(user_query: str,
domain_name: str,
fields_list: List[str],
prior_schema_links: Mapping[str,str],
example_selector: SemanticSimilarityExampleSelector,
) -> str:
def schema_linking_exampler(
user_query: str,
domain_name: str,
fields_list: List[str],
prior_schema_links: Mapping[str, str],
example_selector: SemanticSimilarityExampleSelector,
) -> str:
prior_schema_links_str = '['+ ','.join(["""'{}'->{}""".format(k,v) for k,v in prior_schema_links.items()]) + ']'
prior_schema_links_str = (
"["
+ ",".join(["""'{}'->{}""".format(k, v) for k, v in prior_schema_links.items()])
+ "]"
)
example_prompt_template = PromptTemplate(input_variables=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links"],
template="Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schema_links}")
example_prompt_template = PromptTemplate(
input_variables=[
"table_name",
"fields_list",
"prior_schema_links",
"question",
"analysis",
"schema_links",
],
template="Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schema_links}",
)
instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links"
@@ -30,81 +44,121 @@ def schema_linking_exampler(user_query: str,
schema_linking_example_prompt_template = FewShotPromptTemplate(
example_selector=example_selector,
example_prompt=example_prompt_template,
example_separator="\n\n",
example_separator="\n\n",
prefix=instruction,
input_variables=["table_name", "fields_list", "prior_schema_links", "question"],
suffix=schema_linking_prompt
)
suffix=schema_linking_prompt,
)
schema_linking_example_prompt = schema_linking_example_prompt_template.format(table_name=domain_name,
fields_list=fields_list,
prior_schema_links=prior_schema_links_str,
question=user_query)
schema_linking_example_prompt = schema_linking_example_prompt_template.format(
table_name=domain_name,
fields_list=fields_list,
prior_schema_links=prior_schema_links_str,
question=user_query,
)
return schema_linking_example_prompt
def sql_exampler(user_query: str,
domain_name: str,
schema_link_str: str,
data_date: str,
example_selector: SemanticSimilarityExampleSelector,
) -> str:
def sql_exampler(
user_query: str,
domain_name: str,
schema_link_str: str,
data_date: str,
example_selector: SemanticSimilarityExampleSelector,
) -> str:
instruction = "# 根据schema_links为每个问题生成SQL查询语句"
sql_example_prompt_template = PromptTemplate(input_variables=["question", "current_date", "table_name", "schema_links", "sql"],
template="问题:{question}\nCurrent_date:{current_date}\nTable {table_name}\nSchema_links:{schema_links}\nSQL:{sql}")
sql_example_prompt_template = PromptTemplate(
input_variables=[
"question",
"current_date",
"table_name",
"schema_links",
"sql",
],
template="问题:{question}\nCurrent_date:{current_date}\nTable {table_name}\nSchema_links:{schema_links}\nSQL:{sql}",
)
sql_prompt = "问题:{question}\nCurrent_date:{current_date}\nTable {table_name}\nSchema_links:{schema_links}\nSQL:"
sql_example_prompt_template = FewShotPromptTemplate(
example_selector=example_selector,
example_prompt=sql_example_prompt_template,
example_separator="\n\n",
example_separator="\n\n",
prefix=instruction,
input_variables=["question", "current_date", "table_name", "schema_links"],
suffix=sql_prompt
)
suffix=sql_prompt,
)
sql_example_prompt = sql_example_prompt_template.format(question=user_query,
current_date=data_date,
table_name=domain_name,
schema_links=schema_link_str)
sql_example_prompt = sql_example_prompt_template.format(
question=user_query,
current_date=data_date,
table_name=domain_name,
schema_links=schema_link_str,
)
return sql_example_prompt
def schema_linking_sql_combo_examplar(user_query: str,
domain_name: str,
data_date : str,
fields_list: List[str],
prior_schema_links: Mapping[str,str],
example_selector: SemanticSimilarityExampleSelector) -> str:
prior_schema_links_str = '['+ ','.join(["""'{}'->{}""".format(k,v) for k,v in prior_schema_links.items()]) + ']'
def schema_linking_sql_combo_examplar(
user_query: str,
domain_name: str,
data_date: str,
fields_list: List[str],
prior_schema_links: Mapping[str, str],
example_selector: SemanticSimilarityExampleSelector,
) -> str:
example_prompt_template = PromptTemplate(input_variables=["table_name", "fields_list", "prior_schema_links", "current_date", "question", "analysis", "schema_links", "sql"],
template="Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\nCurrent_date:{current_date}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schema_links}\nSQL:{sql}")
prior_schema_links_str = (
"["
+ ",".join(["""'{}'->{}""".format(k, v) for k, v in prior_schema_links.items()])
+ "]"
)
instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links,再根据schema_links为每个问题生成SQL查询语句"
example_prompt_template = PromptTemplate(
input_variables=[
"table_name",
"fields_list",
"prior_schema_links",
"current_date",
"question",
"analysis",
"schema_links",
"sql",
],
template="Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\nCurrent_date:{current_date}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schema_links}\nSQL:{sql}",
)
instruction = (
"# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links,再根据schema_links为每个问题生成SQL查询语句"
)
schema_linking_sql_combo_prompt = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\nCurrent_date:{current_date}\n问题:{question}\n分析: 让我们一步一步地思考。"
schema_linking_sql_combo_example_prompt_template = FewShotPromptTemplate(
example_selector=example_selector,
example_prompt=example_prompt_template,
example_separator="\n\n",
example_separator="\n\n",
prefix=instruction,
input_variables=["table_name", "fields_list", "prior_schema_links", "current_date", "question"],
suffix=schema_linking_sql_combo_prompt
input_variables=[
"table_name",
"fields_list",
"prior_schema_links",
"current_date",
"question",
],
suffix=schema_linking_sql_combo_prompt,
)
    schema_linking_sql_combo_example_prompt = schema_linking_sql_combo_example_prompt_template.format(table_name=domain_name,
        fields_list=fields_list,
        prior_schema_links=prior_schema_links_str,
        current_date=data_date,
        question=user_query)
    schema_linking_sql_combo_example_prompt = (
        schema_linking_sql_combo_example_prompt_template.format(
            table_name=domain_name,
            fields_list=fields_list,
            prior_schema_links=prior_schema_links_str,
            current_date=data_date,
            question=user_query,
        )
    )
return schema_linking_sql_combo_example_prompt
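An illustrative call for the combo examplar; example_selector is the selector built in sql/constructor.py, the remaining arguments are placeholder values:

# Builds the few-shot prompt that asks for schema_links and SQL in one pass.
prompt = schema_linking_sql_combo_examplar(
    user_query="内容库近30天访问次数的平均数",
    domain_name="内容库产品",
    data_date="2023-09-04",
    fields_list=["部门", "访问次数", "数据日期"],
    prior_schema_links={},
    example_selector=sql_example_selector,  # from sql.constructor
)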

View File

@@ -7,133 +7,182 @@ import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from sql.prompt_maker import schema_linking_exampler, sql_exampler, schema_linking_sql_combo_examplar
from sql.constructor import sql_examples_vectorstore, sql_example_selector, reload_sql_example_collection
from sql.output_parser import schema_link_parse, combo_schema_link_parse, combo_sql_parse
from sql.prompt_maker import (
schema_linking_exampler,
sql_exampler,
schema_linking_sql_combo_examplar,
)
from sql.constructor import (
sql_examples_vectorstore,
sql_example_selector,
reload_sql_example_collection,
)
from sql.output_parser import (
schema_link_parse,
combo_schema_link_parse,
combo_sql_parse,
)
from util.llm_instance import llm
from run_config import TEXT2DSL_IS_SHORTCUT
class Text2DSLAgent(object):
    def __init__(self):
        self.schema_linking_exampler = schema_linking_exampler
        self.sql_exampler = sql_exampler
        self.schema_linking_sql_combo_exampler = schema_linking_sql_combo_examplar
        self.sql_examples_vectorstore = sql_examples_vectorstore
        self.sql_example_selector = sql_example_selector
        self.schema_link_parse = schema_link_parse
        self.combo_schema_link_parse = combo_schema_link_parse
        self.combo_sql_parse = combo_sql_parse
        self.llm = llm
        self.is_shortcut = TEXT2DSL_IS_SHORTCUT
def update_examples(self, sql_examples, example_nums, is_shortcut):
self.sql_examples_vectorstore, self.sql_example_selector = reload_sql_example_collection(self.sql_examples_vectorstore,
sql_examples,
self.sql_example_selector,
example_nums)
self.is_shortcut = is_shortcut
def update_examples(self, sql_examples, example_nums, is_shortcut):
(
self.sql_examples_vectorstore,
self.sql_example_selector,
) = reload_sql_example_collection(
self.sql_examples_vectorstore,
sql_examples,
self.sql_example_selector,
example_nums,
)
self.is_shortcut = is_shortcut
def query2sql(self, query_text: str,
schema : Union[dict, None] = None,
current_date: str = None,
linking: Union[List[Mapping[str, str]], None] = None
):
def query2sql(
self,
query_text: str,
schema: Union[dict, None] = None,
current_date: str = None,
linking: Union[List[Mapping[str, str]], None] = None,
):
print("query_text: ", query_text)
print("schema: ", schema)
print("current_date: ", current_date)
print("prior_schema_links: ", linking)
print("query_text: ", query_text)
print("schema: ", schema)
print("current_date: ", current_date)
print("prior_schema_links: ", linking)
if linking is not None:
prior_schema_links = {item['fieldValue']:item['fieldName'] for item in linking}
else:
prior_schema_links = {}
if linking is not None:
prior_schema_links = {
item["fieldValue"]: item["fieldName"] for item in linking
}
else:
prior_schema_links = {}
model_name = schema['modelName']
fields_list = schema['fieldNameList']
model_name = schema["modelName"]
fields_list = schema["fieldNameList"]
schema_linking_prompt = self.schema_linking_exampler(query_text, model_name, fields_list, prior_schema_links, self.sql_example_selector)
print("schema_linking_prompt->", schema_linking_prompt)
schema_link_output = self.llm(schema_linking_prompt)
schema_link_str = self.schema_link_parse(schema_link_output)
sql_prompt = self.sql_exampler(query_text, model_name, schema_link_str, current_date, self.sql_example_selector)
print("sql_prompt->", sql_prompt)
sql_output = self.llm(sql_prompt)
schema_linking_prompt = self.schema_linking_exampler(
query_text,
model_name,
fields_list,
prior_schema_links,
self.sql_example_selector,
)
print("schema_linking_prompt->", schema_linking_prompt)
schema_link_output = self.llm(schema_linking_prompt)
schema_link_str = self.schema_link_parse(schema_link_output)
resp = dict()
resp['query'] = query_text
resp['model'] = model_name
resp['fields'] = fields_list
resp['priorSchemaLinking'] = linking
resp['dataDate'] = current_date
sql_prompt = self.sql_exampler(
query_text,
model_name,
schema_link_str,
current_date,
self.sql_example_selector,
)
print("sql_prompt->", sql_prompt)
sql_output = self.llm(sql_prompt)
resp['analysisOutput'] = schema_link_output
resp['schemaLinkStr'] = schema_link_str
resp['sqlOutput'] = sql_output
resp = dict()
resp["query"] = query_text
resp["model"] = model_name
resp["fields"] = fields_list
resp["priorSchemaLinking"] = linking
resp["dataDate"] = current_date
print("resp: ", resp)
resp["analysisOutput"] = schema_link_output
resp["schemaLinkStr"] = schema_link_str
return resp
resp["sqlOutput"] = sql_output
def query2sqlcombo(self, query_text: str,
schema : Union[dict, None] = None,
current_date: str = None,
linking: Union[List[Mapping[str, str]], None] = None
):
print("resp: ", resp)
print("query_text: ", query_text)
print("schema: ", schema)
print("current_date: ", current_date)
print("prior_schema_links: ", linking)
return resp
if linking is not None:
prior_schema_links = {item['fieldValue']:item['fieldName'] for item in linking}
else:
prior_schema_links = {}
def query2sqlcombo(
self,
query_text: str,
schema: Union[dict, None] = None,
current_date: str = None,
linking: Union[List[Mapping[str, str]], None] = None,
):
model_name = schema['modelName']
fields_list = schema['fieldNameList']
print("query_text: ", query_text)
print("schema: ", schema)
print("current_date: ", current_date)
print("prior_schema_links: ", linking)
schema_linking_sql_combo_prompt = self.schema_linking_sql_combo_exampler(query_text, model_name, current_date, fields_list,
prior_schema_links, self.sql_example_selector)
print("schema_linking_sql_combo_prompt->", schema_linking_sql_combo_prompt)
schema_linking_sql_combo_output = self.llm(schema_linking_sql_combo_prompt)
if linking is not None:
prior_schema_links = {
item["fieldValue"]: item["fieldName"] for item in linking
}
else:
prior_schema_links = {}
schema_linking_str = self.combo_schema_link_parse(schema_linking_sql_combo_output)
sql_str = self.combo_sql_parse(schema_linking_sql_combo_output)
model_name = schema["modelName"]
fields_list = schema["fieldNameList"]
resp = dict()
resp['query'] = query_text
resp['model'] = model_name
resp['fields'] = fields_list
resp['priorSchemaLinking'] = prior_schema_links
resp['dataDate'] = current_date
schema_linking_sql_combo_prompt = self.schema_linking_sql_combo_exampler(
query_text,
model_name,
current_date,
fields_list,
prior_schema_links,
self.sql_example_selector,
)
print("schema_linking_sql_combo_prompt->", schema_linking_sql_combo_prompt)
schema_linking_sql_combo_output = self.llm(schema_linking_sql_combo_prompt)
resp['analysisOutput'] = schema_linking_sql_combo_output
resp['schemaLinkStr'] = schema_linking_str
resp['sqlOutput'] = sql_str
schema_linking_str = self.combo_schema_link_parse(
schema_linking_sql_combo_output
)
sql_str = self.combo_sql_parse(schema_linking_sql_combo_output)
print("resp: ", resp)
resp = dict()
resp["query"] = query_text
resp["model"] = model_name
resp["fields"] = fields_list
resp["priorSchemaLinking"] = prior_schema_links
resp["dataDate"] = current_date
return resp
resp["analysisOutput"] = schema_linking_sql_combo_output
resp["schemaLinkStr"] = schema_linking_str
resp["sqlOutput"] = sql_str
        print("resp: ", resp)
        return resp
    def query2sql_run(self, query_text: str,
                      schema : Union[dict, None] = None,
                      current_date: str = None,
                      linking: Union[List[Mapping[str, str]], None] = None):
def query2sql_run(
self,
query_text: str,
schema: Union[dict, None] = None,
current_date: str = None,
linking: Union[List[Mapping[str, str]], None] = None,
):
if self.is_shortcut:
return self.query2sqlcombo(query_text, schema, current_date, linking)
else:
return self.query2sql(query_text, schema, current_date, linking)
text2sql_agent = Text2DSLAgent()
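For reference, a hypothetical end-to-end call on the agent; the schema and linking key names mirror exactly what query2sql/query2sqlcombo read above (modelName, fieldNameList, fieldValue, fieldName), while the values are placeholders:

# is_shortcut decides between the two-step flow (schema link, then SQL)
# and the combo prompt; both accept the same arguments.
resp = text2sql_agent.query2sql_run(
    query_text="内容库近30天访问次数的平均数",
    schema={"modelName": "内容库产品", "fieldNameList": ["访问次数", "数据日期"]},
    current_date="2023-09-04",
    linking=[{"fieldValue": "内容库", "fieldName": "模块"}],
)
print(resp["sqlOutput"])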

View File

@@ -13,11 +13,19 @@ from fastapi import FastAPI, HTTPException
from sql.run import text2sql_agent
from preset_retrieval.run import preset_query_retrieval_run, collection as preset_query_collection
from preset_retrieval.preset_query_db import (add2preset_query_collection, update_preset_query_collection,
empty_preset_query_collection, delete_preset_query_by_ids,
update_preset_query_collection, get_preset_query_by_ids,
preset_query_collection_size)
from preset_retrieval.run import (
preset_query_retrieval_run,
collection as preset_query_collection,
)
from preset_retrieval.preset_query_db import (
add2preset_query_collection,
update_preset_query_collection,
empty_preset_query_collection,
delete_preset_query_by_ids,
get_preset_query_by_ids,
preset_query_collection_size,
)
from plugin_call.run import plugin_selection_run
@@ -27,62 +35,64 @@ from run_config import LLMPARSER_PORT
app = FastAPI()
@app.post("/query2sql/")
async def din_query2sql(query_body: Mapping[str, Any]):
if 'queryText' not in query_body:
raise HTTPException(status_code=400,
detail="query_text is not in query_body")
if "queryText" not in query_body:
raise HTTPException(status_code=400, detail="query_text is not in query_body")
else:
query_text = query_body['queryText']
query_text = query_body["queryText"]
if 'schema' not in query_body:
if "schema" not in query_body:
raise HTTPException(status_code=400, detail="schema is not in query_body")
else:
schema = query_body['schema']
schema = query_body["schema"]
if 'currentDate' not in query_body:
if "currentDate" not in query_body:
raise HTTPException(status_code=400, detail="currentDate is not in query_body")
else:
current_date = query_body['currentDate']
current_date = query_body["currentDate"]
if 'linking' not in query_body:
if "linking" not in query_body:
linking = None
else:
linking = query_body['linking']
linking = query_body["linking"]
resp = text2sql_agent.query2sql_run(query_text=query_text,
schema=schema, current_date=current_date, linking=linking)
resp = text2sql_agent.query2sql_run(
query_text=query_text, schema=schema, current_date=current_date, linking=linking
)
return resp
@app.post("/query2sql_setting_update/")
async def query2sql_setting_update(query_body: Mapping[str, Any]):
if 'sqlExamplars' not in query_body:
raise HTTPException(status_code=400,
detail="sqlExamplars is not in query_body")
if "sqlExamplars" not in query_body:
raise HTTPException(status_code=400, detail="sqlExamplars is not in query_body")
else:
sql_examplars = query_body['sqlExamplars']
sql_examplars = query_body["sqlExamplars"]
if 'exampleNums' not in query_body:
if "exampleNums" not in query_body:
raise HTTPException(status_code=400, detail="exampleNums is not in query_body")
else:
example_nums = query_body['exampleNums']
example_nums = query_body["exampleNums"]
if 'isShortcut' not in query_body:
if "isShortcut" not in query_body:
raise HTTPException(status_code=400, detail="isShortcut is not in query_body")
else:
is_shortcut = query_body['isShortcut']
is_shortcut = query_body["isShortcut"]
text2sql_agent.update_examples(sql_examples=sql_examplars, example_nums=example_nums, is_shortcut=is_shortcut)
text2sql_agent.update_examples(
sql_examples=sql_examplars, example_nums=example_nums, is_shortcut=is_shortcut
)
return "success"
@app.post("/preset_query_retrival/")
async def preset_query_retrival(query_text_list: List[str], n_results: int = 5):
parsed_retrieval_res_format = preset_query_retrieval_run(preset_query_collection, query_text_list, n_results)
parsed_retrieval_res_format = preset_query_retrieval_run(
preset_query_collection, query_text_list, n_results
)
return parsed_retrieval_res_format
@@ -93,27 +103,32 @@ async def preset_query_add(preset_info_list: List[Mapping[str, str]]):
preset_query_ids = []
for preset_info in preset_info_list:
preset_queries.append(preset_info['preset_query'])
preset_query_ids.append(preset_info['preset_query_id'])
preset_queries.append(preset_info["preset_query"])
preset_query_ids.append(preset_info["preset_query_id"])
add2preset_query_collection(collection=preset_query_collection,
preset_queries=preset_queries,
preset_query_ids=preset_query_ids)
add2preset_query_collection(
collection=preset_query_collection,
preset_queries=preset_queries,
preset_query_ids=preset_query_ids,
)
return "success"
@app.post("/preset_query_update/")
async def preset_query_update(preset_info_list: List[Mapping[str, str]]):
preset_queries = []
preset_query_ids = []
for preset_info in preset_info_list:
preset_queries.append(preset_info['preset_query'])
preset_query_ids.append(preset_info['preset_query_id'])
preset_queries.append(preset_info["preset_query"])
preset_query_ids.append(preset_info["preset_query_id"])
update_preset_query_collection(collection=preset_query_collection,
preset_queries=preset_queries,
preset_query_ids=preset_query_ids)
update_preset_query_collection(
collection=preset_query_collection,
preset_queries=preset_queries,
preset_query_ids=preset_query_ids,
)
return "success"
@@ -124,39 +139,50 @@ async def preset_query_empty():
return "success"
@app.post("/preset_delete_by_ids/")
async def preset_delete_by_ids(preset_query_ids: List[str]):
delete_preset_query_by_ids(collection=preset_query_collection, preset_query_ids=preset_query_ids)
delete_preset_query_by_ids(
collection=preset_query_collection, preset_query_ids=preset_query_ids
)
return "success"
@app.post("/preset_get_by_ids/")
async def preset_get_by_ids(preset_query_ids: List[str]):
preset_queries = get_preset_query_by_ids(collection=preset_query_collection, preset_query_ids=preset_query_ids)
preset_queries = get_preset_query_by_ids(
collection=preset_query_collection, preset_query_ids=preset_query_ids
)
return preset_queries
@app.get("/preset_query_size/")
async def preset_query_size():
size = preset_query_collection_size(collection=preset_query_collection)
return size
@app.post("/plugin_selection/")
async def tool_selection(query_body: Mapping[str, Any]):
if 'queryText' not in query_body:
if "queryText" not in query_body:
raise HTTPException(status_code=400, detail="query_text is not in query_body")
else:
query_text = query_body['queryText']
query_text = query_body["queryText"]
if 'pluginConfigs' not in query_body:
raise HTTPException(status_code=400, detail="pluginConfigs is not in query_body")
if "pluginConfigs" not in query_body:
raise HTTPException(
status_code=400, detail="pluginConfigs is not in query_body"
)
else:
plugin_configs = query_body['pluginConfigs']
plugin_configs = query_body["pluginConfigs"]
resp = plugin_selection_run(query_text=query_text, plugin_configs=plugin_configs)
return resp
if __name__ == "__main__":
uvicorn.run(app, host=LLMPARSER_HOST, port=LLMPARSER_PORT)
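A sketch of a client request against the /query2sql/ endpoint above; the body keys are exactly the ones the handler validates, the query and schema values are illustrative:

import json
import requests
from run_config import LLMPARSER_HOST, LLMPARSER_PORT

payload = {
    "queryText": "内容库近30天访问次数的平均数",
    "schema": {"modelName": "内容库产品", "fieldNameList": ["访问次数", "数据日期"]},
    "currentDate": "2023-09-04",
    "linking": [],  # optional prior schema links; [] means none
}
resp = requests.post(
    f"http://{LLMPARSER_HOST}:{LLMPARSER_PORT}/query2sql/",
    data=json.dumps(payload),
    headers={"content-type": "application/json"},
)
print(resp.json())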

View File

@@ -7,13 +7,15 @@ from chromadb.config import Settings
from run_config import CHROMA_DB_PERSIST_PATH
client = chromadb.Client(Settings(
chroma_db_impl="duckdb+parquet",
persist_directory=CHROMA_DB_PERSIST_PATH # Optional, defaults to .chromadb/ in the current directory
))
client = chromadb.Client(
Settings(
chroma_db_impl="duckdb+parquet",
persist_directory=CHROMA_DB_PERSIST_PATH, # Optional, defaults to .chromadb/ in the current directory
)
)
def empty_chroma_collection_2(collection:Collection):
def empty_chroma_collection_2(collection: Collection):
collection_name = collection.name
client = collection._client
metadata = collection.metadata
@@ -21,17 +23,18 @@ def empty_chroma_collection_2(collection:Collection):
client.delete_collection(collection_name)
new_collection = client.get_or_create_collection(name=collection_name,
metadata=metadata,
embedding_function=embedding_function)
new_collection = client.get_or_create_collection(
name=collection_name, metadata=metadata, embedding_function=embedding_function
)
size_of_new_collection = new_collection.count()
print(f'Collection {collection_name} emptied. Size of new collection: {size_of_new_collection}')
print(
f"Collection {collection_name} emptied. Size of new collection: {size_of_new_collection}"
)
return new_collection
def empty_chroma_collection(collection:Collection):
def empty_chroma_collection(collection: Collection):
collection.delete()

View File

@@ -4,5 +4,6 @@ from langchain.llms import OpenAI
from run_config import MODEL_NAME, OPENAI_API_KEY, TEMPERATURE
llm = OpenAI(openai_api_key=OPENAI_API_KEY, model_name=MODEL_NAME,
temperature=TEMPERATURE)
llm = OpenAI(
openai_api_key=OPENAI_API_KEY, model_name=MODEL_NAME, temperature=TEMPERATURE
)

View File

@@ -9,6 +9,7 @@ from run_config import HF_TEXT2VEC_MODEL_NAME
hg_embedding = HuggingFaceEmbeddings(model_name=HF_TEXT2VEC_MODEL_NAME)
class Text2VecEmbeddingFunction(EmbeddingFunction):
def __call__(self, texts: Documents) -> Embeddings:
@@ -16,13 +17,8 @@ class Text2VecEmbeddingFunction(EmbeddingFunction):
return embeddings
def get_embeddings(documents:List[str]) -> List[List[float]]:
def get_embeddings(documents: List[str]) -> List[List[float]]:
embeddings = hg_embedding.embed_documents(documents)
return embeddings

View File

@@ -3,7 +3,7 @@
<mapper namespace="com.tencent.supersonic.chat.persistence.mapper.ChatQueryDOMapper">
<resultMap id="BaseResultMap" type="com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO">
<id column="question_id" jdbcType="BIGINT" property="questionId" />
<result column="agent_id" jdbcType="BIGINT" property="agentId" />
<result column="agent_id" jdbcType="INTEGER" property="agentId" />
<result column="create_time" jdbcType="TIMESTAMP" property="createTime" />
<result column="user_name" jdbcType="VARCHAR" property="userName" />
<result column="query_state" jdbcType="INTEGER" property="queryState" />
@@ -77,7 +77,7 @@
query_state, chat_id, score,
feedback, query_text, query_result
)
values (#{questionId,jdbcType=BIGINT}, #{agentId,jdbcType=BIGINT}, #{createTime,jdbcType=TIMESTAMP}, #{userName,jdbcType=VARCHAR},
values (#{questionId,jdbcType=BIGINT}, #{agentId,jdbcType=INTEGER}, #{createTime,jdbcType=TIMESTAMP}, #{userName,jdbcType=VARCHAR},
#{queryState,jdbcType=INTEGER}, #{chatId,jdbcType=BIGINT}, #{score,jdbcType=INTEGER},
#{feedback,jdbcType=VARCHAR}, #{queryText,jdbcType=LONGVARCHAR}, #{queryResult,jdbcType=LONGVARCHAR}
)
@@ -98,9 +98,6 @@
<if test="chatId != null">
chat_id = #{chatId,jdbcType=BIGINT},
</if>
<if test="agentId != null">
agent_id = #{agentId,jdbcType=INTEGER},
</if>
<if test="score != null">
score = #{score,jdbcType=INTEGER},
</if>
@@ -116,5 +113,4 @@
</set>
where question_id = #{questionId,jdbcType=BIGINT}
</update>
</mapper>

View File

@@ -59,7 +59,7 @@
join (
select distinct chat_id
from s2_chat_query
where query_state = 0 and agent_id = ${agentId}
where query_state = 1 and agent_id = ${agentId}
order by chat_id desc
limit #{start}, #{limit}
) q2

View File

@@ -1,45 +0,0 @@
package com.tencent.supersonic.chat.corrector;
import static org.mockito.ArgumentMatchers.any;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
class DateFieldCorrectorTest {
@Test
void corrector() {
MockedStatic<DSLDateHelper> dslDateHelper = Mockito.mockStatic(DSLDateHelper.class);
dslDateHelper.when(() -> DSLDateHelper.getReferenceDate(any())).thenReturn("2023-08-14");
DateFieldCorrector dateFieldCorrector = new DateFieldCorrector();
SemanticParseInfo parseInfo = new SemanticParseInfo();
SchemaElement model = new SchemaElement();
model.setId(2L);
parseInfo.setModel(model);
SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder()
.sql("select count(歌曲名) from 歌曲库 ")
.parseInfo(parseInfo)
.build();
dateFieldCorrector.correct(semanticCorrectInfo);
Assert.assertEquals("SELECT count(歌曲名) FROM 歌曲库 WHERE 数据日期 = '2023-08-14'", semanticCorrectInfo.getSql());
semanticCorrectInfo = SemanticCorrectInfo.builder()
.sql("select count(歌曲名) from 歌曲库 where 数据日期 = '2023-08-14'")
.parseInfo(parseInfo)
.build();
dateFieldCorrector.correct(semanticCorrectInfo);
Assert.assertEquals("select count(歌曲名) from 歌曲库 where 数据日期 = '2023-08-14'", semanticCorrectInfo.getSql());
}
}

View File

@@ -1,65 +0,0 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
import com.tencent.supersonic.common.pojo.Constants;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
class FieldNameCorrectorTest {
@Test
void corrector() {
FieldNameCorrector corrector = new FieldNameCorrector();
SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder()
.sql("select 歌曲名 from 歌曲库 where 专辑照片 = '七里香' and 专辑名 = '流行' and 数据日期 = '2023-08-19'")
.build();
SemanticParseInfo parseInfo = new SemanticParseInfo();
DSLParseResult dslParseResult = new DSLParseResult();
LLMReq llmReq = new LLMReq();
List<ElementValue> linking = new ArrayList<>();
ElementValue elementValue = new ElementValue();
elementValue.setFieldValue("流行");
elementValue.setFieldName("歌曲风格");
linking.add(elementValue);
ElementValue elementValue2 = new ElementValue();
elementValue2.setFieldValue("七里香");
elementValue2.setFieldName("歌曲名");
linking.add(elementValue2);
ElementValue elementValue3 = new ElementValue();
elementValue3.setFieldValue("周杰伦");
elementValue3.setFieldName("歌手名");
linking.add(elementValue3);
ElementValue elementValue4 = new ElementValue();
elementValue4.setFieldValue("流行");
elementValue4.setFieldName("歌曲流派");
linking.add(elementValue4);
llmReq.setLinking(linking);
dslParseResult.setLlmReq(llmReq);
Map<String, Object> properties = new HashMap<>();
properties.put(Constants.CONTEXT, dslParseResult);
parseInfo.setProperties(properties);
semanticCorrectInfo.setParseInfo(parseInfo);
corrector.correct(semanticCorrectInfo);
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲库 WHERE 歌曲名 = '七里香' AND 歌曲流派 = '流行' AND 数据日期 = '2023-08-19'",
semanticCorrectInfo.getSql());
}
}

View File

@@ -1,71 +0,0 @@
package com.tencent.supersonic.chat.corrector;
import static org.mockito.Mockito.when;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
class FieldValueCorrectorTest {
@Test
void corrector() {
MockedStatic<ContextUtils> mockContextUtils = Mockito.mockStatic(ContextUtils.class);
SchemaService mockSchemaService = Mockito.mock(SchemaService.class);
SemanticSchema mockSemanticSchema = Mockito.mock(SemanticSchema.class);
List<SchemaElement> dimensions = new ArrayList<>();
List<SchemaValueMap> schemaValueMaps = new ArrayList<>();
SchemaValueMap value1 = new SchemaValueMap();
value1.setBizName("杰伦");
value1.setTechName("周杰伦");
value1.setAlias(Arrays.asList("周杰倫", "Jay Chou", "周董", "周先生"));
schemaValueMaps.add(value1);
SchemaElement schemaElement = SchemaElement.builder()
.bizName("singer_name")
.name("歌手名")
.model(2L)
.schemaValueMaps(schemaValueMaps)
.build();
dimensions.add(schemaElement);
when(mockSemanticSchema.getDimensions()).thenReturn(dimensions);
when(mockSchemaService.getSemanticSchema()).thenReturn(mockSemanticSchema);
mockContextUtils.when(() -> ContextUtils.getBean(SchemaService.class)).thenReturn(mockSchemaService);
SemanticParseInfo parseInfo = new SemanticParseInfo();
SchemaElement model = new SchemaElement();
model.setId(2L);
parseInfo.setModel(model);
SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder()
.sql("select count(song_name) from 歌曲库 where singer_name = '周先生'")
.parseInfo(parseInfo)
.build();
FieldValueCorrector corrector = new FieldValueCorrector();
corrector.correct(semanticCorrectInfo);
Assert.assertEquals("SELECT count(song_name) FROM 歌曲库 WHERE singer_name = '周杰伦'", semanticCorrectInfo.getSql());
semanticCorrectInfo.setSql("select count(song_name) from 歌曲库 where singer_name = '杰伦'");
corrector.correct(semanticCorrectInfo);
Assert.assertEquals("SELECT count(song_name) FROM 歌曲库 WHERE singer_name = '周杰伦'", semanticCorrectInfo.getSql());
}
}

View File

@@ -1,46 +0,0 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
class SelectFieldAppendCorrectorTest {
@Test
void corrector() {
SelectFieldAppendCorrector corrector = new SelectFieldAppendCorrector();
SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder()
.sql("select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 and 歌手名 = '邓紫棋' "
+ "and sys_imp_date = '2023-08-09' and 歌曲发布时 = '2023-08-01' order by 播放量 desc limit 11")
.build();
corrector.correct(semanticCorrectInfo);
Assert.assertEquals(
"SELECT 歌曲名, 歌手名, 播放量, 歌曲发布时, 发布日期 FROM 歌曲库 WHERE "
+ "datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '邓紫棋' "
+ "AND sys_imp_date = '2023-08-09' AND 歌曲发布时 = '2023-08-01'"
+ " ORDER BY 播放量 DESC LIMIT 11", semanticCorrectInfo.getSql());
semanticCorrectInfo.setSql("select 用户名 from 内容库产品 where datediff('day', 数据日期, '2023-09-14') <= 30"
+ " group by 用户名 having sum(访问次数) > 2000");
corrector.correct(semanticCorrectInfo);
Assert.assertEquals(
"SELECT 用户名, sum(访问次数) FROM 内容库产品 WHERE "
+ "datediff('day', 数据日期, '2023-09-14') <= 30 "
+ "GROUP BY 用户名 HAVING sum(访问次数) > 2000", semanticCorrectInfo.getSql());
semanticCorrectInfo.setSql("SELECT 用户名, sum(访问次数) FROM 内容库产品 WHERE "
+ "datediff('day', 数据日期, '2023-09-14') <= 30 "
+ "GROUP BY 用户名 HAVING sum(访问次数) > 2000");
corrector.correct(semanticCorrectInfo);
Assert.assertEquals(
"SELECT 用户名, sum(访问次数) FROM 内容库产品 WHERE "
+ "datediff('day', 数据日期, '2023-09-14') <= 30 "
+ "GROUP BY 用户名 HAVING sum(访问次数) > 2000", semanticCorrectInfo.getSql());
}
}

View File

@@ -7,8 +7,8 @@ import com.tencent.supersonic.knowledge.service.WordService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.context.event.ApplicationStartedEvent;
import org.springframework.context.ApplicationListener;
import org.springframework.boot.CommandLineRunner;
import org.springframework.core.annotation.Order;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
@@ -17,7 +17,8 @@ import java.util.concurrent.CompletableFuture;
@Slf4j
@Component
public class ApplicationStartedListener implements ApplicationListener<ApplicationStartedEvent> {
@Order(5)
public class ApplicationStartedListener implements CommandLineRunner {
@Autowired
private KnowledgeService knowledgeService;
@@ -27,7 +28,7 @@ public class ApplicationStartedListener implements ApplicationListener<Applicati
private SchemaService schemaService;
@Override
public void onApplicationEvent(ApplicationStartedEvent event) {
public void run(String... args) {
updateKnowledgeDimValue();
}

View File

@@ -4,18 +4,13 @@ import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.tencent.supersonic.chat.api.component.SemanticLayer;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.common.pojo.ResultData;
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.util.CollectionUtils;
@Slf4j
@@ -24,10 +19,6 @@ public abstract class BaseSemanticLayer implements SemanticLayer {
protected final Cache<String, List<ModelSchemaResp>> modelSchemaCache =
CacheBuilder.newBuilder().expireAfterWrite(10, TimeUnit.SECONDS).build();
protected ParameterizedTypeReference<ResultData<QueryResultWithSchemaResp>> structTypeRef =
new ParameterizedTypeReference<ResultData<QueryResultWithSchemaResp>>() {
};
@SneakyThrows
public List<ModelSchemaResp> fetchModelSchema(List<Long> ids, Boolean cacheEnable) {
if (cacheEnable) {

View File

@@ -57,17 +57,13 @@ public class LocalSemanticLayer extends BaseSemanticLayer {
}
@Override
@SneakyThrows
public QueryResultWithSchemaResp queryByDsl(QueryDslReq queryDslReq, User user) {
        try {
            queryService = ContextUtils.getBean(QueryService.class);
            Object object = queryService.queryBySql(queryDslReq, user);
            QueryResultWithSchemaResp queryResultWithSchemaResp = JsonUtil.toObject(JsonUtil.toString(object),
                    QueryResultWithSchemaResp.class);
            return queryResultWithSchemaResp;
        } catch (Exception e) {
            log.info("queryByDsl has an exception:{}", e);
        }
        return null;
        queryService = ContextUtils.getBean(QueryService.class);
        Object object = queryService.queryBySql(queryDslReq, user);
        QueryResultWithSchemaResp queryResultWithSchemaResp = JsonUtil.toObject(JsonUtil.toString(object),
                QueryResultWithSchemaResp.class);
        return queryResultWithSchemaResp;
}
@Override