[release](project) update version 0.7.4 backend (#66)

daikon
2023-09-10 21:26:46 +08:00
committed by GitHub
parent 02068f58c7
commit a8add4c013
172 changed files with 2180 additions and 1082 deletions

View File

@@ -3,11 +3,8 @@ package com.tencent.supersonic.chat.api.component;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
/**
* This interface defines the contract for a schema mapper that identifies references to schema
* elements in natural language queries.
*
* The schema mapper matches queries against the knowledge base which is constructed using the
* schema of semantic models.
* A schema mapper identifies references to schema elements (metrics/dimensions/entities/values)
* in user queries. It matches the query text against the knowledge base.
*/
public interface SchemaMapper {
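
For illustration, a minimal mapper sketch is shown below. It is not part of this commit and assumes the interface declares a single void map(QueryContext) method (the method itself is outside this hunk); QueryContext's getRequest()/getMapInfo() accessors are taken from the other files in this change.

import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.pojo.QueryContext;

// Illustrative sketch only, not taken from the commit.
public class NoMatchSchemaMapper implements SchemaMapper {

    @Override
    public void map(QueryContext queryContext) {
        String queryText = queryContext.getRequest().getQueryText();
        if (queryText == null || queryText.isEmpty()) {
            return;
        }
        // A real mapper would match queryText against the knowledge base built from the
        // semantic model schemas and record the results, e.g.:
        // queryContext.getMapInfo().setMatchedElements(modelId, matches);
    }
}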

View File

@@ -1,9 +1,13 @@
package com.tencent.supersonic.chat.api.component;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import net.sf.jsqlparser.JSQLParserException;
/**
* A semantic corrector checks the validity of extracted semantic information and
* performs corrections and optimizations as needed.
*/
public interface SemanticCorrector {
CorrectionInfo corrector(CorrectionInfo correctionInfo) throws JSQLParserException;
void correct(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException;
}
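
The correctors in this commit all follow the same read-modify-write pattern on SemanticCorrectInfo: read getSql(), record it via setPreSql(), then write the transformed statement back with setSql(). A minimal sketch of that pattern (not part of the commit; the whitespace normalization is only a placeholder transform):

import com.tencent.supersonic.chat.api.component.SemanticCorrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import net.sf.jsqlparser.JSQLParserException;

public class WhitespaceCorrector implements SemanticCorrector {

    @Override
    public void correct(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException {
        String preSql = semanticCorrectInfo.getSql();
        semanticCorrectInfo.setPreSql(preSql);
        // placeholder transformation: collapse redundant whitespace
        String sql = preSql.trim().replaceAll("\\s+", " ");
        semanticCorrectInfo.setSql(sql);
    }
}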

View File

@@ -18,10 +18,9 @@ import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.util.List;
/**
* This interface defines the contract for a semantic layer that provides a simplified and
* consistent view of data from multiple sources.
* The semantic layer abstracts away the complexity of the underlying data sources and provides
* a unified view of the data that is easier to understand and use.
* A semantic layer provides a simplified and consistent view of data from multiple sources.
* It abstracts away the complexity of the underlying data sources and provides a unified view
* of the data that is easier to understand and use.
* <p>
* The interface defines methods for getting metadata as well as querying data in the semantic layer.
* Implementations of this interface should provide concrete implementations that interact with the

View File

@@ -5,11 +5,9 @@ import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
/**
* This interface defines the contract for a semantic parser that can analyze natural language query
* and extract meaning from it.
*
* The semantic parser uses either rule-based or model-based algorithms to identify query intent
* and related semantic items described in the query.
* A semantic parser understands user queries and extracts semantic information.
* It may leverage either a rule-based or an LLM-based approach to identify query intent
* and extract related semantic items from the query.
*/
public interface SemanticParser {
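
A trivial parser sketch follows, for orientation only. It assumes the interface declares void parse(QueryContext, ChatContext), a signature inferred from the parsers touched later in this diff (ContextInheritParser, AgentCheckParser, the plugin parsers) and not shown verbatim here.

import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;

// Illustrative sketch only, not taken from the commit.
public class NoOpParser implements SemanticParser {

    @Override
    public void parse(QueryContext queryContext, ChatContext chatContext) {
        // A rule-based implementation would inspect queryContext.getRequest().getQueryText()
        // and the matched schema elements in queryContext.getMapInfo(), then register
        // candidate queries via queryContext.getCandidateQueries().add(...),
        // as QueryModeParser and the plugin parsers below do.
    }
}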

View File

@@ -6,8 +6,7 @@ import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import org.apache.calcite.sql.parser.SqlParseException;
/**
* This class defines the contract for a semantic query that executes specific type of
* query based on the results of semantic parsing.
* A semantic query executes a specific type of query based on the results of semantic parsing.
*/
public interface SemanticQuery {

View File

@@ -21,6 +21,10 @@ public class SchemaMapInfo {
return modelElementMatches;
}
public void setModelElementMatches(Map<Long, List<SchemaElementMatch>> modelElementMatches) {
this.modelElementMatches = modelElementMatches;
}
public void setMatchedElements(Long model, List<SchemaElementMatch> elementMatches) {
modelElementMatches.put(model, elementMatches);
}
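
The new setMatchedElements(Long, List) overload is a convenience for writing a single model's matches without replacing the whole map. A brief usage sketch (illustrative only; it assumes SchemaMapInfo's default constructor initializes the backing map, which the put() call above implies):

import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import java.util.ArrayList;
import java.util.List;

class SchemaMapInfoUsage {
    static void example() {
        SchemaMapInfo mapInfo = new SchemaMapInfo();
        List<SchemaElementMatch> matches = new ArrayList<>();
        // new in this commit: write one model's matches without touching other entries
        mapInfo.setMatchedElements(1L, matches);
        // the existing setter still replaces the whole map:
        // mapInfo.setModelElementMatches(new HashMap<>());
    }
}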

View File

@@ -10,7 +10,7 @@ import lombok.NoArgsConstructor;
@AllArgsConstructor
@NoArgsConstructor
@Builder
public class CorrectionInfo {
public class SemanticCorrectInfo {
private QueryFilters queryFilters;

View File

@@ -0,0 +1,13 @@
package com.tencent.supersonic.chat.api.pojo.request;
import lombok.Data;
import lombok.ToString;
@Data
@ToString
public class ItemNameVisibility {
private ItemNameVisibilityInfo aggVisibilityInfo;
private ItemNameVisibilityInfo detailVisibilityInfo;
}

View File

@@ -0,0 +1,23 @@
package com.tencent.supersonic.chat.api.pojo.request;
import lombok.Data;
import lombok.ToString;
import java.util.ArrayList;
import java.util.List;
@Data
@ToString
public class ItemNameVisibilityInfo {
/**
* invisible dimensions
*/
private List<String> blackDimNameList = new ArrayList<>();
/**
* invisible metrics
*/
private List<String> blackMetricNameList = new ArrayList<>();
}
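
ItemNameVisibilityInfo is a plain holder of blacklisted (invisible) dimension and metric names. The fragment below shows how such a blacklist would typically be applied when filtering item names; the filter itself is illustrative and not taken from this commit (the getters come from Lombok's @Data):

import com.tencent.supersonic.chat.api.pojo.request.ItemNameVisibilityInfo;
import java.util.List;
import java.util.stream.Collectors;

class VisibilityFilterExample {
    // keep only metric names that are not blacklisted for the model
    static List<String> visibleMetrics(List<String> allMetricNames, ItemNameVisibilityInfo visibility) {
        return allMetricNames.stream()
                .filter(name -> !visibility.getBlackMetricNameList().contains(name))
                .collect(Collectors.toList());
    }
}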

View File

@@ -32,6 +32,4 @@ public class KnowledgeInfoReq {
* advanced knowledge config for single item
*/
private KnowledgeAdvancedConfig knowledgeAdvancedConfig;
}

View File

@@ -12,5 +12,4 @@ public class AgentTool {
private String name;
private AgentToolType type;
}

View File

@@ -1,99 +0,0 @@
/*
//package com.tencent.supersonic.chat.aspect;
//
//import lombok.extern.slf4j.Slf4j;
//import org.aspectj.lang.JoinPoint;
//import org.aspectj.lang.ProceedingJoinPoint;
//import org.aspectj.lang.annotation.*;
//import org.springframework.stereotype.Component;
//
//import java.util.HashMap;
//import java.util.Map;
//
//@Aspect
//@Component
//@Slf4j
//public class TimeCostAspect {
//
// ThreadLocal<Long> startTime = new ThreadLocal<>();
//
// ThreadLocal<Map<String, Long>> map = new ThreadLocal<>();
//
// @Pointcut("execution(public * com.tencent.supersonic.chat.mapper.HanlpDictMapper.*(*))")
// //@Pointcut("execution(* public com.tencent.supersonic.chat.parser.*.*(..))")
// //@Pointcut("execution(* com.tencent.supersonic.chat.mapper.*Mapper.map(..)) ")
// //@Pointcut("execution(* com.tencent.supersonic.chat.mapper.HanlpDictMapper.map(..)) ")
// //@Pointcut("execution(* com.tencent.supersonic.chat.parser.rule.QueryModeParser.*(..)) ")
// public void point() {
// }
//
// @Around("point()")
// public void doAround(ProceedingJoinPoint joinPoint) throws Throwable {
// long start = System.currentTimeMillis();
// try {
// log.info("切面开始");
// Object result = joinPoint.proceed();
// log.info("切面开始");
// if (result == null) {
// //如果切到了 没有返回类型的void方法这里直接返回
// //return null;
// }
// long end = System.currentTimeMillis();
// log.info("===================");
// String targetClassName = joinPoint.getSignature().getDeclaringTypeName();
// String MethodName = joinPoint.getSignature().getName();
// String typeStr = joinPoint.getSignature().getDeclaringType().toString().split(" ")[0];
// log.info("类/接口:" + targetClassName + "(" + typeStr + ")");
// log.info("方法:" + MethodName);
// Long total = end - start;
// log.info("耗时: " + total + " ms!");
// map.get().put(targetClassName + "_" + MethodName, total);
// //return result;
// } catch (Throwable e) {
// long end = System.currentTimeMillis();
// log.info("====around " + joinPoint + "\tUse time : " + (end - start) + " ms with exception : "
// + e.getMessage());
// throw e;
// }
// }
//
//// // Cut in before Controller methods execute to initialize the start time
//// @Before(value = "execution(* com.appleyk.controller.*.*(..))")
//// public void beforMehhod(JoinPoint jp) {
//// startTime.set(System.currentTimeMillis());
//// }
////
//// // Cut in after Controller methods execute to count invocations and record time cost
//// // Note: the statistics here include not only Controller methods but also every method hit by the around advice
//// @AfterReturning(value = "execution(* com.appleyk.controller.*.*(..))")
//// public void afterMehhod(JoinPoint jp) {
//// long end = System.currentTimeMillis();
//// long total = end - startTime.get();
//// String methodName = jp.getSignature().getName();
//// log.info("连接点方法为:" + methodName + ",执行总耗时为:" +total+"ms");
////
//// // Create a new map
//// Map<String, Long> map = new HashMap<>();
////// Remove the join point method from map2 and replace it with the final value, so it is not accumulated across multiple runs
//// // map2 is guarded by ThreadLocal and cannot be removed from here, so a separate map is used to hand over the data
//// for(Map.Entry<String, Long> entry:map2.get().entrySet()){
//// if(entry.getKey().equals(methodName)){
//// map.put(methodName, total);
////
//// }else{
//// map.put(entry.getKey(), entry.getValue());
//// }
//// }
////
//// for (Map.Entry<String, Long> entry :map1.get().entrySet()) {
//// for(Map.Entry<String, Long> entry2 :map.entrySet()){
//// if(entry.getKey().equals(entry2.getKey())){
//// System.err.println(entry.getKey()+",被调用次数:"+entry.getValue()+",综合耗时:"+entry2.getValue()+"ms");
//// }
//// }
////
//// }
//// }
//
//}
*/

View File

@@ -9,4 +9,7 @@ import org.springframework.context.annotation.Configuration;
public class FunctionCallInfoConfig {
@Value("${functionCall.url:}")
private String url;
@Value("${funtionCall.plugin.select.path:/plugin_selection}")
private String pluginSelectPath;
}

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
@@ -12,17 +12,16 @@ import org.springframework.util.CollectionUtils;
public class DateFieldCorrector extends BaseSemanticCorrector {
@Override
public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
String sql = correctionInfo.getSql();
String sql = semanticCorrectInfo.getSql();
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sql);
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(DATE_FIELD)) {
String currentDate = DSLDateHelper.getCurrentDate(correctionInfo.getParseInfo().getModelId());
String currentDate = DSLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId());
sql = SqlParserUpdateHelper.addWhere(sql, DATE_FIELD, currentDate);
}
correctionInfo.setPreSql(correctionInfo.getSql());
correctionInfo.setSql(sql);
return correctionInfo;
semanticCorrectInfo.setPreSql(semanticCorrectInfo.getSql());
semanticCorrectInfo.setSql(sql);
}
}
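
To make the DateFieldCorrector change concrete: when the generated SQL has no date predicate, the corrector injects one using the model's configured reference date (via the renamed DSLDateHelper.getReferenceDate) instead of a fixed offset. An illustrative before/after using SqlParserUpdateHelper.addWhere as it is called above; the date field name and values are placeholders, and the rendered output shape is an assumption:

import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;

class DateFieldCorrectorExample {
    static void example() {
        // placeholder query and date field name; the real field comes from DATE_FIELD
        String sql = "SELECT department, SUM(pv) FROM t_1 GROUP BY department";
        String corrected = SqlParserUpdateHelper.addWhere(sql, "sys_imp_date", "2023-09-10");
        // expected (illustrative) result: a WHERE predicate on the date field is appended,
        // e.g. "... WHERE sys_imp_date = '2023-09-10' GROUP BY department"
        System.out.println(corrected);
    }
}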

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import lombok.extern.slf4j.Slf4j;
@@ -8,12 +8,11 @@ import lombok.extern.slf4j.Slf4j;
public class FieldCorrector extends BaseSemanticCorrector {
@Override
public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
String preSql = correctionInfo.getSql();
correctionInfo.setPreSql(preSql);
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
String preSql = semanticCorrectInfo.getSql();
semanticCorrectInfo.setPreSql(preSql);
String sql = SqlParserUpdateHelper.replaceFields(preSql,
getFieldToBizName(correctionInfo.getParseInfo().getModelId()));
correctionInfo.setSql(sql);
return correctionInfo;
getFieldToBizName(semanticCorrectInfo.getParseInfo().getModelId()));
semanticCorrectInfo.setSql(sql);
}
}

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
@@ -19,32 +19,31 @@ import org.springframework.util.CollectionUtils;
public class FieldNameCorrector extends BaseSemanticCorrector {
@Override
public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
Object context = correctionInfo.getParseInfo().getProperties().get(Constants.CONTEXT);
Object context = semanticCorrectInfo.getParseInfo().getProperties().get(Constants.CONTEXT);
if (Objects.isNull(context)) {
return correctionInfo;
return;
}
DSLParseResult dslParseResult = JsonUtil.toObject(JsonUtil.toString(context), DSLParseResult.class);
if (Objects.isNull(dslParseResult) || Objects.isNull(dslParseResult.getLlmReq())) {
return correctionInfo;
return;
}
LLMReq llmReq = dslParseResult.getLlmReq();
List<ElementValue> linking = llmReq.getLinking();
if (CollectionUtils.isEmpty(linking)) {
return correctionInfo;
return;
}
Map<String, Set<String>> fieldValueToFieldNames = linking.stream().collect(
Collectors.groupingBy(ElementValue::getFieldValue,
Collectors.mapping(ElementValue::getFieldName, Collectors.toSet())));
String preSql = correctionInfo.getSql();
correctionInfo.setPreSql(preSql);
String preSql = semanticCorrectInfo.getSql();
semanticCorrectInfo.setPreSql(preSql);
String sql = SqlParserUpdateHelper.replaceFieldNameByValue(preSql, fieldValueToFieldNames);
correctionInfo.setSql(sql);
return correctionInfo;
semanticCorrectInfo.setSql(sql);
}
}

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
@@ -20,23 +20,23 @@ import org.springframework.util.CollectionUtils;
public class FieldValueCorrector extends BaseSemanticCorrector {
@Override
public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
Long modelId = correctionInfo.getParseInfo().getModel().getId();
Long modelId = semanticCorrectInfo.getParseInfo().getModel().getId();
List<SchemaElement> dimensions = semanticSchema.getDimensions().stream()
.filter(schemaElement -> modelId.equals(schemaElement.getModel()))
.collect(Collectors.toList());
if (CollectionUtils.isEmpty(dimensions)) {
return correctionInfo;
return;
}
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
String preSql = correctionInfo.getSql();
correctionInfo.setPreSql(preSql);
String preSql = semanticCorrectInfo.getSql();
semanticCorrectInfo.setPreSql(preSql);
String sql = SqlParserUpdateHelper.replaceValue(preSql, aliasAndBizNameToTechName);
correctionInfo.setSql(sql);
return correctionInfo;
semanticCorrectInfo.setSql(sql);
return;
}

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import lombok.extern.slf4j.Slf4j;
@@ -8,11 +8,10 @@ import lombok.extern.slf4j.Slf4j;
public class FunctionCorrector extends BaseSemanticCorrector {
@Override
public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
String preSql = correctionInfo.getSql();
correctionInfo.setPreSql(preSql);
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
String preSql = semanticCorrectInfo.getSql();
semanticCorrectInfo.setPreSql(preSql);
String sql = SqlParserUpdateHelper.replaceFunction(preSql);
correctionInfo.setSql(sql);
return correctionInfo;
semanticCorrectInfo.setSql(sql);
}
}

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.StringUtil;
@@ -18,18 +18,17 @@ import org.apache.commons.lang3.StringUtils;
public class QueryFilterAppend extends BaseSemanticCorrector {
@Override
public CorrectionInfo corrector(CorrectionInfo correctionInfo) throws JSQLParserException {
String queryFilter = getQueryFilter(correctionInfo.getQueryFilters());
String preSql = correctionInfo.getSql();
public void correct(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException {
String queryFilter = getQueryFilter(semanticCorrectInfo.getQueryFilters());
String preSql = semanticCorrectInfo.getSql();
if (StringUtils.isNotEmpty(queryFilter)) {
log.info("add queryFilter to preSql :{}", queryFilter);
Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
String sql = SqlParserUpdateHelper.addWhere(preSql, expression);
correctionInfo.setPreSql(preSql);
correctionInfo.setSql(sql);
semanticCorrectInfo.setPreSql(preSql);
semanticCorrectInfo.setSql(sql);
}
return correctionInfo;
}
private String getQueryFilter(QueryFilters queryFilters) {

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
@@ -14,16 +14,16 @@ import org.springframework.util.CollectionUtils;
public class SelectFieldAppendCorrector extends BaseSemanticCorrector {
@Override
public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
String preSql = correctionInfo.getSql();
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
String preSql = semanticCorrectInfo.getSql();
if (SqlParserSelectHelper.hasAggregateFunction(preSql)) {
return correctionInfo;
return;
}
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(preSql));
Set<String> whereFields = new HashSet<>(SqlParserSelectHelper.getWhereFields(preSql));
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(whereFields)) {
return correctionInfo;
return;
}
whereFields.addAll(SqlParserSelectHelper.getOrderByFields(preSql));
@@ -32,8 +32,7 @@ public class SelectFieldAppendCorrector extends BaseSemanticCorrector {
whereFields.remove(TimeDimensionEnum.WEEK.getName());
whereFields.remove(TimeDimensionEnum.MONTH.getName());
String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(preSql, new ArrayList<>(whereFields));
correctionInfo.setPreSql(preSql);
correctionInfo.setSql(replaceFields);
return correctionInfo;
semanticCorrectInfo.setPreSql(preSql);
semanticCorrectInfo.setSql(replaceFields);
}
}

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import lombok.extern.slf4j.Slf4j;
@@ -10,13 +10,12 @@ public class TableNameCorrector extends BaseSemanticCorrector {
public static final String TABLE_PREFIX = "t_";
@Override
public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
Long modelId = correctionInfo.getParseInfo().getModelId();
String preSql = correctionInfo.getSql();
correctionInfo.setPreSql(preSql);
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
Long modelId = semanticCorrectInfo.getParseInfo().getModelId();
String preSql = semanticCorrectInfo.getSql();
semanticCorrectInfo.setPreSql(preSql);
String sql = SqlParserUpdateHelper.replaceTable(preSql, TABLE_PREFIX + modelId);
correctionInfo.setSql(sql);
return correctionInfo;
semanticCorrectInfo.setSql(sql);
}
}

View File

@@ -10,7 +10,7 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.dictionary.builder.BaseWordBuilder;
import com.tencent.supersonic.knowledge.dictionary.builder.WordBuilderFactory;

View File

@@ -14,6 +14,7 @@ import java.util.Set;
import java.util.stream.Collectors;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@@ -49,7 +50,6 @@ public class MapperHelper {
}
public double getThresholdMatch(List<String> natures) {
log.info("optimizationConfig:{}", optimizationConfig);
if (existDimensionValues(natures)) {
return optimizationConfig.getDimensionValueThresholdConfig();
}
@@ -90,9 +90,20 @@ public class MapperHelper {
AgentService agentService = ContextUtils.getBean(AgentService.class);
Set<Long> detectModelIds = agentService.getDslToolsModelIds(request.getAgentId(), null);
// contains all models
if (isContainsAllModel(detectModelIds)) {
if (Objects.nonNull(modelId) && modelId > 0) {
Set<Long> result = new HashSet<>();
result.add(modelId);
return result;
}
return new HashSet<>();
}
if (Objects.nonNull(detectModelIds)) {
detectModelIds = detectModelIds.stream().filter(entry -> entry > 0).collect(Collectors.toSet());
}
if (Objects.nonNull(modelId) && modelId > 0 && Objects.nonNull(detectModelIds)) {
if (detectModelIds.contains(modelId)) {
Set<Long> result = new HashSet<>();
@@ -103,4 +114,8 @@ public class MapperHelper {
return detectModelIds;
}
private boolean isContainsAllModel(Set<Long> detectModelIds) {
return CollectionUtils.isNotEmpty(detectModelIds) && detectModelIds.contains(-1L);
}
}
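
The new isContainsAllModel check treats -1L in an agent's detect-model set as an "all models" sentinel: when present, a concrete modelId in the request resolves to just that model, otherwise the result is left empty. A simplified restatement of the added logic (the surrounding method name and wiring are omitted):

import java.util.HashSet;
import java.util.Objects;
import java.util.Set;
import org.apache.commons.collections.CollectionUtils;

class DetectModelIdExample {
    // simplified restatement of the branch added to MapperHelper
    static Set<Long> resolve(Set<Long> detectModelIds, Long modelId) {
        // -1L acts as the "all models" sentinel
        if (CollectionUtils.isNotEmpty(detectModelIds) && detectModelIds.contains(-1L)) {
            Set<Long> result = new HashSet<>();
            if (Objects.nonNull(modelId) && modelId > 0) {
                result.add(modelId);
            }
            return result;
        }
        return detectModelIds;
    }
}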

View File

@@ -3,7 +3,7 @@ package com.tencent.supersonic.chat.mapper;
import com.google.common.collect.Lists;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.service.SearchService;
import java.util.List;

View File

@@ -1,11 +1,55 @@
package com.tencent.supersonic.chat.parser.llm.dsl;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.service.ConfigService;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DatePeriodEnum;
import com.tencent.supersonic.common.util.DateUtils;
import java.util.List;
import java.util.Objects;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
public class DSLDateHelper {
public static String getCurrentDate(Long modelId) {
return DateUtils.getBeforeDate(4);
public static String getReferenceDate(Long modelId) {
String chatDetailDate = getChatDetailDate(modelId);
if (StringUtils.isNotBlank(chatDetailDate)) {
return chatDetailDate;
}
return DateUtils.getBeforeDate(0);
}
private static String getChatDetailDate(Long modelId) {
if (Objects.isNull(modelId)) {
return null;
}
ChatConfigFilter filter = new ChatConfigFilter();
filter.setModelId(modelId);
List<ChatConfigResp> configResps = ContextUtils.getBean(ConfigService.class).search(filter, null);
if (CollectionUtils.isEmpty(configResps)) {
return null;
}
ChatConfigResp chatConfigResp = configResps.get(0);
if (Objects.isNull(chatConfigResp.getChatDetailConfig()) || Objects.isNull(
chatConfigResp.getChatDetailConfig().getChatDefaultConfig())) {
return null;
}
ChatDefaultConfigReq chatDefaultConfig = chatConfigResp.getChatDetailConfig().getChatDefaultConfig();
Integer unit = chatDefaultConfig.getUnit();
String period = chatDefaultConfig.getPeriod();
if (Objects.nonNull(unit)) {
DatePeriodEnum datePeriodEnum = DatePeriodEnum.get(period);
if (Objects.isNull(datePeriodEnum)) {
return DateUtils.getBeforeDate(unit);
} else {
return DateUtils.getBeforeDate(unit, datePeriodEnum);
}
}
return null;
}
}
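
getReferenceDate now prefers the model's chat default date config (unit plus period) and only falls back to today (getBeforeDate(0)) when no config is found, replacing the previous fixed getBeforeDate(4). A brief illustration of that fallback order; the DateUtils and DatePeriodEnum calls are the same ones used above, but the wrapper method is only a sketch:

import com.tencent.supersonic.common.util.DatePeriodEnum;
import com.tencent.supersonic.common.util.DateUtils;

class ReferenceDateExample {
    // mirrors the fallback order in DSLDateHelper
    static String resolve(Integer unit, String period) {
        if (unit != null) {
            DatePeriodEnum datePeriodEnum = DatePeriodEnum.get(period);
            return datePeriodEnum == null
                    ? DateUtils.getBeforeDate(unit)
                    : DateUtils.getBeforeDate(unit, datePeriodEnum);
        }
        return DateUtils.getBeforeDate(0);   // no config: fall back to today
    }
}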

View File

@@ -6,11 +6,11 @@ import com.tencent.supersonic.chat.agent.tool.DslTool;
import com.tencent.supersonic.chat.api.component.SemanticCorrector;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
@@ -67,7 +67,7 @@ public class LLMDslParser implements SemanticParser {
QueryReq request = queryCtx.getRequest();
LLMConfig llmConfig = ContextUtils.getBean(LLMConfig.class);
if (StringUtils.isEmpty(llmConfig.getUrl()) || SatisfactionChecker.check(queryCtx)) {
log.info("llmConfig:{}, skip function parser, queryText:{}", llmConfig, request.getQueryText());
log.info("llmConfig:{}, skip dsl parser, queryText:{}", llmConfig, request.getQueryText());
return;
}
try {
@@ -93,22 +93,56 @@ public class LLMDslParser implements SemanticParser {
SemanticParseInfo parseInfo = getParseInfo(queryCtx, modelId, dslTool, dslParseResult);
CorrectionInfo correctionInfo = getCorrectorSql(queryCtx, parseInfo, llmResp.getSqlOutput());
SemanticCorrectInfo semanticCorrectInfo = getCorrectorSql(queryCtx, parseInfo, llmResp.getSqlOutput());
llmResp.setCorrectorSql(correctionInfo.getSql());
llmResp.setCorrectorSql(semanticCorrectInfo.getSql());
setFilter(correctionInfo, modelId, parseInfo);
setFilter(semanticCorrectInfo, modelId, parseInfo);
setDimensionsAndMetrics(modelId, parseInfo, semanticCorrectInfo.getSql());
} catch (Exception e) {
log.error("LLMDSLParser error", e);
}
}
public void setFilter(CorrectionInfo correctionInfo, Long modelId, SemanticParseInfo parseInfo) {
private void setDimensionsAndMetrics(Long modelId, SemanticParseInfo parseInfo, String sql) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
String correctorSql = correctionInfo.getPreSql();
if (Objects.isNull(semanticSchema)) {
return;
}
List<String> allFields = getFieldsExceptDate(sql);
Set<SchemaElement> metrics = getElements(modelId, allFields, semanticSchema.getMetrics());
parseInfo.setMetrics(metrics);
Set<SchemaElement> dimensions = getElements(modelId, allFields, semanticSchema.getDimensions());
parseInfo.setDimensions(dimensions);
}
private Set<SchemaElement> getElements(Long modelId, List<String> allFields, List<SchemaElement> elements) {
return elements.stream()
.filter(schemaElement -> modelId.equals(schemaElement.getModel())
&& allFields.contains(schemaElement.getBizName())
).collect(Collectors.toSet());
}
private List<String> getFieldsExceptDate(String sql) {
List<String> allFields = SqlParserSelectHelper.getAllFields(sql);
if (CollectionUtils.isEmpty(allFields)) {
return new ArrayList<>();
}
return allFields.stream()
.filter(entry -> !TimeDimensionEnum.getNameList().contains(entry))
.collect(Collectors.toList());
}
public void setFilter(SemanticCorrectInfo semanticCorrectInfo, Long modelId, SemanticParseInfo parseInfo) {
String correctorSql = semanticCorrectInfo.getPreSql();
if (StringUtils.isEmpty(correctorSql)) {
correctorSql = correctionInfo.getSql();
correctorSql = semanticCorrectInfo.getSql();
}
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctorSql);
if (CollectionUtils.isEmpty(expressions)) {
@@ -204,9 +238,9 @@ public class LLMDslParser implements SemanticParser {
return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue());
}
private CorrectionInfo getCorrectorSql(QueryContext queryCtx, SemanticParseInfo parseInfo, String sql) {
private SemanticCorrectInfo getCorrectorSql(QueryContext queryCtx, SemanticParseInfo parseInfo, String sql) {
CorrectionInfo correctionInfo = CorrectionInfo.builder()
SemanticCorrectInfo correctInfo = SemanticCorrectInfo.builder()
.queryFilters(queryCtx.getRequest().getQueryFilters()).sql(sql)
.parseInfo(parseInfo).build();
@@ -214,14 +248,13 @@ public class LLMDslParser implements SemanticParser {
dslCorrections.forEach(dslCorrection -> {
try {
dslCorrection.corrector(correctionInfo);
log.info("sqlCorrection:{} sql:{}", dslCorrection.getClass().getSimpleName(),
correctionInfo.getSql());
dslCorrection.correct(correctInfo);
log.info("sqlCorrection:{} sql:{}", dslCorrection.getClass().getSimpleName(), correctInfo.getSql());
} catch (Exception e) {
log.error("sqlCorrection:{} execute error,correctionInfo:{}", dslCorrection, correctionInfo, e);
log.error("sqlCorrection:{} correct error,correctInfo:{}", dslCorrection, correctInfo, e);
}
});
return correctionInfo;
return correctInfo;
}
private SemanticParseInfo getParseInfo(QueryContext queryCtx, Long modelId, DslTool dslTool,
@@ -305,12 +338,12 @@ public class LLMDslParser implements SemanticParser {
List<ElementValue> linking = new ArrayList<>();
linking.addAll(getValueList(queryCtx, modelId, semanticSchema));
llmReq.setLinking(linking);
String currentDate = DSLDateHelper.getCurrentDate(modelId);
String currentDate = DSLDateHelper.getReferenceDate(modelId);
llmReq.setCurrentDate(currentDate);
return llmReq;
}
private List<ElementValue> getValueList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
protected List<ElementValue> getValueList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
@@ -348,7 +381,7 @@ public class LLMDslParser implements SemanticParser {
}
private List<String> getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
protected List<String> getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
@@ -375,7 +408,7 @@ public class LLMDslParser implements SemanticParser {
return new ArrayList<>(fieldNameList);
}
private Map<Long, String> getItemIdToName(Long modelId, SemanticSchema semanticSchema) {
protected Map<Long, String> getItemIdToName(Long modelId, SemanticSchema semanticSchema) {
return semanticSchema.getDimensions().stream()
.filter(entry -> modelId.equals(entry.getModel()))
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));

View File

@@ -17,7 +17,7 @@ import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.chat.query.metricinterpret.MetricInterpretQuery;
import com.tencent.supersonic.chat.query.llm.interpret.MetricInterpretQuery;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.service.AgentService;

View File

@@ -1,5 +1,7 @@
package com.tencent.supersonic.chat.parser.plugin;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
@@ -8,7 +10,6 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginManager;
@@ -16,14 +17,13 @@ import com.tencent.supersonic.chat.plugin.PluginParseResult;
import com.tencent.supersonic.chat.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import org.springframework.util.CollectionUtils;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
public abstract class PluginParser implements SemanticParser {
@@ -45,7 +45,11 @@ public abstract class PluginParser implements SemanticParser {
public void buildQuery(QueryContext queryContext, PluginRecallResult pluginRecallResult) {
Plugin plugin = pluginRecallResult.getPlugin();
for (Long modelId : pluginRecallResult.getModelIds()) {
Set<Long> modelIds = pluginRecallResult.getModelIds();
if (plugin.isContainsAllModel()) {
modelIds = Sets.newHashSet(-1L);
}
for (Long modelId : modelIds) {
PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType());
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, plugin, queryContext.getRequest(),
queryContext.getMapInfo().getMatchedElements(modelId), pluginRecallResult.getDistance());
@@ -53,9 +57,6 @@ public abstract class PluginParser implements SemanticParser {
semanticParseInfo.setScore(pluginRecallResult.getScore());
pluginQuery.setParseInfo(semanticParseInfo);
queryContext.getCandidateQueries().add(pluginQuery);
if (plugin.isContainsAllModel()) {
break;
}
}
}
@@ -68,6 +69,9 @@ public abstract class PluginParser implements SemanticParser {
if (modelId == null && !CollectionUtils.isEmpty(plugin.getModelList())) {
modelId = plugin.getModelList().get(0);
}
if (schemaElementMatches == null) {
schemaElementMatches = Lists.newArrayList();
}
SchemaElement model = new SchemaElement();
model.setModel(modelId);
model.setId(modelId);
@@ -85,17 +89,9 @@ public abstract class PluginParser implements SemanticParser {
semanticParseInfo.setProperties(properties);
semanticParseInfo.setScore(distance);
fillSemanticParseInfo(semanticParseInfo);
setEntity(modelId, semanticParseInfo);
return semanticParseInfo;
}
private void setEntity(Long modelId, SemanticParseInfo semanticParseInfo) {
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
ModelSchema modelSchema = semanticService.getModelSchema(modelId);
if (modelSchema != null && modelSchema.getEntity() != null) {
semanticParseInfo.setEntity(modelSchema.getEntity());
}
}
private void fillSemanticParseInfo(SemanticParseInfo semanticParseInfo) {
List<SchemaElementMatch> schemaElementMatches = semanticParseInfo.getElementMatches();

View File

@@ -28,6 +28,10 @@ public class EmbeddingBasedParser extends PluginParser {
if (StringUtils.isBlank(embeddingConfig.getUrl())) {
return false;
}
List<Plugin> plugins = getPluginList(queryContext);
if (CollectionUtils.isEmpty(plugins)) {
return false;
}
List<SemanticQuery> semanticQueries = queryContext.getCandidateQueries();
for (SemanticQuery semanticQuery : semanticQueries) {
if (queryContext.getRequest().getQueryText().length() <= semanticQuery.getParseInfo().getScore()) {

View File

@@ -11,13 +11,13 @@ public class EmbeddingConfig {
@Value("${embedding.url:}")
private String url;
@Value("${embedding.recognize.path:preset_query_retrival}")
@Value("${embedding.recognize.path:/preset_query_retrival}")
private String recognizePath;
@Value("${embedding.delete.path:preset_delete_by_ids}")
@Value("${embedding.delete.path:/preset_delete_by_ids}")
private String deletePath;
@Value("${embedding.add.path:preset_query_add}")
@Value("${embedding.add.path:/preset_query_add}")
private String addPath;
@Value("${embedding.nResult:1}")

View File

@@ -44,7 +44,8 @@ public class FunctionBasedParser extends PluginParser {
queryContext.getRequest().getQueryText());
return false;
}
return true;
List<Plugin> plugins = getPluginList(queryContext);
return !CollectionUtils.isEmpty(plugins);
}
@Override
@@ -82,7 +83,6 @@ public class FunctionBasedParser extends PluginParser {
log.info("function call parser, plugin is empty, skip");
return null;
}
FunctionCallInfoConfig functionCallConfig = ContextUtils.getBean(FunctionCallInfoConfig.class);
FunctionResp functionResp = new FunctionResp();
if (pluginToFunctionCall.size() == 1) {
functionResp.setToolSelection(pluginToFunctionCall.iterator().next().getName());
@@ -90,7 +90,7 @@ public class FunctionBasedParser extends PluginParser {
FunctionReq functionReq = FunctionReq.builder()
.queryText(queryContext.getRequest().getQueryText())
.pluginConfigs(pluginToFunctionCall).build();
functionResp = requestFunction(functionCallConfig.getUrl(), functionReq);
functionResp = requestFunction(functionReq);
}
return functionResp;
}
@@ -133,7 +133,9 @@ public class FunctionBasedParser extends PluginParser {
return functionDOList;
}
public FunctionResp requestFunction(String url, FunctionReq functionReq) {
public FunctionResp requestFunction(FunctionReq functionReq) {
FunctionCallInfoConfig functionCallInfoConfig = ContextUtils.getBean(FunctionCallInfoConfig.class);
String url = functionCallInfoConfig.getUrl() + functionCallInfoConfig.getPluginSelectPath();
HttpHeaders headers = new HttpHeaders();
long startTime = System.currentTimeMillis();
headers.setContentType(MediaType.APPLICATION_JSON);

View File

@@ -43,6 +43,9 @@ public class AgentCheckParser implements SemanticParser {
if (!tool.getQueryModes().contains(query.getQueryMode())) {
return true;
}
if (CollectionUtils.isEmpty(tool.getModelIds())) {
return true;
}
if (tool.isContainsAllModel() || tool.getModelIds().contains(query.getParseInfo().getModelId())) {
return false;
}

View File

@@ -1,5 +1,12 @@
package com.tencent.supersonic.chat.parser.rule;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.DIMENSION;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ENTITY;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ID;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.METRIC;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.MODEL;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.VALUE;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
@@ -8,8 +15,8 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricSemanticQuery;
import java.util.AbstractMap;
import java.util.ArrayList;
@@ -21,13 +28,6 @@ import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.extern.slf4j.Slf4j;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.METRIC;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.DIMENSION;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.VALUE;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ENTITY;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.MODEL;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ID;
@Slf4j
public class ContextInheritParser implements SemanticParser {
@@ -97,10 +97,10 @@ public class ContextInheritParser implements SemanticParser {
}
protected boolean shouldInherit(QueryContext queryContext, ChatContext chatContext) {
Long contextmodelId = chatContext.getParseInfo().getModelId();
Long contextModelId = chatContext.getParseInfo().getModelId();
// if map info doesn't contain the same Model of the context,
// no inheritance could be done
if (queryContext.getMapInfo().getMatchedElements(contextmodelId) == null) {
if (queryContext.getMapInfo().getMatchedElements(contextModelId) == null) {
return false;
}

View File

@@ -6,8 +6,13 @@ import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
@Slf4j
public class QueryModeParser implements SemanticParser {
@@ -24,6 +29,14 @@ public class QueryModeParser implements SemanticParser {
queryContext.getCandidateQueries().add(query);
}
}
// if a model's element matches are empty, remove the entry.
Map<Long, List<SchemaElementMatch>> filterModelElementMatches = new HashMap<>();
for (Long modelId : queryContext.getMapInfo().getModelElementMatches().keySet()) {
if (!CollectionUtils.isEmpty(queryContext.getMapInfo().getModelElementMatches().get(modelId))) {
filterModelElementMatches.put(modelId, queryContext.getMapInfo().getModelElementMatches().get(modelId));
}
}
queryContext.getMapInfo().setModelElementMatches(filterModelElementMatches);
}
}
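
The new loop drops models whose element-match list is empty before the map is written back. An equivalent, slightly more compact stream-based version is sketched below purely for readability; it is behaviorally the same as the loop added above:

import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.collections.CollectionUtils;

class FilterEmptyMatchesExample {
    // keep only models that actually have element matches
    static Map<Long, List<SchemaElementMatch>> filterEmpty(Map<Long, List<SchemaElementMatch>> matches) {
        return matches.entrySet().stream()
                .filter(e -> !CollectionUtils.isEmpty(e.getValue()))
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
    }
}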

View File

@@ -153,6 +153,7 @@ public class TimeRangeParser implements SemanticParser {
}
info.setDetectWord(detectWord);
info.setStartDate(LocalDate.now().minusDays(days).toString());
info.setEndDate(LocalDate.now().minusDays(1).toString());
info.setUnit(num);
return info;

View File

@@ -55,4 +55,8 @@ public class Plugin extends RecordInfo {
return CollectionUtils.isNotEmpty(modelList) && modelList.contains(-1L);
}
public Long getDefaultMode() {
return -1L;
}
}

View File

@@ -15,6 +15,7 @@ import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingConfig;
import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingResp;
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval;
import com.tencent.supersonic.chat.plugin.event.PluginAddEvent;
import com.tencent.supersonic.chat.plugin.event.PluginDelEvent;
import com.tencent.supersonic.chat.plugin.event.PluginUpdateEvent;
import com.tencent.supersonic.chat.query.plugin.ParamOption;
import com.tencent.supersonic.chat.query.plugin.WebBase;
@@ -116,8 +117,8 @@ public class PluginManager {
}
@EventListener
public void delPlugin(PluginAddEvent pluginAddEvent) {
Plugin plugin = pluginAddEvent.getPlugin();
public void delPlugin(PluginDelEvent pluginDelEvent) {
Plugin plugin = pluginDelEvent.getPlugin();
if (CollectionUtils.isNotEmpty(plugin.getExampleQuestionList())) {
requestEmbeddingPluginDelete(getEmbeddingId(Lists.newArrayList(plugin)));
}
@@ -142,18 +143,22 @@ public class PluginManager {
return ResponseEntity.of(Optional.empty());
}
String url = embeddingConfig.getUrl() + path;
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
headers.setLocation(URI.create(url));
URI requestUrl = UriComponentsBuilder
.fromHttpUrl(url).build().encode().toUri();
HttpEntity<String> entity = new HttpEntity<>(jsonBody, headers);
log.info("[embedding] equest body :{}, url:{}", jsonBody, url);
ResponseEntity<String> responseEntity =
restTemplate.exchange(requestUrl, HttpMethod.POST, entity, new ParameterizedTypeReference<String>() {
});
log.info("[embedding] result body:{}", responseEntity);
return responseEntity;
try {
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
headers.setLocation(URI.create(url));
URI requestUrl = UriComponentsBuilder
.fromHttpUrl(url).build().encode().toUri();
HttpEntity<String> entity = new HttpEntity<>(jsonBody, headers);
log.info("[embedding] equest body :{}, url:{}", jsonBody, url);
ResponseEntity<String> responseEntity = restTemplate.exchange(requestUrl,
HttpMethod.POST, entity, new ParameterizedTypeReference<String>() {});
log.info("[embedding] result body:{}", responseEntity);
return responseEntity;
} catch (Throwable e) {
log.warn("connect to embedding service failed, url:{}", url);
}
return ResponseEntity.of(Optional.empty());
}
public void requestEmbeddingPluginAddALL(List<Plugin> plugins) {
@@ -298,7 +303,7 @@ public class PluginManager {
private static Set<Long> getPluginMatchedModel(Plugin plugin, QueryContext queryContext) {
Set<Long> matchedModel = queryContext.getMapInfo().getMatchedModels();
if (plugin.isContainsAllModel()) {
return matchedModel;
return Sets.newHashSet(plugin.getDefaultMode());
}
List<Long> modelIds = plugin.getModelList();
Set<Long> pluginMatchedModel = Sets.newHashSet();

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.query.metricinterpret;
package com.tencent.supersonic.chat.query.llm.interpret;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.query.metricinterpret;
package com.tencent.supersonic.chat.query.llm.interpret;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.query.metricinterpret;
package com.tencent.supersonic.chat.query.llm.interpret;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;

View File

@@ -11,6 +11,7 @@ import com.tencent.supersonic.chat.query.plugin.ParamOption;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.query.plugin.WebBase;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.ContextUtils;
@@ -56,9 +57,18 @@ public class WebServiceQuery extends PluginSemanticQuery {
PluginParseResult pluginParseResult = JsonUtil.toObject(
JsonUtil.toString(properties.get(Constants.CONTEXT)), PluginParseResult.class);
WebServiceResponse webServiceResponse = buildResponse(pluginParseResult);
queryResult.setResponse(webServiceResponse);
queryResult.setQueryState(QueryState.SUCCESS);
//parseInfo.setProperties(null);
Object object = webServiceResponse.getResult();
// In order to show the webServiceQuery result in the frontend conveniently,
// the webServiceResponse result format is kept consistent with the queryByStruct result.
log.info("webServiceResponse result:{}", JsonUtil.toString(object));
try {
Map<String, Object> data = JsonUtil.toMap(JsonUtil.toString(object), String.class, Object.class);
queryResult.setQueryResults((List<Map<String, Object>>) data.get("resultList"));
queryResult.setQueryColumns((List<QueryColumn>) data.get("columns"));
queryResult.setQueryState(QueryState.SUCCESS);
} catch (Exception e) {
log.info("webServiceResponse result has an exception:{}", e.getMessage());
}
return queryResult;
}
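
The mapping above expects the plugin web service to return an object whose JSON form carries a "columns" array and a "resultList" array, mirroring the queryByStruct result shape. A sketch of a payload that would satisfy it; only the two top-level keys come from the code, the row fields and column descriptors are placeholders:

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

class WebServiceResultExample {
    static Map<String, Object> samplePayload() {
        Map<String, Object> row = new HashMap<>();
        row.put("department", "sales");   // placeholder column value
        row.put("pv", 123);

        Map<String, Object> payload = new HashMap<>();
        payload.put("columns", new ArrayList<>());             // QueryColumn-compatible descriptors
        payload.put("resultList", Collections.singletonList(row));
        return payload;
    }
}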

View File

@@ -31,6 +31,7 @@ import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
@@ -141,61 +142,44 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
if (!id2Values.isEmpty()) {
for (Map.Entry<Long, List<SchemaElementMatch>> entry : id2Values.entrySet()) {
SchemaElement entity = modelSchema.getElement(SchemaElementType.ENTITY, entry.getKey());
if (entry.getValue().size() == 1) {
SchemaElementMatch schemaMatch = entry.getValue().get(0);
QueryFilter dimensionFilter = new QueryFilter();
dimensionFilter.setValue(schemaMatch.getWord());
dimensionFilter.setBizName(entity.getBizName());
dimensionFilter.setName(entity.getName());
dimensionFilter.setOperator(FilterOperatorEnum.EQUALS);
dimensionFilter.setElementID(schemaMatch.getElement().getId());
parseInfo.getDimensionFilters().add(dimensionFilter);
parseInfo.setEntity(modelSchema.getEntity());
} else {
QueryFilter dimensionFilter = new QueryFilter();
List<String> vals = new ArrayList<>();
entry.getValue().stream().forEach(i -> vals.add(i.getWord()));
dimensionFilter.setValue(vals);
dimensionFilter.setBizName(entity.getBizName());
dimensionFilter.setName(entity.getName());
dimensionFilter.setOperator(FilterOperatorEnum.IN);
dimensionFilter.setElementID(entry.getKey());
parseInfo.getDimensionFilters().add(dimensionFilter);
}
addFilters(parseInfo, modelSchema, entry, SchemaElementType.ENTITY);
}
}
if (!dim2Values.isEmpty()) {
for (Map.Entry<Long, List<SchemaElementMatch>> entry : dim2Values.entrySet()) {
SchemaElement dimension = modelSchema.getElement(SchemaElementType.DIMENSION, entry.getKey());
if (entry.getValue().size() == 1) {
SchemaElementMatch schemaMatch = entry.getValue().get(0);
QueryFilter dimensionFilter = new QueryFilter();
dimensionFilter.setValue(schemaMatch.getWord());
dimensionFilter.setBizName(dimension.getBizName());
dimensionFilter.setName(dimension.getName());
dimensionFilter.setOperator(FilterOperatorEnum.EQUALS);
dimensionFilter.setElementID(schemaMatch.getElement().getId());
parseInfo.getDimensionFilters().add(dimensionFilter);
parseInfo.setEntity(modelSchema.getEntity());
} else {
QueryFilter dimensionFilter = new QueryFilter();
List<String> vals = new ArrayList<>();
entry.getValue().stream().forEach(i -> vals.add(i.getWord()));
dimensionFilter.setValue(vals);
dimensionFilter.setBizName(dimension.getBizName());
dimensionFilter.setName(dimension.getName());
dimensionFilter.setOperator(FilterOperatorEnum.IN);
dimensionFilter.setElementID(entry.getKey());
parseInfo.getDimensionFilters().add(dimensionFilter);
}
addFilters(parseInfo, modelSchema, entry, SchemaElementType.DIMENSION);
}
}
}
private void addFilters(SemanticParseInfo parseInfo, ModelSchema modelSchema,
Entry<Long, List<SchemaElementMatch>> entry, SchemaElementType elementType) {
SchemaElement dimension = modelSchema.getElement(elementType, entry.getKey());
if (entry.getValue().size() == 1) {
SchemaElementMatch schemaMatch = entry.getValue().get(0);
QueryFilter dimensionFilter = new QueryFilter();
dimensionFilter.setValue(schemaMatch.getWord());
dimensionFilter.setBizName(dimension.getBizName());
dimensionFilter.setName(dimension.getName());
dimensionFilter.setOperator(FilterOperatorEnum.EQUALS);
dimensionFilter.setElementID(schemaMatch.getElement().getId());
parseInfo.getDimensionFilters().add(dimensionFilter);
parseInfo.setEntity(modelSchema.getEntity());
} else {
QueryFilter dimensionFilter = new QueryFilter();
List<String> vals = new ArrayList<>();
entry.getValue().stream().forEach(i -> vals.add(i.getWord()));
dimensionFilter.setValue(vals);
dimensionFilter.setBizName(dimension.getBizName());
dimensionFilter.setName(dimension.getName());
dimensionFilter.setOperator(FilterOperatorEnum.IN);
dimensionFilter.setElementID(entry.getKey());
parseInfo.getDimensionFilters().add(dimensionFilter);
}
}
@Override
public QueryResult execute(User user) {
@@ -292,7 +276,6 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
}
protected QueryStructReq convertQueryStruct() {
return QueryReqBuilder.buildStructReq(parseInfo);
}

View File

@@ -4,29 +4,35 @@ package com.tencent.supersonic.chat.rest;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.service.DictionaryService;
import com.tencent.supersonic.knowledge.listener.ApplicationStartedListener;
import com.tencent.supersonic.knowledge.dictionary.DictTaskFilter;
import com.tencent.supersonic.knowledge.dictionary.DimValue2DictCommand;
import com.tencent.supersonic.knowledge.dictionary.DimValueDictInfo;
import java.util.List;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.DeleteMapping;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.bind.annotation.DeleteMapping;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PutMapping;
@RestController
@RequestMapping("/api/chat/dict")
public class DictionaryController {
public class KnowledgeController {
@Autowired
private DictionaryService dictApplicationService;
@Autowired
private ApplicationStartedListener applicationStartedListener;
/**
* addDictInfo
*
@@ -34,8 +40,8 @@ public class DictionaryController {
*/
@PostMapping("/task")
public Long addDictTask(@RequestBody DimValue2DictCommand dimValue2DictCommend,
HttpServletRequest request,
HttpServletResponse response) {
HttpServletRequest request,
HttpServletResponse response) {
User user = UserHolder.findUser(request, response);
return dictApplicationService.addDictTask(dimValue2DictCommend, user);
}
@@ -47,8 +53,8 @@ public class DictionaryController {
*/
@DeleteMapping("/task")
public Long deleteDictTask(@RequestBody DimValue2DictCommand dimValue2DictCommend,
HttpServletRequest request,
HttpServletResponse response) {
HttpServletRequest request,
HttpServletResponse response) {
User user = UserHolder.findUser(request, response);
return dictApplicationService.deleteDictTask(dimValue2DictCommend, user);
}
@@ -60,16 +66,22 @@ public class DictionaryController {
*/
@PostMapping("/task/search")
public List<DimValueDictInfo> searchDictTaskList(@RequestBody DictTaskFilter filter,
HttpServletRequest request,
HttpServletResponse response) {
HttpServletRequest request,
HttpServletResponse response) {
User user = UserHolder.findUser(request, response);
return dictApplicationService.searchDictTaskList(filter, user);
}
@GetMapping("/rootPath")
public String getDictRootPath(HttpServletRequest request,
HttpServletResponse response) {
HttpServletResponse response) {
return dictApplicationService.getDictRootPath();
}
}
@PutMapping("/knowledge/dimValue")
public Boolean updateDimValue(HttpServletRequest request,
HttpServletResponse response) {
return applicationStartedListener.updateKnowledgeDimValue();
}
}

View File

@@ -5,8 +5,10 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigBaseReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigEditReqReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
import com.tencent.supersonic.chat.api.pojo.request.ItemNameVisibilityInfo;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
import com.tencent.supersonic.chat.config.ChatConfig;
import java.util.List;
@@ -16,6 +18,10 @@ public interface ConfigService {
Long editConfig(ChatConfigEditReqReq extendEditCmd, User user);
ItemNameVisibilityInfo getItemNameVisibility(ChatConfig chatConfig);
ItemNameVisibilityInfo getVisibilityByModelId(Long modelId);
List<ChatConfigResp> search(ChatConfigFilter filter, User user);
ChatConfigRichResp getConfigRichInfo(Long modelId);

View File

@@ -5,15 +5,16 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticLayer;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigEditReqReq;
import com.tencent.supersonic.chat.api.pojo.request.ItemVisibility;
import com.tencent.supersonic.chat.api.pojo.request.ChatDetailConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.ItemNameVisibilityInfo;
import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatDetailConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigBaseReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
import com.tencent.supersonic.chat.api.pojo.request.Entity;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigEditReqReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.KnowledgeInfoReq;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatDefaultRichConfigResp;
@@ -27,6 +28,7 @@ import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.persistence.repository.ChatConfigRepository;
import com.tencent.supersonic.chat.utils.ChatConfigHelper;
import com.tencent.supersonic.chat.utils.VisibilityEvent;
import com.tencent.supersonic.common.util.JsonUtil;
import java.util.ArrayList;
@@ -36,9 +38,14 @@ import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
import com.tencent.supersonic.semantic.model.domain.DimensionService;
import com.tencent.supersonic.semantic.model.domain.MetricService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@@ -49,16 +56,24 @@ public class ConfigServiceImpl implements ConfigService {
private final ChatConfigRepository chatConfigRepository;
private final ChatConfigHelper chatConfigHelper;
private final DimensionService dimensionService;
private final MetricService metricService;
@Autowired
private SemanticService semanticService;
@Autowired
private ApplicationEventPublisher applicationEventPublisher;
private SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer();
public ConfigServiceImpl(ChatConfigRepository chatConfigRepository,
ChatConfigHelper chatConfigHelper) {
ChatConfigHelper chatConfigHelper,
DimensionService dimensionService,
MetricService metricService) {
this.chatConfigRepository = chatConfigRepository;
this.chatConfigHelper = chatConfigHelper;
this.dimensionService = dimensionService;
this.metricService = metricService;
}
@Override
@@ -68,6 +83,7 @@ public class ConfigServiceImpl implements ConfigService {
permissionCheckLogic(configBaseCmd.getModelId(), user.getName());
ChatConfig chaConfig = chatConfigHelper.newChatConfig(configBaseCmd, user);
Long id = chatConfigRepository.createConfig(chaConfig);
applicationEventPublisher.publishEvent(new VisibilityEvent(this, chaConfig));
return id;
}
@@ -91,9 +107,56 @@ public class ConfigServiceImpl implements ConfigService {
permissionCheckLogic(configEditCmd.getModelId(), user.getName());
ChatConfig chaConfig = chatConfigHelper.editChatConfig(configEditCmd, user);
chatConfigRepository.updateConfig(chaConfig);
applicationEventPublisher.publishEvent(new VisibilityEvent(this, chaConfig));
return configEditCmd.getId();
}
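/**
 * builds the item name visibility info for a model from its persisted chat config
 */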
public ItemNameVisibilityInfo getVisibilityByModelId(Long modelId) {
ChatConfigResp chatConfigResp = fetchConfigByModelId(modelId);
ChatConfig chatConfig = new ChatConfig();
chatConfig.setModelId(modelId);
chatConfig.setChatAggConfig(chatConfigResp.getChatAggConfig());
chatConfig.setChatDetailConfig(chatConfigResp.getChatDetailConfig());
ItemNameVisibilityInfo itemNameVisibility = getItemNameVisibility(chatConfig);
return itemNameVisibility;
}
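/**
 * merges the black (hidden) dimension/metric ids from both agg and detail configs
 * and resolves them to names so they can be filtered out at search time
 */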
public ItemNameVisibilityInfo getItemNameVisibility(ChatConfig chatConfig) {
Long modelId = chatConfig.getModelId();
List<Long> blackDimIdList = new ArrayList<>();
if (Objects.nonNull(chatConfig.getChatAggConfig()) && Objects.nonNull(chatConfig.getChatAggConfig().getVisibility())) {
blackDimIdList.addAll(chatConfig.getChatAggConfig().getVisibility().getBlackDimIdList());
}
if (Objects.nonNull(chatConfig.getChatDetailConfig()) && Objects.nonNull(chatConfig.getChatDetailConfig().getVisibility())) {
blackDimIdList.addAll(chatConfig.getChatDetailConfig().getVisibility().getBlackDimIdList());
}
List<Long> filterDimIdList = blackDimIdList.stream().distinct().collect(Collectors.toList());
List<Long> blackMetricIdList = new ArrayList<>();
if (Objects.nonNull(chatConfig.getChatAggConfig()) && Objects.nonNull(chatConfig.getChatAggConfig().getVisibility())) {
blackMetricIdList.addAll(chatConfig.getChatAggConfig().getVisibility().getBlackMetricIdList());
}
if (Objects.nonNull(chatConfig.getChatDetailConfig()) && Objects.nonNull(chatConfig.getChatDetailConfig().getVisibility())) {
blackMetricIdList.addAll(chatConfig.getChatDetailConfig().getVisibility().getBlackMetricIdList());
}
List<Long> filterMetricIdList = blackMetricIdList.stream().distinct().collect(Collectors.toList());
ItemNameVisibilityInfo itemNameVisibility = new ItemNameVisibilityInfo();
if (!CollectionUtils.isEmpty(blackDimIdList)) {
List<DimensionResp> dimensionRespList = dimensionService.getDimensions(modelId);
List<String> blackDimNameList = dimensionRespList.stream().filter(o -> filterDimIdList.contains(o.getId()))
.map(o -> o.getName()).collect(Collectors.toList());
itemNameVisibility.setBlackDimNameList(blackDimNameList);
}
if (!CollectionUtils.isEmpty(blackMetricIdList)) {
List<MetricResp> metricRespList = metricService.getMetrics(modelId);
List<String> blackMetricList = metricRespList.stream().filter(o -> filterMetricIdList.contains(o.getId()))
.map(o -> o.getName()).collect(Collectors.toList());
itemNameVisibility.setBlackMetricNameList(blackMetricList);
}
return itemNameVisibility;
}
/**
* model administrators have the right to modify related configuration information.

View File

@@ -247,7 +247,9 @@ public class QueryServiceImpl implements QueryService {
public QueryResult executeDirectQuery(QueryDataReq queryData, User user) throws SqlParseException {
SemanticQuery semanticQuery = QueryManager.createRuleQuery(queryData.getQueryMode());
BeanUtils.copyProperties(queryData, semanticQuery.getParseInfo());
return semanticQuery.execute(user);
QueryResult queryResult = semanticQuery.execute(user);
queryResult.setChatContext(semanticQuery.getParseInfo());
return queryResult;
}
@Override
@@ -274,7 +276,7 @@ public class QueryServiceImpl implements QueryService {
List<String> groups = new ArrayList<>();
groups.add(dimensionValueReq.getBizName());
queryStructReq.setGroups(groups);
if (Objects.isNull(dimensionValueReq.getValue())) {
if (!Objects.isNull(dimensionValueReq.getValue())) {
List<Filter> dimensionFilters = new ArrayList<>();
Filter dimensionFilter = new Filter();
dimensionFilter.setOperator(FilterOperatorEnum.LIKE);

View File

@@ -1,16 +1,19 @@
package com.tencent.supersonic.chat.service.impl;
import com.github.benmanes.caffeine.cache.Cache;
import com.google.common.collect.Lists;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.ItemNameVisibilityInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.SearchResult;
import com.tencent.supersonic.chat.mapper.MapperHelper;
import com.tencent.supersonic.chat.service.ConfigService;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.dictionary.ModelInfoStat;
import com.tencent.supersonic.chat.mapper.ModelWithSemanticType;
@@ -22,7 +25,7 @@ import com.tencent.supersonic.chat.service.SearchService;
import com.tencent.supersonic.knowledge.utils.NatureHelper;
import com.tencent.supersonic.knowledge.dictionary.DictWord;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
import java.util.ArrayList;
@@ -40,6 +43,7 @@ import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
@@ -59,6 +63,12 @@ public class SearchServiceImpl implements SearchService {
private SearchMatchStrategy searchMatchStrategy;
@Autowired
private AgentService agentService;
@Autowired
@Qualifier("searchCaffeineCache")
private Cache<Long, Object> caffeineCache;
@Autowired
private ConfigService configService;
@Override
public List<SearchResult> search(QueryReq queryReq) {
@@ -80,7 +90,7 @@ public class SearchServiceImpl implements SearchService {
// 3.detect by segment
List<Term> originals = HanlpHelper.getTerms(queryText);
log.info("hanlp parse result: {}", originals);
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
Set<Long> detectModelIds = mapperHelper.getModelIds(queryReq);
@@ -94,7 +104,6 @@ public class SearchServiceImpl implements SearchService {
.reduce((entry1, entry2) ->
entry1.getKey().getDetectSegment().length() >= entry2.getKey().getDetectSegment().length()
? entry1 : entry2);
log.debug("mostSimilarSearchResult:{}", mostSimilarSearchResult);
// 5.optimize the results after the query
if (!mostSimilarSearchResult.isPresent()) {
@@ -194,7 +203,16 @@ public class SearchServiceImpl implements SearchService {
}
searchResults.add(searchResult);
int metricSize = getMetricSize(natureToNameMap);
List<String> metrics = filerMetricsByModel(metricsDb, modelId, metricSize);
// filter out metrics hidden by the model's item name visibility config
ItemNameVisibilityInfo itemNameVisibility = (ItemNameVisibilityInfo) caffeineCache.getIfPresent(modelId);
if (itemNameVisibility == null) {
itemNameVisibility = configService.getVisibilityByModelId(modelId);
caffeineCache.put(modelId, itemNameVisibility);
}
List<String> blackMetricNameList = itemNameVisibility.getBlackMetricNameList();
List<String> metrics = filerMetricsByModel(metricsDb, modelId, metricSize * 3)
.stream().filter(o -> !blackMetricNameList.contains(o))
.limit(metricSize).collect(Collectors.toList());
for (String metric : metrics) {
SearchResult result = SearchResult.builder()
@@ -279,7 +297,7 @@ public class SearchServiceImpl implements SearchService {
private boolean searchMetricAndDimension(Set<Long> possibleModels, Map<Long, String> modelToName,
Map.Entry<MatchText, List<MapResult>> searchTextEntry, Set<SearchResult> searchResults) {
boolean existMetric = false;
log.info("searchMetricAndDimension searchTextEntry:{}", searchTextEntry);
MatchText matchText = searchTextEntry.getKey();
List<MapResult> mapResults = searchTextEntry.getValue();
@@ -297,7 +315,6 @@ public class SearchServiceImpl implements SearchService {
existMetric = true;
Long modelId = modelWithSemanticType.getModel();
SchemaElementType semanticType = modelWithSemanticType.getSemanticType();
SearchResult searchResult = SearchResult.builder()
.modelId(modelId)
.modelName(modelToName.get(modelId))
@@ -305,12 +322,25 @@ public class SearchServiceImpl implements SearchService {
.subRecommend(mapResult.getName())
.schemaElementType(semanticType)
.build();
searchResults.add(searchResult);
// only keep results whose dimension/metric name is not hidden by the model's visibility config
ItemNameVisibilityInfo visibility = (ItemNameVisibilityInfo) caffeineCache.getIfPresent(modelId);
if (visibility == null) {
visibility = configService.getVisibilityByModelId(modelId);
caffeineCache.put(modelId, visibility);
}
if (semanticType.equals(SchemaElementType.DIMENSION)
&& !visibility.getBlackDimNameList().contains(mapResult.getName())) {
searchResults.add(searchResult);
}
if (semanticType.equals(SchemaElementType.METRIC)
&& !visibility.getBlackMetricNameList().contains(mapResult.getName())) {
searchResults.add(searchResult);
}
}
log.info("parseResult:{},dimensionMetricClassIds:{},possibleModels:{}", mapResult, dimensionMetricClassIds,
possibleModels);
}
log.info("searchMetricAndDimension searchResults:{}", searchResults);
return existMetric;
}

View File

@@ -0,0 +1,22 @@
package com.tencent.supersonic.chat.utils;
import com.tencent.supersonic.chat.config.ChatConfig;
import org.springframework.context.ApplicationEvent;
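/**
 * event published when a model's chat config changes, carrying the updated config
 * so that cached item name visibility can be refreshed
 */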
public class VisibilityEvent extends ApplicationEvent {
private static final long serialVersionUID = 1L;
private ChatConfig chatConfig;
public VisibilityEvent(Object source, ChatConfig chatConfig) {
super(source);
this.chatConfig = chatConfig;
}
public void setChatConfig(ChatConfig chatConfig) {
this.chatConfig = chatConfig;
}
public ChatConfig getChatConfig() {
return chatConfig;
}
}

View File

@@ -0,0 +1,30 @@
package com.tencent.supersonic.chat.utils;
import com.github.benmanes.caffeine.cache.Cache;
import com.tencent.supersonic.chat.api.pojo.request.ItemNameVisibilityInfo;
import com.tencent.supersonic.chat.service.ConfigService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.ApplicationListener;
import org.springframework.stereotype.Component;
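/**
 * refreshes the per-model visibility cache whenever a VisibilityEvent is published
 */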
@Component
@Slf4j
public class VisibilityListener implements ApplicationListener<VisibilityEvent> {
@Autowired
@Qualifier("searchCaffeineCache")
private Cache<Long, Object> caffeineCache;
@Autowired
private ConfigService configService;
@Override
public void onApplicationEvent(VisibilityEvent event) {
log.info("visibility has changed,so update cache!");
ItemNameVisibilityInfo itemNameVisibility = configService.getItemNameVisibility(event.getChatConfig());
log.info("itemNameVisibility :{}", itemNameVisibility);
caffeineCache.put(event.getChatConfig().getModelId(), itemNameVisibility);
}
}
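
Note: the "searchCaffeineCache" bean injected via @Qualifier in SearchServiceImpl and in the listener above is not defined anywhere in this diff. A minimal sketch of how such a bean could be declared is shown below; the class name, maximum size and expiry are illustrative assumptions, not part of this commit.

import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import java.time.Duration;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
public class SearchCacheConfig {

    // assumed bean definition: only the bean name "searchCaffeineCache" and the
    // Cache<Long, Object> type come from the diff; size and expiry are illustrative
    @Bean("searchCaffeineCache")
    public Cache<Long, Object> searchCaffeineCache() {
        return Caffeine.newBuilder()
                .maximumSize(1000)
                .expireAfterWrite(Duration.ofMinutes(10))
                .build();
    }
}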

View File

@@ -25,31 +25,20 @@ app = FastAPI()
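# the /query2sql/ endpoint expects a JSON body with at least queryText and schema,
# and returns the schema-linking output together with the generated SQL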
@app.post("/query2sql/")
async def din_query2sql(query_body: Mapping[str, Any]):
if 'queryText' not in query_body:
raise HTTPException(status_code=400,
detail="query_text is not in query_body")
else:
query_text = query_body['queryText']
if 'schema' not in query_body:
raise HTTPException(status_code=400, detail="schema is not in query_body")
else:
schema = query_body['schema']
if 'currentDate' not in query_body:
raise HTTPException(status_code=400, detail="currentDate is not in query_body")
else:
current_date = query_body['currentDate']
resp = query2sql(query_text=query_text, schema=schema)
if 'linking' not in query_body:
linking = None
else:
linking = query_body['linking']
resp = query2sql(query_text=query_text,
schema=schema, current_date=current_date, linking=linking)
return resp
@app.post("/preset_query_retrival/")

View File

@@ -1,296 +1,147 @@
examplars= [
{ "current_date":"2020-12-01",
{
"table_name":"内容库产品",
"fields_list":"""["部门", "模块", "用户名", "访问次数", "访问人数", "访问时长", "数据日期"]""",
"question":"比较jackjchen和robinlee在内容库的访问次数",
"prior_schema_links":"""['jackjchen'->用户名, 'robinlee'->用户名]""",
"analysis": """让我们一步一步地思考。在问题“比较jackjchen和robinlee在内容库的访问次数“中我们被问
比较jackjchen和robinlee”所以我们需要column=[用户名]
”内容库的访问次数“所以我们需要column=[访问次数]
基于table和columns可能的cell values 是 = ['jackjchen', 'robinlee']""",
"schema_links":"""["用户名", "访问次数", "'jackjchen'", "'robinlee'"]""",
"sql":"""select 用户名, 访问次数 from 内容库产品 where 用户名 in ('jackjchen', 'robinlee') and 数据日期 = '2020-12-01' """
},
{ "current_date":"2022-11-06",
"fields_list":"""["部门", "模块", "用户名", "访问次数", "访问人数", "访问时长"]""",
"question":"比较jerry和tom在内容库的访问次数",
"analysis": """让我们一步一步地思考。在问题“比较jerry和tom在内容库的访问次数“中我们被问
“内容库的访问次数”所以我们需要column=[访问次数]
比较jerry和tom“所以我们需要column=[用户名]
基于table和columns可能的cell values 是 = ['jerry', 'tom']。""",
"schema_links":"""["访问次数", "用户名", "'jerry'", "'tom'"]""",
"sql":"""select 用户名, 访问次数 from 内容库产品 where 用户名 in ('jerry', 'tom')"""
},
{
"table_name":"内容库产品",
"fields_list":"""["部门", "模块", "用户名", "访问次数", "访问人数", "访问时长", "数据日期"]""",
"fields_list":"""["部门", "模块", "用户名", "访问次数", "访问人数", "访问时长"]""",
"question":"内容库近12个月访问人数 按部门",
"prior_schema_links":"""[]""",
"analysis": """让我们一步一步地思考。在问题“内容库近12个月访问人数 按部门“中,我们被问:
内容库近12个月所以我们需要column=[数据日期]
“访问人数”所以我们需要column=[访问人数]
内容库近12个月访问人数”所以我们需要column=[访问人数]
”按部门“所以我们需要column=[部门]
基于table和columns可能的cell values 是 = [12]。""",
"schema_links":"""["访问人数", "部门", "数据日期", 12]""",
"sql":"""select 部门, 数据日期, 访问人数 from 内容库产品 where datediff('month', 数据日期, '2022-11-06') <= 12 """
},
{ "current_date":"2023-04-21",
基于table和columns可能的cell values 是 = []。""",
"schema_links":"""["访问人数", "部门"]""",
"sql":"""select 部门, sum(访问人数) from 内容库产品 where 部门 group by 部门"""
},
{
"table_name":"内容库产品",
"fields_list":"""["部门", "模块", "用户名", "访问次数", "访问人数", "访问时长", "数据日期"]""",
"question":"内容库美术部、技术研发部的访问时长",
"prior_schema_links":"""['美术部'->部门, '技术研发部'->部门]""",
"analysis": """让我们一步一步地思考。在问题“内容库美术部、技术研发部的访问时长“中,我们被问:
"fields_list":"""["部门", "模块", "用户名", "访问次数", "访问人数", "访问时长"]""",
"question":"内容库编辑部、美术部的访问时长",
"analysis": """让我们一步一步地思考。在问题“内容库编辑部、美术部的访问时长“中,我们被问:
“访问时长”所以我们需要column=[访问时长]
”内容库美术部、技术研发部“所以我们需要column=[部门]
基于table和columns可能的cell values 是 = ['美术', '技术研发']。""",
"schema_links":"""["访问时长", "部门", "'美术'", "'技术研发'"]""",
"sql":"""select 部门, 访问时长 from 内容库产品 where 部门 in ('美术', '技术研发') and 数据日期 = '2023-04-21' """
},
{ "current_date":"2023-08-21",
"table_name":"",
"fields_list":"""["严选版权归属系", "付费模式", "结算播放份额", "付费用户结算播放份额", "数据日期"]""",
"question":"近3天海田飞系MPPM结算播放份额",
"prior_schema_links":"""['海田飞系'->严选版权归属系]""",
"analysis": """让我们一步一步地思考。在问题“近3天海田飞系MPPM结算播放份额“中我们被问
“MPPM结算播放份额所以我们需要column=[结算播放份额]
”海田飞系“所以我们需要column=[严选版权归属系]
”近3天“所以我们需要column=[数据日期]
基于table和columns可能的cell values 是 = ['海田飞系', 3]。""",
"schema_links":"""["结算播放份额", "严选版权归属系", "数据日期", "'海田飞系'", 3]""",
"sql":"""select 严选版权归属系, 结算播放份额 from 严选 where 严选版权归属系 = '海田飞系' and datediff('day', 数据日期, '2023-08-21') <= 3 """
},
{ "current_date":"2023-05-22",
”内容库编辑部、美术部“所以我们需要column=[部门]
基于table和columns可能的cell values 是 = ['编辑', '美术']。""",
"schema_links":"""["访问时长", "部门", "'编辑'", "'美术'"]""",
"sql":"""select 部门, 访问时长 from 内容库产品 where 部门 in ('编辑', '美术')"""
},
{
"table_name":"",
"fields_list":"""['归属系', '付费模式', '结算播放份额', '付费用户结算播放份额']""",
"question":"近3天飞天系结算播放份额",
"analysis": """让我们一步一步地思考。在问题“近3天飞天系结算播放份额“中我们被问
“结算播放份额”所以我们需要column=[结算播放份额]
飞天系“所以我们需要column=[归属系]
基于table和columns可能的cell values 是 = ['飞天系']。""",
"schema_links":"""["结算播放份额", "归属系", "'飞天系'"]""",
"sql":"""select 归属系, 结算播放份额 from 精选 where 归属系 in ('')"""
},
{
"table_name":"歌曲库",
"fields_list":"""["是否潮流人歌曲", "C音歌曲ID", "C音歌曲MID", "歌曲名", "歌曲版本", "语种", "歌曲类型", "翻唱类型", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "结算播放量", "运营播放量", "付费用户结算播放量", "历史累计结算播放量", "运营搜播量", "结算搜播量", "运营完播量", "运营推播量", "近7日复播率", "日均搜播量", "数据日期"]""",
"question":"对比近7天翻唱版和纯音乐的歌曲播放量",
"prior_schema_links":"""['纯音乐'->语种, '翻唱版'->歌曲版本]""",
"fields_list":"""['歌曲ID', '歌曲MID', '歌曲名', '歌曲版本', '歌曲类型', '翻唱类型', '结算播放量', '运营播放量', '付费用户结算播放量', '历史累计结算播放量', '运营搜播量', '结算搜播量', '运营完播量', '运营推播量', '近7日复播率', '日均搜播量']""",
"question":"对比近3天翻唱版和纯音乐的歌曲播放量",
"analysis": """让我们一步一步地思考。在问题“对比近3天翻唱版和纯音乐的歌曲播放量“中我们被问
“歌曲播放量”所以我们需要column=[结算播放量]
”翻唱版“所以我们需要column=[歌曲版本]
”和纯音乐的歌曲“所以我们需要column=[语种]
”近7天“所以我们需要column=[数据日期]
基于table和columns可能的cell values 是 = ['翻唱版', '纯音乐', 7]。""",
"schema_links":"""["结算播放量", "歌曲版本", "语种", "数据日期", "'翻唱版'", "'纯音乐'", 7]""",
"sql":"""select 歌曲版本, 语种, 结算播放量 from 歌曲库 where 歌曲版本 = '翻唱版' and 语种 = '纯音乐' and datediff('day', 数据日期, '2023-05-22') <= 7 """
},
{ "current_date":"2023-05-31",
”翻唱版和纯音乐所以我们需要column=[歌曲类型]
基于table和columns可能的cell values 是 = ['翻唱版', '纯音乐']。""",
"schema_links":"""["结算播放量", "歌曲类型", "'翻唱版'", "'纯音乐'"]""",
"sql":"""select 歌曲类型, 结算播放量 from 歌曲库 where 歌曲类型 in ('翻唱版', '纯音乐')"""
},
{
"table_name":"艺人库",
"fields_list":"""["上下架状态", "歌手名", "歌手等级", "歌手类型", "歌手来源", "MPPM潮流人等级", "活跃区域", "年龄", "歌手才能", "歌手风格", "粉丝数", "潮音粉丝数", "超声波粉丝数", "推博粉丝数", "超声波歌曲数", "在架歌曲数", "超声波分享数", "独占歌曲数", "超声波在架歌曲评论", "有播放量歌曲数", "数据日期"]""",
"question":"对比一下陈拙悬、孟梅琦、赖媚韵的粉丝数",
"prior_schema_links":"""['1527896'->MPPM歌手ID, '1565463'->MPPM歌手ID, '2141459'->MPPM歌手ID]""",
"analysis": """让我们一步一步地思考。在问题“对比一下陈拙悬、孟梅琦、赖媚韵的粉丝数“中,我们被问:
"fields_list":"""['上下架状态', '歌手名', '歌手等级', '歌手类型', '歌手来源', '活跃区域', '年龄', '歌手才能', '歌手风格', '粉丝数', '在架歌曲数', '有播放量歌曲数']""",
"question":"对比一下流得滑、锅富程、章雪友的粉丝数",
"analysis": """让我们一步一步地思考。在问题“对比一下流得滑、锅富程、章雪友的粉丝数“中,我们被问:
“粉丝数”所以我们需要column=[粉丝数]
陈拙悬、孟梅琦、赖媚韵所以我们需要column=[歌手名]
基于table和columns可能的cell values 是 = ['陈拙悬', '孟梅琦', '赖媚韵']。""",
"schema_links":"""["粉丝数", "歌手名", "'陈拙悬'", "'孟梅琦'", "'赖媚韵'"]""",
"sql":"""select 歌手名, 粉丝数 from 艺人库 where 歌手名 in ('陈拙悬', '孟梅琦', '赖媚韵') and 数据日期 = '2023-05-31' """
},
{ "current_date":"2023-07-31",
流得滑、锅富程、章雪友所以我们需要column=[歌手名]
基于table和columns可能的cell values 是 = ['流得滑', '锅富程', '章雪友']。""",
"schema_links":"""["粉丝数", "歌手名", "'流得滑'", "'锅富程'", "'章雪友'"]""",
"sql":"""select 歌手名, 粉丝数 from 艺人库 where 歌手名 in ('流得滑', '锅富程', '章雪友')"""
},
{
"table_name":"歌曲库",
"fields_list":"""["歌曲", "歌曲版本", "歌曲类型", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享", "收藏", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享", "结算播放份额", "数据日期"]""",
"fields_list":"""['歌曲ID', '歌曲MID', '歌曲', '歌曲版本', '歌曲类型', '翻唱类型', '结算播放量', '运营播放量', '付费用户结算播放', '历史累计结算播放', '运营搜播量', '结算搜播量', '运营完播量', '运营推播量', '近7日复播', '日均搜播量']""",
"question":"播放量大于1万的歌曲有多少",
"prior_schema_links":"""[]""",
"analysis": """让我们一步一步地思考。在问题“播放量大于1万的歌曲有多少“中我们被问
“歌曲有多少”所以我们需要column=[歌曲名]
”播放量大于1万所以我们需要column=[结算播放量]
”播放量大于1万“所以我们需要column=[结算播放量]
基于table和columns可能的cell values 是 = [10000]。""",
"schema_links":"""["歌曲名", "结算播放量", 10000]""",
"sql":"""select 歌曲名 from 歌曲库 where 结算播放量 > 10000 and 数据日期 = '2023-07-31' """
},
{ "current_date":"2023-07-31",
"sql":"""select 歌曲名 from 歌曲库 where 结算播放量 > 10000"""
},
{
"table_name":"内容库产品",
"fields_list":"""["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""",
"fields_list":"""['用户名', '部门', '模块', '访问时长', '访问次数', '访问人数']""",
"question":"内容库访问时长小于1小时且来自美术部的用户是哪些",
"prior_schema_links":"""['美术部'->部门]""",
"analysis": """让我们一步一步地思考。在问题“内容库访问时长小于1小时且来自美术部的用户是哪些“中我们被问
“用户是哪些”所以我们需要column=[用户名]
”美术部的“所以我们需要column=[部门]
”访问时长小于1小时“所以我们需要column=[访问时长]
基于table和columns可能的cell values 是 = ['美术部', 1]。""",
"schema_links":"""["用户名", "部门", "访问时长", "'美术部'", 1]""",
"sql":"""select 用户名 from 内容库产品 where 部门 = '美术部' and 访问时长 < 1 and 数据日期 = '2023-07-31' """
},
{ "current_date":"2023-08-31",
"sql":"""select 用户名 from 内容库产品 where 部门 = '美术部' and 访问时长 < 1"""
},
{
"table_name":"内容库产品",
"fields_list":"""["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""",
"fields_list":"""['用户名', '部门', '模块', '访问次数', '访问人数', '访问时长']""",
"question":"内容库pv最高的用户有哪些",
"prior_schema_links":"""[]""",
"analysis": """让我们一步一步地思考。在问题“内容库pv最高的用户有哪些“中我们被问
“用户有哪些”所以我们需要column=[用户名]
”pv最高的“所以我们需要column=[访问次数]
基于table和columns可能的cell values 是 = []。""",
"schema_links":"""["用户名", "访问次数"]""",
"sql":"""select 用户名 from 内容库产品 where 数据日期 = '2023-08-31' order by 访问次数 desc limit 10 """
},
{ "current_date":"2023-08-31",
"sql":"""select 用户名 from 内容库产品 order by 访问次数 desc limit 10"""
},
{
"table_name":"艺人库",
"fields_list":"""["播放量层级", "播放量单调性", "播放量方差", "播放量突增类型", "播放量集中度", "歌手名", "歌手等级", "歌手类型", "歌手来源", "MPPM潮流人等级", "结算播放量", "运营播放量", "历史累计结算播放量", "有播放量歌曲数", "历史累计运营播放量", "付费用户结算播放量", "结算播放量占比", "运营播放份额", "免费用户结算播放占比", "完播量", "数据日期"]""",
"question":"近90天袁亚伟播放量平均值是多少",
"prior_schema_links":"""['152789226'->MPPM歌手ID]""",
"analysis": """让我们一步一步地思考。在问题“近90天袁亚伟播放量平均值是多少“中我们被问
"fields_list":"""['歌手名', '歌手等级', '歌手类型', '歌手来源', '结算播放量', '运营播放量', '历史累计结算播放量', '有播放量歌曲数', '历史累计运营播放量', '付费用户结算播放量', '结算播放量占比', '运营播放份额', '完播量']""",
"question":"近90天袁呀味播放量平均值是多少",
"analysis": """让我们一步一步地思考。在问题“近90天袁呀味播放量平均值是多少“中我们被问
“播放量平均值是多少”所以我们需要column=[结算播放量]
”袁亚伟所以我们需要column=[歌手名]
”近90天“所以我们需要column=[数据日期]
基于table和columns可能的cell values 是 = ['亚伟', 90]。""",
"schema_links":"""["结算播放量", "歌手名", "数据日期", "'亚伟'", 90]""",
"sql":"""select avg(结算播放量) from 艺人库 where 歌手名 = '袁亚伟' and datediff('day', 数据日期, '2023-08-31') <= 90 """
},
{ "current_date":"2023-08-31",
”袁呀味所以我们需要column=[歌手名]
基于table和columns可能的cell values 是 = ['袁呀味']。""",
"schema_links":"""["结算播放量", "歌手名", "'呀味'"]""",
"sql":"""select avg(结算播放量) from 艺人库 where 歌手名 = '呀味'"""
},
{
"table_name":"艺人库",
"fields_list":"""["播放量层级", "播放量单调性", "播放量方差", "播放量突增类型", "播放量集中度", "歌手名", "歌手等级", "歌手类型", "歌手来源", "MPPM潮流人等级", "结算播放量", "运营播放量", "历史累计结算播放量", "有播放量歌曲数", "历史累计运营播放量", "付费用户结算播放量", "结算播放量占比", "运营播放份额", "免费用户结算播放占比", "完播量", "数据日期"]""",
"question":"倩倩近7天结算播放量总和是多少",
"prior_schema_links":"""['199509'->MPPM歌手ID]""",
"analysis": """让我们一步一步地思考。在问题“周倩倩近7天结算播放量总和是多少“中我们被问
"fields_list":"""['歌手名', '歌手等级', '歌手类型', '歌手来源', '结算播放量', '运营播放量', '历史累计结算播放量', '有播放量歌曲数', '历史累计运营播放量', '付费用户结算播放量', '结算播放量占比', '运营播放份额', '完播量']""",
"question":"近7天结算播放量总和是多少",
"analysis": """让我们一步一步地思考。在问题“周浅近7天结算播放量总和是多少“中我们被问
“结算播放量总和是多少”所以我们需要column=[结算播放量]
”周倩倩所以我们需要column=[歌手名]
”近7天“所以我们需要column=[数据日期]
基于table和columns可能的cell values 是 = ['倩倩', 7]。""",
"schema_links":"""["结算播放量", "歌手名", "数据日期", "'倩倩'", 7]""",
"sql":"""select sum(结算播放量) from 艺人库 where 歌手名 = '周倩倩' and datediff('day', 数据日期, '2023-08-31') <= 7 """
},
{ "current_date":"2023-09-14",
”周所以我们需要column=[歌手名]
基于table和columns可能的cell values 是 = ['周浅']。""",
"schema_links":"""["结算播放量", "歌手名", "''"]""",
"sql":"""select sum(结算播放量) from 艺人库 where 歌手名 = ''"""
},
{
"table_name":"内容库产品",
"fields_list":"""["部门", "模块", "用户名", "访问次数", "访问人数", "访问时长", "数据日期"]""",
"fields_list":"""['部门', '模块', '用户名', '访问次数', '访问人数', '访问时长']""",
"question":"内容库访问次数大于1k的部门是哪些",
"prior_schema_links":"""[]""",
"analysis": """让我们一步一步地思考。在问题“内容库访问次数大于1k的部门是哪些“中我们被问
“部门是哪些”所以我们需要column=[部门]
”访问次数大于1k的“所以我们需要column=[访问次数]
基于table和columns可能的cell values 是 = [1000]。""",
"schema_links":"""["部门", "访问次数", 1000]""",
"sql":"""select 部门 from 内容库产品 where 访问次数 > 1000 and 数据日期 = '2023-09-14' """
},
{ "current_date":"2023-09-18",
"sql":"""select 部门 from 内容库产品 where 访问次数 > 1000"""
},
{
"table_name":"歌曲库",
"fields_list":"""["歌曲", "MPPM歌手ID", "歌曲版本", "歌曲类型", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享", "收藏", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
"question":"亿训唱的所有的播放量大于20k的孤勇者有哪些",
"prior_schema_links":"""['199509'->MPPM歌手ID, '1527123'->MPPM歌曲ID]""",
"analysis": """让我们一步一步地思考。在问题“陈亿训唱的所有的播放量大于20k的孤勇者有哪些“中我们被问
“孤勇者有哪些”所以我们需要column=[歌曲名]
"fields_list":"""['歌曲ID', '歌曲MID', '歌曲', '歌曲版本', '歌曲类型', '翻唱类型', '结算播放量', '运营播放量', '付费用户结算播放', '历史累计结算播放', '运营搜播量', '结算搜播量', '运营完播量', '运营推播量', '近7日复播率', '日均搜播量']""",
"question":"奕迅唱的所有的播放量大于20k的雇佣者有哪些",
"analysis": """让我们一步一步地思考。在问题“陈易迅唱的所有的播放量大于20k的雇佣者有哪些“中我们被问
“雇佣者有哪些”所以我们需要column=[歌曲名]
”播放量大于20k的“所以我们需要column=[结算播放量]
”陈亿训唱的“所以我们需要column=[歌手名]
基于table和columns可能的cell values 是 = [20000, '亿训', '孤勇者']。""",
"schema_links":"""["歌曲名", "结算播放量", "歌手名", 20000, "'亿训'", "'孤勇者'"]""",
"sql":"""select 歌曲名 from 歌曲库 where 结算播放量 > 20000 and 歌手名 = '亿训' and 歌曲名 = '孤勇者' and 数据日期 = '2023-09-18' """
},
{ "current_date":"2023-09-18",
"table_name":"歌曲库",
"fields_list":"""["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
"question":"周洁轮去年发布的歌曲有哪些",
"prior_schema_links":"""['23109'->MPPM歌手ID]""",
"analysis": """让我们一步一步地思考。在问题“周洁轮去年发布的歌曲有哪些“中,我们被问:
“歌曲有哪些”所以我们需要column=[歌曲名]
”去年发布的“所以我们需要column=[发布时间]
”周洁轮“所以我们需要column=[歌手名]
基于table和columns可能的cell values 是 = ['周洁轮', 1]。""",
"schema_links":"""["歌曲名", "发布时间", "歌手名", 1, "'周洁轮'"]""",
"sql":"""select 歌曲名 from 歌曲库 where datediff('year', 发布时间, '2023-09-18') <= 1 and 歌手名 = '周洁轮' and 数据日期 = '2023-09-18' """
},
{ "current_date":"2023-09-11",
"table_name":"艺人库",
"fields_list":"""["播放量层级", "播放量单调性", "播放量方差", "播放量突增类型", "播放量集中度", "歌手名", "歌手等级", "歌手类型", "歌手来源", "签约日期", "MPPM潮流人等级", "结算播放量", "运营播放量", "历史累计结算播放量", "有播放量歌曲数", "历史累计运营播放量", "付费用户结算播放量", "结算播放量占比", "运营播放份额", "免费用户结算播放占比", "完播量", "数据日期"]""",
"question":"我想要近半年签约的播放量前十的歌手有哪些",
"prior_schema_links":"""[]""",
"analysis": """让我们一步一步地思考。在问题“我想要近半年签约的播放量前十的歌手“中,我们被问:
“歌手有哪些”所以我们需要column=[歌手名]
”播放量前十的“所以我们需要column=[结算播放量]
”近半年签约的“所以我们需要column=[签约日期]
基于table和columns可能的cell values 是 = [0.5, 10]。""",
"schema_links":"""["歌手名", "结算播放量", "签约日期", 0.5, 10]""",
"sql":"""select 歌手名 from 艺人库 where datediff('year', 签约日期, '2023-09-11') <= 0.5 and 数据日期 = '2023-09-11' order by 结算播放量 desc limit 10"""
},
{ "current_date":"2023-08-12",
"table_name":"歌曲库",
"fields_list": """["发行日期", "歌曲语言", "歌曲来源", "歌曲流派", "歌曲名", "歌曲版本", "歌曲类型", "发行时间", "数据日期"]""",
"question":"最近一年发行的歌曲中有哪些在近7天播放超过一千万的",
"prior_schema_links":"""[]""",
"analysis": """让我们一步一步地思考。在问题“最近一年发行的歌曲中有哪些在近7天播放超过一千万的“中我们被问
“发行的歌曲中有哪些”所以我们需要column=[歌曲名]
”最近一年发行的“所以我们需要column=[发行日期]
”在近7天播放超过一千万的“所以我们需要column=[数据日期, 结算播放量]
基于table和columns可能的cell values 是 = [1, 10000000]""",
"schema_links":"""["歌曲名", "发行日期", "数据日期", "结算播放量", 1, 10000000]""",
"sql":"""select 歌曲名 from 歌曲库 where datediff('year', 发行日期, '2023-08-12') <= 1 and datediff('day', 数据日期, '2023-08-12') <= 7 and 结算播放量 > 10000000"""
},
{ "current_date":"2023-08-12",
"table_name":"歌曲库",
"fields_list": """["发行日期", "歌曲语言", "歌曲来源", "歌曲流派", "歌曲名", "歌曲版本", "歌曲类型", "发行时间", "数据日期"]""",
"question":"今年以来发行的歌曲中有哪些在近7天播放超过一千万的",
"prior_schema_links":"""[]""",
"analysis": """让我们一步一步地思考。在问题“今年以来发行的歌曲中有哪些在近7天播放超过一千万的“中我们被问
“发行的歌曲中有哪些”所以我们需要column=[歌曲名]
”今年以来发行的“所以我们需要column=[发行日期]
”在近7天播放超过一千万的“所以我们需要column=[数据日期, 结算播放量]
基于table和columns可能的cell values 是 = [0, 7, 10000000]""",
"schema_links":"""["歌曲名", "发行日期", "数据日期", "结算播放量", 0, 7, 10000000]""",
"sql":"""select 歌曲名 from 歌曲库 where datediff('year', 发行日期, '2023-08-12') <= 0 and datediff('day', 数据日期, '2023-08-12') <= 7 and 结算播放量 > 10000000"""
},
{ "current_date":"2023-08-12",
"table_name":"歌曲库",
"fields_list": """["发行日期", "歌曲语言", "歌曲来源", "歌曲流派", "歌曲名", "歌曲版本", "歌曲类型", "发行时间", "数据日期"]""",
"question":"2023年以来发行的歌曲中有哪些在近7天播放超过一千万的",
"prior_schema_links":"""['514129144'->MPPM歌曲ID]""",
"analysis": """让我们一步一步地思考。在问题“2023年以来发行的歌曲中有哪些在近7天播放超过一千万的“中我们被问
“发行的歌曲中有哪些”所以我们需要column=[歌曲名]
”2023年以来发行的“所以我们需要column=[发行日期]
”在近7天播放超过一千万的“所以我们需要column=[数据日期, 结算播放量]
基于table和columns可能的cell values 是 = [2023, 7, 10000000]""",
"schema_links":"""["歌曲名", "发行日期", "数据日期", "结算播放量", 2023, 7, 10000000]""",
"sql":"""select 歌曲名 from 歌曲库 where YEAR(发行日期) >= 2023 and datediff('day', 数据日期, '2023-08-12') <= 7 and 结算播放量 > 10000000"""
},
{ "current_date":"2023-08-01",
"table_name":"歌曲库",
"fields_list":"""["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
"question":"周洁轮2023年6月之后发布的歌曲有哪些",
"prior_schema_links":"""['23109'->MPPM歌手ID]""",
"analysis": """让我们一步一步地思考。在问题“周洁轮2023年6月之后发布的歌曲有哪些“中我们被问
“歌曲有哪些”所以我们需要column=[歌曲名]
”2023年6月之后发布的“所以我们需要column=[发布时间]
”周洁轮“所以我们需要column=[歌手名]
基于table和columns可能的cell values 是 = ['周洁轮', 2023, 6]。""",
"schema_links":"""["歌曲名", "发布时间", "歌手名", "周洁轮", 2023, 6]""",
"sql":"""select 歌曲名 from 歌曲库 where YEAR(发布时间) >= 2023 and MONTH(发布时间) >= 6 and 歌手名 = '周洁轮' and 数据日期 = '2023-08-01' """
},
{ "current_date":"2023-08-01",
"table_name":"歌曲库",
"fields_list":"""["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
"question":"邓梓琦在2023年1月5日之后发布的歌曲中有哪些播放量大于500W的",
"prior_schema_links":"""['2312311'->MPPM歌手ID]""",
"analysis": """让我们一步一步地思考。在问题“邓梓琦在2023年1月5日之后发布的歌曲中有哪些播放量大于500W的“中我们被问
“播放量大于500W的”所以我们需要column=[结算播放量]
”邓梓琦在2023年1月5日之后发布的“所以我们需要column=[发布时间]
”邓梓琦“所以我们需要column=[歌手名]
基于table和columns可能的cell values 是 = ['邓梓琦', 2023, 1, 5, 5000000]。""",
"schema_links":"""["结算播放量", "发布时间", "歌手名", "邓梓琦", 2023, 1, 5, 5000000]""",
"sql":"""select 歌曲名 from 歌曲库 where YEAR(发布时间) >= 2023 and MONTH(发布时间) >= 1 and DAY(发布时间) >= 5 and 歌手名 = '邓梓琦' and 结算播放量 > 5000000 and 数据日期 = '2023-08-01'"""
},
{ "current_date":"2023-09-17",
"table_name":"歌曲库",
"fields_list":"""["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
"question":"2023年6月以后张亮英播放量大于200万的歌曲有哪些",
"prior_schema_links":"""['45453'->MPPM歌手ID]""",
"analysis": """让我们一步一步地思考。在问题“2023年6月以后张亮英播放量大于200万的歌曲有哪些“中我们被问
“播放量大于200万的”所以我们需要column=[结算播放量]
”2023年6月以后张亮英“所以我们需要column=[数据日期, 歌手名]
”歌曲有哪些“所以我们需要column=[歌曲名]
基于table和columns可能的cell values 是 = ['张亮英', 2023, 6, 2000000]。""",
"schema_links":"""["结算播放量", "数据日期", "歌手名", "张亮英", 2023, 6, 2000000]""",
"sql":"""select 歌曲名 from 歌曲库 where YEAR(数据日期) >= 2023 and MONTH(数据日期) >= 6 and 歌手名 = '张亮英' and 结算播放量 > 2000000 """
},
{ "current_date":"2023-08-16",
"table_name":"歌曲库",
"fields_list":"""["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
"question":"2021年6月以后发布的李雨纯的播放量大于20万的歌曲有哪些",
"prior_schema_links":"""['23109'->MPPM歌手ID]""",
"analysis": """让我们一步一步地思考。在问题“2021年6月以后发布的李雨纯的播放量大于20万的歌曲有哪些“中我们被问
“播放量大于20万的”所以我们需要column=[结算播放量]
”2021年6月以后发布的“所以我们需要column=[发布时间]
”李雨纯“所以我们需要column=[歌手名]
基于table和columns可能的cell values 是 = ['李雨纯', 2021, 6, 200000]。""",
"schema_links":"""["结算播放量", "发布时间", "歌手名", "李雨纯", 2021, 6, 200000]""",
"sql":"""select 歌曲名 from 歌曲库 where YEAR(发布时间) >= 2021 and MONTH(发布时间) >= 6 and 歌手名 = '李雨纯' and 结算播放量 > 200000 and 数据日期 = '2023-08-16'"""
},
{ "current_date":"2023-08-16",
"table_name":"歌曲库",
"fields_list":"""["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
"question":"刘锝桦在1992年4月2日到2020年5月2日之间发布的播放量大于20万的歌曲有哪些",
"prior_schema_links":"""['4234234'->MPPM歌手ID]""",
"analysis": """让我们一步一步地思考。在问题“刘锝桦在1992年4月2日到2020年5月2日之间发布的播放量大于20万的歌曲有哪些“中我们被问
“播放量大于20万的”所以我们需要column=[结算播放量]
”1992年4月2日到2020年5月2日之间发布的“所以我们需要column=[发布时间]
”刘锝桦“所以我们需要column=[歌手名]
基于table和columns可能的cell values 是 = ['刘锝桦', 1992, 4, 2, 2020, 5, 2, 200000]。""",
"schema_links":"""["结算播放量", "发布时间", "歌手名", "刘锝桦", 1992, 4, 2, 2020, 5, 2, 200000]""",
"sql":"""select 歌曲名 from 歌曲库 where YEAR(发布时间) >= 1992 and MONTH(发布时间) >= 4 and DAY(发布时间) >= 2 and YEAR(发布时间) <= 2020 and MONTH(发布时间) <= 5 and DAY(发布时间) <= 2 and 歌手名 = '刘锝桦' and 结算播放量 > 200000 and 数据日期 = '2023-08-16'"""
}
”陈易迅唱的“所以我们需要column=[歌手名]
基于table和columns可能的cell values 是 = [20000, '易迅']。""",
"schema_links":"""["歌曲名", "结算播放量", "歌手名", 20000, "'易迅'"]""",
"sql":"""select 歌曲名 from 歌曲库 where 结算播放量 > 20000 and 歌手名 = '易迅'"""
}
]

View File

@@ -8,7 +8,8 @@ from typing import Any, List, Mapping, Optional, Union
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import chromadb
from chromadb.config import Settings
from chromadb.api import Collection, Documents, Embeddings
from langchain.llms import OpenAI
@@ -20,9 +21,13 @@ from preset_query_db import (get_ids, add2preset_query_collection,
from util.text2vec import Text2VecEmbeddingFunction
from run_config import CHROMA_DB_PERSIST_PATH, PRESET_QUERY_COLLECTION_NAME
from util.chromadb_instance import client
client = chromadb.Client(Settings(
chroma_db_impl="duckdb+parquet",
persist_directory=CHROMA_DB_PERSIST_PATH # Optional, defaults to .chromadb/ in the current directory
))
emb_func = Text2VecEmbeddingFunction()
collection = client.get_or_create_collection(name=PRESET_QUERY_COLLECTION_NAME,
@@ -30,8 +35,6 @@ collection = client.get_or_create_collection(name=PRESET_QUERY_COLLECTION_NAME,
metadata={"hnsw:space": "cosine"}
) # Get a collection object from an existing collection, by name. If it doesn't exist, create it.
print("init_preset_query_collection_size: ", preset_query_collection_size(collection))
def preset_query_retrieval_run(collection:Collection, query_texts_list:List[str], n_results:int=5):
retrieval_res = query2preset_query_collection(collection=collection,

View File

@@ -9,7 +9,6 @@ TEMPERATURE = 0.0
CHROMA_DB_PERSIST_DIR = 'chm_db'
PRESET_QUERY_COLLECTION_NAME = "preset_query_collection"
TEXT2DSL_COLLECTION_NAME = "text2dsl_collection"
CHROMA_DB_PERSIST_PATH = os.path.join(PROJECT_DIR_PATH, CHROMA_DB_PERSIST_DIR)

View File

@@ -1,53 +0,0 @@
# -*- coding:utf-8 -*-
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts import PromptTemplate
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings
from langchain.prompts.example_selector import SemanticSimilarityExampleSelector
import chromadb
from chromadb.config import Settings
from few_shot_example.sql_exampler import examplars as din_sql_examplars
from util.text2vec import Text2VecEmbeddingFunction, hg_embedding
from util.chromadb_instance import client as chromadb_client
from run_config import TEXT2DSL_COLLECTION_NAME
vectorstore = Chroma(collection_name=TEXT2DSL_COLLECTION_NAME,
embedding_function=hg_embedding,
client=chromadb_client)
example_nums = 15
schema_linking_example_selector = SemanticSimilarityExampleSelector(vectorstore=vectorstore, k=example_nums,
input_keys=["question"],
example_keys=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links"])
sql_example_selector = SemanticSimilarityExampleSelector(vectorstore=vectorstore, k=example_nums,
input_keys=["question"],
example_keys=["question", "current_date", "table_name", "schema_links", "sql"])
if vectorstore._collection.count() > 0:
print("examples already in din_sql_vectorstore")
print("init din_sql_vectorstore size:", vectorstore._collection.count())
if vectorstore._collection.count() < len(din_sql_examplars):
print("din_sql_examplars size:", len(din_sql_examplars))
vectorstore._collection.delete()
print("empty din_sql_vectorstore")
for example in din_sql_examplars:
schema_linking_example_selector.add_example(example)
print("added din_sql_vectorstore size:", vectorstore._collection.count())
else:
for example in din_sql_examplars:
schema_linking_example_selector.add_example(example)
print("added din_sql_vectorstore size:", vectorstore._collection.count())

View File

@@ -1,13 +1,15 @@
# -*- coding:utf-8 -*-
import re
def schema_link_parse(schema_link_output):
try:
schema_link_output = schema_link_output.strip()
pattern = r'Schema_links:(.*)'
schema_link_output = re.findall(pattern, schema_link_output, re.DOTALL)[0].strip()
except Exception as e:
print(e)
schema_link_output = None
return schema_link_output

View File

@@ -1,5 +1,8 @@
# -*- coding:utf-8 -*-
from typing import Any, List, Mapping, Optional, Union
import requests
import logging
import json
import os
import sys
@@ -8,68 +11,78 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from langchain.prompts import PromptTemplate
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.example_selector import SemanticSimilarityExampleSelector
from langchain.llms import OpenAI
from few_shot_example.sql_exampler import examplars
from output_parser import schema_link_parse
def schema_linking_prompt_maker(user_query: str, model_name: str,
fields_list: List[str],
few_shots_example: str):
instruction = "# 根据数据库的表结构,找出为每个问题生成SQL查询语句的schema_links\n"
schema_linking_prompt = "Table {table_name}, columns = {fields_list}\n问题:{user_query}\n分析: 让我们一步一步地思考。".format(
table_name=model_name,
fields_list=fields_list,
user_query=user_query)
return instruction + few_shots_example + schema_linking_prompt
def schema_linking_exampler(user_query: str,
domain_name: str,
fields_list: List[str],
prior_schema_links: Mapping[str,str],
example_selector: SemanticSimilarityExampleSelector,
) -> str:
model_name: str,
fields_list: List[str]
) -> str:
example_prompt_template = PromptTemplate(
input_variables=["table_name", "fields_list", "question", "analysis",
"schema_links"],
template="Table {table_name}, columns = {fields_list}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schema_links}")
prior_schema_links_str = '['+ ','.join(["""'{}'->{}""".format(k,v) for k,v in prior_schema_links.items()]) + ']'
instruction = "# 根据数据库的表结构,找出为每个问题生成SQL查询语句的schema_links"
example_prompt_template = PromptTemplate(input_variables=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links"],
template="Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schema_links}")
schema_linking_prompt = "Table {table_name}, columns = {fields_list}\n问题:{question}\n分析: 让我们一步一步地思考。"
instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links"
schema_linking_example_prompt_template = FewShotPromptTemplate(
examples=examplars,
example_prompt=example_prompt_template,
example_separator="\n\n",
prefix=instruction,
input_variables=["table_name", "fields_list", "question"],
suffix=schema_linking_prompt
)
schema_linking_prompt = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\n问题:{question}\n分析: 让我们一步一步地思考。"
schema_linking_example_prompt = schema_linking_example_prompt_template.format(
table_name=model_name,
fields_list=fields_list,
question=user_query)
schema_linking_example_prompt_template = FewShotPromptTemplate(
example_selector=example_selector,
example_prompt=example_prompt_template,
example_separator="\n\n",
prefix=instruction,
input_variables=["table_name", "fields_list", "prior_schema_links", "question"],
suffix=schema_linking_prompt
)
schema_linking_example_prompt = schema_linking_example_prompt_template.format(table_name=domain_name,
fields_list=fields_list,
prior_schema_links=prior_schema_links_str,
question=user_query)
return schema_linking_example_prompt
def sql_exampler(user_query: str,
domain_name: str,
schema_link_str: str,
data_date: str,
example_selector: SemanticSimilarityExampleSelector,
) -> str:
instruction = "# 根据schema_links为每个问题生成SQL查询语句"
model_name: str,
schema_link_str: str
) -> str:
instruction = "# 根据schema_links为每个问题生成SQL查询语句"
sql_example_prompt_template = PromptTemplate(input_variables=["question", "current_date", "table_name", "schema_links", "sql"],
template="问题:{question}\nCurrent_date:{current_date}\nTable {table_name}\nSchema_links:{schema_links}\nSQL:{sql}")
sql_example_prompt_template = PromptTemplate(
input_variables=["question", "table_name", "schema_links", "sql"],
template="问题:{question}\nTable {table_name}\nSchema_links:{schema_links}\nSQL:{sql}")
sql_prompt = "问题:{question}\nCurrent_date:{current_date}\nTable {table_name}\nSchema_links:{schema_links}\nSQL:"
sql_prompt = "问题:{question}\nTable {table_name}\nSchema_links:{schema_links}\nSQL:"
sql_example_prompt_template = FewShotPromptTemplate(
example_selector=example_selector,
example_prompt=sql_example_prompt_template,
example_separator="\n\n",
prefix=instruction,
input_variables=["question", "current_date", "table_name", "schema_links"],
suffix=sql_prompt
)
sql_example_prompt_template = FewShotPromptTemplate(
examples=examplars,
example_prompt=sql_example_prompt_template,
example_separator="\n\n",
prefix=instruction,
input_variables=["question", "table_name", "schema_links"],
suffix=sql_prompt
)
sql_example_prompt = sql_example_prompt_template.format(question=user_query,
current_date=data_date,
table_name=domain_name,
schema_links=schema_link_str)
sql_example_prompt = sql_example_prompt_template.format(question=user_query,
table_name=model_name,
schema_links=schema_link_str)
return sql_example_prompt

View File

@@ -1,4 +1,6 @@
from typing import List, Union, Mapping
# -*- coding:utf-8 -*-
from typing import List, Union
import logging
import json
import os
@@ -7,54 +9,33 @@ import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from sql.prompt_maker import schema_linking_exampler, sql_exampler
from sql.constructor import schema_linking_example_selector, sql_example_selector
from sql.output_parser import schema_link_parse
from sql.prompt_maker import schema_linking_exampler, schema_link_parse, \
sql_exampler
from util.llm_instance import llm
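# query2sql runs a two-step LLM pipeline: schema linking over the model's fields first,
# then few-shot SQL generation from the parsed schema links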
def query2sql(query_text: str, schema: dict):
print("schema: ", schema)
def query2sql(query_text: str,
schema : Union[dict, None] = None,
current_date: str = None,
linking: Union[List[Mapping[str, str]], None] = None
):
print("query_text: ", query_text)
print("schema: ", schema)
print("current_date: ", current_date)
print("prior_schema_links: ", linking)
model_name = schema['modelName']
fields_list = schema['fieldNameList']
if linking is not None:
prior_schema_links = {item['fieldValue']:item['fieldName'] for item in linking}
else:
prior_schema_links = {}
schema_linking_prompt = schema_linking_exampler(query_text, model_name,
fields_list)
schema_link_output = llm(schema_linking_prompt)
schema_link_str = schema_link_parse(schema_link_output)
sql_prompt = sql_exampler(query_text, model_name, schema_link_str)
sql_output = llm(sql_prompt)
schema_linking_prompt = schema_linking_exampler(query_text, model_name, fields_list, prior_schema_links, schema_linking_example_selector)
print("schema_linking_prompt->", schema_linking_prompt)
sql_prompt = sql_exampler(query_text, model_name, schema_link_str, current_date, sql_example_selector)
print("sql_prompt->", sql_prompt)
resp = dict()
resp['query'] = query_text
resp['model'] = model_name
resp['fields'] = fields_list
resp['priorSchemaLinking'] = linking
resp['dataDate'] = current_date
resp['schemaLinkingOutput'] = schema_link_output
resp['schemaLinkStr'] = schema_link_str
resp['sqlOutput'] = sql_output
print("resp: ", resp)
return resp

View File

@@ -1,10 +0,0 @@
# -*- coding:utf-8 -*-
import chromadb
from chromadb.config import Settings
from run_config import CHROMA_DB_PERSIST_PATH
client = chromadb.Client(Settings(
chroma_db_impl="duckdb+parquet",
persist_directory=CHROMA_DB_PERSIST_PATH # Optional, defaults to .chromadb/ in the current directory
))

View File

@@ -1,37 +1,45 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import static org.mockito.ArgumentMatchers.any;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
class DateFieldCorrectorTest {
@Test
void rewriter() {
void corrector() {
MockedStatic<DSLDateHelper> dslDateHelper = Mockito.mockStatic(DSLDateHelper.class);
dslDateHelper.when(() -> DSLDateHelper.getReferenceDate(any())).thenReturn("2023-08-14");
DateFieldCorrector dateFieldCorrector = new DateFieldCorrector();
SemanticParseInfo parseInfo = new SemanticParseInfo();
SchemaElement model = new SchemaElement();
model.setId(2L);
parseInfo.setModel(model);
CorrectionInfo correctionInfo = CorrectionInfo.builder()
SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder()
.sql("select count(歌曲名) from 歌曲库 ")
.parseInfo(parseInfo)
.build();
CorrectionInfo rewriter = dateFieldCorrector.corrector(correctionInfo);
dateFieldCorrector.correct(semanticCorrectInfo);
Assert.assertEquals("SELECT count(歌曲名) FROM 歌曲库 WHERE 数据日期 = '2023-08-14'", rewriter.getSql());
Assert.assertEquals("SELECT count(歌曲名) FROM 歌曲库 WHERE 数据日期 = '2023-08-14'", semanticCorrectInfo.getSql());
correctionInfo = CorrectionInfo.builder()
semanticCorrectInfo = SemanticCorrectInfo.builder()
.sql("select count(歌曲名) from 歌曲库 where 数据日期 = '2023-08-14'")
.parseInfo(parseInfo)
.build();
rewriter = dateFieldCorrector.corrector(correctionInfo);
dateFieldCorrector.correct(semanticCorrectInfo);
Assert.assertEquals("select count(歌曲名) from 歌曲库 where 数据日期 = '2023-08-14'", rewriter.getSql());
Assert.assertEquals("select count(歌曲名) from 歌曲库 where 数据日期 = '2023-08-14'", semanticCorrectInfo.getSql());
}
}

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
@@ -16,10 +16,10 @@ import org.junit.jupiter.api.Test;
class FieldNameCorrectorTest {
@Test
void rewriter() {
void corrector() {
FieldNameCorrector corrector = new FieldNameCorrector();
CorrectionInfo correctionInfo = CorrectionInfo.builder()
SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder()
.sql("select 歌曲名 from 歌曲库 where 专辑照片 = '七里香' and 专辑名 = '流行' and 数据日期 = '2023-08-19'")
.build();
@@ -55,11 +55,11 @@ class FieldNameCorrectorTest {
properties.put(Constants.CONTEXT, dslParseResult);
parseInfo.setProperties(properties);
correctionInfo.setParseInfo(parseInfo);
semanticCorrectInfo.setParseInfo(parseInfo);
CorrectionInfo rewriter = corrector.corrector(correctionInfo);
corrector.correct(semanticCorrectInfo);
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲库 WHERE 歌曲名 = '七里香' AND 歌曲流派 = '流行' AND 数据日期 = '2023-08-19'",
rewriter.getSql());
semanticCorrectInfo.getSql());
}
}

View File

@@ -2,9 +2,9 @@ package com.tencent.supersonic.chat.corrector;
import static org.mockito.Mockito.when;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.util.ContextUtils;
@@ -53,19 +53,19 @@ class FieldValueCorrectorTest {
SchemaElement model = new SchemaElement();
model.setId(2L);
parseInfo.setModel(model);
CorrectionInfo correctionInfo = CorrectionInfo.builder()
SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder()
.sql("select count(song_name) from 歌曲库 where singer_name = '周先生'")
.parseInfo(parseInfo)
.build();
FieldValueCorrector corrector = new FieldValueCorrector();
CorrectionInfo info = corrector.corrector(correctionInfo);
corrector.correct(semanticCorrectInfo);
Assert.assertEquals("SELECT count(song_name) FROM 歌曲库 WHERE singer_name = '周杰伦'", info.getSql());
Assert.assertEquals("SELECT count(song_name) FROM 歌曲库 WHERE singer_name = '周杰伦'", semanticCorrectInfo.getSql());
correctionInfo.setSql("select count(song_name) from 歌曲库 where singer_name = '杰伦'");
info = corrector.corrector(correctionInfo);
semanticCorrectInfo.setSql("select count(song_name) from 歌曲库 where singer_name = '杰伦'");
corrector.correct(semanticCorrectInfo);
Assert.assertEquals("SELECT count(song_name) FROM 歌曲库 WHERE singer_name = '周杰伦'", info.getSql());
Assert.assertEquals("SELECT count(song_name) FROM 歌曲库 WHERE singer_name = '周杰伦'", semanticCorrectInfo.getSql());
}
}

View File

@@ -1,25 +1,26 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
class SelectFieldAppendCorrectorTest {
@Test
void rewriter() {
void corrector() {
SelectFieldAppendCorrector corrector = new SelectFieldAppendCorrector();
CorrectionInfo correctionInfo = CorrectionInfo.builder()
SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder()
.sql("select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 and 歌手名 = '邓紫棋' "
+ "and sys_imp_date = '2023-08-09' and 歌曲发布时 = '2023-08-01' order by 播放量 desc limit 11")
.build();
CorrectionInfo rewriter = corrector.corrector(correctionInfo);
corrector.correct(semanticCorrectInfo);
Assert.assertEquals(
"SELECT 歌曲名, 歌手名, 歌曲发布时, 发布日期 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 "
+ "AND 歌手名 = '邓紫棋' AND sys_imp_date = '2023-08-09' "
+ "AND 歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11", rewriter.getSql());
"SELECT 歌曲名, 歌手名, 播放量, 歌曲发布时, 发布日期 FROM 歌曲库 WHERE "
+ "datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '邓紫棋' "
+ "AND sys_imp_date = '2023-08-09' AND 歌曲发布时 = '2023-08-01'"
+ " ORDER BY 播放量 DESC LIMIT 11", semanticCorrectInfo.getSql());
}
}

View File

@@ -2,9 +2,9 @@ package com.tencent.supersonic.chat.parser.llm.dsl;
import static org.mockito.Mockito.when;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.util.ContextUtils;
@@ -67,14 +67,14 @@ class LLMDslParserTest {
SchemaElement model = new SchemaElement();
model.setId(2L);
parseInfo.setModel(model);
CorrectionInfo correctionInfo = CorrectionInfo.builder()
SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder()
.sql("select count(song_name) from 歌曲库 where singer_name = '周先生' and YEAR(publish_time) >= 2023 and ")
.parseInfo(parseInfo)
.build();
LLMDslParser llmDslParser = new LLMDslParser();
llmDslParser.setFilter(correctionInfo, 2L, parseInfo);
llmDslParser.setFilter(semanticCorrectInfo, 2L, parseInfo);
}
}

View File

@@ -11,10 +11,6 @@
<artifactId>chat-knowledge</artifactId>
<dependencies>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-context</artifactId>
@@ -111,8 +107,9 @@
<groupId>com.tencent.supersonic</groupId>
<artifactId>semantic-query</artifactId>
<version>${project.version}</version>
<scope>compile</scope>
</dependency>
</dependencies>
</project>

View File

@@ -1,55 +0,0 @@
package com.tencent.supersonic.knowledge.dictionary;
import org.apache.commons.lang3.StringUtils;
/***
* nature type
* such as: metric, dimension, etc.
*/
public enum DictWordType {
METRIC("metric"),
DIMENSION("dimension"),
VALUE("value"),
DOMAIN("dm"),
MODEL("model"),
ENTITY("entity"),
NUMBER("m"),
SUFFIX("suffix");
public static final String NATURE_SPILT = "_";
public static final String SPACE = " ";
private String type;
DictWordType(String type) {
this.type = type;
}
public String getType() {
return NATURE_SPILT + type;
}
public static DictWordType getNatureType(String nature) {
if (StringUtils.isEmpty(nature) || !nature.startsWith(NATURE_SPILT)) {
return null;
}
for (DictWordType dictWordType : values()) {
if (nature.endsWith(dictWordType.getType())) {
return dictWordType;
}
}
//domain
String[] natures = nature.split(DictWordType.NATURE_SPILT);
if (natures.length == 2 && StringUtils.isNumeric(natures[1])) {
return DOMAIN;
}
//dimension value
if (natures.length == 3 && StringUtils.isNumeric(natures[1]) && StringUtils.isNumeric(natures[2])) {
return VALUE;
}
return null;
}
}
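
For reference, getNatureType parses underscore-delimited nature strings. A minimal sketch of the expected results is below; the numeric ids are made up, and it assumes the enum keeps the same parsing logic after moving to com.tencent.supersonic.common.pojo.enums.

import com.tencent.supersonic.common.pojo.enums.DictWordType;

public class DictWordTypeExample {
    public static void main(String[] args) {
        // a known type suffix wins first; otherwise numeric-only natures fall back to DOMAIN or VALUE
        System.out.println(DictWordType.getNatureType("_4_metric")); // METRIC
        System.out.println(DictWordType.getNatureType("_4"));        // DOMAIN (model/domain id only)
        System.out.println(DictWordType.getNatureType("_4_7"));      // VALUE (model id + element id)
    }
}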

View File

@@ -375,4 +375,4 @@ public class MultiCustomDictionary extends DynamicCustomDictionary {
return true;
}
}
}

View File

@@ -5,7 +5,7 @@ import java.util.List;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.knowledge.dictionary.DictWord;
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import lombok.extern.slf4j.Slf4j;
/**

View File

@@ -7,7 +7,7 @@ import java.util.List;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.knowledge.dictionary.DictWord;
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

View File

@@ -6,7 +6,7 @@ import java.util.List;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.knowledge.dictionary.DictWord;
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;

View File

@@ -7,7 +7,7 @@ import java.util.Objects;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.knowledge.dictionary.DictWord;
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;

View File

@@ -7,7 +7,7 @@ import java.util.List;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.knowledge.dictionary.DictWord;
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

View File

@@ -7,7 +7,7 @@ import java.util.Objects;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.knowledge.dictionary.DictWord;
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.knowledge.dictionary.builder;
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.knowledge;
package com.tencent.supersonic.knowledge.listener;
import com.tencent.supersonic.knowledge.dictionary.DictWord;
import com.tencent.supersonic.knowledge.service.SchemaService;
@@ -16,7 +16,7 @@ import java.util.List;
@Slf4j
@Component
public class ApplicationStartedInit implements ApplicationListener<ApplicationStartedEvent> {
public class ApplicationStartedListener implements ApplicationListener<ApplicationStartedEvent> {
@Autowired
private KnowledgeService knowledgeService;
@@ -27,6 +27,11 @@ public class ApplicationStartedInit implements ApplicationListener<ApplicationSt
@Override
public void onApplicationEvent(ApplicationStartedEvent event) {
updateKnowledgeDimValue();
}
public Boolean updateKnowledgeDimValue() {
Boolean isOk = false;
try {
log.debug("ApplicationStartedInit start");
@@ -35,9 +40,11 @@ public class ApplicationStartedInit implements ApplicationListener<ApplicationSt
knowledgeService.reloadAllData(dictWords);
log.debug("ApplicationStartedInit end");
isOk = true;
} catch (Exception e) {
log.error("ApplicationStartedInit error", e);
}
return isOk;
}
/***
@@ -66,4 +73,4 @@ public class ApplicationStartedInit implements ApplicationListener<ApplicationSt
log.debug("reloadKnowledge end");
}
}
}

View File

@@ -0,0 +1,27 @@
package com.tencent.supersonic.knowledge.listener;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DataAddEvent;
import com.tencent.supersonic.knowledge.dictionary.DictWord;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.ApplicationListener;
import org.springframework.stereotype.Component;
@Component
@Slf4j
public class DataAddListener implements ApplicationListener<DataAddEvent> {
@Override
public void onApplicationEvent(DataAddEvent event) {
DictWord dictWord = new DictWord();
dictWord.setWord(event.getName());
String sign = DictWordType.NATURE_SPILT;
String nature = sign + event.getModelId() + sign + event.getId() + event.getType();
String natureWithFrequency = nature + " " + Constants.DEFAULT_FREQUENCY;
dictWord.setNature(nature);
dictWord.setNatureWithFrequency(natureWithFrequency);
log.info("dataAddListener begins to add data:{}", dictWord);
HanlpHelper.addToCustomDictionary(dictWord);
}
}
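
A minimal sketch of how a newly added dimension value could reach DataAddListener. The DataAddEvent constructor shape is assumed from the getters used above (name, modelId, id, type) and may differ from the actual class:

import com.tencent.supersonic.common.pojo.DataAddEvent;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.stereotype.Service;

@Service
public class DimValuePublisherSketch {

    @Autowired
    private ApplicationEventPublisher eventPublisher;

    public void addDimValue(Long modelId, Long dimensionId, String value) {
        // hypothetical constructor: (source, name, modelId, id, type)
        eventPublisher.publishEvent(
                new DataAddEvent(this, value, modelId, dimensionId, DictWordType.VALUE.getType()));
    }
}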

View File

@@ -0,0 +1,27 @@
package com.tencent.supersonic.knowledge.listener;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DataDeleteEvent;
import com.tencent.supersonic.knowledge.dictionary.DictWord;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.ApplicationListener;
import org.springframework.stereotype.Component;
@Component
@Slf4j
public class DataDeleteListener implements ApplicationListener<DataDeleteEvent> {
@Override
public void onApplicationEvent(DataDeleteEvent event) {
DictWord dictWord = new DictWord();
dictWord.setWord(event.getName());
String sign = DictWordType.NATURE_SPILT;
String nature = sign + event.getModelId() + sign + event.getId() + event.getType();
String natureWithFrequency = nature + " " + Constants.DEFAULT_FREQUENCY;
dictWord.setNature(nature);
dictWord.setNatureWithFrequency(natureWithFrequency);
log.info("dataDeleteListener begins to delete data:{}", dictWord);
HanlpHelper.removeFromCustomDictionary(dictWord);
}
}

View File

@@ -0,0 +1,29 @@
package com.tencent.supersonic.knowledge.listener;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DataUpdateEvent;
import com.tencent.supersonic.knowledge.dictionary.DictWord;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.ApplicationListener;
import org.springframework.stereotype.Component;
@Component
@Slf4j
public class DataUpdateListener implements ApplicationListener<DataUpdateEvent> {
@Override
public void onApplicationEvent(DataUpdateEvent event) {
DictWord dictWord = new DictWord();
dictWord.setWord(event.getName());
String sign = DictWordType.NATURE_SPILT;
String nature = sign + event.getModelId() + sign + event.getId() + event.getType();
String natureWithFrequency = nature + " " + Constants.DEFAULT_FREQUENCY;
dictWord.setNature(nature);
dictWord.setNatureWithFrequency(natureWithFrequency);
log.info("dataUpdateListener begins to update data:{}", dictWord);
HanlpHelper.removeFromCustomDictionary(dictWord);
dictWord.setWord(event.getNewName());
HanlpHelper.addToCustomDictionary(dictWord);
}
}

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.knowledge.service;
import com.tencent.supersonic.knowledge.dictionary.DictWord;
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
import java.util.List;
@@ -52,4 +52,4 @@ public class KnowledgeServiceImpl implements KnowledgeService {
}
}
}
}

View File

@@ -5,7 +5,7 @@ import com.hankcs.hanlp.collection.trie.bintrie.BinTrie;
import com.hankcs.hanlp.corpus.tag.Nature;
import com.hankcs.hanlp.dictionary.CoreDictionary;
import com.tencent.supersonic.knowledge.dictionary.DictWord;
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.knowledge.dictionary.DictionaryAttributeUtil;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import java.util.Arrays;

View File

@@ -4,7 +4,7 @@ import com.tencent.supersonic.chat.api.component.SemanticLayer;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.knowledge.dictionary.DictWord;
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.knowledge.dictionary.builder.WordBuilderFactory;
import java.util.ArrayList;

View File

@@ -7,7 +7,7 @@ import com.hankcs.hanlp.dictionary.CoreDictionary;
import com.hankcs.hanlp.dictionary.DynamicCustomDictionary;
import com.hankcs.hanlp.seg.Segment;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.knowledge.dictionary.DictWord;
import java.io.FileNotFoundException;
import java.io.IOException;
@@ -20,6 +20,7 @@ import com.tencent.supersonic.knowledge.dictionary.HadoopFileIOAdapter;
import com.tencent.supersonic.knowledge.service.SearchService;
import com.tencent.supersonic.knowledge.dictionary.MultiCustomDictionary;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.util.ResourceUtils;
@@ -163,6 +164,29 @@ public class HanlpHelper {
return getDynamicCustomDictionary().insert(dictWord.getWord(), dictWord.getNatureWithFrequency());
}
public static void removeFromCustomDictionary(DictWord dictWord) {
log.info("dictWord:{}", dictWord);
CoreDictionary.Attribute attribute = getDynamicCustomDictionary().get(dictWord.getWord());
        // nothing to remove if the word is not registered in the custom dictionary
        if (attribute == null) {
            return;
        }
log.info("get attribute:{}", attribute);
getDynamicCustomDictionary().remove(dictWord.getWord());
StringBuilder sb = new StringBuilder();
for (int i = 0; i < attribute.nature.length; i++) {
if (!attribute.nature[i].toString().equals(dictWord.getNature())) {
sb.append(attribute.nature[i].toString() + " ");
sb.append(attribute.frequency[i] + " ");
}
}
String natureWithFrequency = sb.toString();
int len = natureWithFrequency.length();
log.info("filtered natureWithFrequency:{}", natureWithFrequency);
if (StringUtils.isNotBlank(natureWithFrequency)) {
getDynamicCustomDictionary().add(dictWord.getWord(), natureWithFrequency.substring(0, len - 1));
}
}
public static void transLetterOriginal(List<MapResult> mapResults) {
if (CollectionUtils.isEmpty(mapResults)) {
return;

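The nature-filtering step in removeFromCustomDictionary above can be hard to read in diff form; the following self-contained sketch (illustrative names, not the production code) shows the intended behavior: when one nature is removed from a word that carries several, the remaining natures and frequencies are re-registered.

public class NatureFilterSketch {

    // rebuild the "nature frequency ..." string without the nature being removed
    static String dropNature(String[] natures, int[] frequencies, String natureToDrop) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < natures.length; i++) {
            if (!natures[i].equals(natureToDrop)) {
                sb.append(natures[i]).append(" ").append(frequencies[i]).append(" ");
            }
        }
        return sb.toString().trim();
    }

    public static void main(String[] args) {
        // a word registered with a generic nature "nz" and a value nature "_4_7_value" (made-up ids)
        String remaining = dropNature(new String[] {"nz", "_4_7_value"}, new int[] {10, 5}, "_4_7_value");
        System.out.println(remaining); // prints "nz 10": the word keeps its other nature
    }
}
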
View File

@@ -3,7 +3,7 @@ package com.tencent.supersonic.knowledge.utils;
import com.hankcs.hanlp.corpus.tag.Nature;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.knowledge.dictionary.ModelInfoStat;
import java.util.ArrayList;
import java.util.Comparator;