(improvement)(Chat) Move chat-core to headless (#805)

Co-authored-by: jolunoluo
Authored by LXW on 2024-03-12 22:20:30 +08:00, committed by GitHub
parent f152deeb81
commit f93bee81cb
301 changed files with 2256 additions and 4527 deletions

File: com/tencent/supersonic/chat/api/pojo/DataSetSchema.java (deleted)

@@ -1,89 +0,0 @@
package com.tencent.supersonic.chat.api.pojo;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
import lombok.Data;
@Data
public class DataSetSchema {
private SchemaElement dataSet;
private Set<SchemaElement> metrics = new HashSet<>();
private Set<SchemaElement> dimensions = new HashSet<>();
private Set<SchemaElement> dimensionValues = new HashSet<>();
private Set<SchemaElement> tags = new HashSet<>();
private Set<SchemaElement> tagValues = new HashSet<>();
private SchemaElement entity = new SchemaElement();
private QueryConfig queryConfig;
private QueryType queryType;
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
Optional<SchemaElement> element = Optional.empty();
switch (elementType) {
case ENTITY:
element = Optional.ofNullable(entity);
break;
case DATASET:
element = Optional.of(dataSet);
break;
case METRIC:
element = metrics.stream().filter(e -> e.getId() == elementID).findFirst();
break;
case DIMENSION:
element = dimensions.stream().filter(e -> e.getId() == elementID).findFirst();
break;
case VALUE:
element = dimensionValues.stream().filter(e -> e.getId() == elementID).findFirst();
break;
case TAG:
element = tags.stream().filter(e -> e.getId() == elementID).findFirst();
break;
case TAG_VALUE:
element = tagValues.stream().filter(e -> e.getId() == elementID).findFirst();
break;
default:
}
if (element.isPresent()) {
return element.get();
} else {
return null;
}
}
public TimeDefaultConfig getTagTypeTimeDefaultConfig() {
if (queryConfig == null) {
return null;
}
if (queryConfig.getTagTypeDefaultConfig() == null) {
return null;
}
return queryConfig.getTagTypeDefaultConfig().getTimeDefaultConfig();
}
public TimeDefaultConfig getMetricTypeTimeDefaultConfig() {
if (queryConfig == null) {
return null;
}
if (queryConfig.getMetricTypeDefaultConfig() == null) {
return null;
}
return queryConfig.getMetricTypeDefaultConfig().getTimeDefaultConfig();
}
public TagTypeDefaultConfig getTagTypeDefaultConfig() {
if (queryConfig == null) {
return null;
}
return queryConfig.getTagTypeDefaultConfig();
}
}
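For orientation while reviewing the move, a usage sketch of DataSetSchema (illustrative only, not part of the commit; it assumes SchemaElement exposes Lombok-style setters, which its no-args construction above suggests):

import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;

public class DataSetSchemaDemo {
    public static void main(String[] args) {
        // Hypothetical data: one metric registered under id 42.
        SchemaElement metric = new SchemaElement();
        metric.setId(42L);
        metric.setName("pv");

        DataSetSchema schema = new DataSetSchema();
        schema.getMetrics().add(metric);

        // getElement dispatches on the element type and returns null on a miss.
        SchemaElement hit = schema.getElement(SchemaElementType.METRIC, 42L);
        SchemaElement miss = schema.getElement(SchemaElementType.DIMENSION, 42L);
        System.out.println(hit.getName()); // pv
        System.out.println(miss);          // null
    }
}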

File: com/tencent/supersonic/chat/api/pojo/SchemaElementMatch.java (deleted)

@@ -1,24 +0,0 @@
package com.tencent.supersonic.chat.api.pojo;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString;
@Data
@ToString
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class SchemaElementMatch {
SchemaElement element;
double similarity;
String detectWord;
String word;
Long frequency;
boolean isInherited;
}

File: com/tencent/supersonic/chat/api/pojo/SchemaMapInfo.java (deleted)

@@ -1,33 +0,0 @@
package com.tencent.supersonic.chat.api.pojo;
import com.google.common.collect.Lists;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
public class SchemaMapInfo {
private Map<Long, List<SchemaElementMatch>> dataSetElementMatches = new HashMap<>();
public Set<Long> getMatchedDataSetInfos() {
return dataSetElementMatches.keySet();
}
public List<SchemaElementMatch> getMatchedElements(Long dataSet) {
return dataSetElementMatches.getOrDefault(dataSet, Lists.newArrayList());
}
public Map<Long, List<SchemaElementMatch>> getDataSetElementMatches() {
return dataSetElementMatches;
}
public void setDataSetElementMatches(Map<Long, List<SchemaElementMatch>> dataSetElementMatches) {
this.dataSetElementMatches = dataSetElementMatches;
}
public void setMatchedElements(Long dataSet, List<SchemaElementMatch> elementMatches) {
dataSetElementMatches.put(dataSet, elementMatches);
}
}
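A similar sketch for SchemaMapInfo, showing why getMatchedElements never hands callers a null (the sample word and similarity are made up; the builder comes from SchemaElementMatch's Lombok annotations above):

import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import java.util.Arrays;

public class SchemaMapInfoDemo {
    public static void main(String[] args) {
        SchemaElementMatch match = SchemaElementMatch.builder()
                .word("pv")
                .detectWord("pv")
                .similarity(1.0)
                .build();

        SchemaMapInfo mapInfo = new SchemaMapInfo();
        mapInfo.setMatchedElements(1L, Arrays.asList(match));

        // Known data sets return their matches; unknown ids fall back to an
        // empty list, so callers can iterate without null checks.
        System.out.println(mapInfo.getMatchedElements(1L).size());  // 1
        System.out.println(mapInfo.getMatchedElements(99L).size()); // 0
    }
}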

File: com/tencent/supersonic/chat/api/pojo/SemanticCorrectInfo.java (deleted)

@@ -1,22 +0,0 @@
package com.tencent.supersonic.chat.api.pojo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@AllArgsConstructor
@NoArgsConstructor
@Builder
public class SemanticCorrectInfo {
private QueryFilters queryFilters;
private SemanticParseInfo parseInfo;
private String sql;
private String preSql;
}

File: com/tencent/supersonic/chat/api/pojo/SemanticParseInfo.java (deleted)

@@ -1,82 +0,0 @@
package com.tencent.supersonic.chat.api.pojo;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.FilterType;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import lombok.Data;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
@Data
public class SemanticParseInfo {
private Integer id;
private String queryMode;
private SchemaElement dataSet;
private Set<SchemaElement> metrics = new TreeSet<>(new SchemaNameLengthComparator());
private Set<SchemaElement> dimensions = new LinkedHashSet();
private SchemaElement entity;
private AggregateTypeEnum aggType = AggregateTypeEnum.NONE;
private FilterType filterType = FilterType.UNION;
private Set<QueryFilter> dimensionFilters = new LinkedHashSet();
private Set<QueryFilter> metricFilters = new LinkedHashSet();
private Set<Order> orders = new LinkedHashSet();
private DateConf dateInfo;
private Long limit;
private double score;
private List<SchemaElementMatch> elementMatches = new ArrayList<>();
private Map<String, Object> properties = new HashMap<>();
private EntityInfo entityInfo;
private SqlInfo sqlInfo = new SqlInfo();
private QueryType queryType = QueryType.ID;
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
@Override
public int compare(SchemaElement o1, SchemaElement o2) {
if (o1.getOrder() != o2.getOrder()) {
if (o1.getOrder() < o2.getOrder()) {
return -1;
} else {
return 1;
}
}
int len1 = o1.getName().length();
int len2 = o2.getName().length();
if (len1 != len2) {
return len1 - len2;
} else {
return o1.getName().compareTo(o2.getName());
}
}
}
public Set<SchemaElement> getMetrics() {
Set<SchemaElement> metricSet = new TreeSet<>(new SchemaNameLengthComparator());
metricSet.addAll(metrics);
metrics = metricSet;
return metrics;
}
public Long getDataSetId() {
if (dataSet == null) {
return null;
}
return dataSet.getDataSet();
}
}
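The nested SchemaNameLengthComparator is the subtle part of this class: metrics sort by their order field, then by name length, then lexicographically. A dependency-free sketch of the same tie-breaking policy applied to plain strings:

import java.util.Comparator;
import java.util.TreeSet;

public class NameLengthOrderDemo {
    public static void main(String[] args) {
        // With equal 'order' values the policy reduces to: shorter name
        // first, lexicographic order as the final tie-breaker.
        TreeSet<String> metricNames = new TreeSet<>(
                Comparator.<String>comparingInt(String::length)
                        .thenComparing(Comparator.naturalOrder()));
        metricNames.add("revenue");
        metricNames.add("uv");
        metricNames.add("pv");
        System.out.println(metricNames); // [pv, uv, revenue]
    }
}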

File: com/tencent/supersonic/chat/api/pojo/SemanticSchema.java (deleted)

@@ -1,172 +0,0 @@
package com.tencent.supersonic.chat.api.pojo;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import org.springframework.util.CollectionUtils;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
public class SemanticSchema implements Serializable {
private List<DataSetSchema> dataSetSchemaList;
public SemanticSchema(List<DataSetSchema> dataSetSchemaList) {
this.dataSetSchemaList = dataSetSchemaList;
}
public void add(DataSetSchema schema) {
dataSetSchemaList.add(schema);
}
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
Optional<SchemaElement> element = Optional.empty();
switch (elementType) {
case ENTITY:
element = getElementsById(elementID, getEntities());
break;
case DATASET:
element = getElementsById(elementID, getDataSets());
break;
case METRIC:
element = getElementsById(elementID, getMetrics());
break;
case DIMENSION:
element = getElementsById(elementID, getDimensions());
break;
case VALUE:
element = getElementsById(elementID, getDimensionValues());
break;
case TAG:
element = getElementsById(elementID, getTags());
break;
case TAG_VALUE:
element = getElementsById(elementID, getTagValues());
break;
default:
}
if (element.isPresent()) {
return element.get();
} else {
return null;
}
}
public Map<Long, String> getDataSetIdToName() {
return dataSetSchemaList.stream()
.collect(Collectors.toMap(a -> a.getDataSet().getId(), a -> a.getDataSet().getName(), (k1, k2) -> k1));
}
public List<SchemaElement> getDimensionValues() {
List<SchemaElement> dimensionValues = new ArrayList<>();
dataSetSchemaList.stream().forEach(d -> dimensionValues.addAll(d.getDimensionValues()));
return dimensionValues;
}
public List<SchemaElement> getDimensions() {
List<SchemaElement> dimensions = new ArrayList<>();
dataSetSchemaList.stream().forEach(d -> dimensions.addAll(d.getDimensions()));
return dimensions;
}
public List<SchemaElement> getDimensions(Long dataSetId) {
List<SchemaElement> dimensions = getDimensions();
return getElementsByDataSetId(dataSetId, dimensions);
}
public SchemaElement getDimension(Long id) {
List<SchemaElement> dimensions = getDimensions();
Optional<SchemaElement> dimension = getElementsById(id, dimensions);
return dimension.orElse(null);
}
public List<SchemaElement> getTags() {
List<SchemaElement> tags = new ArrayList<>();
dataSetSchemaList.stream().forEach(d -> tags.addAll(d.getTags()));
return tags;
}
public List<SchemaElement> getTags(Long dataSetId) {
List<SchemaElement> tags = getTags();
return getElementsByDataSetId(dataSetId, tags);
}
public List<SchemaElement> getTagValues() {
List<SchemaElement> tags = new ArrayList<>();
dataSetSchemaList.stream().forEach(d -> tags.addAll(d.getTagValues()));
return tags;
}
public List<SchemaElement> getTagValues(Long dataSetId) {
List<SchemaElement> tags = getTagValues();
return getElementsByDataSetId(dataSetId, tags);
}
public List<SchemaElement> getMetrics() {
List<SchemaElement> metrics = new ArrayList<>();
dataSetSchemaList.stream().forEach(d -> metrics.addAll(d.getMetrics()));
return metrics;
}
public List<SchemaElement> getMetrics(Long dataSetId) {
List<SchemaElement> metrics = getMetrics();
return getElementsByDataSetId(dataSetId, metrics);
}
public List<SchemaElement> getEntities() {
List<SchemaElement> entities = new ArrayList<>();
dataSetSchemaList.stream().forEach(d -> entities.add(d.getEntity()));
return entities;
}
public List<SchemaElement> getEntities(Long dataSetId) {
List<SchemaElement> entities = getEntities();
return getElementsByDataSetId(dataSetId, entities);
}
private List<SchemaElement> getElementsByDataSetId(Long dataSetId, List<SchemaElement> elements) {
return elements.stream()
.filter(schemaElement -> dataSetId.equals(schemaElement.getDataSet()))
.collect(Collectors.toList());
}
private Optional<SchemaElement> getElementsById(Long id, List<SchemaElement> elements) {
return elements.stream()
.filter(schemaElement -> id.equals(schemaElement.getId()))
.findFirst();
}
public SchemaElement getDataSet(Long dataSetId) {
List<SchemaElement> dataSets = getDataSets();
return getElementsById(dataSetId, dataSets).orElse(null);
}
public List<SchemaElement> getDataSets() {
List<SchemaElement> dataSets = new ArrayList<>();
dataSetSchemaList.stream().forEach(d -> dataSets.add(d.getDataSet()));
return dataSets;
}
public Map<String, String> getBizNameToName(Long dataSetId) {
List<SchemaElement> allElements = new ArrayList<>();
allElements.addAll(getDimensions(dataSetId));
allElements.addAll(getMetrics(dataSetId));
return allElements.stream()
.collect(Collectors.toMap(SchemaElement::getBizName, SchemaElement::getName, (k1, k2) -> k1));
}
public Map<Long, DataSetSchema> getDataSetSchemaMap() {
if (CollectionUtils.isEmpty(dataSetSchemaList)) {
return new HashMap<>();
}
return dataSetSchemaList.stream().collect(Collectors.toMap(dataSetSchema
-> dataSetSchema.getDataSet().getDataSet(), dataSetSchema -> dataSetSchema));
}
}
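A sketch of the facade in action over a single DataSetSchema (the SchemaElement setters are assumed from its Lombok usage elsewhere in this commit):

import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import java.util.ArrayList;
import java.util.Arrays;

public class SemanticSchemaDemo {
    public static void main(String[] args) {
        SchemaElement dataSet = new SchemaElement();
        dataSet.setId(1L);
        dataSet.setName("visits");

        SchemaElement dim = new SchemaElement();
        dim.setId(10L);
        dim.setName("City");
        dim.setBizName("city");
        dim.setDataSet(1L); // ties the dimension back to data set 1

        DataSetSchema schema = new DataSetSchema();
        schema.setDataSet(dataSet);
        schema.getDimensions().add(dim);

        SemanticSchema semanticSchema = new SemanticSchema(new ArrayList<>(Arrays.asList(schema)));
        System.out.println(semanticSchema.getDataSetIdToName()); // {1=visits}
        System.out.println(semanticSchema.getBizNameToName(1L)); // {city=City}
    }
}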

File: ChatConfigBaseReq.java (modified)

@@ -16,16 +16,6 @@ public class ChatConfigBaseReq {
private Long modelId;
- /**
- * the chatDetailConfig about the model
- */
- private ChatDetailConfigReq chatDetailConfig;
- /**
- * the chatAggConfig about the model
- */
- private ChatAggConfigReq chatAggConfig;
/**
* the recommended questions about the model

File: com/tencent/supersonic/chat/api/pojo/request/ChatExecuteReq.java (renamed from ExecuteQueryReq.java)

@@ -1,20 +1,21 @@
package com.tencent.supersonic.chat.api.pojo.request;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
- @Builder
@Data
- public class ExecuteQueryReq {
+ @Builder
+ @NoArgsConstructor
+ @AllArgsConstructor
+ public class ChatExecuteReq {
private User user;
private Integer agentId;
- private Integer chatId;
- private String queryText;
private Long queryId;
- private Integer parseId;
private SemanticParseInfo parseInfo;
+ private Integer chatId;
+ private int parseId;
+ private String queryText;
private boolean saveAnswer;
}

File: com/tencent/supersonic/chat/api/pojo/request/ChatParseReq.java (renamed from QueryReq.java)

@@ -1,15 +1,16 @@
package com.tencent.supersonic.chat.api.pojo.request;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
+ import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import lombok.Data;
@Data
- public class QueryReq {
+ public class ChatParseReq {
private String queryText;
private Integer chatId;
+ private Long dataSetId;
+ private Integer agentId;
private User user;
private QueryFilters queryFilters;
private boolean saveAnswer = true;
- private Integer agentId;
}
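The two renames above mirror the chat flow's two phases: a ChatParseReq asks for parse candidates, and a ChatExecuteReq replays a chosen candidate by queryId/parseId. A construction sketch under that assumption (the ids are placeholders; the service calls that consume these requests are outside this hunk):

import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;

public class ChatRequestDemo {
    public static void main(String[] args) {
        // Phase 1: parse the natural-language question.
        ChatParseReq parseReq = new ChatParseReq();
        parseReq.setQueryText("pv by city in the last 7 days");
        parseReq.setChatId(1);
        parseReq.setAgentId(1);

        // Phase 2: execute one of the returned parses.
        ChatExecuteReq executeReq = ChatExecuteReq.builder()
                .chatId(1)
                .agentId(1)
                .queryId(100L)
                .parseId(1)
                .queryText(parseReq.getQueryText())
                .saveAnswer(true)
                .build();
        System.out.println(executeReq);
    }
}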

File: com/tencent/supersonic/chat/api/pojo/request/QueryDataReq.java (deleted)

@@ -1,21 +0,0 @@
package com.tencent.supersonic.chat.api.pojo.request;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.common.pojo.DateConf;
import java.util.HashSet;
import java.util.Set;
import lombok.Data;
@Data
public class QueryDataReq {
private User user;
private Set<SchemaElement> metrics = new HashSet<>();
private Set<SchemaElement> dimensions = new HashSet<>();
private Set<QueryFilter> dimensionFilters = new HashSet<>();
private Set<QueryFilter> metricFilters = new HashSet<>();
private DateConf dateInfo;
private Long queryId;
private Integer parseId;
}

File: com/tencent/supersonic/chat/api/pojo/request/QueryFilter.java (deleted)

@@ -1,43 +0,0 @@
package com.tencent.supersonic.chat.api.pojo.request;
import com.google.common.base.Objects;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import lombok.Data;
import lombok.ToString;
@Data
@ToString(callSuper = true)
public class QueryFilter {
private String bizName;
private String name;
private FilterOperatorEnum operator = FilterOperatorEnum.EQUALS;
private Object value;
private Long elementID;
private String function;
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
QueryFilter that = (QueryFilter) o;
return Objects.equal(bizName, that.bizName) && Objects.equal(name,
that.name) && operator == that.operator && Objects.equal(value, that.value)
&& Objects.equal(elementID, that.elementID) && Objects.equal(
function, that.function);
}
@Override
public int hashCode() {
return Objects.hashCode(bizName, name, operator, value, elementID, function);
}
}
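Because SemanticParseInfo keeps these filters in sets, the hand-written equals/hashCode over all six fields is what makes structurally identical filters collapse. A quick sketch (setters assumed from @Data; sample values invented):

import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import java.util.LinkedHashSet;
import java.util.Set;

public class QueryFilterEqualityDemo {
    public static void main(String[] args) {
        QueryFilter first = new QueryFilter();
        first.setBizName("city");
        first.setValue("Shenzhen");

        QueryFilter duplicate = new QueryFilter();
        duplicate.setBizName("city");
        duplicate.setValue("Shenzhen");

        Set<QueryFilter> filters = new LinkedHashSet<>();
        filters.add(first);
        filters.add(duplicate); // structurally equal, so it is dropped
        System.out.println(filters.size()); // 1
    }
}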

File: com/tencent/supersonic/chat/api/pojo/request/QueryFilters.java (deleted)

@@ -1,13 +0,0 @@
package com.tencent.supersonic.chat.api.pojo.request;
import lombok.Data;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@Data
public class QueryFilters {
private List<QueryFilter> filters = new ArrayList<>();
private Map<String, Object> params = new HashMap<>();
}

File: com/tencent/supersonic/chat/api/pojo/response/DataInfo.java (deleted)

@@ -1,16 +0,0 @@
package com.tencent.supersonic.chat.api.pojo.response;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@AllArgsConstructor
@NoArgsConstructor
public class DataInfo {
private Integer itemId;
private String name;
private String bizName;
private String value;
}

File: com/tencent/supersonic/chat/api/pojo/response/DataSetInfo.java (deleted)

@@ -1,13 +0,0 @@
package com.tencent.supersonic.chat.api.pojo.response;
import lombok.Data;
import java.io.Serializable;
import java.util.List;
@Data
public class DataSetInfo extends DataInfo implements Serializable {
private List<String> words;
private String primaryKey;
}

File: com/tencent/supersonic/chat/api/pojo/response/EntityInfo.java (deleted)

@@ -1,16 +0,0 @@
package com.tencent.supersonic.chat.api.pojo.response;
import lombok.Data;
import java.util.ArrayList;
import java.util.List;
@Data
public class EntityInfo {
private DataSetInfo dataSetInfo = new DataSetInfo();
private List<DataInfo> dimensions = new ArrayList<>();
private List<DataInfo> metrics = new ArrayList<>();
private String entityId;
}

File: com/tencent/supersonic/chat/api/pojo/response/ParseResp.java (deleted)

@@ -1,23 +0,0 @@
package com.tencent.supersonic.chat.api.pojo.response;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import lombok.Data;
import java.util.List;
@Data
public class ParseResp {
private Integer chatId;
private String queryText;
private Long queryId;
private ParseState state;
private List<SemanticParseInfo> selectedParses = Lists.newArrayList();
private ParseTimeCostDO parseTimeCost = new ParseTimeCostDO();
public enum ParseState {
COMPLETED,
PENDING,
FAILED
}
}

File: com/tencent/supersonic/chat/api/pojo/response/ParseTimeCostDO.java (deleted)

@@ -1,15 +0,0 @@
package com.tencent.supersonic.chat.api.pojo.response;
import lombok.Data;
@Data
public class ParseTimeCostDO {
private long parseStartTime;
private long parseTime;
private long sqlTime;
public ParseTimeCostDO() {
this.parseStartTime = System.currentTimeMillis();
}
}
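The constructor stamps the start time, so callers only record deltas once each phase completes. A small timing sketch (the sleep stands in for real parsing work):

import com.tencent.supersonic.chat.api.pojo.response.ParseTimeCostDO;

public class ParseTimingDemo {
    public static void main(String[] args) throws InterruptedException {
        ParseTimeCostDO timeCost = new ParseTimeCostDO(); // start stamped here
        Thread.sleep(5);                                  // pretend to parse
        timeCost.setParseTime(System.currentTimeMillis() - timeCost.getParseStartTime());
        System.out.println("parse took " + timeCost.getParseTime() + " ms");
    }
}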

File: com/tencent/supersonic/chat/api/pojo/response/QueryResp.java (deleted)

@@ -1,22 +0,0 @@
package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import lombok.Data;
import java.util.Date;
import java.util.List;
@Data
public class QueryResp {
private Long questionId;
private Date createTime;
private Long chatId;
private Integer score;
private String feedback;
private String queryText;
private QueryResult queryResult;
private List<SemanticParseInfo> parseInfos;
private List<SimilarQueryRecallResp> similarQueries;
}

File: com/tencent/supersonic/chat/api/pojo/response/QueryResult.java (deleted)

@@ -1,27 +0,0 @@
package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.common.pojo.QueryAuthorization;
import com.tencent.supersonic.common.pojo.QueryColumn;
import lombok.Data;
import java.util.List;
import java.util.Map;
@Data
public class QueryResult {
public EntityInfo entityInfo;
public AggregateInfo aggregateInfo;
private Long queryId;
private String queryMode;
private String querySql;
private QueryState queryState = QueryState.EMPTY;
private List<QueryColumn> queryColumns;
private QueryAuthorization queryAuthorization;
private SemanticParseInfo chatContext;
private Object response;
private List<Map<String, Object>> queryResults;
private Long queryTimeCost;
private List<SchemaElement> recommendedDimensions;
}

File: com/tencent/supersonic/chat/api/pojo/response/QueryState.java (deleted)

@@ -1,8 +0,0 @@
package com.tencent.supersonic.chat.api.pojo.response;
public enum QueryState {
SUCCESS,
SEARCH_EXCEPTION,
EMPTY,
INVALID;
}

File: com/tencent/supersonic/chat/api/pojo/response/SearchResp.java (deleted)

@@ -1,15 +0,0 @@
package com.tencent.supersonic.chat.api.pojo.response;
import java.util.List;
import lombok.Data;
@Data
public class SearchResp {
private List<SearchResult> searchResults;
public SearchResp(List<SearchResult> searchResults) {
this.searchResults = searchResults;
}
}

File: com/tencent/supersonic/chat/api/pojo/response/SearchResult.java (deleted)

@@ -1,45 +0,0 @@
package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import java.util.Objects;
import lombok.Builder;
import lombok.Data;
import lombok.Getter;
import lombok.Setter;
@Data
@Setter
@Getter
@Builder
public class SearchResult {
private String recommend;
private String subRecommend;
private String modelName;
private Long modelId;
private SchemaElementType schemaElementType;
private boolean isComplete = true;
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
SearchResult searchResult1 = (SearchResult) o;
return Objects.equals(recommend, searchResult1.recommend) && Objects.equals(modelName,
searchResult1.modelName);
}
@Override
public int hashCode() {
return Objects.hash(recommend, modelName);
}
}


@@ -1,6 +1,7 @@
package com.tencent.supersonic.chat.api.pojo.response;
+ import com.tencent.supersonic.headless.api.pojo.response.QueryResp;
import lombok.Data;
import java.util.List;

File: com/tencent/supersonic/chat/api/pojo/response/SqlInfo.java (deleted)

@@ -1,11 +0,0 @@
package com.tencent.supersonic.chat.api.pojo.response;
import lombok.Data;
@Data
public class SqlInfo {
private String s2SQL;
private String correctS2SQL;
private String querySQL;
}

File: chat-core/pom.xml (deleted)

@@ -1,110 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>chat</artifactId>
<groupId>com.tencent.supersonic</groupId>
<version>${revision}</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>chat-core</artifactId>
<properties>
<maven.compiler.source>8</maven.compiler.source>
<maven.compiler.target>8</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-context</artifactId>
</dependency>
<dependency>
<groupId>org.testng</groupId>
<artifactId>testng</artifactId>
<version>${org.testng.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-compress</artifactId>
<version>${commons.compress.version}</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>druid</artifactId>
<version>${alibaba.druid.version}</version>
</dependency>
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
</dependency>
<dependency>
<groupId>com.h2database</groupId>
<artifactId>h2</artifactId>
<version>${h2.version}</version>
</dependency>
<dependency>
<groupId>com.tencent.supersonic</groupId>
<artifactId>headless-api</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>com.tencent.supersonic</groupId>
<artifactId>headless-core</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>com.tencent.supersonic</groupId>
<artifactId>chat-api</artifactId>
<version>${project.version}</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>com.github.xkzhangsan</groupId>
<artifactId>xk-time</artifactId>
<version>${xk.time.version}</version>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-inline</artifactId>
<version>${mockito-inline.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.tencent.supersonic</groupId>
<artifactId>headless-server</artifactId>
<version>${project.version}</version>
<scope>compile</scope>
</dependency>
</dependencies>
</project>

File: com/tencent/supersonic/chat/core/config/AggregatorConfig.java (deleted)

@@ -1,13 +0,0 @@
package com.tencent.supersonic.chat.core.config;
import lombok.Data;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
@Data
@Configuration
public class AggregatorConfig {
@Value("${metric.aggregator.ratio.enable:true}")
private Boolean enableRatio;
}

File: com/tencent/supersonic/chat/core/config/ChatLocalFileConfig.java (deleted)

@@ -1,42 +0,0 @@
package com.tencent.supersonic.chat.core.config;
import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper;
import java.io.FileNotFoundException;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
@Data
@Configuration
@Slf4j
public class ChatLocalFileConfig {
@Value("${dict.directory.latest:/data/dictionary/custom}")
private String dictDirectoryLatest;
@Value("${dict.directory.backup:./dict/backup}")
private String dictDirectoryBackup;
public String getDictDirectoryLatest() {
return getResourceDir() + dictDirectoryLatest;
}
public String getDictDirectoryBackup() {
return dictDirectoryBackup;
}
private String getResourceDir() {
String hanlpPropertiesPath = "";
try {
hanlpPropertiesPath = HanlpHelper.getHanlpPropertiesPath();
} catch (FileNotFoundException e) {
log.warn("getResourceDir, e:", e);
}
return hanlpPropertiesPath;
}
}

File: com/tencent/supersonic/chat/core/config/DefaultMetric.java (deleted)

@@ -1,42 +0,0 @@
package com.tencent.supersonic.chat.core.config;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@AllArgsConstructor
@NoArgsConstructor
public class DefaultMetric {
/**
* default metrics
*/
private Long metricId;
/**
* default time span unit
*/
private Integer unit;
/**
* default time type: DAY
* DAY, WEEK, MONTH, YEAR
*/
private String period;
private String bizName;
private String name;
public DefaultMetric(Long metricId, Integer unit, String period) {
this.metricId = metricId;
this.unit = unit;
this.period = period;
}
public DefaultMetric(String bizName, Integer unit, String period) {
this.bizName = bizName;
this.unit = unit;
this.period = period;
}
}

File: com/tencent/supersonic/chat/core/config/DefaultMetricInfo.java (deleted)

@@ -1,32 +0,0 @@
package com.tencent.supersonic.chat.core.config;
import com.tencent.supersonic.common.pojo.Constants;
import lombok.Data;
import lombok.ToString;
/**
* default metrics about the model
*/
@ToString
@Data
public class DefaultMetricInfo {
/**
* default metrics
*/
private Long metricId;
/**
* default time span unit
*/
private Integer unit = 1;
/**
* default time type: day
* DAY, WEEK, MONTH, YEAR
*/
private String period = Constants.DAY;
}

File: com/tencent/supersonic/chat/core/config/DefaultSemanticConfig.java (deleted)

@@ -1,44 +0,0 @@
package com.tencent.supersonic.chat.core.config;
import lombok.Data;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
@Configuration
@Data
public class DefaultSemanticConfig {
@Value("${semantic.url.prefix:http://localhost:8081}")
private String semanticUrl;
@Value("${searchByStruct.path:/api/semantic/query/struct}")
private String searchByStructPath;
@Value("${searchByStruct.path:/api/semantic/query/multiStruct}")
private String searchByMultiStructPath;
@Value("${searchByStruct.path:/api/semantic/query/sql}")
private String searchBySqlPath;
@Value("${searchByStruct.path:/api/semantic/query/queryDimValue}")
private String queryDimValuePath;
@Value("${fetchModelSchemaPath.path:/api/semantic/schema}")
private String fetchModelSchemaPath;
@Value("${fetchModelList.path:/api/semantic/schema/dimension/page}")
private String fetchDimensionPagePath;
@Value("${fetchModelList.path:/api/semantic/schema/metric/page}")
private String fetchMetricPagePath;
@Value("${fetchModelList.path:/api/semantic/schema/domain/list}")
private String fetchDomainListPath;
@Value("${fetchModelList.path:/api/semantic/schema/model/list}")
private String fetchModelListPath;
@Value("${explain.path:/api/semantic/query/explain}")
private String explainPath;
}

File: com/tencent/supersonic/chat/core/config/Dim4Dict.java (deleted)

@@ -1,21 +0,0 @@
package com.tencent.supersonic.chat.core.config;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString;
@Data
@AllArgsConstructor
@ToString
@NoArgsConstructor
public class Dim4Dict {
private Long dimId;
private String bizName;
private List<String> blackList;
private List<String> whiteList;
private List<String> ruleList;
}

File: com/tencent/supersonic/chat/core/config/EntityDetailData.java (deleted)

@@ -1,15 +0,0 @@
package com.tencent.supersonic.chat.core.config;
import java.util.List;
import lombok.Data;
/**
* when query an entity, return related dimension/metric info
*/
@Data
public class EntityDetailData {
private List<Long> dimensionIds;
private List<Long> metricIds;
}

File: com/tencent/supersonic/chat/core/config/EntityInternalDetail.java (deleted)

@@ -1,14 +0,0 @@
package com.tencent.supersonic.chat.core.config;
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
import java.util.List;
import lombok.Data;
@Data
public class EntityInternalDetail {
List<DimSchemaResp> dimensionList;
List<MetricSchemaResp> metricList;
}

File: com/tencent/supersonic/chat/core/config/LLMParserConfig.java (deleted)

@@ -1,30 +0,0 @@
package com.tencent.supersonic.chat.core.config;
import lombok.Data;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
@Configuration
@Data
public class LLMParserConfig {
@Value("${llm.parser.url:}")
private String url;
@Value("${query2sql.path:/query2sql}")
private String queryToSqlPath;
@Value("${dimension.topn:10}")
private Integer dimensionTopN;
@Value("${metric.topn:10}")
private Integer metricTopN;
@Value("${tag.topn:20}")
private Integer tagTopN;
@Value("${all.model:false}")
private Boolean allModel;
}

File: com/tencent/supersonic/chat/core/config/OptimizationConfig.java (deleted)

@@ -1,175 +0,0 @@
package com.tencent.supersonic.chat.core.config;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.common.service.SysParameterService;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
@Configuration
@Data
@Slf4j
public class OptimizationConfig {
@Value("${one.detection.size:8}")
private Integer oneDetectionSize;
@Value("${one.detection.max.size:20}")
private Integer oneDetectionMaxSize;
@Value("${metric.dimension.min.threshold:0.3}")
private Double metricDimensionMinThresholdConfig;
@Value("${metric.dimension.threshold:0.3}")
private Double metricDimensionThresholdConfig;
@Value("${dimension.value.threshold:0.5}")
private Double dimensionValueThresholdConfig;
@Value("${long.text.threshold:0.8}")
private Double longTextThreshold;
@Value("${short.text.threshold:0.5}")
private Double shortTextThreshold;
@Value("${query.text.length.threshold:10}")
private Integer queryTextLengthThreshold;
@Value("${embedding.mapper.word.min:4}")
private int embeddingMapperWordMin;
@Value("${embedding.mapper.word.max:5}")
private int embeddingMapperWordMax;
@Value("${embedding.mapper.batch:50}")
private int embeddingMapperBatch;
@Value("${embedding.mapper.number:5}")
private int embeddingMapperNumber;
@Value("${embedding.mapper.round.number:10}")
private int embeddingMapperRoundNumber;
@Value("${embedding.mapper.distance.threshold:0.01}")
private Double embeddingMapperDistanceThreshold;
@Value("${s2SQL.linking.value.switch:true}")
private boolean useLinkingValueSwitch;
@Value("${s2SQL.generation:TWO_PASS_AUTO_COT}")
private SqlGenerationMode sqlGenerationMode;
@Value("${s2SQL.use.switch:true}")
private boolean useS2SqlSwitch;
@Value("${text2sql.example.num:15}")
private int text2sqlExampleNum;
@Value("${text2sql.fewShots.num:10}")
private int text2sqlFewShotsNum;
@Value("${text2sql.self.consistency.num:5}")
private int text2sqlSelfConsistencyNum;
@Value("${parse.show.count:3}")
private Integer parseShowCount;
@Autowired
private SysParameterService sysParameterService;
public Integer getOneDetectionSize() {
return convertValue("one.detection.size", Integer.class, oneDetectionSize);
}
public Integer getOneDetectionMaxSize() {
return convertValue("one.detection.max.size", Integer.class, oneDetectionMaxSize);
}
public Double getMetricDimensionMinThresholdConfig() {
return convertValue("metric.dimension.min.threshold", Double.class, metricDimensionMinThresholdConfig);
}
public Double getMetricDimensionThresholdConfig() {
return convertValue("metric.dimension.threshold", Double.class, metricDimensionThresholdConfig);
}
public Double getDimensionValueThresholdConfig() {
return convertValue("dimension.value.threshold", Double.class, dimensionValueThresholdConfig);
}
public Double getLongTextThreshold() {
return convertValue("long.text.threshold", Double.class, longTextThreshold);
}
public Double getShortTextThreshold() {
return convertValue("short.text.threshold", Double.class, shortTextThreshold);
}
public Integer getQueryTextLengthThreshold() {
return convertValue("query.text.length.threshold", Integer.class, queryTextLengthThreshold);
}
public boolean isUseS2SqlSwitch() {
return convertValue("use.s2SQL.switch", Boolean.class, useS2SqlSwitch);
}
public Integer getEmbeddingMapperWordMin() {
return convertValue("embedding.mapper.word.min", Integer.class, embeddingMapperWordMin);
}
public Integer getEmbeddingMapperWordMax() {
return convertValue("embedding.mapper.word.max", Integer.class, embeddingMapperWordMax);
}
public Integer getEmbeddingMapperBatch() {
return convertValue("embedding.mapper.batch", Integer.class, embeddingMapperBatch);
}
public Integer getEmbeddingMapperNumber() {
return convertValue("embedding.mapper.number", Integer.class, embeddingMapperNumber);
}
public Integer getEmbeddingMapperRoundNumber() {
return convertValue("embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber);
}
public Double getEmbeddingMapperDistanceThreshold() {
return convertValue("embedding.mapper.distance.threshold", Double.class, embeddingMapperDistanceThreshold);
}
public boolean isUseLinkingValueSwitch() {
return convertValue("s2SQL.linking.value.switch", Boolean.class, useLinkingValueSwitch);
}
public SqlGenerationMode getSqlGenerationMode() {
return convertValue("s2SQL.generation", SqlGenerationMode.class, sqlGenerationMode);
}
public Integer getParseShowCount() {
return convertValue("parse.show.count", Integer.class, parseShowCount);
}
public <T> T convertValue(String paramName, Class<T> targetType, T defaultValue) {
try {
String value = sysParameterService.getSysParameter().getParameterByName(paramName);
if (StringUtils.isBlank(value)) {
return defaultValue;
}
if (targetType == Double.class) {
return targetType.cast(Double.parseDouble(value));
} else if (targetType == Integer.class) {
return targetType.cast(Integer.parseInt(value));
} else if (targetType == Boolean.class) {
return targetType.cast(Boolean.parseBoolean(value));
} else if (targetType == SqlGenerationMode.class) {
return targetType.cast(SqlGenerationMode.valueOf(value));
}
} catch (Exception e) {
log.error("convertValue", e);
}
return defaultValue;
}
}
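convertValue is the interesting pattern here: a runtime system parameter, when present and parseable, overrides the @Value-injected default. A dependency-free sketch of the same precedence rule (the map stands in for SysParameterService; names are illustrative, not SuperSonic APIs):

import java.util.HashMap;
import java.util.Map;

public class ParamOverrideDemo {
    static <T> T convertValue(Map<String, String> sysParams, String name, Class<T> type, T defaultValue) {
        String value = sysParams.get(name);
        if (value == null || value.trim().isEmpty()) {
            return defaultValue; // nothing configured at runtime
        }
        try {
            if (type == Integer.class) {
                return type.cast(Integer.parseInt(value));
            }
            if (type == Double.class) {
                return type.cast(Double.parseDouble(value));
            }
            if (type == Boolean.class) {
                return type.cast(Boolean.parseBoolean(value));
            }
        } catch (Exception e) {
            // a malformed override degrades to the default, mirroring the original
        }
        return defaultValue;
    }

    public static void main(String[] args) {
        Map<String, String> sysParams = new HashMap<>();
        sysParams.put("parse.show.count", "5");
        System.out.println(convertValue(sysParams, "parse.show.count", Integer.class, 3)); // 5
        System.out.println(convertValue(sysParams, "metric.topn", Integer.class, 10));     // 10
    }
}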

File: com/tencent/supersonic/chat/core/corrector/BaseSemanticCorrector.java (deleted)

@@ -1,158 +0,0 @@
package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectFunctionHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.core.env.Environment;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/**
* basic semantic correction functionality, offering common methods and an
* abstract method called doCorrect
*/
@Slf4j
public abstract class BaseSemanticCorrector implements SemanticCorrector {
public void correct(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
try {
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectS2SQL())) {
return;
}
doCorrect(queryContext, semanticParseInfo);
log.info("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(), semanticParseInfo.getSqlInfo());
} catch (Exception e) {
log.error(String.format("correct error,sqlInfo:%s", semanticParseInfo.getSqlInfo()), e);
}
}
public abstract void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo);
protected Map<String, String> getFieldNameMap(QueryContext queryContext, Long dataSetId) {
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
List<SchemaElement> dbAllFields = new ArrayList<>();
dbAllFields.addAll(semanticSchema.getMetrics());
dbAllFields.addAll(semanticSchema.getDimensions());
// support fieldName and field alias
Map<String, String> result = dbAllFields.stream()
.filter(entry -> dataSetId.equals(entry.getDataSet()))
.flatMap(schemaElement -> {
Set<String> elements = new HashSet<>();
elements.add(schemaElement.getName());
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
elements.addAll(schemaElement.getAlias());
}
return elements.stream();
})
.collect(Collectors.toMap(a -> a, a -> a, (k1, k2) -> k1));
result.put(TimeDimensionEnum.DAY.getChName(), TimeDimensionEnum.DAY.getChName());
result.put(TimeDimensionEnum.MONTH.getChName(), TimeDimensionEnum.MONTH.getChName());
result.put(TimeDimensionEnum.WEEK.getChName(), TimeDimensionEnum.WEEK.getChName());
result.put(TimeDimensionEnum.DAY.getName(), TimeDimensionEnum.DAY.getChName());
result.put(TimeDimensionEnum.MONTH.getName(), TimeDimensionEnum.MONTH.getChName());
result.put(TimeDimensionEnum.WEEK.getName(), TimeDimensionEnum.WEEK.getChName());
return result;
}
protected void addFieldsToSelect(SemanticParseInfo semanticParseInfo, String correctS2SQL) {
Set<String> selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL));
Set<String> needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL));
//decide whether add order by expression field to select
Environment environment = ContextUtils.getBean(Environment.class);
String correctorAdditionalInfo = environment.getProperty("corrector.additional.information");
if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) {
needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL));
}
// If there is no aggregate function in the S2SQL statement and
// there is a data field in 'WHERE' statement, add the field to the 'SELECT' statement.
if (!SqlSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
List<String> timeChNameList = TimeDimensionEnum.getChNameList();
Set<String> timeFields = whereFields.stream().filter(field -> timeChNameList.contains(field))
.collect(Collectors.toSet());
needAddFields.addAll(timeFields);
}
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(needAddFields)) {
return;
}
needAddFields.removeAll(selectFields);
String replaceFields = SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceFields);
}
protected void addAggregateToMetric(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
//add aggregate to all metric
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
Long dataSetId = semanticParseInfo.getDataSet().getDataSet();
List<SchemaElement> metrics = getMetricElements(queryContext, dataSetId);
Map<String, String> metricToAggregate = metrics.stream()
.map(schemaElement -> {
if (Objects.isNull(schemaElement.getDefaultAgg())) {
schemaElement.setDefaultAgg(AggregateTypeEnum.SUM.name());
}
return schemaElement;
}).flatMap(schemaElement -> {
Set<String> elements = new HashSet<>();
elements.add(schemaElement.getName());
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
elements.addAll(schemaElement.getAlias());
}
return elements.stream().map(element -> Pair.of(element, schemaElement.getDefaultAgg())
);
}).collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (k1, k2) -> k1));
if (CollectionUtils.isEmpty(metricToAggregate)) {
return;
}
String aggregateSql = SqlAddHelper.addAggregateToField(correctS2SQL, metricToAggregate);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
}
protected List<SchemaElement> getMetricElements(QueryContext queryContext, Long dataSetId) {
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
return semanticSchema.getMetrics(dataSetId);
}
protected Set<String> getDimensions(Long dataSetId, SemanticSchema semanticSchema) {
Set<String> dimensions = semanticSchema.getDimensions(dataSetId).stream()
.flatMap(
schemaElement -> {
Set<String> elements = new HashSet<>();
elements.add(schemaElement.getName());
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
elements.addAll(schemaElement.getAlias());
}
return elements.stream();
}
).collect(Collectors.toSet());
dimensions.add(TimeDimensionEnum.DAY.getChName());
return dimensions;
}
}

File: com/tencent/supersonic/chat/core/corrector/GrammarCorrector.java (deleted)

@@ -1,31 +0,0 @@
package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import java.util.ArrayList;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
/**
* Correcting SQL syntax, primarily including fixes to select, where, groupBy, and Having clauses
*/
@Slf4j
public class GrammarCorrector extends BaseSemanticCorrector {
private List<BaseSemanticCorrector> correctors;
public GrammarCorrector() {
correctors = new ArrayList<>();
correctors.add(new SelectCorrector());
correctors.add(new WhereCorrector());
correctors.add(new GroupByCorrector());
correctors.add(new HavingCorrector());
}
@Override
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
for (BaseSemanticCorrector corrector : correctors) {
corrector.correct(queryContext, semanticParseInfo);
}
}
}

File: com/tencent/supersonic/chat/core/corrector/GroupByCorrector.java (deleted)

@@ -1,116 +0,0 @@
package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig;
import com.tencent.supersonic.headless.api.pojo.Dim;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
import com.tencent.supersonic.headless.server.service.ModelService;
import com.tencent.supersonic.headless.server.service.DataSetService;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
/**
* Perform SQL corrections on the "Group by" section in S2SQL.
*/
@Slf4j
public class GroupByCorrector extends BaseSemanticCorrector {
@Override
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
Boolean needAddGroupBy = needAddGroupBy(queryContext, semanticParseInfo);
if (!needAddGroupBy) {
return;
}
addGroupByFields(queryContext, semanticParseInfo);
}
private Boolean needAddGroupBy(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
Long dataSetId = semanticParseInfo.getDataSetId();
DataSetService dataSetService = ContextUtils.getBean(DataSetService.class);
ModelService modelService = ContextUtils.getBean(ModelService.class);
DataSetResp dataSetResp = dataSetService.getDataSet(dataSetId);
List<Long> modelIds = dataSetResp.getDataSetDetail()
.getDataSetModelConfigs().stream().map(DataSetModelConfig::getId)
.collect(Collectors.toList());
MetaFilter metaFilter = new MetaFilter();
metaFilter.setIds(modelIds);
List<ModelResp> modelRespList = modelService.getModelList(metaFilter);
for (ModelResp modelResp : modelRespList) {
List<Dim> dimList = modelResp.getModelDetail().getDimensions();
for (Dim dim : dimList) {
if (Objects.nonNull(dim.getTypeParams()) && dim.getTypeParams().getTimeGranularity().equals("none")) {
return false;
}
}
}
//add dimension group by
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String correctS2SQL = sqlInfo.getCorrectS2SQL();
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
// check has distinct
if (SqlSelectHelper.hasDistinct(correctS2SQL)) {
log.info("not add group by ,exist distinct in correctS2SQL:{}", correctS2SQL);
return false;
}
//add alias field name
Set<String> dimensions = getDimensions(dataSetId, semanticSchema);
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {
return false;
}
// if only date in select not add group by.
if (selectFields.size() == 1 && selectFields.contains(TimeDimensionEnum.DAY.getChName())) {
return false;
}
if (SqlSelectHelper.hasGroupBy(correctS2SQL)) {
log.info("not add group by ,exist group by in correctS2SQL:{}", correctS2SQL);
return false;
}
return true;
}
private void addGroupByFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
Long dataSetId = semanticParseInfo.getDataSetId();
//add dimension group by
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String correctS2SQL = sqlInfo.getCorrectS2SQL();
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
//add alias field name
Set<String> dimensions = getDimensions(dataSetId, semanticSchema);
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
Set<String> groupByFields = selectFields.stream()
.filter(field -> dimensions.contains(field))
.filter(field -> {
if (!CollectionUtils.isEmpty(aggregateFields) && aggregateFields.contains(field)) {
return false;
}
return true;
})
.collect(Collectors.toSet());
semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlAddHelper.addGroupBy(correctS2SQL, groupByFields));
addAggregate(queryContext, semanticParseInfo);
}
private void addAggregate(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
List<String> sqlGroupByFields = SqlSelectHelper.getGroupByFields(
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
if (CollectionUtils.isEmpty(sqlGroupByFields)) {
return;
}
addAggregateToMetric(queryContext, semanticParseInfo);
}
}

File: com/tencent/supersonic/chat/core/corrector/HavingCorrector.java (deleted)

@@ -1,69 +0,0 @@
package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectFunctionHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import org.apache.commons.lang3.StringUtils;
import org.springframework.core.env.Environment;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Perform SQL corrections on the "Having" section in S2SQL.
*/
@Slf4j
public class HavingCorrector extends BaseSemanticCorrector {
@Override
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
//add aggregate to all metric
addHaving(queryContext, semanticParseInfo);
//decide whether add having expression field to select
Environment environment = ContextUtils.getBean(Environment.class);
String correctorAdditionalInfo = environment.getProperty("corrector.additional.information");
if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) {
addHavingToSelect(semanticParseInfo);
}
}
private void addHaving(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
Long dataSet = semanticParseInfo.getDataSet().getDataSet();
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
Set<String> metrics = semanticSchema.getMetrics(dataSet).stream()
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
if (CollectionUtils.isEmpty(metrics)) {
return;
}
String havingSql = SqlAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), metrics);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(havingSql);
}
private void addHavingToSelect(SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
if (!SqlSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
return;
}
List<Expression> havingExpressionList = SqlSelectHelper.getHavingExpression(correctS2SQL);
if (!CollectionUtils.isEmpty(havingExpressionList)) {
String replaceSql = SqlAddHelper.addFunctionToSelect(correctS2SQL, havingExpressionList);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql);
}
return;
}
}

File: com/tencent/supersonic/chat/core/corrector/SchemaCorrector.java (deleted)

@@ -1,150 +0,0 @@
package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.core.parser.sql.llm.ParseResult;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.AggregateEnum;
import com.tencent.supersonic.common.util.jsqlparser.FieldExpression;
import com.tencent.supersonic.common.util.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Perform schema corrections on the Schema information in S2SQL.
*/
@Slf4j
public class SchemaCorrector extends BaseSemanticCorrector {
@Override
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
removeFilterIfNotInLinkingValue(queryContext, semanticParseInfo);
correctAggFunction(semanticParseInfo);
replaceAlias(semanticParseInfo);
updateFieldNameByLinkingValue(semanticParseInfo);
updateFieldValueByLinkingValue(semanticParseInfo);
correctFieldName(queryContext, semanticParseInfo);
}
private void correctAggFunction(SemanticParseInfo semanticParseInfo) {
Map<String, String> aggregateEnum = AggregateEnum.getAggregateEnum();
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlReplaceHelper.replaceFunction(sqlInfo.getCorrectS2SQL(), aggregateEnum);
sqlInfo.setCorrectS2SQL(sql);
}
private void replaceAlias(SemanticParseInfo semanticParseInfo) {
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String replaceAlias = SqlReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL());
sqlInfo.setCorrectS2SQL(replaceAlias);
}
private void correctFieldName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
Map<String, String> fieldNameMap = getFieldNameMap(queryContext, semanticParseInfo.getDataSetId());
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
sqlInfo.setCorrectS2SQL(sql);
}
private void updateFieldNameByLinkingValue(SemanticParseInfo semanticParseInfo) {
List<ElementValue> linking = getLinkingValues(semanticParseInfo);
if (CollectionUtils.isEmpty(linking)) {
return;
}
Map<String, Set<String>> fieldValueToFieldNames = linking.stream().collect(
Collectors.groupingBy(ElementValue::getFieldValue,
Collectors.mapping(ElementValue::getFieldName, Collectors.toSet())));
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectS2SQL(), fieldValueToFieldNames);
sqlInfo.setCorrectS2SQL(sql);
}
private List<ElementValue> getLinkingValues(SemanticParseInfo semanticParseInfo) {
Object context = semanticParseInfo.getProperties().get(Constants.CONTEXT);
if (Objects.isNull(context)) {
return null;
}
ParseResult parseResult = JsonUtil.toObject(JsonUtil.toString(context), ParseResult.class);
if (Objects.isNull(parseResult) || Objects.isNull(parseResult.getLlmReq())) {
return null;
}
return parseResult.getLinkingValues();
}
private void updateFieldValueByLinkingValue(SemanticParseInfo semanticParseInfo) {
List<ElementValue> linking = getLinkingValues(semanticParseInfo);
if (CollectionUtils.isEmpty(linking)) {
return;
}
Map<String, Map<String, String>> filedNameToValueMap = linking.stream().collect(
Collectors.groupingBy(ElementValue::getFieldName,
Collectors.mapping(ElementValue::getFieldValue, Collectors.toMap(
oldValue -> oldValue,
newValue -> newValue,
(existingValue, newValue) -> newValue)
)));
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false);
sqlInfo.setCorrectS2SQL(sql);
}
public void removeFilterIfNotInLinkingValue(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String correctS2SQL = sqlInfo.getCorrectS2SQL();
List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctS2SQL);
if (CollectionUtils.isEmpty(whereExpressionList)) {
return;
}
List<ElementValue> linkingValues = getLinkingValues(semanticParseInfo);
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
Set<String> dimensions = getDimensions(semanticParseInfo.getDataSetId(), semanticSchema);
if (CollectionUtils.isEmpty(linkingValues)) {
linkingValues = new ArrayList<>();
}
Set<String> linkingFieldNames = linkingValues.stream().map(linking -> linking.getFieldName())
.collect(Collectors.toSet());
Set<String> removeFieldNames = whereExpressionList.stream()
.filter(fieldExpression -> StringUtils.isBlank(fieldExpression.getFunction()))
.filter(fieldExpression -> !TimeDimensionEnum.containsTimeDimension(fieldExpression.getFieldName()))
.filter(fieldExpression -> FilterOperatorEnum.EQUALS.getValue().equals(fieldExpression.getOperator()))
.filter(fieldExpression -> dimensions.contains(fieldExpression.getFieldName()))
.filter(fieldExpression -> !DateUtils.isAnyDateString(fieldExpression.getFieldValue().toString()))
.filter(fieldExpression -> !linkingFieldNames.contains(fieldExpression.getFieldName()))
.map(fieldExpression -> fieldExpression.getFieldName()).collect(Collectors.toSet());
String sql = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames);
sqlInfo.setCorrectS2SQL(sql);
}
}

File: com/tencent/supersonic/chat/core/corrector/SelectCorrector.java (deleted)

@@ -1,29 +0,0 @@
package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
/**
* Perform SQL corrections on the "Select" section in S2SQL.
*/
@Slf4j
public class SelectCorrector extends BaseSemanticCorrector {
@Override
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
// If the number of aggregated fields is equal to the number of queried fields, do not add fields to select.
if (!CollectionUtils.isEmpty(aggregateFields)
&& !CollectionUtils.isEmpty(selectFields)
&& aggregateFields.size() == selectFields.size()) {
return;
}
addFieldsToSelect(semanticParseInfo, correctS2SQL);
}
}

File: com/tencent/supersonic/chat/core/corrector/SemanticCorrector.java (deleted)

@@ -1,13 +0,0 @@
package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
/**
* A semantic corrector checks validity of extracted semantic information and
* performs correction and optimization if needed.
*/
public interface SemanticCorrector {
void correct(QueryContext queryContext, SemanticParseInfo semanticParseInfo);
}

View File

@@ -1,57 +0,0 @@
package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.util.jsqlparser.DateVisitor.DateBoundInfo;
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlDateSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import org.apache.commons.lang3.StringUtils;
/**
* Perform SQL corrections on the time in S2SQL.
*/
@Slf4j
public class TimeCorrector extends BaseSemanticCorrector {
@Override
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
parseDateDiffFunction(semanticParseInfo);
addLowerBoundDate(semanticParseInfo);
}
private void addLowerBoundDate(SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(correctS2SQL);
if (Objects.isNull(dateBoundInfo)) {
return;
}
if (StringUtils.isBlank(dateBoundInfo.getLowerBound())
&& StringUtils.isNotBlank(dateBoundInfo.getUpperBound())
&& StringUtils.isNotBlank(dateBoundInfo.getUpperDate())) {
String upperDate = dateBoundInfo.getUpperDate();
try {
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
String condExpr = dateBoundInfo.getColumName() + " >= '" + upperDate + "'";
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, CCJSqlParserUtil.parseCondExpression(condExpr));
} catch (JSQLParserException e) {
log.error("parseCondExpression", e);
}
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
}
}
private void parseDateDiffFunction(SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
correctS2SQL = SqlReplaceHelper.replaceFunction(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
}
}

View File

@@ -1,157 +0,0 @@
package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaValueMap;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.core.utils.S2SqlDateHelper;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.StringUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.util.Strings;
import org.springframework.util.CollectionUtils;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
/**
* Perform SQL corrections on the "Where" section in S2SQL.
*/
@Slf4j
public class WhereCorrector extends BaseSemanticCorrector {
@Override
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
addDateIfNotExist(queryContext, semanticParseInfo);
addQueryFilter(queryContext, semanticParseInfo);
updateFieldValueByTechName(queryContext, semanticParseInfo);
}
private void addQueryFilter(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
String queryFilter = getQueryFilter(queryContext.getQueryFilters());
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
if (StringUtils.isNotEmpty(queryFilter)) {
log.info("add queryFilter to correctS2SQL :{}", queryFilter);
Expression expression = null;
try {
expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
} catch (JSQLParserException e) {
log.error("parseCondExpression", e);
}
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
}
}
private void addDateIfNotExist(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(queryContext,
semanticParseInfo.getDataSetId(), semanticParseInfo.getQueryType());
if (StringUtils.isNotBlank(startEndDate.getLeft())
&& StringUtils.isNotBlank(startEndDate.getRight())) {
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
String dateChName = TimeDimensionEnum.DAY.getChName();
String condExpr = String.format(" ( %s >= '%s' and %s <= '%s' )", dateChName,
startEndDate.getLeft(), dateChName, startEndDate.getRight());
try {
Expression expression = CCJSqlParserUtil.parseCondExpression(condExpr);
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression);
} catch (JSQLParserException e) {
log.error("parseCondExpression:{}", e);
}
}
}
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
}
private String getQueryFilter(QueryFilters queryFilters) {
if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) {
return null;
}
return queryFilters.getFilters().stream()
.map(filter -> {
String bizNameWrap = StringUtil.getSpaceWrap(filter.getName());
String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue());
String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString());
return bizNameWrap + operatorWrap + valueWrap;
})
.collect(Collectors.joining(Constants.AND_UPPER));
}
private void updateFieldValueByTechName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
Long dataSetId = semanticParseInfo.getDataSetId();
List<SchemaElement> dimensions = semanticSchema.getDimensions(dataSetId);
if (CollectionUtils.isEmpty(dimensions)) {
return;
}
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
String correctS2SQL = SqlReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
aliasAndBizNameToTechName);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
}
private Map<String, Map<String, String>> getAliasAndBizNameToTechName(List<SchemaElement> dimensions) {
if (CollectionUtils.isEmpty(dimensions)) {
return new HashMap<>();
}
Map<String, Map<String, String>> result = new HashMap<>();
for (SchemaElement dimension : dimensions) {
if (Objects.isNull(dimension)
|| Strings.isEmpty(dimension.getName())
|| CollectionUtils.isEmpty(dimension.getSchemaValueMaps())) {
continue;
}
String name = dimension.getName();
Map<String, String> aliasAndBizNameToTechName = new HashMap<>();
for (SchemaValueMap valueMap : dimension.getSchemaValueMaps()) {
if (Objects.isNull(valueMap) || Strings.isEmpty(valueMap.getTechName())) {
continue;
}
if (Strings.isNotEmpty(valueMap.getBizName())) {
aliasAndBizNameToTechName.put(valueMap.getBizName(), valueMap.getTechName());
}
if (!CollectionUtils.isEmpty(valueMap.getAlias())) {
valueMap.getAlias().stream().forEach(alias -> {
if (Strings.isNotEmpty(alias)) {
aliasAndBizNameToTechName.put(alias, valueMap.getTechName());
}
});
}
}
if (!CollectionUtils.isEmpty(aliasAndBizNameToTechName)) {
result.put(name, aliasAndBizNameToTechName);
}
}
return result;
}
}

View File

@@ -1,100 +0,0 @@
package com.tencent.supersonic.chat.core.mapper;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
@Slf4j
public abstract class BaseMapper implements SchemaMapper {
@Override
public void map(QueryContext queryContext) {
String simpleName = this.getClass().getSimpleName();
long startTime = System.currentTimeMillis();
log.debug("before {},mapInfo:{}", simpleName, queryContext.getMapInfo().getDataSetElementMatches());
try {
doMap(queryContext);
} catch (Exception e) {
log.error("work error", e);
}
long cost = System.currentTimeMillis() - startTime;
log.debug("after {},cost:{},mapInfo:{}", simpleName, cost,
queryContext.getMapInfo().getDataSetElementMatches());
}
public abstract void doMap(QueryContext queryContext);
public void addToSchemaMap(SchemaMapInfo schemaMap, Long modelId, SchemaElementMatch newElementMatch) {
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getDataSetElementMatches();
List<SchemaElementMatch> schemaElementMatches =
modelElementMatches.computeIfAbsent(modelId, k -> new ArrayList<>());
//remove duplication
AtomicBoolean needAddNew = new AtomicBoolean(true);
schemaElementMatches.removeIf(
existElementMatch -> {
SchemaElement existElement = existElementMatch.getElement();
SchemaElement newElement = newElementMatch.getElement();
if (existElement.equals(newElement)) {
if (newElementMatch.getSimilarity() > existElementMatch.getSimilarity()) {
return true;
} else {
needAddNew.set(false);
}
}
return false;
}
);
if (needAddNew.get()) {
schemaElementMatches.add(newElementMatch);
}
}
public SchemaElement getSchemaElement(Long dataSetId, SchemaElementType elementType, Long elementID,
SemanticSchema semanticSchema) {
SchemaElement element = new SchemaElement();
DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
if (Objects.isNull(dataSetSchema)) {
return null;
}
SchemaElement elementDb = dataSetSchema.getElement(elementType, elementID);
if (Objects.isNull(elementDb)) {
log.info("element is null, elementType:{},elementID:{}", elementType, elementID);
return null;
}
BeanUtils.copyProperties(elementDb, element);
element.setAlias(getAlias(elementDb));
return element;
}
public List<String> getAlias(SchemaElement element) {
if (!SchemaElementType.VALUE.equals(element.getType())) {
return element.getAlias();
}
if (org.apache.commons.collections.CollectionUtils.isNotEmpty(element.getAlias()) && StringUtils.isNotEmpty(
element.getName())) {
return element.getAlias().stream()
.filter(aliasItem -> aliasItem.contains(element.getName()))
.collect(Collectors.toList());
}
return element.getAlias();
}
}

View File

@@ -1,156 +0,0 @@
package com.tencent.supersonic.chat.core.mapper;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@Service
@Slf4j
public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
@Autowired
private MapperHelper mapperHelper;
@Override
public Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms,
Set<Long> detectDataSetIds) {
String text = queryContext.getQueryText();
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
return null;
}
log.debug("terms:{},,detectDataSetIds:{}", terms, detectDataSetIds);
List<T> detects = detect(queryContext, terms, detectDataSetIds);
Map<MatchText, List<T>> result = new HashMap<>();
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
return result;
}
public List<T> detect(QueryContext queryContext, List<S2Term> terms, Set<Long> detectDataSetIds) {
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(terms);
String text = queryContext.getQueryText();
Set<T> results = new HashSet<>();
Set<String> detectSegments = new HashSet<>();
for (Integer startIndex = 0; startIndex <= text.length() - 1; ) {
for (Integer index = startIndex; index <= text.length(); ) {
int offset = mapperHelper.getStepOffset(terms, startIndex);
index = mapperHelper.getStepIndex(regOffsetToLength, index);
if (index <= text.length()) {
String detectSegment = text.substring(startIndex, index).trim();
detectSegments.add(detectSegment);
detectByStep(queryContext, results, detectDataSetIds, detectSegment, offset);
}
}
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
}
detectByBatch(queryContext, results, detectDataSetIds, detectSegments);
return new ArrayList<>(results);
}
protected void detectByBatch(QueryContext queryContext, Set<T> results, Set<Long> detectDataSetIds,
Set<String> detectSegments) {
// no-op by default; subclasses may override to detect in batch
}
public Map<Integer, Integer> getRegOffsetToLength(List<S2Term> terms) {
return terms.stream().sorted(Comparator.comparing(S2Term::length))
.collect(Collectors.toMap(S2Term::getOffset, term -> term.word.length(),
(value1, value2) -> value2));
}
public void selectResultInOneRound(Set<T> existResults, List<T> oneRoundResults) {
if (CollectionUtils.isEmpty(oneRoundResults)) {
return;
}
for (T oneRoundResult : oneRoundResults) {
if (existResults.contains(oneRoundResult)) {
boolean isDeleted = existResults.removeIf(
existResult -> {
boolean delete = needDelete(oneRoundResult, existResult);
if (delete) {
log.info("deleted existResult:{}", existResult);
}
return delete;
}
);
if (isDeleted) {
log.info("deleted, add oneRoundResult:{}", oneRoundResult);
existResults.add(oneRoundResult);
}
} else {
existResults.add(oneRoundResult);
}
}
}
public List<T> getMatches(QueryContext queryContext, List<S2Term> terms) {
Set<Long> dataSetIds = mapperHelper.getDataSetIds(queryContext.getDataSetId(), queryContext.getAgent());
terms = filterByDataSetId(terms, dataSetIds);
Map<MatchText, List<T>> matchResult = match(queryContext, terms, dataSetIds);
List<T> matches = new ArrayList<>();
if (Objects.isNull(matchResult)) {
return matches;
}
Optional<List<T>> first = matchResult.entrySet().stream()
.filter(entry -> CollectionUtils.isNotEmpty(entry.getValue()))
.map(entry -> entry.getValue()).findFirst();
if (first.isPresent()) {
matches = first.get();
}
return matches;
}
public List<S2Term> filterByDataSetId(List<S2Term> terms, Set<Long> dataSetIds) {
logTerms(terms);
if (CollectionUtils.isNotEmpty(dataSetIds)) {
terms = terms.stream().filter(term -> {
Long dataSetId = NatureHelper.getDataSetId(term.getNature().toString());
if (Objects.nonNull(dataSetId)) {
return dataSetIds.contains(dataSetId);
}
return false;
}).collect(Collectors.toList());
log.info("terms filter by dataSetId:{}", dataSetIds);
logTerms(terms);
}
return terms;
}
public void logTerms(List<S2Term> terms) {
if (CollectionUtils.isEmpty(terms)) {
return;
}
for (S2Term term : terms) {
log.debug("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
}
}
public abstract boolean needDelete(T oneRoundResult, T existResult);
public abstract String getMapKey(T a);
public abstract void detectByStep(QueryContext queryContext, Set<T> existResults, Set<Long> detectDataSetIds,
String detectSegment, int offset);
}

View File

@@ -1,125 +0,0 @@
package com.tencent.supersonic.chat.core.mapper;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.knowledge.DatabaseMapResult;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.Constants;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.stream.Collectors;
/**
* DatabaseMatchStrategy performs LIKE-style substring matching on schema elements.
* It currently supports fuzzy matching against names and aliases.
*/
@Service
@Slf4j
public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult> {
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private MapperHelper mapperHelper;
private List<SchemaElement> allElements;
@Override
public Map<MatchText, List<DatabaseMapResult>> match(QueryContext queryContext, List<S2Term> terms,
Set<Long> detectDataSetIds) {
this.allElements = getSchemaElements(queryContext);
return super.match(queryContext, terms, detectDataSetIds);
}
@Override
public boolean needDelete(DatabaseMapResult oneRoundResult, DatabaseMapResult existResult) {
return getMapKey(oneRoundResult).equals(getMapKey(existResult))
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
}
@Override
public String getMapKey(DatabaseMapResult a) {
return a.getName() + Constants.UNDERLINE + a.getSchemaElement().getId()
+ Constants.UNDERLINE + a.getSchemaElement().getName();
}
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectDataSetIds,
String detectSegment, int offset) {
if (StringUtils.isBlank(detectSegment)) {
return;
}
Double metricDimensionThresholdConfig = getThreshold(queryContext);
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements);
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
String name = entry.getKey();
if (!name.contains(detectSegment)
|| mapperHelper.getSimilarity(detectSegment, name) < metricDimensionThresholdConfig) {
continue;
}
Set<SchemaElement> schemaElements = entry.getValue();
if (!CollectionUtils.isEmpty(detectDataSetIds)) {
schemaElements = schemaElements.stream()
.filter(schemaElement -> detectDataSetIds.contains(schemaElement.getDataSet()))
.collect(Collectors.toSet());
}
for (SchemaElement schemaElement : schemaElements) {
DatabaseMapResult databaseMapResult = new DatabaseMapResult();
databaseMapResult.setDetectWord(detectSegment);
databaseMapResult.setName(schemaElement.getName());
databaseMapResult.setSchemaElement(schemaElement);
existResults.add(databaseMapResult);
}
}
}
private List<SchemaElement> getSchemaElements(QueryContext queryContext) {
List<SchemaElement> allElements = new ArrayList<>();
allElements.addAll(queryContext.getSemanticSchema().getDimensions());
allElements.addAll(queryContext.getSemanticSchema().getMetrics());
return allElements;
}
private Double getThreshold(QueryContext queryContext) {
Double metricDimensionThresholdConfig = optimizationConfig.getMetricDimensionThresholdConfig();
Double metricDimensionMinThresholdConfig = optimizationConfig.getMetricDimensionMinThresholdConfig();
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo().getDataSetElementMatches();
boolean existElement = modelElementMatches.entrySet().stream().anyMatch(entry -> entry.getValue().size() >= 1);
if (!existElement) {
double halfThreshold = metricDimensionThresholdConfig / 2;
metricDimensionThresholdConfig = halfThreshold >= metricDimensionMinThresholdConfig ? halfThreshold
: metricDimensionMinThresholdConfig;
log.info("ModelElementMatches:{} , not exist Element metricDimensionThresholdConfig reduce by half:{}",
modelElementMatches, metricDimensionThresholdConfig);
}
return metricDimensionThresholdConfig;
}
private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> models) {
return models.stream().collect(
Collectors.toMap(SchemaElement::getName, a -> {
Set<SchemaElement> result = new HashSet<>();
result.add(a);
return result;
}, (k1, k2) -> {
k1.addAll(k2);
return k1;
}));
}
}

View File

@@ -1,60 +0,0 @@
package com.tencent.supersonic.chat.core.mapper;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder;
import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper;
import com.tencent.supersonic.headless.server.service.KnowledgeService;
import lombok.extern.slf4j.Slf4j;
import java.util.List;
import java.util.Objects;
/**
* A mapper that recognizes schema elements via vector embeddings.
*/
@Slf4j
public class EmbeddingMapper extends BaseMapper {
@Override
public void doMap(QueryContext queryContext) {
//1. query from embedding by queryText
String queryText = queryContext.getQueryText();
KnowledgeService knowledgeService = ContextUtils.getBean(KnowledgeService.class);
List<S2Term> terms = knowledgeService.getTerms(queryText);
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
List<EmbeddingResult> matchResults = matchStrategy.getMatches(queryContext, terms);
HanlpHelper.transLetterOriginal(matchResults);
//2. build SchemaElementMatch by info
for (EmbeddingResult matchResult : matchResults) {
Long elementId = Retrieval.getLongId(matchResult.getId());
Long dataSetId = Retrieval.getLongId(matchResult.getMetadata().get("dataSetId"));
if (Objects.isNull(dataSetId)) {
continue;
}
SchemaElementType elementType = SchemaElementType.valueOf(matchResult.getMetadata().get("type"));
SchemaElement schemaElement = getSchemaElement(dataSetId, elementType, elementId,
queryContext.getSemanticSchema());
if (schemaElement == null) {
continue;
}
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(schemaElement)
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
.word(matchResult.getName())
.similarity(1 - matchResult.getDistance())
.detectWord(matchResult.getDetectWord())
.build();
//3. add to mapInfo
addToSchemaMap(queryContext.getMapInfo(), dataSetId, schemaElementMatch);
}
}
}

View File

@@ -1,125 +0,0 @@
package com.tencent.supersonic.chat.core.mapper;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
import com.tencent.supersonic.headless.server.service.MetaEmbeddingService;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
/**
* EmbeddingMatchStrategy uses a vector database to perform
* similarity search against the embeddings of schema elements.
*/
@Service
@Slf4j
public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private MetaEmbeddingService metaEmbeddingService;
@Override
public boolean needDelete(EmbeddingResult oneRoundResult, EmbeddingResult existResult) {
return getMapKey(oneRoundResult).equals(getMapKey(existResult))
&& existResult.getDistance() > oneRoundResult.getDistance();
}
@Override
public String getMapKey(EmbeddingResult a) {
return a.getName() + Constants.UNDERLINE + a.getId();
}
@Override
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectDataSetIds,
String detectSegment, int offset) {
}
@Override
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectDataSetIds,
Set<String> detectSegments) {
List<String> queryTextsList = detectSegments.stream()
.map(detectSegment -> detectSegment.trim())
.filter(detectSegment -> StringUtils.isNotBlank(detectSegment)
&& detectSegment.length() >= optimizationConfig.getEmbeddingMapperWordMin()
&& detectSegment.length() <= optimizationConfig.getEmbeddingMapperWordMax())
.collect(Collectors.toList());
List<List<String>> queryTextsSubList = Lists.partition(queryTextsList,
optimizationConfig.getEmbeddingMapperBatch());
for (List<String> queryTextsSub : queryTextsSubList) {
detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub);
}
}
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectDataSetIds,
List<String> queryTextsSub) {
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
Double distance = optimizationConfig.getEmbeddingMapperDistanceThreshold();
// step1. build query params
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
// step2. retrieveQuery by detectSegment
List<RetrieveQueryResult> retrieveQueryResults = metaEmbeddingService.retrieveQuery(
new ArrayList<>(detectDataSetIds), retrieveQuery, embeddingNumber);
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
return;
}
// step3. build EmbeddingResults
List<EmbeddingResult> collect = retrieveQueryResults.stream()
.map(retrieveQueryResult -> {
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
if (CollectionUtils.isNotEmpty(retrievals)) {
retrievals.removeIf(retrieval -> {
if (!retrieveQueryResult.getQuery().contains(retrieval.getQuery())) {
return retrieval.getDistance() > distance.doubleValue();
}
return false;
});
}
return retrieveQueryResult;
})
.filter(retrieveQueryResult -> CollectionUtils.isNotEmpty(retrieveQueryResult.getRetrieval()))
.flatMap(retrieveQueryResult -> retrieveQueryResult.getRetrieval().stream()
.map(retrieval -> {
EmbeddingResult embeddingResult = new EmbeddingResult();
BeanUtils.copyProperties(retrieval, embeddingResult);
embeddingResult.setDetectWord(retrieveQueryResult.getQuery());
embeddingResult.setName(retrieval.getQuery());
Map<String, String> convertedMap = retrieval.getMetadata().entrySet().stream()
.collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().toString()));
embeddingResult.setMetadata(convertedMap);
return embeddingResult;
}))
.collect(Collectors.toList());
// step4. select mapResult in one round
int roundNumber = optimizationConfig.getEmbeddingMapperRoundNumber() * queryTextsSub.size();
List<EmbeddingResult> oneRoundResults = collect.stream()
.sorted(Comparator.comparingDouble(EmbeddingResult::getDistance))
.limit(roundNumber)
.collect(Collectors.toList());
selectResultInOneRound(results, oneRoundResults);
}
}

View File

@@ -1,76 +0,0 @@
package com.tencent.supersonic.chat.core.mapper;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.stream.Collectors;
/**
* A mapper that converts VALUE-type matches on the entity dimension into ID-type matches.
*/
@Slf4j
public class EntityMapper extends BaseMapper {
@Override
public void doMap(QueryContext queryContext) {
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
for (Long dataSetId : schemaMapInfo.getMatchedDataSetInfos()) {
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(dataSetId);
if (CollectionUtils.isEmpty(schemaElementMatchList)) {
continue;
}
SchemaElement entity = getEntity(dataSetId, queryContext);
if (entity == null || entity.getId() == null) {
continue;
}
List<SchemaElementMatch> valueSchemaElements = schemaElementMatchList.stream()
.filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toList());
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
if (!entity.getId().equals(schemaElementMatch.getElement().getId())) {
continue;
}
if (!checkExistSameEntitySchemaElements(schemaElementMatch, schemaElementMatchList)) {
SchemaElementMatch entitySchemaElementMatch = new SchemaElementMatch();
BeanUtils.copyProperties(schemaElementMatch, entitySchemaElementMatch);
entitySchemaElementMatch.setElement(entity);
schemaElementMatchList.add(entitySchemaElementMatch);
}
schemaElementMatch.getElement().setType(SchemaElementType.ID);
}
}
}
private boolean checkExistSameEntitySchemaElements(SchemaElementMatch valueSchemaElementMatch,
List<SchemaElementMatch> schemaElementMatchList) {
List<SchemaElementMatch> entitySchemaElements = schemaElementMatchList.stream().filter(schemaElementMatch ->
SchemaElementType.ENTITY.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toList());
for (SchemaElementMatch schemaElementMatch : entitySchemaElements) {
if (schemaElementMatch.getElement().getId().equals(valueSchemaElementMatch.getElement().getId())) {
return true;
}
}
return false;
}
private SchemaElement getEntity(Long dataSetId, QueryContext queryContext) {
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
DataSetSchema modelSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
if (modelSchema != null && modelSchema.getEntity() != null) {
return modelSchema.getEntity();
}
return null;
}
}

View File

@@ -1,119 +0,0 @@
package com.tencent.supersonic.chat.core.mapper;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.knowledge.HanlpMapResult;
import com.tencent.supersonic.headless.server.service.KnowledgeService;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
/**
* HanlpDictMatchStrategy uses <a href="https://www.hanlp.com/">HanLP</a> to
* match schema elements. It currently supports prefix and suffix matching
* against names, values and aliases.
*/
@Service
@Slf4j
public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
@Autowired
private MapperHelper mapperHelper;
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private KnowledgeService knowledgeService;
@Override
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> terms,
Set<Long> detectDataSetIds) {
String text = queryContext.getQueryText();
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
return null;
}
log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectDataSetIds);
List<HanlpMapResult> detects = detect(queryContext, terms, detectDataSetIds);
Map<MatchText, List<HanlpMapResult>> result = new HashMap<>();
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
return result;
}
@Override
public boolean needDelete(HanlpMapResult oneRoundResult, HanlpMapResult existResult) {
return getMapKey(oneRoundResult).equals(getMapKey(existResult))
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
}
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectDataSetIds,
String detectSegment, int offset) {
// step1. pre search
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
LinkedHashSet<HanlpMapResult> hanlpMapResults = knowledgeService.prefixSearch(detectSegment,
oneDetectionMaxSize, detectDataSetIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
// step2. suffix search
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = knowledgeService.suffixSearch(detectSegment,
oneDetectionMaxSize, detectDataSetIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
hanlpMapResults.addAll(suffixHanlpMapResults);
if (CollectionUtils.isEmpty(hanlpMapResults)) {
return;
}
// step3. merge pre/suffix result
hanlpMapResults = hanlpMapResults.stream()
.sorted((a, b) -> a.getName().length() - b.getName().length())
.collect(Collectors.toCollection(LinkedHashSet::new));
// step4. filter by similarity
hanlpMapResults = hanlpMapResults.stream()
.filter(term -> mapperHelper.getSimilarity(detectSegment, term.getName())
>= mapperHelper.getThresholdMatch(term.getNatures()))
.filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
.collect(Collectors.toCollection(LinkedHashSet::new));
log.info("after isSimilarity parseResults:{}", hanlpMapResults);
hanlpMapResults = hanlpMapResults.stream().map(parseResult -> {
parseResult.setOffset(offset);
parseResult.setSimilarity(mapperHelper.getSimilarity(detectSegment, parseResult.getName()));
return parseResult;
}).collect(Collectors.toCollection(LinkedHashSet::new));
// step5. take only one dimension value, otherwise up to oneDetectionSize results per round.
List<HanlpMapResult> dimensionMetrics = hanlpMapResults.stream()
.filter(entry -> mapperHelper.existDimensionValues(entry.getNatures()))
.limit(1)
.collect(Collectors.toList());
Integer oneDetectionSize = optimizationConfig.getOneDetectionSize();
List<HanlpMapResult> oneRoundResults = hanlpMapResults.stream().limit(oneDetectionSize)
.collect(Collectors.toList());
if (CollectionUtils.isNotEmpty(dimensionMetrics)) {
oneRoundResults = dimensionMetrics;
}
// step6. select mapResult in one round
selectResultInOneRound(existResults, oneRoundResults);
}
public String getMapKey(HanlpMapResult a) {
return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures());
}
}

View File

@@ -1,126 +0,0 @@
package com.tencent.supersonic.chat.core.mapper;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.knowledge.DatabaseMapResult;
import com.tencent.supersonic.headless.core.knowledge.HanlpMapResult;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.headless.server.service.KnowledgeService;
import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper;
import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/**
* A mapper that recognizes schema elements via keyword matching.
* It leverages two matching strategies: HanlpDictMatchStrategy and DatabaseMatchStrategy.
*/
@Slf4j
public class KeywordMapper extends BaseMapper {
@Override
public void doMap(QueryContext queryContext) {
String queryText = queryContext.getQueryText();
//1.hanlpDict Match
KnowledgeService knowledgeService = ContextUtils.getBean(KnowledgeService.class);
List<S2Term> terms = knowledgeService.getTerms(queryText);
HanlpDictMatchStrategy hanlpMatchStrategy = ContextUtils.getBean(HanlpDictMatchStrategy.class);
List<HanlpMapResult> hanlpMapResults = hanlpMatchStrategy.getMatches(queryContext, terms);
convertHanlpMapResultToMapInfo(hanlpMapResults, queryContext, terms);
//2.database Match
DatabaseMatchStrategy databaseMatchStrategy = ContextUtils.getBean(DatabaseMatchStrategy.class);
List<DatabaseMapResult> databaseResults = databaseMatchStrategy.getMatches(queryContext, terms);
convertDatabaseMapResultToMapInfo(queryContext, databaseResults);
}
private void convertHanlpMapResultToMapInfo(List<HanlpMapResult> mapResults, QueryContext queryContext,
List<S2Term> terms) {
if (CollectionUtils.isEmpty(mapResults)) {
return;
}
HanlpHelper.transLetterOriginal(mapResults);
Map<String, Long> wordNatureToFrequency = terms.stream().collect(
Collectors.toMap(entry -> entry.getWord() + entry.getNature(),
term -> Long.valueOf(term.getFrequency()), (value1, value2) -> value2));
for (HanlpMapResult hanlpMapResult : mapResults) {
for (String nature : hanlpMapResult.getNatures()) {
Long dataSetId = NatureHelper.getDataSetId(nature);
if (Objects.isNull(dataSetId)) {
continue;
}
SchemaElementType elementType = NatureHelper.convertToElementType(nature);
if (Objects.isNull(elementType)) {
continue;
}
Long elementID = NatureHelper.getElementID(nature);
SchemaElement element = getSchemaElement(dataSetId, elementType,
elementID, queryContext.getSemanticSchema());
if (element == null) {
continue;
}
if (element.getType().equals(SchemaElementType.VALUE) || element.getType()
.equals(SchemaElementType.TAG_VALUE)) {
element.setName(hanlpMapResult.getName());
}
Long frequency = wordNatureToFrequency.get(hanlpMapResult.getName() + nature);
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(element)
.frequency(frequency)
.word(hanlpMapResult.getName())
.similarity(hanlpMapResult.getSimilarity())
.detectWord(hanlpMapResult.getDetectWord())
.build();
addToSchemaMap(queryContext.getMapInfo(), dataSetId, schemaElementMatch);
}
}
}
private void convertDatabaseMapResultToMapInfo(QueryContext queryContext, List<DatabaseMapResult> mapResults) {
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
for (DatabaseMapResult match : mapResults) {
SchemaElement schemaElement = match.getSchemaElement();
Set<Long> regElementSet = getRegElementSet(queryContext.getMapInfo(), schemaElement);
if (regElementSet.contains(schemaElement.getId())) {
continue;
}
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(schemaElement)
.word(schemaElement.getName())
.detectWord(match.getDetectWord())
.frequency(10000L)
.similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName()))
.build();
log.info("add to schema, elementMatch {}", schemaElementMatch);
addToSchemaMap(queryContext.getMapInfo(), schemaElement.getDataSet(), schemaElementMatch);
}
}
private Set<Long> getRegElementSet(SchemaMapInfo schemaMap, SchemaElement schemaElement) {
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getDataSet());
if (CollectionUtils.isEmpty(elements)) {
return new HashSet<>();
}
return elements.stream()
.filter(elementMatch ->
SchemaElementType.METRIC.equals(elementMatch.getElement().getType())
|| SchemaElementType.DIMENSION.equals(elementMatch.getElement().getType()))
.map(elementMatch -> elementMatch.getElement().getId())
.collect(Collectors.toSet());
}
}

View File

@@ -1,114 +0,0 @@
package com.tencent.supersonic.chat.core.mapper;
import com.hankcs.hanlp.algorithm.EditDistance;
import com.tencent.supersonic.chat.core.agent.Agent;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
@Data
@Service
@Slf4j
public class MapperHelper {
@Autowired
private OptimizationConfig optimizationConfig;
public Integer getStepIndex(Map<Integer, Integer> regOffsetToLength, Integer index) {
Integer subRegLength = regOffsetToLength.get(index);
if (Objects.nonNull(subRegLength)) {
index = index + subRegLength;
} else {
index++;
}
return index;
}
public Integer getStepOffset(List<S2Term> termList, Integer index) {
List<Integer> offsetList = termList.stream().sorted(Comparator.comparing(S2Term::getOffset))
.map(term -> term.getOffset()).collect(Collectors.toList());
for (int j = 0; j < termList.size() - 1; j++) {
if (offsetList.get(j) <= index && offsetList.get(j + 1) > index) {
return offsetList.get(j);
}
}
return index;
}
public double getThresholdMatch(List<String> natures) {
if (existDimensionValues(natures)) {
return optimizationConfig.getDimensionValueThresholdConfig();
}
return optimizationConfig.getMetricDimensionThresholdConfig();
}
/**
* Checks whether any of the given natures refers to a dimension value.
*/
public boolean existDimensionValues(List<String> natures) {
for (String nature : natures) {
if (NatureHelper.isDimensionValueDataSetId(nature)) {
return true;
}
}
return false;
}
/**
* Computes similarity as 1 - editDistance / max(length), case-insensitively.
*/
public double getSimilarity(String detectSegment, String matchName) {
String detectSegmentLower = detectSegment == null ? null : detectSegment.toLowerCase();
String matchNameLower = matchName == null ? null : matchName.toLowerCase();
return 1 - (double) EditDistance.compute(detectSegmentLower, matchNameLower) / Math.max(matchName.length(),
detectSegment.length());
}
public Set<Long> getDataSetIds(Long dataSetId, Agent agent) {
Set<Long> detectDataSetIds = new HashSet<>();
if (Objects.nonNull(agent)) {
detectDataSetIds = agent.getDataSetIds();
}
// the agent is granted access to all data sets
if (Agent.containsAllModel(detectDataSetIds)) {
if (Objects.nonNull(dataSetId) && dataSetId > 0) {
Set<Long> result = new HashSet<>();
result.add(dataSetId);
return result;
}
return new HashSet<>();
}
if (Objects.nonNull(detectDataSetIds)) {
detectDataSetIds = detectDataSetIds.stream().filter(entry -> entry > 0).collect(Collectors.toSet());
}
if (Objects.nonNull(dataSetId) && dataSetId > 0 && Objects.nonNull(detectDataSetIds)) {
if (detectDataSetIds.contains(dataSetId)) {
Set<Long> result = new HashSet<>();
result.add(dataSetId);
return result;
}
}
return detectDataSetIds;
}
}

View File

@@ -1,18 +0,0 @@
package com.tencent.supersonic.chat.core.mapper;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
* MatchStrategy encapsulates a concrete matching algorithm
* executed during the query or search process.
*/
public interface MatchStrategy<T> {
Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms, Set<Long> detectDataSetIds);
}

View File

@@ -1,33 +0,0 @@
package com.tencent.supersonic.chat.core.mapper;
import java.util.Objects;
import lombok.Builder;
import lombok.Data;
import lombok.ToString;
@Data
@ToString
@Builder
public class MatchText {
private String regText;
private String detectSegment;
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
MatchText that = (MatchText) o;
return Objects.equals(regText, that.regText) && Objects.equals(detectSegment, that.detectSegment);
}
@Override
public int hashCode() {
return Objects.hash(regText, detectSegment);
}
}

View File

@@ -1,19 +0,0 @@
package com.tencent.supersonic.chat.core.mapper;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import java.io.Serializable;
import lombok.Data;
import lombok.ToString;
@Data
@ToString
public class ModelWithSemanticType implements Serializable {
private Long model;
private SchemaElementType schemaElementType;
public ModelWithSemanticType(Long model, SchemaElementType schemaElementType) {
this.model = model;
this.schemaElementType = schemaElementType;
}
}

View File

@@ -1,99 +0,0 @@
package com.tencent.supersonic.chat.core.mapper;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.core.agent.Agent;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
@Slf4j
public class QueryFilterMapper implements SchemaMapper {
private double similarity = 1.0;
@Override
public void map(QueryContext queryContext) {
Agent agent = queryContext.getAgent();
if (agent == null || CollectionUtils.isEmpty(agent.getDataSetIds())) {
return;
}
if (Agent.containsAllModel(agent.getDataSetIds())) {
return;
}
Set<Long> viewIds = agent.getDataSetIds();
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
clearOtherSchemaElementMatch(viewIds, schemaMapInfo);
for (Long viewId : viewIds) {
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(viewId);
if (schemaElementMatches == null) {
schemaElementMatches = Lists.newArrayList();
schemaMapInfo.setMatchedElements(viewId, schemaElementMatches);
}
addValueSchemaElementMatch(viewId, queryContext, schemaElementMatches);
}
}
private void clearOtherSchemaElementMatch(Set<Long> viewIds, SchemaMapInfo schemaMapInfo) {
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.getDataSetElementMatches().entrySet()) {
if (!viewIds.contains(entry.getKey())) {
entry.getValue().clear();
}
}
}
private List<SchemaElementMatch> addValueSchemaElementMatch(Long viewId, QueryContext queryContext,
List<SchemaElementMatch> candidateElementMatches) {
QueryFilters queryFilters = queryContext.getQueryFilters();
if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) {
return candidateElementMatches;
}
for (QueryFilter filter : queryFilters.getFilters()) {
if (checkExistSameValueSchemaElementMatch(filter, candidateElementMatches)) {
continue;
}
SchemaElement element = SchemaElement.builder()
.id(filter.getElementID())
.name(String.valueOf(filter.getValue()))
.type(SchemaElementType.VALUE)
.bizName(filter.getBizName())
.dataSet(viewId)
.build();
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(element)
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
.word(String.valueOf(filter.getValue()))
.similarity(similarity)
.detectWord(Constants.EMPTY)
.build();
candidateElementMatches.add(schemaElementMatch);
}
return candidateElementMatches;
}
private boolean checkExistSameValueSchemaElementMatch(QueryFilter queryFilter,
List<SchemaElementMatch> schemaElementMatches) {
List<SchemaElementMatch> valueSchemaElements = schemaElementMatches.stream().filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toList());
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
if (schemaElementMatch.getElement().getId().equals(queryFilter.getElementID())
&& schemaElementMatch.getWord().equals(String.valueOf(queryFilter.getValue()))) {
return true;
}
}
return false;
}
}

View File

@@ -1,12 +0,0 @@
package com.tencent.supersonic.chat.core.mapper;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
/**
* A schema mapper identifies references to schema elements (metrics/dimensions/entities/values)
* in user queries. It matches the query text against the knowledge base.
*/
public interface SchemaMapper {
void map(QueryContext queryContext);
}

View File

@@ -1,101 +0,0 @@
package com.tencent.supersonic.chat.core.mapper;
import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.knowledge.HanlpMapResult;
import com.tencent.supersonic.headless.core.knowledge.SearchService;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.headless.server.service.KnowledgeService;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
/**
* SearchMatchStrategy encapsulates a concrete matching algorithm
* executed during the search process.
*/
@Service
public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
private static final int SEARCH_SIZE = 3;
@Autowired
private KnowledgeService knowledgeService;
@Override
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> originals,
Set<Long> detectDataSetIds) {
String text = queryContext.getQueryText();
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);
List<Integer> detectIndexList = Lists.newArrayList();
for (Integer index = 0; index < text.length(); ) {
detectIndexList.add(index);
Integer regLength = regOffsetToLength.get(index);
if (Objects.nonNull(regLength)) {
index = index + regLength;
} else {
index++;
}
}
Map<MatchText, List<HanlpMapResult>> regTextMap = new ConcurrentHashMap<>();
detectIndexList.stream().parallel().forEach(detectIndex -> {
String regText = text.substring(0, detectIndex);
String detectSegment = text.substring(detectIndex);
if (StringUtils.isNotEmpty(detectSegment)) {
List<HanlpMapResult> hanlpMapResults = knowledgeService.prefixSearch(detectSegment,
SearchService.SEARCH_SIZE, detectDataSetIds);
List<HanlpMapResult> suffixHanlpMapResults = knowledgeService.suffixSearch(
detectSegment, SEARCH_SIZE, detectDataSetIds);
hanlpMapResults.addAll(suffixHanlpMapResults);
// drop results whose natures are all entity names
hanlpMapResults = hanlpMapResults.stream().filter(entry -> {
List<String> natures = entry.getNatures().stream()
.filter(nature -> !nature.endsWith(DictWordType.ENTITY.getTypeWithSpilt()))
.collect(Collectors.toList());
return !CollectionUtils.isEmpty(natures);
}).collect(Collectors.toList());
MatchText matchText = MatchText.builder()
.regText(regText)
.detectSegment(detectSegment)
.build();
regTextMap.put(matchText, hanlpMapResults);
}
}
);
return regTextMap;
}
@Override
public boolean needDelete(HanlpMapResult oneRoundResult, HanlpMapResult existResult) {
return false;
}
@Override
public String getMapKey(HanlpMapResult a) {
return null;
}
@Override
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectDataSetIds,
String detectSegment, int offset) {
}
}

View File

@@ -1,66 +0,0 @@
package com.tencent.supersonic.chat.core.parser;
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionPromptGenerator;
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionReq;
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
import com.tencent.supersonic.chat.core.parser.sql.llm.OutputFormat;
import com.tencent.supersonic.chat.core.parser.sql.llm.SqlGeneration;
import com.tencent.supersonic.chat.core.parser.sql.llm.SqlGenerationFactory;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import java.util.Objects;
/**
* An LLMProxy implementation based on LangChain4j (the Java version of LangChain).
*/
@Slf4j
@Component
public class JavaLLMProxy implements LLMProxy {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
@Override
public boolean isSkip(QueryContext queryContext) {
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
if (Objects.isNull(chatLanguageModel)) {
log.warn("chatLanguageModel is null, skip :{}", JavaLLMProxy.class.getName());
return true;
}
return false;
}
public LLMResp query2sql(LLMReq llmReq, Long dataSetId) {
SqlGeneration sqlGeneration = SqlGenerationFactory.get(
SqlGenerationMode.getMode(llmReq.getSqlGenerationMode()));
String modelName = llmReq.getSchema().getDataSetName();
LLMResp result = sqlGeneration.generation(llmReq, dataSetId);
result.setQuery(llmReq.getQueryText());
result.setModelName(modelName);
return result;
}
@Override
public FunctionResp requestFunction(FunctionReq functionReq) {
FunctionPromptGenerator promptGenerator = ContextUtils.getBean(FunctionPromptGenerator.class);
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
String functionCallPrompt = promptGenerator.generateFunctionCallPrompt(functionReq.getQueryText(),
functionReq.getPluginConfigs());
keyPipelineLog.info("functionCallPrompt:{}", functionCallPrompt);
String response = chatLanguageModel.generate(functionCallPrompt);
keyPipelineLog.info("functionCall response:{}", response);
return OutputFormat.functionCallParse(response);
}
}

View File

@@ -1,22 +0,0 @@
package com.tencent.supersonic.chat.core.parser;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionReq;
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
/**
* LLMProxy encapsulates functions performed by LLMs so that multiple
* orchestration frameworks (e.g. LangChain in Python, LangChain4j in Java)
* can be used.
*/
public interface LLMProxy {
boolean isSkip(QueryContext queryContext);
LLMResp query2sql(LLMReq llmReq, Long dataSetId);
FunctionResp requestFunction(FunctionReq functionReq);
}

View File

@@ -1,104 +0,0 @@
package com.tencent.supersonic.chat.core.parser;
import com.alibaba.fastjson.JSON;
import com.tencent.supersonic.chat.core.config.LLMParserConfig;
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionCallConfig;
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionReq;
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
import com.tencent.supersonic.chat.core.parser.sql.llm.OutputFormat;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;
import java.net.URI;
import java.net.URL;
import java.util.ArrayList;
/**
* PythonLLMProxy sends requests to a LangChain-based Python service.
*/
@Slf4j
@Component
public class PythonLLMProxy implements LLMProxy {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
@Override
public boolean isSkip(QueryContext queryContext) {
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
if (StringUtils.isEmpty(llmParserConfig.getUrl())) {
log.warn("llmParserUrl is empty, skip :{}", PythonLLMProxy.class.getName());
return true;
}
return false;
}
@Override
public LLMResp query2sql(LLMReq llmReq, Long dataSetId) {
long startTime = System.currentTimeMillis();
log.info("requestLLM request, dataSetId:{},llmReq:{}", dataSetId, llmReq);
keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq);
try {
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
URL url = new URL(new URL(llmParserConfig.getUrl()), llmParserConfig.getQueryToSqlPath());
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
HttpEntity<String> entity = new HttpEntity<>(JsonUtil.toString(llmReq), headers);
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
ResponseEntity<LLMResp> responseEntity = restTemplate.exchange(url.toString(), HttpMethod.POST, entity,
LLMResp.class);
LLMResp llmResp = responseEntity.getBody();
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
System.currentTimeMillis() - startTime, url, entity, llmResp);
keyPipelineLog.info("LLMResp:{}", llmResp);
if (MapUtils.isEmpty(llmResp.getSqlRespMap())) {
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(new ArrayList<>(), llmResp.getSqlWeight()));
}
return llmResp;
} catch (Exception e) {
log.error("requestLLM error", e);
}
return null;
}
@Override
public FunctionResp requestFunction(FunctionReq functionReq) {
FunctionCallConfig functionCallInfoConfig = ContextUtils.getBean(FunctionCallConfig.class);
String url = functionCallInfoConfig.getUrl() + functionCallInfoConfig.getPluginSelectPath();
HttpHeaders headers = new HttpHeaders();
long startTime = System.currentTimeMillis();
headers.setContentType(MediaType.APPLICATION_JSON);
HttpEntity<String> entity = new HttpEntity<>(JSON.toJSONString(functionReq), headers);
URI requestUrl = UriComponentsBuilder.fromHttpUrl(url).build().encode().toUri();
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
try {
log.info("requestFunction functionReq:{}", JsonUtil.toString(functionReq));
keyPipelineLog.info("requestFunction functionReq:{}", JsonUtil.toString(functionReq));
ResponseEntity<FunctionResp> responseEntity = restTemplate.exchange(requestUrl, HttpMethod.POST, entity,
FunctionResp.class);
log.info("requestFunction responseEntity:{},cost:{}", responseEntity,
System.currentTimeMillis() - startTime);
keyPipelineLog.info("response:{}", responseEntity.getBody());
return responseEntity.getBody();
} catch (Exception e) {
log.error("requestFunction error", e);
}
return null;
}
}

View File

@@ -1,32 +0,0 @@
package com.tencent.supersonic.chat.core.parser;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
/**
* QueryTypeParser resolves the query type as one of METRIC, TAG, or ID.
*/
@Slf4j
public class QueryTypeParser implements SemanticParser {
@Override
public void parse(QueryContext queryContext, ChatContext chatContext) {
List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries();
User user = queryContext.getUser();
for (SemanticQuery semanticQuery : candidateQueries) {
// 1.init S2SQL
semanticQuery.initS2Sql(queryContext.getSemanticSchema(), user);
// 2.set queryType
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
parseInfo.setQueryType(queryContext.getQueryType(parseInfo.getDataSetId()));
}
}
}

View File

@@ -1,49 +0,0 @@
package com.tencent.supersonic.chat.core.parser;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
/**
* This checker can be used by semantic parsers to check whether the query intent
* has already been satisfied by the current candidate queries. If so, the current
* parser can be skipped.
*/
@Slf4j
public class SatisfactionChecker {
// checks whether any candidate's parse info already satisfies the query
public static boolean isSkip(QueryContext queryContext) {
for (SemanticQuery query : queryContext.getCandidateQueries()) {
if (query.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
continue;
}
if (checkThreshold(queryContext.getQueryText(), query.getParseInfo())) {
return true;
}
}
return false;
}
private static boolean checkThreshold(String queryText, SemanticParseInfo semanticParseInfo) {
int queryTextLength = queryText.replaceAll(" ", "").length();
double degree = semanticParseInfo.getScore() / queryTextLength;
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
if (queryTextLength > optimizationConfig.getQueryTextLengthThreshold()) {
if (degree < optimizationConfig.getLongTextThreshold()) {
return false;
}
} else if (degree < optimizationConfig.getShortTextThreshold()) {
return false;
}
log.info("queryMode:{}, degree:{}, parse info:{}",
semanticParseInfo.getQueryMode(), degree, semanticParseInfo);
return true;
}
}
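To make the threshold arithmetic concrete, a small worked example (the threshold values are illustrative, not the shipped defaults):

// queryText "show sales last week" has 17 characters once spaces are removed;
// assume the candidate's parse score is 12.
int queryTextLength = "show sales last week".replaceAll(" ", "").length(); // 17
double degree = 12.0 / queryTextLength; // ~0.71
// with hypothetical queryTextLengthThreshold=10 and longTextThreshold=0.6,
// the long-text branch applies and 0.71 >= 0.6, so the intent counts as satisfied
boolean satisfied = queryTextLength > 10 ? degree >= 0.6 : degree >= 0.8;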

View File

@@ -1,15 +0,0 @@
package com.tencent.supersonic.chat.core.parser;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
/**
* A semantic parser understands user queries and extracts semantic information.
* It could leverage either rule-based or LLM-based approach to identify query intent
* and extract related semantic items from the query.
*/
public interface SemanticParser {
void parse(QueryContext queryContext, ChatContext chatContext);
}
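A minimal sketch of a custom rule-based implementation (the class name and rule are hypothetical):

@Slf4j
public class KeywordParser implements SemanticParser {
    @Override
    public void parse(QueryContext queryContext, ChatContext chatContext) {
        // a real parser would inspect queryContext.getQueryText() and
        // queryContext.getMapInfo(), then contribute candidates via
        // queryContext.getCandidateQueries().add(...)
        log.info("parsing query: {}", queryContext.getQueryText());
    }
}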

View File

@@ -1,127 +0,0 @@
package com.tencent.supersonic.chat.core.parser.plugin.function;
import com.tencent.supersonic.chat.core.parser.PythonLLMProxy;
import com.tencent.supersonic.chat.core.parser.plugin.ParseMode;
import com.tencent.supersonic.chat.core.parser.plugin.PluginParser;
import com.tencent.supersonic.chat.core.plugin.Plugin;
import com.tencent.supersonic.chat.core.plugin.PluginManager;
import com.tencent.supersonic.chat.core.plugin.PluginParseConfig;
import com.tencent.supersonic.chat.core.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/**
* FunctionCallParser is a plugin recall implementation based on LLM function calling.
*/
@Slf4j
public class FunctionCallParser extends PluginParser {
@Override
public boolean checkPreCondition(QueryContext queryContext) {
FunctionCallConfig functionCallConfig = ContextUtils.getBean(FunctionCallConfig.class);
String functionUrl = functionCallConfig.getUrl();
if (StringUtils.isBlank(functionUrl) && ComponentFactory.getLLMProxy() instanceof PythonLLMProxy) {
log.info("functionUrl:{}, skip function parser, queryText:{}", functionUrl,
queryContext.getQueryText());
return false;
}
List<Plugin> plugins = getPluginList(queryContext);
return !CollectionUtils.isEmpty(plugins);
}
@Override
public PluginRecallResult recallPlugin(QueryContext queryContext) {
FunctionResp functionResp = functionCall(queryContext);
if (skipFunction(functionResp)) {
return null;
}
log.info("requestFunction result:{}", functionResp.getToolSelection());
String toolSelection = functionResp.getToolSelection();
Plugin plugin = queryContext.getNameToPlugin().get(toolSelection);
if (Objects.isNull(plugin)) {
log.info("pluginOptional is not exist:{}, skip the parse", toolSelection);
return null;
}
plugin.setParseMode(ParseMode.FUNCTION_CALL);
Pair<Boolean, Set<Long>> pluginResolveResult = PluginManager.resolve(plugin, queryContext);
if (pluginResolveResult.getLeft()) {
Set<Long> dataSetList = pluginResolveResult.getRight();
if (CollectionUtils.isEmpty(dataSetList)) {
return null;
}
double score = queryContext.getQueryText().length();
return PluginRecallResult.builder().plugin(plugin).dataSetIds(dataSetList).score(score).build();
}
return null;
}
public FunctionResp functionCall(QueryContext queryContext) {
List<PluginParseConfig> pluginToFunctionCall =
getPluginToFunctionCall(queryContext.getDataSetId(), queryContext);
if (CollectionUtils.isEmpty(pluginToFunctionCall)) {
log.info("function call parser, plugin is empty, skip");
return null;
}
FunctionResp functionResp = new FunctionResp();
if (pluginToFunctionCall.size() == 1) {
functionResp.setToolSelection(pluginToFunctionCall.iterator().next().getName());
} else {
FunctionReq functionReq = FunctionReq.builder()
.queryText(queryContext.getQueryText())
.pluginConfigs(pluginToFunctionCall).build();
functionResp = ComponentFactory.getLLMProxy().requestFunction(functionReq);
}
return functionResp;
}
private boolean skipFunction(FunctionResp functionResp) {
return Objects.isNull(functionResp) || StringUtils.isBlank(functionResp.getToolSelection());
}
private List<PluginParseConfig> getPluginToFunctionCall(Long modelId, QueryContext queryContext) {
log.info("user decide Model:{}", modelId);
List<Plugin> plugins = getPluginList(queryContext);
List<PluginParseConfig> functionDOList = plugins.stream().filter(plugin -> {
if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) {
return false;
}
if (plugin.getParseModeConfig() == null) {
return false;
}
PluginParseConfig pluginParseConfig = JsonUtil.toObject(plugin.getParseModeConfig(),
PluginParseConfig.class);
if (StringUtils.isBlank(pluginParseConfig.getName())) {
return false;
}
Pair<Boolean, Set<Long>> pluginResolverResult = PluginManager.resolve(plugin, queryContext);
log.info("plugin [{}-{}] resolve: {}", plugin.getId(), plugin.getName(), pluginResolverResult);
if (!pluginResolverResult.getLeft()) {
return false;
} else {
Set<Long> resolveModel = pluginResolverResult.getRight();
if (modelId != null && modelId > 0) {
if (plugin.isContainsAllModel()) {
return true;
}
return resolveModel.contains(modelId);
}
return true;
}
}).map(o -> JsonUtil.toObject(o.getParseModeConfig(), PluginParseConfig.class)).collect(Collectors.toList());
log.info("PluginToFunctionCall: {}", JsonUtil.toString(functionDOList));
return functionDOList;
}
}
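Concretely, the function-call round trip looks like this (a hypothetical exchange; the JSON key is the one parsed by OutputFormat.functionCallParse further below):

// prompt: queryText plus the candidate plugin configs, rendered by FunctionPromptGenerator
// LLM output: {"选择工具": "weather-plugin"}   ("选择工具" = "selected tool"; plugin name made up)
// parsed:  FunctionResp.toolSelection = "weather-plugin", which recallPlugin then
//          resolves against queryContext.getNameToPlugin()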

View File

@@ -1,9 +0,0 @@
package com.tencent.supersonic.chat.core.parser.sql.llm;
import lombok.Data;
@Data
public class DataSetMatchResult {
private Integer count = 0;
private double maxSimilarity;
}

View File

@@ -1,12 +0,0 @@
package com.tencent.supersonic.chat.core.parser.sql.llm;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import java.util.Set;
public interface DataSetResolver {
Long resolve(QueryContext queryContext, Set<Long> restrictiveModels);
}

View File

@@ -1,138 +0,0 @@
package com.tencent.supersonic.chat.core.parser.sql.llm;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
@Slf4j
public class HeuristicDataSetResolver implements DataSetResolver {
protected static Long selectDataSetBySchemaElementMatchScore(Map<Long, SemanticQuery> dataSetQueryModes,
SchemaMapInfo schemaMap) {
//dataSet match score takes priority
Long dataSetIdByDataSetCount = getDataSetIdByMatchDataSetScore(schemaMap);
if (Objects.nonNull(dataSetIdByDataSetCount)) {
log.info("selectDataSet by dataSet count:{}", dataSetIdByDataSetCount);
return dataSetIdByDataSetCount;
}
Map<Long, DataSetMatchResult> dataSetTypeMap = getDataSetTypeMap(schemaMap);
if (dataSetTypeMap.size() == 1) {
Long dataSetSelect = new ArrayList<>(dataSetTypeMap.entrySet()).get(0).getKey();
if (dataSetQueryModes.containsKey(dataSetSelect)) {
log.info("selectDataSet with only one DataSet [{}]", dataSetSelect);
return dataSetSelect;
}
} else {
Map.Entry<Long, DataSetMatchResult> maxDataSet = dataSetTypeMap.entrySet().stream()
.filter(entry -> dataSetQueryModes.containsKey(entry.getKey()))
.sorted((o1, o2) -> {
int difference = o2.getValue().getCount() - o1.getValue().getCount();
if (difference == 0) {
return (int) ((o2.getValue().getMaxSimilarity()
- o1.getValue().getMaxSimilarity()) * 100);
}
return difference;
}).findFirst().orElse(null);
if (maxDataSet != null) {
log.info("selectDataSet with multiple DataSets [{}]", maxDataSet.getKey());
return maxDataSet.getKey();
}
}
return null;
}
private static Long getDataSetIdByMatchDataSetScore(SchemaMapInfo schemaMap) {
Map<Long, List<SchemaElementMatch>> dataSetElementMatches = schemaMap.getDataSetElementMatches();
// calculate dataSet match score: an exact match scores 1.0, an inherited match scores 0.5
Map<Long, Double> dataSetIdToDataSetScore = new HashMap<>();
if (Objects.nonNull(dataSetElementMatches)) {
for (Entry<Long, List<SchemaElementMatch>> dataSetElementMatch : dataSetElementMatches.entrySet()) {
Long dataSetId = dataSetElementMatch.getKey();
List<Double> dataSetMatchesScore = dataSetElementMatch.getValue().stream()
.filter(elementMatch -> elementMatch.getSimilarity() >= 1)
.filter(elementMatch -> SchemaElementType.DATASET.equals(elementMatch.getElement().getType()))
.map(elementMatch -> elementMatch.isInherited() ? 0.5 : 1.0).collect(Collectors.toList());
if (!CollectionUtils.isEmpty(dataSetMatchesScore)) {
// get sum of dataSet match score
double score = dataSetMatchesScore.stream().mapToDouble(Double::doubleValue).sum();
dataSetIdToDataSetScore.put(dataSetId, score);
}
}
Entry<Long, Double> maxDataSetScore = dataSetIdToDataSetScore.entrySet().stream()
.max(Comparator.comparingDouble(Entry::getValue)).orElse(null);
log.info("maxDataSetCount:{},dataSetIdToDataSetCount:{}", maxDataSetScore, dataSetIdToDataSetScore);
if (Objects.nonNull(maxDataSetScore)) {
return maxDataSetScore.getKey();
}
}
return null;
}
public static Map<Long, DataSetMatchResult> getDataSetTypeMap(SchemaMapInfo schemaMap) {
Map<Long, DataSetMatchResult> dataSetCount = new HashMap<>();
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMap.getDataSetElementMatches().entrySet()) {
List<SchemaElementMatch> schemaElementMatches = schemaMap.getMatchedElements(entry.getKey());
if (schemaElementMatches != null && schemaElementMatches.size() > 0) {
if (!dataSetCount.containsKey(entry.getKey())) {
dataSetCount.put(entry.getKey(), new DataSetMatchResult());
}
DataSetMatchResult dataSetMatchResult = dataSetCount.get(entry.getKey());
Set<SchemaElementType> schemaElementTypes = new HashSet<>();
schemaElementMatches.stream()
.forEach(schemaElementMatch -> schemaElementTypes.add(
schemaElementMatch.getElement().getType()));
SchemaElementMatch schemaElementMatchMax = schemaElementMatches.stream()
.sorted((o1, o2) ->
((int) ((o2.getSimilarity() - o1.getSimilarity()) * 100))
).findFirst().orElse(null);
if (schemaElementMatchMax != null) {
dataSetMatchResult.setMaxSimilarity(schemaElementMatchMax.getSimilarity());
}
dataSetMatchResult.setCount(schemaElementTypes.size());
}
}
return dataSetCount;
}
public Long resolve(QueryContext queryContext, Set<Long> agentDataSetIds) {
SchemaMapInfo mapInfo = queryContext.getMapInfo();
Set<Long> matchedDataSets = mapInfo.getMatchedDataSetInfos();
Long dataSetId = queryContext.getDataSetId();
if (Objects.nonNull(dataSetId) && dataSetId > 0) {
if (CollectionUtils.isEmpty(agentDataSetIds) || agentDataSetIds.contains(dataSetId)) {
return dataSetId;
}
return null;
}
if (CollectionUtils.isNotEmpty(agentDataSetIds)) {
matchedDataSets.retainAll(agentDataSetIds);
}
Map<Long, SemanticQuery> dataSetQueryModes = new HashMap<>();
for (Long matchedDataSetId : matchedDataSets) {
dataSetQueryModes.put(matchedDataSetId, null);
}
if (dataSetQueryModes.size() == 1) {
return dataSetQueryModes.keySet().stream().findFirst().get();
}
return selectDataSetBySchemaElementMatchScore(dataSetQueryModes, mapInfo);
}
}
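An illustrative walk-through of the resolution order above (all numbers made up):

// dataSet 1: two exact DATASET matches        -> score 1.0 + 1.0 = 2.0
// dataSet 2: one exact + one inherited match  -> score 1.0 + 0.5 = 1.5
// getDataSetIdByMatchDataSetScore picks dataSet 1. If no DATASET-type matches
// exist at all, resolution falls back to getDataSetTypeMap, which counts the
// distinct element types matched per dataSet and breaks ties by max similarity.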

View File

@@ -1,42 +0,0 @@
package com.tencent.supersonic.chat.core.parser.sql.llm;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class InputFormat {
public static final String SEPERATOR = "\n\n";
public static String format(String template, List<String> templateKey,
List<Map<String, String>> toFormatList) {
List<String> result = new ArrayList<>();
for (Map<String, String> formatItem : toFormatList) {
Map<String, String> retrievalMeta = subDict(formatItem, templateKey);
result.add(format(template, retrievalMeta));
}
return String.join(SEPERATOR, result);
}
public static String format(String input, Map<String, String> replacements) {
for (Map.Entry<String, String> entry : replacements.entrySet()) {
input = input.replace(entry.getKey(), entry.getValue());
}
return input;
}
private static Map<String, String> subDict(Map<String, String> dict, List<String> keys) {
Map<String, String> subDict = new HashMap<>();
for (String key : keys) {
if (dict.containsKey(key)) {
subDict.put(key, dict.get(key));
}
}
return subDict;
}
}
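A usage sketch (values made up): keys are replaced literally wherever they occur in the template, so substitution is order-dependent and template keys should not be substrings of one another.

Map<String, String> item = new HashMap<>();
item.put("question", "total sales by city");
item.put("sql", "SELECT city, SUM(sales) FROM t GROUP BY city");
String rendered = InputFormat.format("Q: question\nSQL: sql", item);
// rendered == "Q: total sales by city\nSQL: SELECT city, SUM(sales) FROM t GROUP BY city"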

View File

@@ -1,269 +0,0 @@
package com.tencent.supersonic.chat.core.parser.sql.llm;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.core.utils.S2SqlDateHelper;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.core.agent.Agent;
import com.tencent.supersonic.chat.core.agent.AgentToolType;
import com.tencent.supersonic.chat.core.agent.NL2SQLTool;
import com.tencent.supersonic.chat.core.config.LLMParserConfig;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
import com.tencent.supersonic.chat.core.parser.SatisfactionChecker;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
@Slf4j
@Service
public class LLMRequestService {
protected SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
@Autowired
private LLMParserConfig llmParserConfig;
@Autowired
private OptimizationConfig optimizationConfig;
public boolean isSkip(QueryContext queryCtx) {
if (ComponentFactory.getLLMProxy().isSkip(queryCtx)) {
return true;
}
if (SatisfactionChecker.isSkip(queryCtx)) {
log.info("skip {}, queryText:{}", LLMSqlParser.class, queryCtx.getQueryText());
return true;
}
return false;
}
public Long getDataSetId(QueryContext queryCtx) {
Agent agent = queryCtx.getAgent();
Set<Long> agentDataSetIds = new HashSet<>();
if (Objects.nonNull(agent)) {
agentDataSetIds = agent.getDataSetIds(AgentToolType.NL2SQL_LLM);
}
if (Agent.containsAllModel(agentDataSetIds)) {
agentDataSetIds = new HashSet<>();
}
DataSetResolver dataSetResolver = ComponentFactory.getModelResolver();
return dataSetResolver.resolve(queryCtx, agentDataSetIds);
}
public NL2SQLTool getParserTool(QueryContext queryCtx, Long dataSetId) {
Agent agent = queryCtx.getAgent();
if (Objects.isNull(agent)) {
return null;
}
List<NL2SQLTool> commonAgentTools = agent.getParserTools(AgentToolType.NL2SQL_LLM);
Optional<NL2SQLTool> llmParserTool = commonAgentTools.stream()
.filter(tool -> {
List<Long> dataSetIds = tool.getDataSetIds();
if (Agent.containsAllModel(new HashSet<>(dataSetIds))) {
return true;
}
return dataSetIds.contains(dataSetId);
})
.findFirst();
return llmParserTool.orElse(null);
}
public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId, List<ElementValue> linkingValues) {
Map<Long, String> dataSetIdToName = queryCtx.getSemanticSchema().getDataSetIdToName();
String queryText = queryCtx.getQueryText();
LLMReq llmReq = new LLMReq();
llmReq.setQueryText(queryText);
LLMReq.FilterCondition filterCondition = new LLMReq.FilterCondition();
llmReq.setFilterCondition(filterCondition);
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
llmSchema.setDataSetName(dataSetIdToName.get(dataSetId));
llmSchema.setDomainName(dataSetIdToName.get(dataSetId));
List<String> fieldNameList = getFieldNameList(queryCtx, dataSetId, llmParserConfig);
String priorExts = getPriorExts(dataSetId, fieldNameList);
llmReq.setPriorExts(priorExts);
fieldNameList.add(TimeDimensionEnum.DAY.getChName());
llmSchema.setFieldNameList(fieldNameList);
llmReq.setSchema(llmSchema);
List<ElementValue> linking = new ArrayList<>();
if (optimizationConfig.isUseLinkingValueSwitch()) {
linking.addAll(linkingValues);
}
llmReq.setLinking(linking);
String currentDate = S2SqlDateHelper.getReferenceDate(queryCtx, dataSetId);
if (StringUtils.isEmpty(currentDate)) {
currentDate = DateUtils.getBeforeDate(0);
}
llmReq.setCurrentDate(currentDate);
llmReq.setSqlGenerationMode(optimizationConfig.getSqlGenerationMode().getName());
return llmReq;
}
public LLMResp requestLLM(LLMReq llmReq, Long dataSetId) {
return ComponentFactory.getLLMProxy().query2sql(llmReq, dataSetId);
}
protected List<String> getFieldNameList(QueryContext queryCtx, Long dataSetId,
LLMParserConfig llmParserConfig) {
Set<String> results = getTopNFieldNames(queryCtx, dataSetId, llmParserConfig);
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, dataSetId);
results.addAll(fieldNameList);
return new ArrayList<>(results);
}
private String getPriorExts(Long dataSetId, List<String> fieldNameList) {
StringBuilder extraInfoSb = new StringBuilder();
List<DataSetSchemaResp> dataSetSchemaResps = semanticInterpreter.fetchDataSetSchema(
Lists.newArrayList(dataSetId), true);
if (!CollectionUtils.isEmpty(dataSetSchemaResps)) {
DataSetSchemaResp dataSetSchemaResp = dataSetSchemaResps.get(0);
Map<String, String> fieldNameToDataFormatType = dataSetSchemaResp.getMetrics()
.stream().filter(metricSchemaResp -> Objects.nonNull(metricSchemaResp.getDataFormatType()))
.flatMap(metricSchemaResp -> {
Set<Pair<String, String>> result = new HashSet<>();
String dataFormatType = metricSchemaResp.getDataFormatType();
result.add(Pair.of(metricSchemaResp.getName(), dataFormatType));
List<String> aliasList = SchemaItem.getAliasList(metricSchemaResp.getAlias());
if (!CollectionUtils.isEmpty(aliasList)) {
for (String alias : aliasList) {
result.add(Pair.of(alias, dataFormatType));
}
}
return result.stream();
})
.collect(Collectors.toMap(a -> a.getLeft(), a -> a.getRight(), (k1, k2) -> k1));
for (String fieldName : fieldNameList) {
String dataFormatType = fieldNameToDataFormatType.get(fieldName);
if (DataFormatTypeEnum.DECIMAL.getName().equalsIgnoreCase(dataFormatType)
|| DataFormatTypeEnum.PERCENT.getName().equalsIgnoreCase(dataFormatType)) {
// appends a hint like "the unit of <fieldName> is 小数 (decimal)" for the LLM
String format = String.format("%s的计量单位是%s", fieldName, "小数; ");
extraInfoSb.append(format);
}
}
}
return extraInfoSb.toString();
}
protected List<ElementValue> getValueList(QueryContext queryCtx, Long dataSetId) {
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, dataSetId);
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
if (CollectionUtils.isEmpty(matchedElements)) {
return new ArrayList<>();
}
Set<ElementValue> valueMatches = matchedElements
.stream()
.filter(elementMatch -> !elementMatch.isInherited())
.filter(schemaElementMatch -> {
SchemaElementType type = schemaElementMatch.getElement().getType();
return SchemaElementType.VALUE.equals(type) || SchemaElementType.TAG_VALUE.equals(type)
|| SchemaElementType.ID.equals(type);
})
.map(elementMatch -> {
ElementValue elementValue = new ElementValue();
elementValue.setFieldName(itemIdToName.get(elementMatch.getElement().getId()));
elementValue.setFieldValue(elementMatch.getWord());
return elementValue;
}).collect(Collectors.toSet());
return new ArrayList<>(valueMatches);
}
protected Map<Long, String> getItemIdToName(QueryContext queryCtx, Long dataSetId) {
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
List<SchemaElement> elements = semanticSchema.getDimensions(dataSetId);
if (QueryType.TAG.equals(queryCtx.getQueryType(dataSetId))) {
elements = semanticSchema.getTags(dataSetId);
}
return elements.stream()
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
}
private Set<String> getTopNFieldNames(QueryContext queryCtx, Long dataSetId, LLMParserConfig llmParserConfig) {
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
Set<String> results = new HashSet<>();
if (QueryType.TAG.equals(queryCtx.getQueryType(dataSetId))) {
Set<String> tags = semanticSchema.getTags(dataSetId).stream()
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getDimensionTopN())
.map(entry -> entry.getName())
.collect(Collectors.toSet());
results.addAll(tags);
} else {
Set<String> dimensions = semanticSchema.getDimensions(dataSetId).stream()
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getDimensionTopN())
.map(entry -> entry.getName())
.collect(Collectors.toSet());
results.addAll(dimensions);
Set<String> metrics = semanticSchema.getMetrics(dataSetId).stream()
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getMetricTopN())
.map(entry -> entry.getName())
.collect(Collectors.toSet());
results.addAll(metrics);
}
return results;
}
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, Long dataSetId) {
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, dataSetId);
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
if (CollectionUtils.isEmpty(matchedElements)) {
return new HashSet<>();
}
Set<String> fieldNameList = matchedElements.stream()
.filter(schemaElementMatch -> {
SchemaElementType elementType = schemaElementMatch.getElement().getType();
return SchemaElementType.METRIC.equals(elementType)
|| SchemaElementType.DIMENSION.equals(elementType)
|| SchemaElementType.VALUE.equals(elementType)
|| SchemaElementType.TAG.equals(elementType)
|| SchemaElementType.TAG_VALUE.equals(elementType);
})
.map(schemaElementMatch -> {
SchemaElement element = schemaElementMatch.getElement();
SchemaElementType elementType = element.getType();
if (SchemaElementType.VALUE.equals(elementType) || SchemaElementType.TAG_VALUE.equals(
elementType)) {
return itemIdToName.get(element.getId());
}
return schemaElementMatch.getWord();
})
.collect(Collectors.toSet());
return fieldNameList;
}
}

View File

@@ -1,62 +0,0 @@
package com.tencent.supersonic.chat.core.parser.sql.llm;
import com.tencent.supersonic.chat.core.agent.NL2SQLTool;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.core.query.QueryManager;
import com.tencent.supersonic.chat.core.query.llm.LLMSemanticQuery;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlResp;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.jsqlparser.SqlEqualHelper;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.MapUtils;
import org.springframework.stereotype.Service;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
@Slf4j
@Service
public class LLMResponseService {
public SemanticParseInfo addParseInfo(QueryContext queryCtx, ParseResult parseResult, String s2SQL, Double weight) {
if (Objects.isNull(weight)) {
weight = 0D;
}
LLMSemanticQuery semanticQuery = QueryManager.createLLMQuery(LLMSqlQuery.QUERY_MODE);
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
parseInfo.setDataSet(queryCtx.getSemanticSchema().getDataSet(parseResult.getDataSetId()));
NL2SQLTool commonAgentTool = parseResult.getCommonAgentTool();
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(parseInfo.getDataSetId()));
Map<String, Object> properties = new HashMap<>();
properties.put(Constants.CONTEXT, parseResult);
properties.put("type", "internal");
properties.put("name", commonAgentTool.getName());
parseInfo.setProperties(properties);
parseInfo.setScore(queryCtx.getQueryText().length() * (1 + weight));
parseInfo.setQueryMode(semanticQuery.getQueryMode());
parseInfo.getSqlInfo().setS2SQL(s2SQL);
queryCtx.getCandidateQueries().add(semanticQuery);
return parseInfo;
}
public Map<String, LLMSqlResp> getDeduplicationSqlResp(LLMResp llmResp) {
if (MapUtils.isEmpty(llmResp.getSqlRespMap())) {
return llmResp.getSqlRespMap();
}
Map<String, LLMSqlResp> result = new HashMap<>();
for (Map.Entry<String, LLMSqlResp> entry : llmResp.getSqlRespMap().entrySet()) {
String key = entry.getKey();
if (result.keySet().stream().anyMatch(existKey -> SqlEqualHelper.equals(existKey, key))) {
continue;
}
result.put(key, entry.getValue());
}
return result;
}
}

View File

@@ -1,72 +0,0 @@
package com.tencent.supersonic.chat.core.parser.sql.llm;
import com.tencent.supersonic.chat.core.agent.NL2SQLTool;
import com.tencent.supersonic.chat.core.parser.SemanticParser;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlResp;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.MapUtils;
@Slf4j
public class LLMSqlParser implements SemanticParser {
@Override
public void parse(QueryContext queryCtx, ChatContext chatCtx) {
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
//1.determine whether to skip this parser.
if (requestService.isSkip(queryCtx)) {
return;
}
try {
//2.get dataSetId from queryCtx and chatCtx.
Long dataSetId = requestService.getDataSetId(queryCtx);
if (dataSetId == null) {
return;
}
//3.get agent tool and determine whether to skip this parser.
NL2SQLTool commonAgentTool = requestService.getParserTool(queryCtx, dataSetId);
if (Objects.isNull(commonAgentTool)) {
log.info("no tool in this agent, skip {}", LLMSqlParser.class);
return;
}
//4.construct a request, call the API for the large model, and retrieve the results.
List<ElementValue> linkingValues = requestService.getValueList(queryCtx, dataSetId);
LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId, linkingValues);
LLMResp llmResp = requestService.requestLLM(llmReq, dataSetId);
if (Objects.isNull(llmResp)) {
return;
}
//5. deduplicate the SQL result list and build parserInfo
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
Map<String, LLMSqlResp> deduplicationSqlResp = responseService.getDeduplicationSqlResp(llmResp);
ParseResult parseResult = ParseResult.builder()
.dataSetId(dataSetId)
.commonAgentTool(commonAgentTool)
.llmReq(llmReq)
.llmResp(llmResp)
.linkingValues(linkingValues)
.build();
if (MapUtils.isEmpty(deduplicationSqlResp)) {
responseService.addParseInfo(queryCtx, parseResult, llmResp.getSqlOutput(), 1D);
} else {
deduplicationSqlResp.forEach((sql, sqlResp) -> {
responseService.addParseInfo(queryCtx, parseResult, sql, sqlResp.getSqlWeight());
});
}
} catch (Exception e) {
log.error("parse", e);
}
}
}

View File

@@ -1,90 +0,0 @@
package com.tencent.supersonic.chat.core.parser.sql.llm;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.JsonUtil;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.stream.Collectors;
@Service
@Slf4j
public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
@Autowired
private ChatLanguageModel chatLanguageModel;
@Autowired
private SqlExamplarLoader sqlExamplarLoader;
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private SqlPromptGenerator sqlPromptGenerator;
@Override
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
//1.retrieve sqlExamples and generate exampleListPool
keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq);
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
optimizationConfig.getText2sqlExampleNum());
List<List<Map<String, String>>> exampleListPool = sqlPromptGenerator.getExampleCombos(sqlExamples,
optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum());
//2.generate linking and SQL prompts from the sqlExamples, then generate responses in parallel.
List<String> linkingSqlPromptPool = sqlPromptGenerator.generatePromptPool(llmReq, exampleListPool, true);
List<String> llmResults = new CopyOnWriteArrayList<>();
linkingSqlPromptPool.parallelStream().forEach(linkingSqlPrompt -> {
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingSqlPrompt))
.apply(new HashMap<>());
keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
String result = response.content().text();
llmResults.add(result);
keyPipelineLog.info("model response:{}", result);
}
);
//3.format response.
List<String> schemaLinkingResults = llmResults.stream()
.map(llmResult -> OutputFormat.getSchemaLinks(llmResult)).collect(Collectors.toList());
List<String> candidateSortedList = OutputFormat.formatList(schemaLinkingResults);
Pair<String, Map<String, Double>> linkingMap = OutputFormat.selfConsistencyVote(candidateSortedList);
List<String> sqlList = llmResults.stream()
.map(llmResult -> OutputFormat.getSql(llmResult)).collect(Collectors.toList());
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlList);
keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMapPair.getRight());
LLMResp result = new LLMResp();
result.setQuery(llmReq.getQueryText());
result.setSqlRespMap(OutputFormat.buildSqlRespMap(sqlExamples, sqlMapPair.getRight()));
return result;
}
@Override
public void afterPropertiesSet() {
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.ONE_PASS_AUTO_COT_SELF_CONSISTENCY, this);
}
}

View File

@@ -1,75 +0,0 @@
package com.tencent.supersonic.chat.core.parser.sql.llm;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlResp;
import com.tencent.supersonic.common.util.JsonUtil;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@Service
@Slf4j
public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
@Autowired
private ChatLanguageModel chatLanguageModel;
@Autowired
private SqlExamplarLoader sqlExampleLoader;
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private SqlPromptGenerator sqlPromptGenerator;
@Override
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
//1.retrieve sqlExamples
keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq);
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
optimizationConfig.getText2sqlExampleNum());
//2.generate the linking and SQL prompt from the sqlExamples, then generate the response.
String promptStr = sqlPromptGenerator.generatorLinkingAndSqlPrompt(llmReq, sqlExamples);
Prompt prompt = PromptTemplate.from(JsonUtil.toString(promptStr)).apply(new HashMap<>());
keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
String result = response.content().text();
keyPipelineLog.info("model response:{}", result);
//3.format response.
String schemaLinkStr = OutputFormat.getSchemaLinks(response.content().text());
String sql = OutputFormat.getSql(response.content().text());
Map<String, LLMSqlResp> sqlRespMap = new HashMap<>();
sqlRespMap.put(sql, LLMSqlResp.builder().sqlWeight(1D).fewShots(sqlExamples).build());
keyPipelineLog.info("schemaLinkStr:{},sqlRespMap:{}", schemaLinkStr, sqlRespMap);
LLMResp llmResp = new LLMResp();
llmResp.setQuery(llmReq.getQueryText());
llmResp.setSqlRespMap(sqlRespMap);
return llmResp;
}
@Override
public void afterPropertiesSet() {
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.ONE_PASS_AUTO_COT, this);
}
}

View File

@@ -1,138 +0,0 @@
package com.tencent.supersonic.chat.core.parser.sql.llm;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlResp;
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair;
/**
* Utilities for parsing and formatting LLM output.
*/
@Slf4j
public class OutputFormat {
public static String getSchemaLink(String schemaLink) {
String reult = "";
try {
reult = schemaLink.trim();
String pattern = "Schema_links:(.*)";
Pattern regexPattern = Pattern.compile(pattern, Pattern.DOTALL);
Matcher matcher = regexPattern.matcher(reult);
if (matcher.find()) {
return matcher.group(1).trim();
}
} catch (Exception e) {
log.error("", e);
}
return reult;
}
public static String getSql(String sqlOutput) {
String sql = "";
try {
sqlOutput = sqlOutput.trim();
String pattern = "SQL:(.*)";
Pattern regexPattern = Pattern.compile(pattern);
Matcher matcher = regexPattern.matcher(sqlOutput);
if (matcher.find()) {
return matcher.group(1);
}
} catch (Exception e) {
log.error("", e);
}
return sql;
}
public static String getSchemaLinks(String text) {
String schemaLinks = "";
try {
text = text.trim();
String pattern = "Schema_links:(\\[.*?\\])|Schema_links: (\\[.*?\\])";
Pattern regexPattern = Pattern.compile(pattern);
Matcher matcher = regexPattern.matcher(text);
if (matcher.find()) {
if (matcher.group(1) != null) {
schemaLinks = matcher.group(1);
} else if (matcher.group(2) != null) {
schemaLinks = matcher.group(2);
}
}
} catch (Exception e) {
log.error("", e);
}
return schemaLinks;
}
public static Pair<String, Map<String, Double>> selfConsistencyVote(List<String> inputList) {
Map<String, Integer> inputCounts = new HashMap<>();
for (String input : inputList) {
inputCounts.put(input, inputCounts.getOrDefault(input, 0) + 1);
}
String inputMax = null;
int maxCount = 0;
int inputSize = inputList.size();
Map<String, Double> votePercentage = new HashMap<>();
for (Map.Entry<String, Integer> entry : inputCounts.entrySet()) {
String input = entry.getKey();
int count = entry.getValue();
if (count > maxCount) {
inputMax = input;
maxCount = count;
}
double percentage = (double) count / inputSize;
votePercentage.put(input, percentage);
}
return Pair.of(inputMax, votePercentage);
}
public static List<String> formatList(List<String> toFormatList) {
List<String> results = new ArrayList<>();
for (String toFormat : toFormatList) {
List<String> items = new ArrayList<>();
String[] split = toFormat.replace("[", "").replace("]", "").split(",");
for (String item : split) {
items.add(item.trim());
}
Collections.sort(items);
String result = "[" + String.join(",", items) + "]";
results.add(result);
}
return results;
}
public static FunctionResp functionCallParse(String llmOutput) {
try {
ObjectMapper objectMapper = new ObjectMapper();
JsonNode jsonNode = objectMapper.readTree(llmOutput);
// "选择工具" ("selected tool") is the key the LLM is prompted to emit
String selectedTool = jsonNode.get("选择工具").asText();
FunctionResp resp = new FunctionResp();
resp.setToolSelection(selectedTool);
return resp;
} catch (Exception e) {
log.error("", e);
}
return null;
}
public static Map<String, LLMSqlResp> buildSqlRespMap(List<Map<String, String>> sqlExamples,
Map<String, Double> sqlMap) {
return sqlMap.entrySet().stream()
.collect(Collectors.toMap(
Map.Entry::getKey,
entry -> LLMSqlResp.builder().sqlWeight(entry.getValue()).fewShots(sqlExamples).build())
);
}
}
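A usage sketch of the self-consistency vote (inputs made up):

List<String> samples = Arrays.asList(
        "SELECT a FROM t", "SELECT a FROM t", "SELECT b FROM t");
Pair<String, Map<String, Double>> vote = OutputFormat.selfConsistencyVote(samples);
// vote.getLeft()  -> "SELECT a FROM t" (2 of 3 samples agree)
// vote.getRight() -> {"SELECT a FROM t"=0.667, "SELECT b FROM t"=0.333}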

View File

@@ -1,32 +0,0 @@
package com.tencent.supersonic.chat.core.parser.sql.llm;
import com.tencent.supersonic.chat.core.agent.NL2SQLTool;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class ParseResult {
private Long dataSetId;
private LLMReq llmReq;
private LLMResp llmResp;
private QueryReq request;
private NL2SQLTool commonAgentTool;
private List<ElementValue> linkingValues;
}

View File

@@ -1,81 +0,0 @@
package com.tencent.supersonic.chat.core.parser.sql.llm;
import com.fasterxml.jackson.core.type.TypeReference;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.ComponentFactory;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.io.ClassPathResource;
import org.springframework.stereotype.Component;
@Slf4j
@Component
public class SqlExamplarLoader {
private static final String EXAMPLE_JSON_FILE = "s2ql_examplar.json";
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
private TypeReference<List<SqlExample>> valueTypeRef = new TypeReference<List<SqlExample>>() {
};
@Autowired
private EmbeddingConfig embeddingConfig;
public List<SqlExample> getSqlExamples() throws IOException {
ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE);
InputStream inputStream = resource.getInputStream();
return JsonUtil.INSTANCE.getObjectMapper().readValue(inputStream, valueTypeRef);
}
public void addEmbeddingStore(List<SqlExample> sqlExamples, String collectionName) {
List<EmbeddingQuery> queries = new ArrayList<>();
for (int i = 0; i < sqlExamples.size(); i++) {
SqlExample sqlExample = sqlExamples.get(i);
String question = sqlExample.getQuestion();
Map<String, Object> metaDataMap = JsonUtil.toMap(JsonUtil.toString(sqlExample), String.class, Object.class);
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
embeddingQuery.setQueryId(String.valueOf(i));
embeddingQuery.setQuery(question);
embeddingQuery.setMetadata(metaDataMap);
queries.add(embeddingQuery);
}
s2EmbeddingStore.addQuery(collectionName, queries);
}
public List<Map<String, String>> retrieverSqlExamples(String queryText, int maxResults) {
String collectionName = embeddingConfig.getText2sqlCollectionName();
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(Collections.singletonList(queryText))
.queryEmbeddings(null).build();
List<RetrieveQueryResult> resultList = s2EmbeddingStore.retrieveQuery(collectionName, retrieveQuery,
maxResults);
List<Map<String, String>> result = new ArrayList<>();
if (CollectionUtils.isEmpty(resultList)) {
return result;
}
for (Retrieval retrieval : resultList.get(0).getRetrieval()) {
if (Objects.nonNull(retrieval.getMetadata()) && !retrieval.getMetadata().isEmpty()) {
Map<String, String> convertedMap = retrieval.getMetadata().entrySet().stream()
.collect(Collectors.toMap(Map.Entry::getKey, entry -> String.valueOf(entry.getValue())));
result.add(convertedMap);
}
}
return result;
}
}

View File

@@ -1,19 +0,0 @@
package com.tencent.supersonic.chat.core.parser.sql.llm;
import lombok.Data;
@Data
public class SqlExample {
private String question;
private String questionAugmented;
private String dbSchema;
private String sql;
private String generatedSchemaLinkingCoT;
private String generatedSchemaLinkings;
}
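For illustration, a hypothetical entry of s2ql_examplar.json matching these fields (all content made up):

// {
//   "question": "total sales by city",
//   "questionAugmented": "total sales by city (补充信息: ...)",
//   "dbSchema": "Table: sales_ds, Columns = [city, amount, date]\nForeign_keys: []",
//   "sql": "SELECT city, SUM(amount) FROM sales_ds GROUP BY city",
//   "generatedSchemaLinkingCoT": "Let's think step by step. ...",
//   "generatedSchemaLinkings": "[city, amount]"
// }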

View File

@@ -1,20 +0,0 @@
package com.tencent.supersonic.chat.core.parser.sql.llm;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
/**
* SqlGeneration interface: generates SQL using a large language model.
*/
public interface SqlGeneration {
/**
* generate an LLMResp (sql, schemaLink, prompt, etc.) from the given LLMReq.
* @param llmReq the LLM request carrying query text, schema and linking values
* @param dataSetId the id of the target dataSet
* @return the LLM response
*/
LLMResp generation(LLMReq llmReq, Long dataSetId);
}
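A hypothetical sketch of how an implementation plugs itself into the factory below, mirroring the pattern used by OnePassSqlGeneration and friends (class name is illustrative; registering an existing mode would replace its current handler):

@Service
public class MySqlGeneration implements SqlGeneration, InitializingBean {
    @Override
    public LLMResp generation(LLMReq llmReq, Long dataSetId) {
        LLMResp resp = new LLMResp();
        resp.setQuery(llmReq.getQueryText());
        return resp; // a real implementation would prompt the LLM here
    }

    @Override
    public void afterPropertiesSet() {
        // self-register so SqlGenerationFactory.get(...) can find this bean
        SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.ONE_PASS_AUTO_COT, this);
    }
}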

View File

@@ -1,19 +0,0 @@
package com.tencent.supersonic.chat.core.parser.sql.llm;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.SqlGenerationMode;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
public class SqlGenerationFactory {
private static Map<SqlGenerationMode, SqlGeneration> sqlGenerationMap = new ConcurrentHashMap<>();
public static SqlGeneration get(SqlGenerationMode strategyType) {
return sqlGenerationMap.get(strategyType);
}
public static void addSqlGenerationForFactory(SqlGenerationMode strategy, SqlGeneration sqlGeneration) {
sqlGenerationMap.put(strategy, sqlGeneration);
}
}

View File

@@ -1,132 +0,0 @@
package com.tencent.supersonic.chat.core.parser.sql.llm;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@Component
@Slf4j
public class SqlPromptGenerator {
public String generatorLinkingAndSqlPrompt(LLMReq llmReq, List<Map<String, String>> exampleList) {
String instruction =
"# Find the schema_links for generating SQL queries for each question based on the database schema "
+ "and Foreign keys. Then use the the schema links to generate the "
+ "SQL queries for each of the questions.";
List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkingCoT", "sql");
String exampleTemplate = "dbSchema\nQ: questionAugmented\nA: generatedSchemaLinkingCoT\nSQL: sql";
String exampleFormat = InputFormat.format(exampleTemplate, exampleKeys, exampleList);
Pair<String, String> questionPrompt = transformQuestionPrompt(llmReq);
String dbSchema = questionPrompt.getLeft();
String questionAugmented = questionPrompt.getRight();
String newCaseTemplate = "%s\nQ: %s\nA: Let's think step by step. In the question \"%s\", we are asked:";
String newCasePrompt = String.format(newCaseTemplate, dbSchema, questionAugmented, questionAugmented);
return instruction + InputFormat.SEPERATOR + exampleFormat + InputFormat.SEPERATOR + newCasePrompt;
}
public String generateLinkingPrompt(LLMReq llmReq, List<Map<String, String>> exampleList) {
String instruction = "# Find the schema_links for generating SQL queries for each question "
+ "based on the database schema and Foreign keys.";
List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkingCoT");
String exampleTemplate = "dbSchema\nQ: questionAugmented\nA: generatedSchemaLinkingCoT";
String exampleFormat = InputFormat.format(exampleTemplate, exampleKeys, exampleList);
Pair<String, String> questionPrompt = transformQuestionPrompt(llmReq);
String dbSchema = questionPrompt.getLeft();
String questionAugmented = questionPrompt.getRight();
String newCaseTemplate = "%s\nQ: %s\nA: Let's think step by step. In the question \"%s\", we are asked:";
String newCasePrompt = String.format(newCaseTemplate, dbSchema, questionAugmented, questionAugmented);
return instruction + InputFormat.SEPERATOR + exampleFormat + InputFormat.SEPERATOR + newCasePrompt;
}
public String generateSqlPrompt(LLMReq llmReq, String schemaLinkStr, List<Map<String, String>> fewshotExampleList) {
String instruction = "# Use the the schema links to generate the SQL queries for each of the questions.";
List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkings", "sql");
String exampleTemplate = "dbSchema\nQ: questionAugmented\n" + "Schema_links: generatedSchemaLinkings\n"
+ "SQL: sql";
String schemaLinkingPrompt = InputFormat.format(exampleTemplate, exampleKeys, fewshotExampleList);
Pair<String, String> questionPrompt = transformQuestionPrompt(llmReq);
String dbSchema = questionPrompt.getLeft();
String questionAugmented = questionPrompt.getRight();
String newCaseTemplate = "%s\nQ: %s\nSchema_links: %s\nSQL: ";
String newCasePrompt = String.format(newCaseTemplate, dbSchema, questionAugmented, schemaLinkStr);
return instruction + InputFormat.SEPERATOR + schemaLinkingPrompt + InputFormat.SEPERATOR + newCasePrompt;
}
public List<String> generatePromptPool(LLMReq llmReq, List<List<Map<String, String>>> exampleListPool,
boolean isSqlPrompt) {
List<String> promptPool = new ArrayList<>();
for (List<Map<String, String>> exampleList : exampleListPool) {
String prompt;
if (isSqlPrompt) {
prompt = generatorLinkingAndSqlPrompt(llmReq, exampleList);
} else {
prompt = generateLinkingPrompt(llmReq, exampleList);
}
promptPool.add(prompt);
}
return promptPool;
}
public List<List<Map<String, String>>> getExampleCombos(List<Map<String, String>> exampleList, int numFewShots,
int numSelfConsistency) {
List<List<Map<String, String>>> results = new ArrayList<>();
for (int i = 0; i < numSelfConsistency; i++) {
List<Map<String, String>> shuffledList = new ArrayList<>(exampleList);
Collections.shuffle(shuffledList);
results.add(shuffledList.subList(0, numFewShots));
}
return results;
}
public Pair<String, String> transformQuestionPrompt(LLMReq llmReq) {
String modelName = llmReq.getSchema().getDataSetName();
List<String> fieldNameList = llmReq.getSchema().getFieldNameList();
List<ElementValue> linking = llmReq.getLinking();
String currentDate = llmReq.getCurrentDate();
String priorExts = llmReq.getPriorExts();
String dbSchema = "Table: " + modelName + ", Columns = " + fieldNameList + "\nForeign_keys: []";
List<String> priorLinkingList = new ArrayList<>();
for (ElementValue priorLinking : linking) {
String fieldName = priorLinking.getFieldName();
String fieldValue = priorLinking.getFieldValue();
priorLinkingList.add("" + fieldValue + "‘是一个‘" + fieldName + "");
}
String currentDateStr = "当前的日期是" + currentDate; // "the current date is <currentDate>"
String linkingListStr = String.join("", priorLinkingList);
// "(supplementary info: <linking> . <current date>) (note: <priorExts>)"
String questionAugmented = String.format("%s (补充信息:%s 。 %s) (备注: %s)", llmReq.getQueryText(), linkingListStr,
currentDateStr, priorExts);
return Pair.of(dbSchema, questionAugmented);
}
public List<String> generateSqlPromptPool(LLMReq llmReq, List<String> schemaLinkStrPool,
List<List<Map<String, String>>> fewshotExampleListPool) {
List<String> sqlPromptPool = new ArrayList<>();
for (int i = 0; i < schemaLinkStrPool.size(); i++) {
String schemaLinkStr = schemaLinkStrPool.get(i);
List<Map<String, String>> fewshotExampleList = fewshotExampleListPool.get(i);
String sqlPrompt = generateSqlPrompt(llmReq, schemaLinkStr, fewshotExampleList);
sqlPromptPool.add(sqlPrompt);
}
return sqlPromptPool;
}
}
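For intuition, an illustrative result of transformQuestionPrompt (values made up; the augmented question uses the Chinese connective literals above, glossed here in comments):

// dbSchema:          "Table: sales_ds, Columns = [city, amount, date]\nForeign_keys: []"
// questionAugmented: "total amount by city (补充信息:‘北京’是一个‘city’ 。 当前的日期是2024-03-12) (备注: )"
//                    i.e. "(supplementary info: 'Beijing' is a 'city'. the current date is 2024-03-12) (note: )"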

View File

@@ -1,92 +0,0 @@
package com.tencent.supersonic.chat.core.parser.sql.llm;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.JsonUtil;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
@Service
public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
@Autowired
private ChatLanguageModel chatLanguageModel;
@Autowired
private SqlExamplarLoader sqlExamplarLoader;
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private SqlPromptGenerator sqlPromptGenerator;
@Override
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
// 1. retrieve SQL examples and build the example list pool
keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq);
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
optimizationConfig.getText2sqlExampleNum());
List<List<Map<String, String>>> exampleListPool = sqlPromptGenerator.getExampleCombos(sqlExamples,
optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum());
// 2. generate linking prompts and collect model responses in parallel
List<String> linkingPromptPool = sqlPromptGenerator.generatePromptPool(llmReq, exampleListPool, false);
List<String> linkingResults = new CopyOnWriteArrayList<>();
linkingPromptPool.parallelStream().forEach(
linkingPrompt -> {
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPrompt)).apply(new HashMap<>());
keyPipelineLog.info("step one request prompt:{}", prompt.toSystemMessage());
Response<AiMessage> linkingResult = chatLanguageModel.generate(prompt.toSystemMessage());
String result = linkingResult.content().text();
keyPipelineLog.info("step one model response:{}", result);
linkingResults.add(OutputFormat.getSchemaLink(result));
}
);
List<String> sortedList = OutputFormat.formatList(linkingResults);
Pair<String, Map<String, Double>> linkingMap = OutputFormat.selfConsistencyVote(sortedList);
// 3. generate SQL prompts and collect model responses in parallel
List<String> sqlPromptPool = sqlPromptGenerator.generateSqlPromptPool(llmReq, sortedList, exampleListPool);
List<String> sqlTaskPool = new CopyOnWriteArrayList<>();
sqlPromptPool.parallelStream().forEach(sqlPrompt -> {
Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(sqlPrompt)).apply(new HashMap<>());
keyPipelineLog.info("step two request prompt:{}", linkingPrompt.toSystemMessage());
Response<AiMessage> sqlResult = chatLanguageModel.generate(linkingPrompt.toSystemMessage());
String result = sqlResult.content().text();
keyPipelineLog.info("step two model response:{}", result);
sqlTaskPool.add(result);
});
// 4. format the responses and vote
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlTaskPool);
keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMapPair.getRight());
LLMResp llmResp = new LLMResp();
llmResp.setQuery(llmReq.getQueryText());
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(sqlExamples, sqlMapPair.getRight()));
return llmResp;
}
@Override
public void afterPropertiesSet() {
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.TWO_PASS_AUTO_COT_SELF_CONSISTENCY, this);
}
}
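
Self-consistency in the class above hinges on OutputFormat.selfConsistencyVote, which is not part of this diff. Conceptually it is a weighted majority vote over the sampled generations; a minimal sketch of that idea (SelfConsistencySketch is a hypothetical helper, not the real OutputFormat):

import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class SelfConsistencySketch {

    // Picks the most frequent candidate and reports each candidate's share of
    // the votes; assumes a non-empty candidate list.
    static Map.Entry<String, Map<String, Double>> vote(List<String> candidates) {
        Map<String, Double> weights = new HashMap<>();
        for (String candidate : candidates) {
            weights.merge(candidate, 1.0 / candidates.size(), Double::sum);
        }
        String winner = weights.entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .map(Map.Entry::getKey)
                .get();
        return Map.entry(winner, weights);
    }

    public static void main(String[] args) {
        System.out.println(vote(List.of("SELECT 1", "SELECT 1", "SELECT 2")));
        // SELECT 1={SELECT 1=0.666..., SELECT 2=0.333...}
    }
}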

View File

@@ -1,76 +0,0 @@
package com.tencent.supersonic.chat.core.parser.sql.llm;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.JsonUtil;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@Service
@Slf4j
public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
@Autowired
private ChatLanguageModel chatLanguageModel;
@Autowired
private SqlExamplarLoader sqlExamplarLoader;
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private SqlPromptGenerator sqlPromptGenerator;
@Override
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq);
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
optimizationConfig.getText2sqlExampleNum());
String linkingPromptStr = sqlPromptGenerator.generateLinkingPrompt(llmReq, sqlExamples);
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>());
keyPipelineLog.info("step one request prompt:{}", prompt.toSystemMessage());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
keyPipelineLog.info("step one model response:{}", response.content().text());
String schemaLinkStr = OutputFormat.getSchemaLink(response.content().text());
String generateSqlPrompt = sqlPromptGenerator.generateSqlPrompt(llmReq, schemaLinkStr, sqlExamples);
Prompt sqlPrompt = PromptTemplate.from(JsonUtil.toString(generateSqlPrompt)).apply(new HashMap<>());
keyPipelineLog.info("step two request prompt:{}", sqlPrompt.toSystemMessage());
Response<AiMessage> sqlResult = chatLanguageModel.generate(sqlPrompt.toSystemMessage());
String result = sqlResult.content().text();
keyPipelineLog.info("step two model response:{}", result);
Map<String, Double> sqlMap = new HashMap<>();
sqlMap.put(result, 1D);
keyPipelineLog.info("schemaLinkStr:{},sqlMap:{}", schemaLinkStr, sqlMap);
LLMResp llmResp = new LLMResp();
llmResp.setQuery(llmReq.getQueryText());
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(sqlExamples, sqlMap));
return llmResp;
}
@Override
public void afterPropertiesSet() {
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.TWO_PASS_AUTO_COT, this);
}
}
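
Both two-pass generators register themselves from afterPropertiesSet into SqlGenerationFactory, keyed by SqlGenerationMode. The factory class is not included in this diff; the following is only a plausible minimal shape, to make the registration pattern concrete (assumed to live alongside the SqlGeneration interface):

import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.SqlGenerationMode;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class SqlGenerationFactorySketch {

    // Static registry: each SqlGeneration bean adds itself once Spring has wired it.
    private static final Map<SqlGenerationMode, SqlGeneration> REGISTRY = new ConcurrentHashMap<>();

    public static void addSqlGenerationForFactory(SqlGenerationMode mode, SqlGeneration generation) {
        REGISTRY.put(mode, generation);
    }

    public static SqlGeneration get(SqlGenerationMode mode) {
        return REGISTRY.get(mode);
    }
}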

View File

@@ -1,73 +0,0 @@
package com.tencent.supersonic.chat.core.parser.sql.rule;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.core.agent.Agent;
import com.tencent.supersonic.chat.core.agent.AgentToolType;
import com.tencent.supersonic.chat.core.agent.RuleParserTool;
import com.tencent.supersonic.chat.core.parser.SemanticParser;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import java.util.List;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
@Slf4j
public class AgentCheckParser implements SemanticParser {
@Override
public void parse(QueryContext queryContext, ChatContext chatContext) {
List<SemanticQuery> queries = queryContext.getCandidateQueries();
log.info("query size before agent filter:{}", queryContext.getCandidateQueries().size());
filterQueries(queryContext, queries);
log.info("query size after agent filter: {}", queryContext.getCandidateQueries().size());
}
private void filterQueries(QueryContext queryContext, List<SemanticQuery> queries) {
Agent agent = queryContext.getAgent();
if (agent == null) {
return;
}
List<RuleParserTool> queryTools = getRuleTools(agent);
if (CollectionUtils.isEmpty(queryTools)) {
queryContext.setCandidateQueries(Lists.newArrayList());
return;
}
log.info("agent name :{}, queries resolved: {}", agent.getName(),
queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
queries.removeIf(query -> {
for (RuleParserTool tool : queryTools) {
if (CollectionUtils.isNotEmpty(tool.getQueryModes())
&& !tool.getQueryModes().contains(query.getQueryMode())) {
return true;
}
if (CollectionUtils.isEmpty(tool.getDataSetIds())) {
return true;
}
if (tool.isContainsAllModel()) {
return false;
}
return !tool.getDataSetIds().contains(query.getParseInfo().getDataSetId());
}
return true;
});
queryContext.setCandidateQueries(queries);
log.info("agent name :{}, rule queries witch can be supported by agent :{}", agent.getName(),
queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
}
private static List<RuleParserTool> getRuleTools(Agent agent) {
if (agent == null) {
return Lists.newArrayList();
}
List<String> tools = agent.getTools(AgentToolType.NL2SQL_RULE);
if (CollectionUtils.isEmpty(tools)) {
return Lists.newArrayList();
}
return tools.stream().map(tool -> JSONObject.parseObject(tool, RuleParserTool.class))
.collect(Collectors.toList());
}
}
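
Each agent tool is stored as a JSON string and bound with fastjson. The getters read above (getQueryModes, getDataSetIds, isContainsAllModel) suggest a payload of roughly the following shape; ToolStandIn is a stripped-down stand-in for RuleParserTool, and the mode name is only an example value:

import com.alibaba.fastjson.JSONObject;
import java.util.List;
import lombok.Data;

public class RuleToolJsonSketch {

    @Data
    public static class ToolStandIn { // hypothetical mirror of RuleParserTool's fields
        private List<String> queryModes;
        private List<Long> dataSetIds;
        private boolean containsAllModel;
    }

    public static void main(String[] args) {
        String json = "{\"queryModes\":[\"METRIC_MODEL\"],\"dataSetIds\":[1,2],\"containsAllModel\":false}";
        ToolStandIn tool = JSONObject.parseObject(json, ToolStandIn.class);
        System.out.println(tool.getQueryModes() + " / " + tool.getDataSetIds() + " / " + tool.isContainsAllModel());
    }
}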

View File

@@ -1,104 +0,0 @@
package com.tencent.supersonic.chat.core.parser.sql.rule;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.AVG;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.COUNT;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.DISTINCT;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.MAX;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.MIN;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.TOPN;
import com.tencent.supersonic.chat.core.parser.SemanticParser;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import java.util.AbstractMap;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
/**
* AggregateTypeParser extracts aggregation type specified in the user query
* based on keyword matching.
* Currently, it supports 7 types of aggregation: max, min, sum, avg, topN,
* distinct count, count.
*/
@Slf4j
public class AggregateTypeParser implements SemanticParser {
private static final Map<AggregateTypeEnum, Pattern> REGX_MAP = Stream.of(
new AbstractMap.SimpleEntry<>(MAX, Pattern.compile("(?i)(最大值|最大|max|峰值|最高|最多)")),
new AbstractMap.SimpleEntry<>(MIN, Pattern.compile("(?i)(最小值|最小|min|最低|最少)")),
new AbstractMap.SimpleEntry<>(SUM, Pattern.compile("(?i)(汇总|总和|sum)")),
new AbstractMap.SimpleEntry<>(AVG, Pattern.compile("(?i)(平均值|日均|平均|avg)")),
new AbstractMap.SimpleEntry<>(TOPN, Pattern.compile("(?i)(top)")),
new AbstractMap.SimpleEntry<>(DISTINCT, Pattern.compile("(?i)(uv)")),
new AbstractMap.SimpleEntry<>(COUNT, Pattern.compile("(?i)(总数|pv)")),
new AbstractMap.SimpleEntry<>(NONE, Pattern.compile("(?i)(明细)"))
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (k1, k2) -> k2));
@Override
public void parse(QueryContext queryContext, ChatContext chatContext) {
String queryText = queryContext.getQueryText();
AggregateConf aggregateConf = resolveAggregateConf(queryText);
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
if (!AggregateTypeEnum.NONE.equals(semanticQuery.getParseInfo().getAggType())) {
continue;
}
semanticQuery.getParseInfo().setAggType(aggregateConf.type);
int detectWordLength = 0;
if (StringUtils.isNotEmpty(aggregateConf.detectWord)) {
detectWordLength = aggregateConf.detectWord.length();
}
semanticQuery.getParseInfo().setScore(semanticQuery.getParseInfo().getScore() + detectWordLength);
}
}
public AggregateTypeEnum resolveAggregateType(String queryText) {
AggregateConf aggregateConf = resolveAggregateConf(queryText);
return aggregateConf.type;
}
private AggregateConf resolveAggregateConf(String queryText) {
Map<AggregateTypeEnum, Integer> aggregateCount = new HashMap<>(REGX_MAP.size());
Map<AggregateTypeEnum, String> aggregateWord = new HashMap<>(REGX_MAP.size());
for (Map.Entry<AggregateTypeEnum, Pattern> entry : REGX_MAP.entrySet()) {
Matcher matcher = entry.getValue().matcher(queryText);
int count = 0;
String detectWord = null;
while (matcher.find()) {
count++;
detectWord = matcher.group();
}
if (count > 0) {
aggregateCount.put(entry.getKey(), count);
aggregateWord.put(entry.getKey(), detectWord);
}
}
AggregateTypeEnum type = aggregateCount.entrySet().stream().max(Map.Entry.comparingByValue())
.map(entry -> entry.getKey()).orElse(NONE);
String detectWord = aggregateWord.get(type);
return new AggregateConf(type, detectWord);
}
@AllArgsConstructor
class AggregateConf {
public AggregateTypeEnum type;
public String detectWord;
}
}
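
AggregateTypeParser can be exercised in isolation through resolveAggregateType. Two made-up queries, assuming the demo sits in the same package as the parser:

public class AggregateTypeDemo {
    public static void main(String[] args) {
        AggregateTypeParser parser = new AggregateTypeParser();
        // "最大值" ("maximum") hits the MAX pattern
        System.out.println(parser.resolveAggregateType("近7天交易额的最大值")); // MAX
        // "日均" hits AVG and "pv" hits COUNT, one occurrence each; with equal
        // counts the max() tie is broken by map iteration order, so either can win
        System.out.println(parser.resolveAggregateType("日均pv"));
    }
}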

View File

@@ -1,124 +0,0 @@
package com.tencent.supersonic.chat.core.parser.sql.rule;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.core.parser.SemanticParser;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.QueryManager;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.core.query.rule.metric.MetricModelQuery;
import com.tencent.supersonic.chat.core.query.rule.metric.MetricSemanticQuery;
import com.tencent.supersonic.chat.core.query.rule.metric.MetricTagQuery;
import lombok.extern.slf4j.Slf4j;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* ContextInheritParser tries to inherit certain schema elements from context
* so that in multi-turn conversations users don't need to mention some keyword
* repeatedly.
*/
@Slf4j
public class ContextInheritParser implements SemanticParser {
private static final Map<SchemaElementType, List<SchemaElementType>> MUTUAL_EXCLUSIVE_MAP = Stream.of(
new AbstractMap.SimpleEntry<>(SchemaElementType.METRIC, Arrays.asList(SchemaElementType.METRIC)),
new AbstractMap.SimpleEntry<>(
SchemaElementType.DIMENSION, Arrays.asList(SchemaElementType.DIMENSION, SchemaElementType.VALUE)),
new AbstractMap.SimpleEntry<>(
SchemaElementType.VALUE, Arrays.asList(SchemaElementType.VALUE, SchemaElementType.DIMENSION)),
new AbstractMap.SimpleEntry<>(SchemaElementType.ENTITY, Arrays.asList(SchemaElementType.ENTITY)),
new AbstractMap.SimpleEntry<>(SchemaElementType.TAG, Arrays.asList(SchemaElementType.TAG)),
new AbstractMap.SimpleEntry<>(SchemaElementType.DATASET, Arrays.asList(SchemaElementType.DATASET)),
new AbstractMap.SimpleEntry<>(SchemaElementType.ID, Arrays.asList(SchemaElementType.ID))
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
@Override
public void parse(QueryContext queryContext, ChatContext chatContext) {
if (!shouldInherit(queryContext)) {
return;
}
Long dataSetId = getMatchedDataSet(queryContext, chatContext);
if (dataSetId == null) {
return;
}
List<SchemaElementMatch> elementMatches = queryContext.getMapInfo().getMatchedElements(dataSetId);
List<SchemaElementMatch> matchesToInherit = new ArrayList<>();
for (SchemaElementMatch match : chatContext.getParseInfo().getElementMatches()) {
SchemaElementType matchType = match.getElement().getType();
// mutually exclusive element types should not be inherited
RuleSemanticQuery ruleQuery = QueryManager.getRuleQuery(chatContext.getParseInfo().getQueryMode());
if (!containsTypes(elementMatches, matchType, ruleQuery)) {
match.setInherited(true);
matchesToInherit.add(match);
}
}
elementMatches.addAll(matchesToInherit);
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
for (RuleSemanticQuery query : queries) {
query.fillParseInfo(queryContext, chatContext);
if (existSameQuery(query.getParseInfo().getDataSetId(), query.getQueryMode(), queryContext)) {
continue;
}
queryContext.getCandidateQueries().add(query);
}
}
private boolean existSameQuery(Long dataSetId, String queryMode, QueryContext queryContext) {
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
if (semanticQuery.getQueryMode().equals(queryMode)
&& semanticQuery.getParseInfo().getDataSetId().equals(dataSetId)) {
return true;
}
}
return false;
}
private boolean containsTypes(List<SchemaElementMatch> matches, SchemaElementType matchType,
RuleSemanticQuery ruleQuery) {
List<SchemaElementType> types = MUTUAL_EXCLUSIVE_MAP.get(matchType);
return matches.stream().anyMatch(m -> {
SchemaElementType type = m.getElement().getType();
if (Objects.nonNull(ruleQuery) && ruleQuery instanceof MetricSemanticQuery
&& !(ruleQuery instanceof MetricTagQuery)) {
return types.contains(type);
}
return type.equals(matchType);
});
}
protected boolean shouldInherit(QueryContext queryContext) {
// inherit from context only when every candidate is a MetricModelQuery
List<SemanticQuery> metricModelQueries = queryContext.getCandidateQueries().stream()
.filter(query -> query instanceof MetricModelQuery).collect(
Collectors.toList());
return metricModelQueries.size() == queryContext.getCandidateQueries().size();
}
protected Long getMatchedDataSet(QueryContext queryContext, ChatContext chatContext) {
Long dataSetId = chatContext.getParseInfo().getDataSetId();
if (dataSetId == null) {
return null;
}
Set<Long> queryDataSets = queryContext.getMapInfo().getMatchedDataSetInfos();
if (queryDataSets.contains(dataSetId)) {
return dataSetId;
}
return dataSetId;
}
}

View File

@@ -1,42 +0,0 @@
package com.tencent.supersonic.chat.core.parser.sql.rule;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.core.parser.SemanticParser;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
import java.util.Arrays;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
/**
* RuleSqlParser resolves specific SemanticQueries according to the co-occurrence
* of certain schema element types.
*/
@Slf4j
public class RuleSqlParser implements SemanticParser {
private static List<SemanticParser> auxiliaryParsers = Arrays.asList(
new ContextInheritParser(),
new TimeRangeParser(),
new AggregateTypeParser(),
new AgentCheckParser()
);
@Override
public void parse(QueryContext queryContext, ChatContext chatContext) {
SchemaMapInfo mapInfo = queryContext.getMapInfo();
// iterate all schemaElementMatches to resolve query mode
for (Long dataSetId : mapInfo.getMatchedDataSetInfos()) {
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(dataSetId);
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
for (RuleSemanticQuery query : queries) {
query.fillParseInfo(queryContext, chatContext);
queryContext.getCandidateQueries().add(query);
}
}
auxiliaryParsers.forEach(p -> p.parse(queryContext, chatContext));
}
}

View File

@@ -1,211 +0,0 @@
package com.tencent.supersonic.chat.core.parser.sql.rule;
import com.tencent.supersonic.chat.core.parser.SemanticParser;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.core.query.QueryManager;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DateConf;
import java.text.DateFormat;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.time.LocalDate;
import java.util.Stack;
import java.util.Date;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import com.xkzhangsan.time.nlp.TimeNLP;
import com.xkzhangsan.time.nlp.TimeNLPUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.logging.log4j.util.Strings;
/**
* TimeRangeParser extracts time range specified in the user query
* based on keyword matching.
* Currently, it supports two kinds of expressions:
* 1. Recent unit: 近N天/周/月/年、过去N天/周/月/年 (last/past N days/weeks/months/years)
* 2. Concrete date: 2023年11月15日 (November 15, 2023) or 20231115
*/
@Slf4j
public class TimeRangeParser implements SemanticParser {
private static final Pattern RECENT_PATTERN_CN = Pattern.compile(
".*(?<periodStr>(近|过去)((?<enNum>\\d+)|(?<zhNum>[一二三四五六七八九十百千万亿]+))个?(?<zhPeriod>[天周月年])).*");
private static final Pattern DATE_PATTERN_NUMBER = Pattern.compile("(\\d{8})");
private static final DateFormat DATE_FORMAT_NUMBER = new SimpleDateFormat("yyyyMMdd");
private static final DateFormat DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd");
@Override
public void parse(QueryContext queryContext, ChatContext chatContext) {
String queryText = queryContext.getQueryText();
DateConf dateConf = parseRecent(queryText);
if (dateConf == null) {
dateConf = parseDateNumber(queryText);
}
if (dateConf == null) {
dateConf = parseDateCN(queryText);
}
if (dateConf != null) {
if (queryContext.getCandidateQueries().size() > 0) {
for (SemanticQuery query : queryContext.getCandidateQueries()) {
query.getParseInfo().setDateInfo(dateConf);
query.getParseInfo().setScore(query.getParseInfo().getScore()
+ dateConf.getDetectWord().length());
}
} else if (QueryManager.containsRuleQuery(chatContext.getParseInfo().getQueryMode())) {
RuleSemanticQuery semanticQuery = QueryManager.createRuleQuery(
chatContext.getParseInfo().getQueryMode());
// inherit parse info from context
chatContext.getParseInfo().setDateInfo(dateConf);
chatContext.getParseInfo().setScore(chatContext.getParseInfo().getScore()
+ dateConf.getDetectWord().length());
semanticQuery.setParseInfo(chatContext.getParseInfo());
queryContext.getCandidateQueries().add(semanticQuery);
}
}
}
private DateConf parseDateCN(String queryText) {
Date startDate = null;
Date endDate;
String detectWord = null;
List<TimeNLP> times = TimeNLPUtil.parse(queryText);
if (times.size() > 0) {
startDate = times.get(0).getTime();
detectWord = times.get(0).getTimeExpression();
} else {
return null;
}
if (times.size() > 1) {
endDate = times.get(1).getTime();
detectWord += "~" + times.get(0).getTimeExpression();
} else {
endDate = startDate;
}
return getDateConf(startDate, endDate, detectWord);
}
private DateConf parseDateNumber(String queryText) {
String startDate;
String endDate = null;
String detectWord = null;
Matcher dateMatcher = DATE_PATTERN_NUMBER.matcher(queryText);
if (dateMatcher.find()) {
startDate = dateMatcher.group();
detectWord = startDate;
} else {
return null;
}
if (dateMatcher.find()) {
endDate = dateMatcher.group();
detectWord += "~" + endDate;
}
endDate = endDate != null ? endDate : startDate;
try {
return getDateConf(DATE_FORMAT_NUMBER.parse(startDate), DATE_FORMAT_NUMBER.parse(endDate), detectWord);
} catch (ParseException e) {
return null;
}
}
private DateConf parseRecent(String queryText) {
Matcher m = RECENT_PATTERN_CN.matcher(queryText);
if (m.matches()) {
int num = 0;
String enNum = m.group("enNum");
String zhNum = m.group("zhNum");
if (enNum != null) {
num = Integer.parseInt(enNum);
} else if (zhNum != null) {
num = zhNumParse(zhNum);
}
if (num > 0) {
DateConf info = new DateConf();
String zhPeriod = m.group("zhPeriod");
int days;
switch (zhPeriod) {
case "周":
days = 7;
info.setPeriod(Constants.WEEK);
break;
case "月":
days = 30;
info.setPeriod(Constants.MONTH);
break;
case "年":
days = 365;
info.setPeriod(Constants.YEAR);
break;
default:
days = 1;
info.setPeriod(Constants.DAY);
}
days = days * num;
info.setDateMode(DateConf.DateMode.RECENT);
String detectWord = "" + num + zhPeriod;
if (Strings.isNotEmpty(m.group("periodStr"))) {
detectWord = m.group("periodStr");
}
info.setDetectWord(detectWord);
info.setStartDate(LocalDate.now().minusDays(days).toString());
info.setEndDate(LocalDate.now().minusDays(1).toString());
info.setUnit(num);
return info;
}
}
return null;
}
private int zhNumParse(String zhNumStr) {
Stack<Integer> stack = new Stack<>();
String numStr = "一二三四五六七八九";
String unitStr = "十百千万亿";
String[] ssArr = zhNumStr.split("");
for (String e : ssArr) {
int numIndex = numStr.indexOf(e);
int unitIndex = unitStr.indexOf(e);
if (numIndex != -1) {
stack.push(numIndex + 1);
} else if (unitIndex != -1) {
int unitNum = unitIndex == 4 ? 100_000_000 : (int) Math.pow(10, unitIndex + 1); // 亿 is 10^8, not 10^5
if (stack.isEmpty()) {
stack.push(unitNum);
} else {
stack.push(stack.pop() * unitNum);
}
}
}
return stack.stream().mapToInt(s -> s).sum();
}
private DateConf getDateConf(Date startDate, Date endDate, String detectWord) {
if (startDate == null || endDate == null) {
return null;
}
DateConf info = new DateConf();
info.setDateMode(DateConf.DateMode.BETWEEN);
info.setStartDate(DATE_FORMAT.format(startDate));
info.setEndDate(DATE_FORMAT.format(endDate));
info.setDetectWord(detectWord);
return info;
}
}
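
zhNumParse reads compound Chinese numerals with a small stack: a digit pushes its value, a unit multiplies the value on top of the stack (or stands alone, so that 十五 = 15 works). The same algorithm as a self-contained sketch with a few spot checks:

import java.util.ArrayDeque;
import java.util.Deque;

public class ZhNumSketch {

    static int parse(String zh) {
        Deque<Integer> stack = new ArrayDeque<>();
        String digits = "一二三四五六七八九";
        String units = "十百千万亿";
        int[] unitValues = {10, 100, 1000, 10000, 100000000};
        for (char c : zh.toCharArray()) {
            int d = digits.indexOf(c);
            int u = units.indexOf(c);
            if (d != -1) {
                stack.push(d + 1);                         // digit: 1..9
            } else if (u != -1) {
                stack.push(stack.isEmpty() ? unitValues[u] // bare unit, e.g. 十五
                        : stack.pop() * unitValues[u]);    // scale the preceding digit
            }
        }
        return stack.stream().mapToInt(Integer::intValue).sum();
    }

    public static void main(String[] args) {
        System.out.println(parse("三"));      // 3
        System.out.println(parse("十五"));    // 15
        System.out.println(parse("二十"));    // 20
        System.out.println(parse("一百二十")); // 120
    }
}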

View File

@@ -1,14 +0,0 @@
package com.tencent.supersonic.chat.core.pojo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import lombok.Data;
@Data
public class ChatContext {
private Integer chatId;
private Integer agentId;
private String queryText;
private SemanticParseInfo parseInfo = new SemanticParseInfo();
private String user;
}

View File

@@ -1,68 +0,0 @@
package com.tencent.supersonic.chat.core.pojo;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
import com.tencent.supersonic.chat.core.agent.Agent;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.plugin.Plugin;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class QueryContext {
private String queryText;
private Integer chatId;
private Long dataSetId;
private User user;
private boolean saveAnswer = true;
private Integer agentId;
private QueryFilters queryFilters;
private List<SemanticQuery> candidateQueries = new ArrayList<>();
private SchemaMapInfo mapInfo = new SchemaMapInfo();
@JsonIgnore
private SemanticSchema semanticSchema;
@JsonIgnore
private Agent agent;
@JsonIgnore
private Map<Long, ChatConfigRichResp> modelIdToChatRichConfig;
@JsonIgnore
private Map<String, Plugin> nameToPlugin;
@JsonIgnore
private List<Plugin> pluginList;
public List<SemanticQuery> getCandidateQueries() {
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
Integer parseShowCount = optimizationConfig.getParseShowCount();
candidateQueries = candidateQueries.stream()
.sorted(Comparator.comparing(semanticQuery -> semanticQuery.getParseInfo().getScore(),
Comparator.reverseOrder()))
.limit(parseShowCount)
.collect(Collectors.toList());
return candidateQueries;
}
public QueryType getQueryType(Long dataSetId) {
SemanticSchema semanticSchema = this.semanticSchema;
DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
return dataSetSchema.getQueryType();
}
}
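
Note that getCandidateQueries re-sorts and truncates on every call: candidates are ordered by parse score, highest first, then capped at parseShowCount. The sort idiom in isolation (toy data, nothing from the codebase):

import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class TopCandidatesDemo {
    public static void main(String[] args) {
        List<Map.Entry<String, Double>> top2 = Stream.of(
                        Map.entry("METRIC_MODEL", 3.0),
                        Map.entry("TAG_DETAIL", 5.0),
                        Map.entry("METRIC_FILTER", 4.0))
                // highest score first, like the candidate-query comparator above
                .sorted(Comparator.comparing((Map.Entry<String, Double> e) -> e.getValue(),
                        Comparator.reverseOrder()))
                .limit(2)
                .collect(Collectors.toList());
        System.out.println(top2); // [TAG_DETAIL=5.0, METRIC_FILTER=4.0]
    }
}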

View File

@@ -1,129 +0,0 @@
package com.tencent.supersonic.chat.core.query;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
import com.tencent.supersonic.chat.core.utils.QueryReqBuilder;
import com.tencent.supersonic.common.pojo.Aggregator;
import com.tencent.supersonic.common.pojo.Filter;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.response.ExplainResp;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
@Slf4j
@ToString
public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
protected SemanticParseInfo parseInfo = new SemanticParseInfo();
protected SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
@Override
public String explain(User user) {
ExplainSqlReq explainSqlReq = null;
try {
SqlInfo sqlInfo = parseInfo.getSqlInfo();
if (StringUtils.isNotBlank(sqlInfo.getCorrectS2SQL())) {
// corrected S2SQL is available: explain it as a SQL query
explainSqlReq = ExplainSqlReq.builder()
.queryTypeEnum(QueryType.SQL)
.queryReq(QueryReqBuilder.buildS2SQLReq(
sqlInfo.getCorrectS2SQL(), parseInfo.getDataSetId()
))
.build();
} else {
// otherwise explain the structured query
explainSqlReq = ExplainSqlReq.builder()
.queryTypeEnum(QueryType.STRUCT)
.queryReq(QueryReqBuilder.buildStructReq(parseInfo))
.build();
}
ExplainResp explain = semanticInterpreter.explain(explainSqlReq, user);
if (Objects.nonNull(explain)) {
return explain.getSql();
}
} catch (Exception e) {
log.error("explain error explainSqlReq:{}", explainSqlReq, e);
}
return null;
}
@Override
public SemanticParseInfo getParseInfo() {
return parseInfo;
}
@Override
public void setParseInfo(SemanticParseInfo parseInfo) {
this.parseInfo = parseInfo;
}
protected QueryStructReq convertQueryStruct() {
return QueryReqBuilder.buildStructReq(parseInfo);
}
protected void convertBizNameToName(SemanticSchema semanticSchema, QueryStructReq queryStructReq) {
Map<String, String> bizNameToName = semanticSchema.getBizNameToName(queryStructReq.getDataSetId());
bizNameToName.putAll(TimeDimensionEnum.getNameToNameMap());
List<Order> orders = queryStructReq.getOrders();
if (CollectionUtils.isNotEmpty(orders)) {
for (Order order : orders) {
order.setColumn(bizNameToName.get(order.getColumn()));
}
}
List<Aggregator> aggregators = queryStructReq.getAggregators();
if (CollectionUtils.isNotEmpty(aggregators)) {
for (Aggregator aggregator : aggregators) {
aggregator.setColumn(bizNameToName.get(aggregator.getColumn()));
}
}
List<String> groups = queryStructReq.getGroups();
if (CollectionUtils.isNotEmpty(groups)) {
groups = groups.stream().map(bizNameToName::get).collect(Collectors.toList());
queryStructReq.setGroups(groups);
}
List<Filter> dimensionFilters = queryStructReq.getDimensionFilters();
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
dimensionFilters.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
}
List<Filter> metricFilters = queryStructReq.getMetricFilters();
if (CollectionUtils.isNotEmpty(metricFilters)) {
metricFilters.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
}
}
protected void initS2SqlByStruct(SemanticSchema semanticSchema) {
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
if (!optimizationConfig.isUseS2SqlSwitch()) {
return;
}
QueryStructReq queryStructReq = convertQueryStruct();
convertBizNameToName(semanticSchema, queryStructReq);
QuerySqlReq querySQLReq = queryStructReq.convert();
parseInfo.getSqlInfo().setS2SQL(querySQLReq.getSql());
parseInfo.getSqlInfo().setCorrectS2SQL(querySQLReq.getSql());
}
}

View File

@@ -1,107 +0,0 @@
package com.tencent.supersonic.chat.core.query;
import com.tencent.supersonic.chat.core.query.llm.LLMSemanticQuery;
import com.tencent.supersonic.chat.core.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.core.query.rule.metric.MetricSemanticQuery;
import com.tencent.supersonic.chat.core.query.rule.tag.TagSemanticQuery;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
public class QueryManager {
private static Map<String, RuleSemanticQuery> ruleQueryMap = new ConcurrentHashMap<>();
private static Map<String, PluginSemanticQuery> pluginQueryMap = new ConcurrentHashMap<>();
private static Map<String, LLMSemanticQuery> llmQueryMap = new ConcurrentHashMap<>();
public static void register(SemanticQuery query) {
if (query instanceof RuleSemanticQuery) {
ruleQueryMap.put(query.getQueryMode(), (RuleSemanticQuery) query);
} else if (query instanceof PluginSemanticQuery) {
pluginQueryMap.put(query.getQueryMode(), (PluginSemanticQuery) query);
} else if (query instanceof LLMSemanticQuery) {
llmQueryMap.put(query.getQueryMode(), (LLMSemanticQuery) query);
}
}
public static SemanticQuery createQuery(String queryMode) {
if (containsRuleQuery(queryMode)) {
return createRuleQuery(queryMode);
}
if (containsPluginQuery(queryMode)) {
return createPluginQuery(queryMode);
}
return createLLMQuery(queryMode);
}
public static RuleSemanticQuery createRuleQuery(String queryMode) {
RuleSemanticQuery semanticQuery = ruleQueryMap.get(queryMode);
return (RuleSemanticQuery) getSemanticQuery(queryMode, semanticQuery);
}
public static PluginSemanticQuery createPluginQuery(String queryMode) {
PluginSemanticQuery semanticQuery = pluginQueryMap.get(queryMode);
return (PluginSemanticQuery) getSemanticQuery(queryMode, semanticQuery);
}
public static LLMSemanticQuery createLLMQuery(String queryMode) {
LLMSemanticQuery semanticQuery = llmQueryMap.get(queryMode);
return (LLMSemanticQuery) getSemanticQuery(queryMode, semanticQuery);
}
private static SemanticQuery getSemanticQuery(String queryMode, SemanticQuery semanticQuery) {
if (Objects.isNull(semanticQuery)) {
throw new RuntimeException("no supported queryMode :" + queryMode);
}
try {
return semanticQuery.getClass().getDeclaredConstructor().newInstance();
} catch (Exception e) {
throw new RuntimeException("no supported queryMode :" + queryMode);
}
}
public static boolean containsRuleQuery(String queryMode) {
if (queryMode == null) {
return false;
}
return ruleQueryMap.containsKey(queryMode);
}
public static boolean isMetricQuery(String queryMode) {
if (queryMode == null || !ruleQueryMap.containsKey(queryMode)) {
return false;
}
return ruleQueryMap.get(queryMode) instanceof MetricSemanticQuery;
}
public static boolean isTagQuery(String queryMode) {
if (queryMode == null || !ruleQueryMap.containsKey(queryMode)) {
return false;
}
return ruleQueryMap.get(queryMode) instanceof TagSemanticQuery;
}
public static boolean containsPluginQuery(String queryMode) {
return queryMode != null && pluginQueryMap.containsKey(queryMode);
}
public static RuleSemanticQuery getRuleQuery(String queryMode) {
if (queryMode == null) {
return null;
}
return ruleQueryMap.get(queryMode);
}
public static List<RuleSemanticQuery> getRuleQueries() {
return new ArrayList<>(ruleQueryMap.values());
}
public static List<String> getPluginQueryModes() {
return new ArrayList<>(pluginQueryMap.keySet());
}
}
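
Concrete queries self-register: each @Component query calls QueryManager.register(this) from its constructor (MetricAnalyzeQuery below is one example), and createQuery then reflectively builds a fresh instance per request through the no-arg constructor. A usage sketch, assuming the demo sits in the same package and the query beans have already been constructed:

public class QueryManagerDemo {
    public static void main(String[] args) {
        // Works only after Spring (or a manual `new MetricAnalyzeQuery()`) has
        // triggered registration of the METRIC_INTERPRET mode.
        SemanticQuery query = QueryManager.createQuery("METRIC_INTERPRET");
        System.out.println(query.getQueryMode()); // METRIC_INTERPRET
    }
}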

View File

@@ -1,25 +0,0 @@
package com.tencent.supersonic.chat.core.query;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import org.apache.calcite.sql.parser.SqlParseException;
/**
* A SemanticQuery executes a specific type of query based on the results of semantic parsing.
*/
public interface SemanticQuery {
String getQueryMode();
QueryResult execute(User user) throws SqlParseException;
void initS2Sql(SemanticSchema semanticSchema, User user);
String explain(User user);
SemanticParseInfo getParseInfo();
void setParseInfo(SemanticParseInfo parseInfo);
}

View File

@@ -1,8 +0,0 @@
package com.tencent.supersonic.chat.core.query.llm;
import com.tencent.supersonic.chat.core.query.BaseSemanticQuery;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public abstract class LLMSemanticQuery extends BaseSemanticQuery {
}

View File

@@ -1,13 +0,0 @@
package com.tencent.supersonic.chat.core.query.llm.analytics;
import lombok.Data;
@Data
public class LLMAnswerReq {
private String queryText;
private String pluginOutput;
}

View File

@@ -1,10 +0,0 @@
package com.tencent.supersonic.chat.core.query.llm.analytics;
import lombok.Data;
@Data
public class LLMAnswerResp {
private String assistantMessage;
}

View File

@@ -1,144 +0,0 @@
package com.tencent.supersonic.chat.core.query.llm.analytics;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.core.query.QueryManager;
import com.tencent.supersonic.chat.core.query.llm.LLMSemanticQuery;
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
import com.tencent.supersonic.chat.core.utils.QueryReqBuilder;
import com.tencent.supersonic.common.pojo.Aggregator;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
@Slf4j
@Component
public class MetricAnalyzeQuery extends LLMSemanticQuery {
public static final String QUERY_MODE = "METRIC_INTERPRET";
public MetricAnalyzeQuery() {
QueryManager.register(this);
}
@Override
public String getQueryMode() {
return QUERY_MODE;
}
@Override
public QueryResult execute(User user) throws SqlParseException {
QueryStructReq queryStructReq = convertQueryStruct();
SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
SemanticQueryResp semanticQueryResp = semanticInterpreter.queryByStruct(queryStructReq, user);
String text = generateTableText(semanticQueryResp);
Map<String, Object> properties = parseInfo.getProperties();
Map<String, String> replacedMap = new HashMap<>();
String textReplaced = replaceText((String) properties.get("queryText"),
parseInfo.getElementMatches(), replacedMap);
String answer = replaceAnswer(fetchInterpret(textReplaced, text), replacedMap);
QueryResult queryResult = new QueryResult();
List<QueryColumn> queryColumns = Lists.newArrayList(new QueryColumn("结果", "string", "answer"));
Map<String, Object> result = new HashMap<>();
result.put("answer", answer);
List<Map<String, Object>> resultList = Lists.newArrayList();
resultList.add(result);
queryResult.setQueryResults(resultList);
queryResult.setQueryColumns(queryColumns);
queryResult.setQueryMode(getQueryMode());
queryResult.setQueryState(QueryState.SUCCESS);
return queryResult;
}
@Override
public void initS2Sql(SemanticSchema semanticSchema, User user) {
initS2SqlByStruct(semanticSchema);
}
protected QueryStructReq convertQueryStruct() {
QueryStructReq queryStructReq = QueryReqBuilder.buildStructReq(parseInfo);
fillAggregator(queryStructReq, parseInfo.getMetrics());
queryStructReq.setQueryType(QueryType.TAG);
return queryStructReq;
}
private String replaceText(String text, List<SchemaElementMatch> schemaElementMatches,
Map<String, String> replacedMap) {
if (CollectionUtils.isEmpty(schemaElementMatches)) {
return text;
}
List<SchemaElementMatch> valueSchemaElementMatches = schemaElementMatches.stream()
.filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toList());
int placeholderIndex = 0;
for (SchemaElementMatch schemaElementMatch : valueSchemaElementMatches) {
String detectWord = schemaElementMatch.getDetectWord();
if (StringUtils.isBlank(detectWord)) {
continue;
}
// use a unique placeholder per value so replaceAnswer can restore each one
String placeholder = "xxx" + placeholderIndex++;
text = text.replace(detectWord, placeholder);
replacedMap.put(placeholder, detectWord);
}
return text;
}
private void fillAggregator(QueryStructReq queryStructReq, Set<SchemaElement> schemaElements) {
queryStructReq.getAggregators().clear();
for (SchemaElement schemaElement : schemaElements) {
Aggregator aggregator = new Aggregator();
aggregator.setColumn(schemaElement.getBizName());
aggregator.setFunc(AggOperatorEnum.SUM);
aggregator.setNameCh(schemaElement.getName());
queryStructReq.getAggregators().add(aggregator);
}
}
private String replaceAnswer(String text, Map<String, String> replacedMap) {
for (String key : replacedMap.keySet()) {
text = text.replaceAll(key, replacedMap.get(key));
}
return text;
}
public static String generateTableText(SemanticQueryResp result) {
StringBuilder tableBuilder = new StringBuilder();
for (QueryColumn column : result.getColumns()) {
tableBuilder.append(column.getName()).append("\t");
}
tableBuilder.append("\n");
for (Map<String, Object> row : result.getResultList()) {
for (QueryColumn column : result.getColumns()) {
tableBuilder.append(row.get(column.getNameEn())).append("\t");
}
tableBuilder.append("\n");
}
return tableBuilder.toString();
}
public String fetchInterpret(String queryText, String dataText) {
return "";
}
}
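
generateTableText flattens the query result into a tab-separated block that is handed to the model as context (via fetchInterpret, a stub here). The same flattening with plain collections standing in for SemanticQueryResp:

import java.util.List;
import java.util.Map;

public class TableTextSketch {

    static String toTsv(List<String> columns, List<Map<String, Object>> rows) {
        StringBuilder sb = new StringBuilder();
        for (String column : columns) {
            sb.append(column).append('\t');
        }
        sb.append('\n');
        for (Map<String, Object> row : rows) {
            for (String column : columns) {
                sb.append(row.get(column)).append('\t');
            }
            sb.append('\n');
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        System.out.print(toTsv(List.of("day", "uv"),
                List.of(Map.of("day", "2024-03-11", "uv", 1024))));
        // day	uv
        // 2024-03-11	1024
    }
}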

View File

@@ -1,83 +0,0 @@
package com.tencent.supersonic.chat.core.query.llm.s2sql;
import com.fasterxml.jackson.annotation.JsonValue;
import lombok.Data;
import java.util.List;
@Data
public class LLMReq {
private String queryText;
private FilterCondition filterCondition;
private LLMSchema schema;
private List<ElementValue> linking;
private String currentDate;
private String priorExts;
private String sqlGenerationMode;
@Data
public static class ElementValue {
private String fieldName;
private String fieldValue;
}
@Data
public static class LLMSchema {
private String domainName;
private String dataSetName;
private List<String> fieldNameList;
}
@Data
public static class FilterCondition {
private String tableName;
}
public enum SqlGenerationMode {
ONE_PASS_AUTO_COT("1_pass_auto_cot"),
ONE_PASS_AUTO_COT_SELF_CONSISTENCY("1_pass_auto_cot_self_consistency"),
TWO_PASS_AUTO_COT("2_pass_auto_cot"),
TWO_PASS_AUTO_COT_SELF_CONSISTENCY("2_pass_auto_cot_self_consistency");
private String name;
SqlGenerationMode(String name) {
this.name = name;
}
@JsonValue
public String getName() {
return name;
}
public static SqlGenerationMode getMode(String name) {
for (SqlGenerationMode sqlGenerationMode : SqlGenerationMode.values()) {
if (sqlGenerationMode.name.equals(name)) {
return sqlGenerationMode;
}
}
return null;
}
}
}
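
For orientation, an LLMReq assembled by hand (field values are invented; the setters come from Lombok's @Data):

import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
import java.util.Arrays;

public class LLMReqDemo {
    public static void main(String[] args) {
        LLMReq req = new LLMReq();
        req.setQueryText("近7天各部门的销售额"); // "sales amount per department over the last 7 days"
        LLMReq.LLMSchema schema = new LLMReq.LLMSchema();
        schema.setDataSetName("sales");
        schema.setFieldNameList(Arrays.asList("部门", "销售额", "日期")); // department, amount, date
        req.setSchema(schema);
        LLMReq.ElementValue link = new LLMReq.ElementValue();
        link.setFieldName("部门");
        link.setFieldValue("研发部"); // a dimension value detected during mapping
        req.setLinking(Arrays.asList(link));
        req.setCurrentDate("2024-03-12");
        req.setPriorExts("");
        req.setSqlGenerationMode(LLMReq.SqlGenerationMode.TWO_PASS_AUTO_COT_SELF_CONSISTENCY.getName());
        System.out.println(req);
    }
}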

View File

@@ -1,25 +0,0 @@
package com.tencent.supersonic.chat.core.query.llm.s2sql;
import java.util.List;
import java.util.Map;
import lombok.Data;
@Data
public class LLMResp {
private String query;
private String modelName;
private String sqlOutput;
private List<String> fields;
private Map<String, LLMSqlResp> sqlRespMap;
/**
* Only for compatibility with the Python code; to be deleted later.
*/
private Map<String, Double> sqlWeight;
}

Some files were not shown because too many files have changed in this diff.