(improvement)(project) support for modifying filter conditions and fix group by pushdown and add windows scipt (#49)

Co-authored-by: lexluo <lexluo@tencent.com>
This commit is contained in:
lexluo09
2023-09-03 23:51:47 +08:00
committed by GitHub
parent 8440f1f30e
commit 559ef974b0
317 changed files with 7449 additions and 9413 deletions

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.auth.api.authentication.adaptor;
import com.tencent.supersonic.auth.api.authentication.pojo.Organization; import com.tencent.supersonic.auth.api.authentication.pojo.Organization;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.request.UserReq; import com.tencent.supersonic.auth.api.authentication.request.UserReq;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;

View File

@@ -1,9 +1,10 @@
package com.tencent.supersonic.auth.api.authentication.pojo; package com.tencent.supersonic.auth.api.authentication.pojo;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import java.util.List;
import lombok.Data; import lombok.Data;
import java.util.List;
@Data @Data
public class Organization { public class Organization {

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.auth.api.authentication.service;
import com.tencent.supersonic.auth.api.authentication.pojo.Organization; import com.tencent.supersonic.auth.api.authentication.pojo.Organization;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.request.UserReq; import com.tencent.supersonic.auth.api.authentication.request.UserReq;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.auth.api.authentication.service; package com.tencent.supersonic.auth.api.authentication.service;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.auth.api.authorization.request;
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRes; import com.tencent.supersonic.auth.api.authorization.pojo.AuthRes;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import lombok.Data; import lombok.Data;
import lombok.ToString; import lombok.ToString;

View File

@@ -4,6 +4,7 @@ import com.tencent.supersonic.auth.api.authorization.pojo.AuthResGrp;
import com.tencent.supersonic.auth.api.authorization.pojo.DimensionFilter; import com.tencent.supersonic.auth.api.authorization.pojo.DimensionFilter;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import lombok.Data; import lombok.Data;
@Data @Data

View File

@@ -6,7 +6,6 @@ import com.tencent.supersonic.auth.api.authorization.request.QueryAuthResReq;
import com.tencent.supersonic.auth.api.authorization.response.AuthorizedResourceResp; import com.tencent.supersonic.auth.api.authorization.response.AuthorizedResourceResp;
import java.util.List; import java.util.List;
public interface AuthService { public interface AuthService {
List<AuthGroup> queryAuthGroups(String domainId, Integer groupId); List<AuthGroup> queryAuthGroups(String domainId, Integer groupId);

View File

@@ -11,10 +11,10 @@ import com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO;
import com.tencent.supersonic.auth.authentication.persistence.repository.UserRepository; import com.tencent.supersonic.auth.authentication.persistence.repository.UserRepository;
import com.tencent.supersonic.auth.authentication.utils.UserTokenUtils; import com.tencent.supersonic.auth.authentication.utils.UserTokenUtils;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import org.springframework.beans.BeanUtils;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.springframework.beans.BeanUtils;
public class DefaultUserAdaptor implements UserAdaptor { public class DefaultUserAdaptor implements UserAdaptor {

View File

@@ -0,0 +1,21 @@
package com.tencent.supersonic.auth.authentication.config;
import lombok.Data;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
@Data
@Configuration
public class TppConfig {
@Value(value = "${auth.app.secret:}")
private String appSecret;
@Value(value = "${auth.app.key:}")
private String appKey;
@Value(value = "${auth.oa.url:}")
private String tppOaUrl;
}

View File

@@ -1,11 +1,12 @@
package com.tencent.supersonic.auth.authentication.interceptor; package com.tencent.supersonic.auth.authentication.interceptor;
import java.util.List;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.support.SpringFactoriesLoader; import org.springframework.core.io.support.SpringFactoriesLoader;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry; import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
import java.util.List;
@Configuration @Configuration
public class InterceptorFactory implements WebMvcConfigurer { public class InterceptorFactory implements WebMvcConfigurer {

View File

@@ -3,8 +3,8 @@ package com.tencent.supersonic.auth.authentication.persistence.repository.impl;
import com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO; import com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO;
import com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDOExample; import com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDOExample;
import com.tencent.supersonic.auth.authentication.persistence.mapper.UserDOMapper;
import com.tencent.supersonic.auth.authentication.persistence.repository.UserRepository; import com.tencent.supersonic.auth.authentication.persistence.repository.UserRepository;
import com.tencent.supersonic.auth.authentication.persistence.mapper.UserDOMapper;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;

View File

@@ -7,6 +7,7 @@ import com.tencent.supersonic.auth.api.authentication.service.UserService;
import com.tencent.supersonic.auth.authentication.utils.ComponentFactory; import com.tencent.supersonic.auth.authentication.utils.ComponentFactory;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@Service @Service

View File

@@ -1,8 +1,8 @@
package com.tencent.supersonic.auth.authentication.utils; package com.tencent.supersonic.auth.authentication.utils;
import com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor; import com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor;
import java.util.Objects;
import org.springframework.core.io.support.SpringFactoriesLoader; import org.springframework.core.io.support.SpringFactoriesLoader;
import java.util.Objects;
public class ComponentFactory { public class ComponentFactory {

View File

@@ -1,159 +1,145 @@
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" <!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.tencent.supersonic.auth.authentication.persistence.mapper.UserDOMapper"> <mapper namespace="com.tencent.supersonic.auth.authentication.persistence.mapper.UserDOMapper">
<resultMap id="BaseResultMap" <resultMap id="BaseResultMap" type="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO">
type="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO"> <id column="id" jdbcType="BIGINT" property="id" />
<id column="id" jdbcType="BIGINT" property="id"/> <result column="name" jdbcType="VARCHAR" property="name" />
<result column="name" jdbcType="VARCHAR" property="name"/> <result column="password" jdbcType="VARCHAR" property="password" />
<result column="password" jdbcType="VARCHAR" property="password"/> <result column="display_name" jdbcType="VARCHAR" property="displayName" />
<result column="display_name" jdbcType="VARCHAR" property="displayName"/> <result column="email" jdbcType="VARCHAR" property="email" />
<result column="email" jdbcType="VARCHAR" property="email"/> </resultMap>
</resultMap> <sql id="Example_Where_Clause">
<sql id="Example_Where_Clause"> <where>
<where> <foreach collection="oredCriteria" item="criteria" separator="or">
<foreach collection="oredCriteria" item="criteria" separator="or"> <if test="criteria.valid">
<if test="criteria.valid"> <trim prefix="(" prefixOverrides="and" suffix=")">
<trim prefix="(" prefixOverrides="and" suffix=")"> <foreach collection="criteria.criteria" item="criterion">
<foreach collection="criteria.criteria" item="criterion"> <choose>
<choose> <when test="criterion.noValue">
<when test="criterion.noValue"> and ${criterion.condition}
and ${criterion.condition} </when>
</when> <when test="criterion.singleValue">
<when test="criterion.singleValue"> and ${criterion.condition} #{criterion.value}
and ${criterion.condition} #{criterion.value} </when>
</when> <when test="criterion.betweenValue">
<when test="criterion.betweenValue"> and ${criterion.condition} #{criterion.value} and #{criterion.secondValue}
and ${criterion.condition} #{criterion.value} and </when>
#{criterion.secondValue} <when test="criterion.listValue">
</when> and ${criterion.condition}
<when test="criterion.listValue"> <foreach close=")" collection="criterion.value" item="listItem" open="(" separator=",">
and ${criterion.condition} #{listItem}
<foreach close=")" collection="criterion.value" item="listItem" </foreach>
open="(" separator=","> </when>
#{listItem} </choose>
</foreach>
</when>
</choose>
</foreach>
</trim>
</if>
</foreach> </foreach>
</where> </trim>
</sql>
<sql id="Base_Column_List">
id
, name, password, display_name, email
</sql>
<select id="selectByExample"
parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDOExample"
resultMap="BaseResultMap">
select
<if test="distinct">
distinct
</if> </if>
<include refid="Base_Column_List"/> </foreach>
from s2_user </where>
<if test="_parameter != null"> </sql>
<include refid="Example_Where_Clause"/> <sql id="Base_Column_List">
</if> id, name, password, display_name, email
<if test="orderByClause != null"> </sql>
order by ${orderByClause} <select id="selectByExample" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDOExample" resultMap="BaseResultMap">
</if> select
<if test="limitStart != null and limitStart>=0"> <if test="distinct">
limit #{limitStart} , #{limitEnd} distinct
</if> </if>
</select> <include refid="Base_Column_List" />
<select id="selectByPrimaryKey" parameterType="java.lang.Long" resultMap="BaseResultMap"> from s2_user
select <if test="_parameter != null">
<include refid="Base_Column_List"/> <include refid="Example_Where_Clause" />
from s2_user </if>
where id = #{id,jdbcType=BIGINT} <if test="orderByClause != null">
</select> order by ${orderByClause}
<delete id="deleteByPrimaryKey" parameterType="java.lang.Long"> </if>
delete <if test="limitStart != null and limitStart>=0">
from s2_user limit #{limitStart} , #{limitEnd}
where id = #{id,jdbcType=BIGINT} </if>
</delete> </select>
<insert id="insert" <select id="selectByPrimaryKey" parameterType="java.lang.Long" resultMap="BaseResultMap">
parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO"> select
insert into s2_user (id, name, password, <include refid="Base_Column_List" />
display_name, email) from s2_user
values (#{id,jdbcType=BIGINT}, #{name,jdbcType=VARCHAR}, #{password,jdbcType=VARCHAR}, where id = #{id,jdbcType=BIGINT}
#{displayName,jdbcType=VARCHAR}, #{email,jdbcType=VARCHAR}) </select>
</insert> <delete id="deleteByPrimaryKey" parameterType="java.lang.Long">
<insert id="insertSelective" delete from s2_user
parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO"> where id = #{id,jdbcType=BIGINT}
insert into s2_user </delete>
<trim prefix="(" suffix=")" suffixOverrides=","> <insert id="insert" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO">
<if test="id != null"> insert into s2_user (id, name, password,
id, display_name, email)
</if> values (#{id,jdbcType=BIGINT}, #{name,jdbcType=VARCHAR}, #{password,jdbcType=VARCHAR},
<if test="name != null"> #{displayName,jdbcType=VARCHAR}, #{email,jdbcType=VARCHAR})
name, </insert>
</if> <insert id="insertSelective" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO">
<if test="password != null"> insert into s2_user
password, <trim prefix="(" suffix=")" suffixOverrides=",">
</if> <if test="id != null">
<if test="displayName != null"> id,
display_name, </if>
</if> <if test="name != null">
<if test="email != null"> name,
email, </if>
</if> <if test="password != null">
</trim> password,
<trim prefix="values (" suffix=")" suffixOverrides=","> </if>
<if test="id != null"> <if test="displayName != null">
#{id,jdbcType=BIGINT}, display_name,
</if> </if>
<if test="name != null"> <if test="email != null">
#{name,jdbcType=VARCHAR}, email,
</if> </if>
<if test="password != null"> </trim>
#{password,jdbcType=VARCHAR}, <trim prefix="values (" suffix=")" suffixOverrides=",">
</if> <if test="id != null">
<if test="displayName != null"> #{id,jdbcType=BIGINT},
#{displayName,jdbcType=VARCHAR}, </if>
</if> <if test="name != null">
<if test="email != null"> #{name,jdbcType=VARCHAR},
#{email,jdbcType=VARCHAR}, </if>
</if> <if test="password != null">
</trim> #{password,jdbcType=VARCHAR},
</insert> </if>
<select id="countByExample" <if test="displayName != null">
parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDOExample" #{displayName,jdbcType=VARCHAR},
resultType="java.lang.Long"> </if>
select count(*) from s2_user <if test="email != null">
<if test="_parameter != null"> #{email,jdbcType=VARCHAR},
<include refid="Example_Where_Clause"/> </if>
</if> </trim>
</select> </insert>
<update id="updateByPrimaryKeySelective" <select id="countByExample" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDOExample" resultType="java.lang.Long">
parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO"> select count(*) from s2_user
update s2_user <if test="_parameter != null">
<set> <include refid="Example_Where_Clause" />
<if test="name != null"> </if>
name = #{name,jdbcType=VARCHAR}, </select>
</if> <update id="updateByPrimaryKeySelective" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO">
<if test="password != null"> update s2_user
password = #{password,jdbcType=VARCHAR}, <set>
</if> <if test="name != null">
<if test="displayName != null"> name = #{name,jdbcType=VARCHAR},
display_name = #{displayName,jdbcType=VARCHAR}, </if>
</if> <if test="password != null">
<if test="email != null"> password = #{password,jdbcType=VARCHAR},
email = #{email,jdbcType=VARCHAR}, </if>
</if> <if test="displayName != null">
</set> display_name = #{displayName,jdbcType=VARCHAR},
where id = #{id,jdbcType=BIGINT} </if>
</update> <if test="email != null">
<update id="updateByPrimaryKey" email = #{email,jdbcType=VARCHAR},
parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO"> </if>
update s2_user </set>
set name = #{name,jdbcType=VARCHAR}, where id = #{id,jdbcType=BIGINT}
password = #{password,jdbcType=VARCHAR}, </update>
display_name = #{displayName,jdbcType=VARCHAR}, <update id="updateByPrimaryKey" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO">
email = #{email,jdbcType=VARCHAR} update s2_user
where id = #{id,jdbcType=BIGINT} set name = #{name,jdbcType=VARCHAR},
</update> password = #{password,jdbcType=VARCHAR},
display_name = #{displayName,jdbcType=VARCHAR},
email = #{email,jdbcType=VARCHAR}
where id = #{id,jdbcType=BIGINT}
</update>
</mapper> </mapper>

View File

@@ -18,4 +18,5 @@ public class CorrectionInfo {
private String sql; private String sql;
private String preSql;
} }

View File

@@ -1,9 +1,10 @@
package com.tencent.supersonic.chat.api.pojo; package com.tencent.supersonic.chat.api.pojo;
import lombok.Data;
import java.util.HashSet; import java.util.HashSet;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
import lombok.Data;
@Data @Data
public class ModelSchema { public class ModelSchema {

View File

@@ -2,11 +2,12 @@ package com.tencent.supersonic.chat.api.pojo;
import com.tencent.supersonic.chat.api.component.SemanticQuery; import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import java.util.ArrayList;
import java.util.List;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import java.util.ArrayList;
import java.util.List;
@Data @Data
@NoArgsConstructor @NoArgsConstructor
public class QueryContext { public class QueryContext {

View File

@@ -1,39 +1,30 @@
package com.tencent.supersonic.chat.api.pojo; package com.tencent.supersonic.chat.api.pojo;
import com.google.common.base.Objects; import com.google.common.base.Objects;
import java.io.Serializable; import java.io.Serializable;
import java.util.List; import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.Getter; import lombok.Getter;
import lombok.Builder;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
@Data @Data
@Getter @Getter
@Builder @Builder
@AllArgsConstructor
@NoArgsConstructor @NoArgsConstructor
public class SchemaElement implements Serializable { public class SchemaElement implements Serializable {
private Long model; private Long model;
private Long id; private Long id;
private String name; private String name;
private String bizName; private String bizName;
private Long useCnt; private Long useCnt;
private SchemaElementType type; private SchemaElementType type;
private List<String> alias; private List<String> alias;
public SchemaElement(Long model, Long id, String name, String bizName, private List<SchemaValueMap> schemaValueMaps;
Long useCnt, SchemaElementType type, List<String> alias) {
this.model = model;
this.id = id;
this.name = name;
this.bizName = bizName;
this.useCnt = useCnt;
this.type = type;
this.alias = alias;
}
@Override @Override
public boolean equals(Object o) { public boolean equals(Object o) {
@@ -54,4 +45,5 @@ public class SchemaElement implements Serializable {
public int hashCode() { public int hashCode() {
return Objects.hashCode(model, id, name, bizName, useCnt, type); return Objects.hashCode(model, id, name, bizName, useCnt, type);
} }
} }

View File

@@ -0,0 +1,24 @@
package com.tencent.supersonic.chat.api.pojo;
import java.util.ArrayList;
import java.util.List;
import lombok.Data;
@Data
public class SchemaValueMap {
/**
* dimension value in db
*/
private String techName;
/**
* dimension value for result show
*/
private String bizName;
/**
* dimension value for user query
*/
private List<String> alias = new ArrayList<>();
}

View File

@@ -7,7 +7,6 @@ import java.util.Map;
import java.util.stream.Collectors; import java.util.stream.Collectors;
public class SemanticSchema implements Serializable { public class SemanticSchema implements Serializable {
private List<ModelSchema> modelSchemaList; private List<ModelSchema> modelSchemaList;
public SemanticSchema(List<ModelSchema> modelSchemaList) { public SemanticSchema(List<ModelSchema> modelSchemaList) {

View File

@@ -1,8 +1,9 @@
package com.tencent.supersonic.chat.api.pojo.request; package com.tencent.supersonic.chat.api.pojo.request;
import java.util.List;
import lombok.Data; import lombok.Data;
import java.util.List;
@Data @Data
public class ChatAggConfigReq { public class ChatAggConfigReq {

View File

@@ -1,10 +1,12 @@
package com.tencent.supersonic.chat.api.pojo.request; package com.tencent.supersonic.chat.api.pojo.request;
import com.tencent.supersonic.common.pojo.enums.StatusEnum; import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import java.util.List;
import lombok.Data; import lombok.Data;
import lombok.ToString; import lombok.ToString;
import java.util.List;
/** /**
* extended information command about model * extended information command about model
*/ */

View File

@@ -2,9 +2,10 @@ package com.tencent.supersonic.chat.api.pojo.request;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import lombok.Data;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import lombok.Data;
@Data @Data
public class ChatDefaultConfigReq { public class ChatDefaultConfigReq {

View File

@@ -1,8 +1,9 @@
package com.tencent.supersonic.chat.api.pojo.request; package com.tencent.supersonic.chat.api.pojo.request;
import java.util.List;
import lombok.Data; import lombok.Data;
import java.util.List;
@Data @Data
public class ChatDetailConfigReq { public class ChatDetailConfigReq {

View File

@@ -0,0 +1,12 @@
package com.tencent.supersonic.chat.api.pojo.request;
import lombok.Data;
@Data
public class DimensionValueReq {
private Long modelId;
private String bizName;
private Object value;
}

View File

@@ -7,13 +7,12 @@ import lombok.Data;
@Data @Data
public class ExecuteQueryReq { public class ExecuteQueryReq {
private User user; private User user;
private Integer agentId; private Integer agentId;
private Integer chatId; private Integer chatId;
private String queryText; private String queryText;
private Long queryId; private Long queryId = 7L;
private Integer parseId; private Integer parseId = 2;
private SemanticParseInfo parseInfo; private SemanticParseInfo parseInfo;
private boolean saveAnswer = true; private boolean saveAnswer = true;
} }

View File

@@ -1,8 +1,9 @@
package com.tencent.supersonic.chat.api.pojo.request; package com.tencent.supersonic.chat.api.pojo.request;
import lombok.Data;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import lombok.Data;
/** /**
* advanced knowledge config * advanced knowledge config

View File

@@ -1,7 +1,9 @@
package com.tencent.supersonic.chat.api.pojo.request; package com.tencent.supersonic.chat.api.pojo.request;
import com.tencent.supersonic.common.pojo.enums.TypeEnums; import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import lombok.Data; import lombok.Data;
/** /**

View File

@@ -4,19 +4,20 @@ package com.tencent.supersonic.chat.api.pojo.request;
import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.Order; import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import java.util.HashSet; import java.util.HashSet;
import java.util.Set; import java.util.Set;
import lombok.Data; import lombok.Data;
@Data @Data
public class QueryDataReq { public class QueryDataReq {
String queryMode; String queryMode;
SchemaElement model; SchemaElement model;
Set<SchemaElement> metrics = new HashSet<>(); Set<SchemaElement> metrics = new HashSet<>();
Set<SchemaElement> dimensions = new HashSet<>(); Set<SchemaElement> dimensions = new HashSet<>();
Set<QueryFilter> dimensionFilters = new HashSet<>(); Set<QueryFilter> dimensionFilters = new HashSet<>();
Set<QueryFilter> metricFilters = new HashSet<>(); Set<QueryFilter> metricFilters = new HashSet<>();
private AggregateTypeEnum aggType = AggregateTypeEnum.NONE;
private Set<Order> orders = new HashSet<>(); private Set<Order> orders = new HashSet<>();
private DateConf dateInfo; private DateConf dateInfo;
private Long limit; private Long limit;

View File

@@ -1,14 +1,13 @@
package com.tencent.supersonic.chat.api.pojo.request; package com.tencent.supersonic.chat.api.pojo.request;
import lombok.Data;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import lombok.Data;
@Data @Data
public class QueryFilters { public class QueryFilters {
private List<QueryFilter> filters = new ArrayList<>(); private List<QueryFilter> filters = new ArrayList<>();
private Map<String, Object> params = new HashMap<>(); private Map<String, Object> params = new HashMap<>();
} }

View File

@@ -5,7 +5,6 @@ import lombok.Data;
@Data @Data
public class QueryReq { public class QueryReq {
private String queryText; private String queryText;
private Integer chatId; private Integer chatId;
private Long modelId = 0L; private Long modelId = 0L;

View File

@@ -2,9 +2,10 @@ package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.chat.api.pojo.request.KnowledgeAdvancedConfig; import com.tencent.supersonic.chat.api.pojo.request.KnowledgeAdvancedConfig;
import com.tencent.supersonic.chat.api.pojo.request.KnowledgeInfoReq; import com.tencent.supersonic.chat.api.pojo.request.KnowledgeInfoReq;
import java.util.List;
import lombok.Data; import lombok.Data;
import java.util.List;
@Data @Data
public class ChatAggRichConfigResp { public class ChatAggRichConfigResp {

View File

@@ -4,8 +4,10 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatDetailConfigReq; import com.tencent.supersonic.chat.api.pojo.request.ChatDetailConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.RecommendedQuestionReq; import com.tencent.supersonic.chat.api.pojo.request.RecommendedQuestionReq;
import com.tencent.supersonic.common.pojo.enums.StatusEnum; import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import java.util.Date; import java.util.Date;
import java.util.List; import java.util.List;
import lombok.Data; import lombok.Data;
@Data @Data

View File

@@ -4,6 +4,7 @@ import com.tencent.supersonic.chat.api.pojo.request.RecommendedQuestionReq;
import com.tencent.supersonic.common.pojo.enums.StatusEnum; import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import java.util.Date; import java.util.Date;
import java.util.List; import java.util.List;
import lombok.Data; import lombok.Data;
@Data @Data

View File

@@ -4,9 +4,10 @@ package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq; import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import java.util.List;
import lombok.Data; import lombok.Data;
import java.util.List;
@Data @Data
public class ChatDefaultRichConfigResp { public class ChatDefaultRichConfigResp {

View File

@@ -2,9 +2,10 @@ package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.chat.api.pojo.request.KnowledgeAdvancedConfig; import com.tencent.supersonic.chat.api.pojo.request.KnowledgeAdvancedConfig;
import com.tencent.supersonic.chat.api.pojo.request.KnowledgeInfoReq; import com.tencent.supersonic.chat.api.pojo.request.KnowledgeInfoReq;
import java.util.List;
import lombok.Data; import lombok.Data;
import java.util.List;
@Data @Data
public class ChatDetailRichConfigResp { public class ChatDetailRichConfigResp {

View File

@@ -1,14 +1,14 @@
package com.tencent.supersonic.chat.api.pojo.response; package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import java.util.List; import java.util.List;
import lombok.Data; import lombok.Data;
@Data @Data
public class EntityRichInfoResp { public class EntityRichInfoResp {
/** /**
* entity alias * entity alias
*/ */
private List<String> names; private List<String> names;

View File

@@ -1,14 +1,14 @@
package com.tencent.supersonic.chat.api.pojo.response; package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.chat.api.pojo.request.RecommendedQuestionReq; import com.tencent.supersonic.chat.api.pojo.request.RecommendedQuestionReq;
import java.util.List;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
import java.util.List;
@Data @Data
@AllArgsConstructor @AllArgsConstructor
public class RecommendQuestionResp { public class RecommendQuestionResp {
private Long modelId; private Long modelId;
private List<RecommendedQuestionReq> recommendedQuestions; private List<RecommendedQuestionReq> recommendedQuestions;
} }

View File

@@ -1,12 +1,12 @@
package com.tencent.supersonic.chat.api.pojo.response; package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import java.util.List;
import lombok.Data; import lombok.Data;
import java.util.List;
@Data @Data
public class RecommendResp { public class RecommendResp {
private List<SchemaElement> dimensions; private List<SchemaElement> dimensions;
private List<SchemaElement> metrics; private List<SchemaElement> metrics;
} }

View File

@@ -136,6 +136,14 @@
<artifactId>xk-time</artifactId> <artifactId>xk-time</artifactId>
<version>${xk.time.version}</version> <version>${xk.time.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-inline</artifactId>
<version>${mockito-inline.version}</version>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>
</project> </project>

View File

@@ -3,12 +3,13 @@ package com.tencent.supersonic.chat.config;
import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq; import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatDetailConfigReq; import com.tencent.supersonic.chat.api.pojo.request.ChatDetailConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.RecommendedQuestionReq; import com.tencent.supersonic.chat.api.pojo.request.RecommendedQuestionReq;
import com.tencent.supersonic.common.pojo.RecordInfo;
import com.tencent.supersonic.common.pojo.enums.StatusEnum; import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import java.util.List; import com.tencent.supersonic.common.pojo.RecordInfo;
import lombok.Data; import lombok.Data;
import lombok.ToString; import lombok.ToString;
import java.util.List;
@Data @Data
@ToString @ToString
public class ChatConfig { public class ChatConfig {

View File

@@ -7,7 +7,6 @@ import org.springframework.context.annotation.Configuration;
@Configuration @Configuration
@Data @Data
public class FunctionCallInfoConfig { public class FunctionCallInfoConfig {
@Value("${functionCall.url:}") @Value("${functionCall.url:}")
private String url; private String url;
} }

View File

@@ -0,0 +1,43 @@
package com.tencent.supersonic.chat.config;
import lombok.Data;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.PropertySource;
@Configuration
@Data
@PropertySource("classpath:optimization.properties")
//@ComponentScan(basePackages = "com.tencent.supersonic.chat")
public class OptimizationConfig {
@Value("${one.detection.size}")
private Integer oneDetectionSize;
@Value("${one.detection.max.size}")
private Integer oneDetectionMaxSize;
@Value("${metric.dimension.min.threshold}")
private Double metricDimensionMinThresholdConfig;
@Value("${metric.dimension.threshold}")
private Double metricDimensionThresholdConfig;
@Value("${dimension.value.threshold}")
private Double dimensionValueThresholdConfig;
@Value("${function.bonus.threshold}")
private Double functionBonusThreshold;
@Value("${long.text.threshold}")
private Double longTextThreshold;
@Value("${short.text.threshold}")
private Double shortTextThreshold;
@Value("${query.text.length.threshold}")
private Integer queryTextLengthThreshold;
@Value("${candidate.threshold}")
private Double candidateThreshold;
}

View File

@@ -20,6 +20,7 @@ public class DateFieldCorrector extends BaseSemanticCorrector {
String currentDate = DSLDateHelper.getCurrentDate(correctionInfo.getParseInfo().getModelId()); String currentDate = DSLDateHelper.getCurrentDate(correctionInfo.getParseInfo().getModelId());
sql = SqlParserUpdateHelper.addWhere(sql, DATE_FIELD, currentDate); sql = SqlParserUpdateHelper.addWhere(sql, DATE_FIELD, currentDate);
} }
correctionInfo.setPreSql(correctionInfo.getSql());
correctionInfo.setSql(sql); correctionInfo.setSql(sql);
return correctionInfo; return correctionInfo;
} }

View File

@@ -9,9 +9,11 @@ public class FieldCorrector extends BaseSemanticCorrector {
@Override @Override
public CorrectionInfo corrector(CorrectionInfo correctionInfo) { public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
String replaceFields = SqlParserUpdateHelper.replaceFields(correctionInfo.getSql(), String preSql = correctionInfo.getSql();
correctionInfo.setPreSql(preSql);
String sql = SqlParserUpdateHelper.replaceFields(preSql,
getFieldToBizName(correctionInfo.getParseInfo().getModelId())); getFieldToBizName(correctionInfo.getParseInfo().getModelId()));
correctionInfo.setSql(replaceFields); correctionInfo.setSql(sql);
return correctionInfo; return correctionInfo;
} }
} }

View File

@@ -0,0 +1,50 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@Slf4j
public class FieldNameCorrector extends BaseSemanticCorrector {
@Override
public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
Object context = correctionInfo.getParseInfo().getProperties().get(Constants.CONTEXT);
if (Objects.isNull(context)) {
return correctionInfo;
}
DSLParseResult dslParseResult = JsonUtil.toObject(JsonUtil.toString(context), DSLParseResult.class);
if (Objects.isNull(dslParseResult) || Objects.isNull(dslParseResult.getLlmReq())) {
return correctionInfo;
}
LLMReq llmReq = dslParseResult.getLlmReq();
List<ElementValue> linking = llmReq.getLinking();
if (CollectionUtils.isEmpty(linking)) {
return correctionInfo;
}
Map<String, Set<String>> fieldValueToFieldNames = linking.stream().collect(
Collectors.groupingBy(ElementValue::getFieldValue,
Collectors.mapping(ElementValue::getFieldName, Collectors.toSet())));
String preSql = correctionInfo.getSql();
correctionInfo.setPreSql(preSql);
String sql = SqlParserUpdateHelper.replaceFieldNameByValue(preSql, fieldValueToFieldNames);
correctionInfo.setSql(sql);
return correctionInfo;
}
}

View File

@@ -1,18 +1,19 @@
package com.tencent.supersonic.chat.corrector; package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo; import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult; import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq; import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue; import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.logging.log4j.util.Strings;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
@Slf4j @Slf4j
@@ -20,29 +21,61 @@ public class FieldValueCorrector extends BaseSemanticCorrector {
@Override @Override
public CorrectionInfo corrector(CorrectionInfo correctionInfo) { public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
Long modelId = correctionInfo.getParseInfo().getModel().getId();
List<SchemaElement> dimensions = semanticSchema.getDimensions().stream()
.filter(schemaElement -> modelId.equals(schemaElement.getModel()))
.collect(Collectors.toList());
Object context = correctionInfo.getParseInfo().getProperties().get(Constants.CONTEXT); if (CollectionUtils.isEmpty(dimensions)) {
if (Objects.isNull(context)) {
return correctionInfo; return correctionInfo;
} }
DSLParseResult dslParseResult = JsonUtil.toObject(JsonUtil.toString(context), DSLParseResult.class); Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
if (Objects.isNull(dslParseResult) || Objects.isNull(dslParseResult.getLlmReq())) { String preSql = correctionInfo.getSql();
return correctionInfo; correctionInfo.setPreSql(preSql);
} String sql = SqlParserUpdateHelper.replaceValue(preSql, aliasAndBizNameToTechName);
LLMReq llmReq = dslParseResult.getLlmReq();
List<ElementValue> linking = llmReq.getLinking();
if (CollectionUtils.isEmpty(linking)) {
return correctionInfo;
}
Map<String, Set<String>> fieldValueToFieldNames = linking.stream().collect(
Collectors.groupingBy(ElementValue::getFieldValue,
Collectors.mapping(ElementValue::getFieldName, Collectors.toSet())));
String sql = SqlParserUpdateHelper.replaceValueFields(correctionInfo.getSql(), fieldValueToFieldNames);
correctionInfo.setSql(sql); correctionInfo.setSql(sql);
return correctionInfo; return correctionInfo;
} }
private Map<String, Map<String, String>> getAliasAndBizNameToTechName(List<SchemaElement> dimensions) {
if (CollectionUtils.isEmpty(dimensions)) {
return new HashMap<>();
}
Map<String, Map<String, String>> result = new HashMap<>();
for (SchemaElement dimension : dimensions) {
if (Objects.isNull(dimension)
|| Strings.isEmpty(dimension.getBizName())
|| CollectionUtils.isEmpty(dimension.getSchemaValueMaps())) {
continue;
}
String bizName = dimension.getBizName();
Map<String, String> aliasAndBizNameToTechName = new HashMap<>();
for (SchemaValueMap valueMap : dimension.getSchemaValueMaps()) {
if (Objects.isNull(valueMap) || Strings.isEmpty(valueMap.getTechName())) {
continue;
}
if (Strings.isNotEmpty(valueMap.getBizName())) {
aliasAndBizNameToTechName.put(valueMap.getBizName(), valueMap.getTechName());
}
if (!CollectionUtils.isEmpty(valueMap.getAlias())) {
valueMap.getAlias().stream().forEach(alias -> {
if (Strings.isNotEmpty(alias)) {
aliasAndBizNameToTechName.put(alias, valueMap.getTechName());
}
});
}
}
if (!CollectionUtils.isEmpty(aliasAndBizNameToTechName)) {
result.put(bizName, aliasAndBizNameToTechName);
}
}
return result;
}
} }

View File

@@ -9,8 +9,10 @@ public class FunctionCorrector extends BaseSemanticCorrector {
@Override @Override
public CorrectionInfo corrector(CorrectionInfo correctionInfo) { public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
String replaceFunction = SqlParserUpdateHelper.replaceFunction(correctionInfo.getSql()); String preSql = correctionInfo.getSql();
correctionInfo.setSql(replaceFunction); correctionInfo.setPreSql(preSql);
String sql = SqlParserUpdateHelper.replaceFunction(preSql);
correctionInfo.setSql(sql);
return correctionInfo; return correctionInfo;
} }
} }

View File

@@ -20,14 +20,15 @@ public class QueryFilterAppend extends BaseSemanticCorrector {
@Override @Override
public CorrectionInfo corrector(CorrectionInfo correctionInfo) throws JSQLParserException { public CorrectionInfo corrector(CorrectionInfo correctionInfo) throws JSQLParserException {
String queryFilter = getQueryFilter(correctionInfo.getQueryFilters()); String queryFilter = getQueryFilter(correctionInfo.getQueryFilters());
String sql = correctionInfo.getSql(); String preSql = correctionInfo.getSql();
if (StringUtils.isNotEmpty(queryFilter)) { if (StringUtils.isNotEmpty(queryFilter)) {
log.info("add queryFilter to sql :{}", queryFilter); log.info("add queryFilter to preSql :{}", queryFilter);
Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter); Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
sql = SqlParserUpdateHelper.addWhere(sql, expression); String sql = SqlParserUpdateHelper.addWhere(preSql, expression);
correctionInfo.setPreSql(preSql);
correctionInfo.setSql(sql);
} }
correctionInfo.setSql(sql);
return correctionInfo; return correctionInfo;
} }

View File

@@ -15,24 +15,24 @@ public class SelectFieldAppendCorrector extends BaseSemanticCorrector {
@Override @Override
public CorrectionInfo corrector(CorrectionInfo correctionInfo) { public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
String sql = correctionInfo.getSql(); String preSql = correctionInfo.getSql();
if (SqlParserSelectHelper.hasAggregateFunction(sql)) { if (SqlParserSelectHelper.hasAggregateFunction(preSql)) {
return correctionInfo; return correctionInfo;
} }
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(sql)); Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(preSql));
Set<String> whereFields = new HashSet<>(SqlParserSelectHelper.getWhereFields(sql)); Set<String> whereFields = new HashSet<>(SqlParserSelectHelper.getWhereFields(preSql));
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(whereFields)) { if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(whereFields)) {
return correctionInfo; return correctionInfo;
} }
whereFields.addAll(SqlParserSelectHelper.getOrderByFields(sql)); whereFields.addAll(SqlParserSelectHelper.getOrderByFields(preSql));
whereFields.removeAll(selectFields); whereFields.removeAll(selectFields);
whereFields.remove(TimeDimensionEnum.DAY.getName()); whereFields.remove(TimeDimensionEnum.DAY.getName());
whereFields.remove(TimeDimensionEnum.WEEK.getName()); whereFields.remove(TimeDimensionEnum.WEEK.getName());
whereFields.remove(TimeDimensionEnum.MONTH.getName()); whereFields.remove(TimeDimensionEnum.MONTH.getName());
String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(preSql, new ArrayList<>(whereFields));
String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(sql, new ArrayList<>(whereFields)); correctionInfo.setPreSql(preSql);
correctionInfo.setSql(replaceFields); correctionInfo.setSql(replaceFields);
return correctionInfo; return correctionInfo;
} }

View File

@@ -12,9 +12,10 @@ public class TableNameCorrector extends BaseSemanticCorrector {
@Override @Override
public CorrectionInfo corrector(CorrectionInfo correctionInfo) { public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
Long modelId = correctionInfo.getParseInfo().getModelId(); Long modelId = correctionInfo.getParseInfo().getModelId();
String sqlOutput = correctionInfo.getSql(); String preSql = correctionInfo.getSql();
String replaceTable = SqlParserUpdateHelper.replaceTable(sqlOutput, TABLE_PREFIX + modelId); correctionInfo.setPreSql(preSql);
correctionInfo.setSql(replaceTable); String sql = SqlParserUpdateHelper.replaceTable(preSql, TABLE_PREFIX + modelId);
correctionInfo.setSql(sql);
return correctionInfo; return correctionInfo;
} }

View File

@@ -8,6 +8,7 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.service.SchemaService; import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.knowledge.utils.HanlpHelper; import com.tencent.supersonic.knowledge.utils.HanlpHelper;
@@ -105,8 +106,9 @@ public class FuzzyNameMapper implements SchemaMapper {
private Double getThreshold(QueryContext queryContext, MapperHelper mapperHelper) { private Double getThreshold(QueryContext queryContext, MapperHelper mapperHelper) {
Double metricDimensionThresholdConfig = mapperHelper.getMetricDimensionThresholdConfig(); OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
Double metricDimensionMinThresholdConfig = mapperHelper.getMetricDimensionMinThresholdConfig(); Double metricDimensionThresholdConfig = optimizationConfig.getMetricDimensionThresholdConfig();
Double metricDimensionMinThresholdConfig = optimizationConfig.getMetricDimensionMinThresholdConfig();
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo() Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo()
.getModelElementMatches(); .getModelElementMatches();

View File

@@ -9,13 +9,13 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType; import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo; import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.knowledge.utils.NatureHelper;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.dictionary.DictWordType; import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import com.tencent.supersonic.knowledge.dictionary.MapResult; import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.dictionary.builder.BaseWordBuilder; import com.tencent.supersonic.knowledge.dictionary.builder.BaseWordBuilder;
import com.tencent.supersonic.knowledge.dictionary.builder.WordBuilderFactory; import com.tencent.supersonic.knowledge.dictionary.builder.WordBuilderFactory;
import com.tencent.supersonic.knowledge.utils.HanlpHelper; import com.tencent.supersonic.knowledge.utils.HanlpHelper;
import com.tencent.supersonic.knowledge.utils.NatureHelper;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@@ -25,6 +25,8 @@ import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
@Slf4j @Slf4j
public class HanlpDictMapper implements SchemaMapper { public class HanlpDictMapper implements SchemaMapper {
@@ -83,11 +85,14 @@ public class HanlpDictMapper implements SchemaMapper {
Long elementID = baseWordBuilder.getElementID(nature); Long elementID = baseWordBuilder.getElementID(nature);
Long frequency = wordNatureToFrequency.get(mapResult.getName() + nature); Long frequency = wordNatureToFrequency.get(mapResult.getName() + nature);
SchemaElement element = modelSchema.getElement(elementType, elementID); SchemaElement elementDb = modelSchema.getElement(elementType, elementID);
if (Objects.isNull(element)) { if (Objects.isNull(elementDb)) {
log.info("element is null, elementType:{},elementID:{}", elementType, elementID); log.info("element is null, elementType:{},elementID:{}", elementType, elementID);
continue; continue;
} }
SchemaElement element = new SchemaElement();
BeanUtils.copyProperties(elementDb, element);
element.setAlias(getAlias(elementDb));
if (element.getType().equals(SchemaElementType.VALUE)) { if (element.getType().equals(SchemaElementType.VALUE)) {
element.setName(mapResult.getName()); element.setName(mapResult.getName());
} }
@@ -124,4 +129,16 @@ public class HanlpDictMapper implements SchemaMapper {
} }
return matches; return matches;
} }
public List<String> getAlias(SchemaElement element) {
if (!SchemaElementType.VALUE.equals(element.getType())) {
return element.getAlias();
}
if (CollectionUtils.isNotEmpty(element.getAlias()) && StringUtils.isNotEmpty(element.getName())) {
return element.getAlias().stream()
.filter(aliasItem -> aliasItem.contains(element.getName()))
.collect(Collectors.toList());
}
return element.getAlias();
}
} }

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.algorithm.EditDistance; import com.hankcs.hanlp.algorithm.EditDistance;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.service.AgentService; import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.utils.NatureHelper; import com.tencent.supersonic.knowledge.utils.NatureHelper;
@@ -13,7 +14,7 @@ import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.Data; import lombok.Data;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
/** /**
@@ -25,17 +26,8 @@ import org.springframework.stereotype.Service;
@Slf4j @Slf4j
public class MapperHelper { public class MapperHelper {
@Value("${one.detection.size:8}") @Autowired
private Integer oneDetectionSize; private OptimizationConfig optimizationConfig;
@Value("${one.detection.max.size:20}")
private Integer oneDetectionMaxSize;
@Value("${metric.dimension.threshold:0.3}")
private Double metricDimensionThresholdConfig;
@Value("${metric.dimension.min.threshold:0.3}")
private Double metricDimensionMinThresholdConfig;
@Value("${dimension.value.threshold:0.5}")
private Double dimensionValueThresholdConfig;
public Integer getStepIndex(Map<Integer, Integer> regOffsetToLength, Integer index) { public Integer getStepIndex(Map<Integer, Integer> regOffsetToLength, Integer index) {
Integer subRegLength = regOffsetToLength.get(index); Integer subRegLength = regOffsetToLength.get(index);
@@ -57,10 +49,11 @@ public class MapperHelper {
} }
public double getThresholdMatch(List<String> natures) { public double getThresholdMatch(List<String> natures) {
log.info("optimizationConfig:{}", optimizationConfig);
if (existDimensionValues(natures)) { if (existDimensionValues(natures)) {
return dimensionValueThresholdConfig; return optimizationConfig.getDimensionValueThresholdConfig();
} }
return metricDimensionThresholdConfig; return optimizationConfig.getMetricDimensionThresholdConfig();
} }
/*** /***
@@ -110,4 +103,4 @@ public class MapperHelper {
return detectModelIds; return detectModelIds;
} }
} }

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.seg.common.Term; import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.knowledge.dictionary.MapResult; import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.service.SearchService; import com.tencent.supersonic.knowledge.service.SearchService;
@@ -31,6 +32,9 @@ public class QueryMatchStrategy implements MatchStrategy {
@Autowired @Autowired
private MapperHelper mapperHelper; private MapperHelper mapperHelper;
@Autowired
private OptimizationConfig optimizationConfig;
@Override @Override
public Map<MatchText, List<MapResult>> match(QueryReq queryReq, List<Term> terms, Set<Long> detectModelIds) { public Map<MatchText, List<MapResult>> match(QueryReq queryReq, List<Term> terms, Set<Long> detectModelIds) {
String text = queryReq.getQueryText(); String text = queryReq.getQueryText();
@@ -111,7 +115,7 @@ public class QueryMatchStrategy implements MatchStrategy {
String detectSegment = text.substring(index, i); String detectSegment = text.substring(index, i);
// step1. pre search // step1. pre search
Integer oneDetectionMaxSize = mapperHelper.getOneDetectionMaxSize(); Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
LinkedHashSet<MapResult> mapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize, agentId, LinkedHashSet<MapResult> mapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize, agentId,
detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new)); detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
// step2. suffix search // step2. suffix search
@@ -153,7 +157,7 @@ public class QueryMatchStrategy implements MatchStrategy {
if (CollectionUtils.isNotEmpty(dimensionMetrics)) { if (CollectionUtils.isNotEmpty(dimensionMetrics)) {
return dimensionMetrics; return dimensionMetrics;
} else { } else {
return mapResults.stream().limit(mapperHelper.getOneDetectionSize()).collect(Collectors.toList()); return mapResults.stream().limit(optimizationConfig.getOneDetectionSize()).collect(Collectors.toList());
} }
} }
} }

View File

@@ -4,7 +4,9 @@ package com.tencent.supersonic.chat.parser;
import com.tencent.supersonic.chat.api.component.SemanticQuery; import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.query.llm.dsl.DslQuery; import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
/** /**
@@ -15,10 +17,6 @@ import lombok.extern.slf4j.Slf4j;
@Slf4j @Slf4j
public class SatisfactionChecker { public class SatisfactionChecker {
private static final double LONG_TEXT_THRESHOLD = 0.8;
private static final double SHORT_TEXT_THRESHOLD = 0.5;
private static final int QUERY_TEXT_LENGTH_THRESHOLD = 10;
// check all the parse info in candidate // check all the parse info in candidate
public static boolean check(QueryContext queryContext) { public static boolean check(QueryContext queryContext) {
for (SemanticQuery query : queryContext.getCandidateQueries()) { for (SemanticQuery query : queryContext.getCandidateQueries()) {
@@ -35,11 +33,12 @@ public class SatisfactionChecker {
private static boolean checkThreshold(String queryText, SemanticParseInfo semanticParseInfo) { private static boolean checkThreshold(String queryText, SemanticParseInfo semanticParseInfo) {
int queryTextLength = queryText.replaceAll(" ", "").length(); int queryTextLength = queryText.replaceAll(" ", "").length();
double degree = semanticParseInfo.getScore() / queryTextLength; double degree = semanticParseInfo.getScore() / queryTextLength;
if (queryTextLength > QUERY_TEXT_LENGTH_THRESHOLD) { OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
if (degree < LONG_TEXT_THRESHOLD) { if (queryTextLength > optimizationConfig.getQueryTextLengthThreshold()) {
if (degree < optimizationConfig.getLongTextThreshold()) {
return false; return false;
} }
} else if (degree < SHORT_TEXT_THRESHOLD) { } else if (degree < optimizationConfig.getShortTextThreshold()) {
return false; return false;
} }
log.info("queryMode:{}, degree:{}, parse info:{}", log.info("queryMode:{}, degree:{}, parse info:{}",

View File

@@ -93,19 +93,23 @@ public class LLMDslParser implements SemanticParser {
SemanticParseInfo parseInfo = getParseInfo(queryCtx, modelId, dslTool, dslParseResult); SemanticParseInfo parseInfo = getParseInfo(queryCtx, modelId, dslTool, dslParseResult);
String correctorSql = getCorrectorSql(queryCtx, parseInfo, llmResp.getSqlOutput()); CorrectionInfo correctionInfo = getCorrectorSql(queryCtx, parseInfo, llmResp.getSqlOutput());
llmResp.setCorrectorSql(correctorSql); llmResp.setCorrectorSql(correctionInfo.getSql());
setFilter(correctorSql, modelId, parseInfo); setFilter(correctionInfo, modelId, parseInfo);
} catch (Exception e) { } catch (Exception e) {
log.error("LLMDSLParser error", e); log.error("LLMDSLParser error", e);
} }
} }
public void setFilter(String correctorSql, Long modelId, SemanticParseInfo parseInfo) { public void setFilter(CorrectionInfo correctionInfo, Long modelId, SemanticParseInfo parseInfo) {
String correctorSql = correctionInfo.getPreSql();
if (StringUtils.isEmpty(correctorSql)) {
correctorSql = correctionInfo.getSql();
}
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctorSql); List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctorSql);
if (CollectionUtils.isEmpty(expressions)) { if (CollectionUtils.isEmpty(expressions)) {
return; return;
@@ -200,7 +204,7 @@ public class LLMDslParser implements SemanticParser {
return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue()); return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue());
} }
private String getCorrectorSql(QueryContext queryCtx, SemanticParseInfo parseInfo, String sql) { private CorrectionInfo getCorrectorSql(QueryContext queryCtx, SemanticParseInfo parseInfo, String sql) {
CorrectionInfo correctionInfo = CorrectionInfo.builder() CorrectionInfo correctionInfo = CorrectionInfo.builder()
.queryFilters(queryCtx.getRequest().getQueryFilters()).sql(sql) .queryFilters(queryCtx.getRequest().getQueryFilters()).sql(sql)
@@ -217,7 +221,7 @@ public class LLMDslParser implements SemanticParser {
log.error("sqlCorrection:{} execute error,correctionInfo:{}", dslCorrection, correctionInfo, e); log.error("sqlCorrection:{} execute error,correctionInfo:{}", dslCorrection, correctionInfo, e);
} }
}); });
return correctionInfo.getSql(); return correctionInfo;
} }
private SemanticParseInfo getParseInfo(QueryContext queryCtx, Long modelId, DslTool dslTool, private SemanticParseInfo getParseInfo(QueryContext queryCtx, Long modelId, DslTool dslTool,

View File

@@ -0,0 +1,118 @@
package com.tencent.supersonic.chat.parser.plugin;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.plugin.PluginParseResult;
import com.tencent.supersonic.chat.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import org.springframework.util.CollectionUtils;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public abstract class PluginParser implements SemanticParser {
@Override
public void parse(QueryContext queryContext, ChatContext chatContext) {
if (!checkPreCondition(queryContext)) {
return;
}
PluginRecallResult pluginRecallResult = recallPlugin(queryContext);
if (pluginRecallResult == null) {
return;
}
buildQuery(queryContext, pluginRecallResult);
}
public abstract boolean checkPreCondition(QueryContext queryContext);
public abstract PluginRecallResult recallPlugin(QueryContext queryContext);
public void buildQuery(QueryContext queryContext, PluginRecallResult pluginRecallResult) {
Plugin plugin = pluginRecallResult.getPlugin();
for (Long modelId : pluginRecallResult.getModelIds()) {
PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType());
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, plugin, queryContext.getRequest(),
queryContext.getMapInfo().getMatchedElements(modelId), pluginRecallResult.getDistance());
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
semanticParseInfo.setScore(pluginRecallResult.getScore());
pluginQuery.setParseInfo(semanticParseInfo);
queryContext.getCandidateQueries().add(pluginQuery);
if (plugin.isContainsAllModel()) {
break;
}
}
}
protected List<Plugin> getPluginList(QueryContext queryContext) {
return PluginManager.getPluginAgentCanSupport(queryContext.getRequest().getAgentId());
}
protected SemanticParseInfo buildSemanticParseInfo(Long modelId, Plugin plugin, QueryReq queryReq,
List<SchemaElementMatch> schemaElementMatches, double distance) {
if (modelId == null && !CollectionUtils.isEmpty(plugin.getModelList())) {
modelId = plugin.getModelList().get(0);
}
SchemaElement model = new SchemaElement();
model.setModel(modelId);
model.setId(modelId);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.setElementMatches(schemaElementMatches);
semanticParseInfo.setModel(model);
Map<String, Object> properties = new HashMap<>();
PluginParseResult pluginParseResult = new PluginParseResult();
pluginParseResult.setPlugin(plugin);
pluginParseResult.setRequest(queryReq);
pluginParseResult.setDistance(distance);
properties.put(Constants.CONTEXT, pluginParseResult);
properties.put("type", "plugin");
properties.put("name", plugin.getName());
semanticParseInfo.setProperties(properties);
semanticParseInfo.setScore(distance);
fillSemanticParseInfo(semanticParseInfo);
setEntity(modelId, semanticParseInfo);
return semanticParseInfo;
}
private void setEntity(Long modelId, SemanticParseInfo semanticParseInfo) {
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
ModelSchema modelSchema = semanticService.getModelSchema(modelId);
if (modelSchema != null && modelSchema.getEntity() != null) {
semanticParseInfo.setEntity(modelSchema.getEntity());
}
}
private void fillSemanticParseInfo(SemanticParseInfo semanticParseInfo) {
List<SchemaElementMatch> schemaElementMatches = semanticParseInfo.getElementMatches();
if (CollectionUtils.isEmpty(schemaElementMatches)) {
return;
}
schemaElementMatches.stream().filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
.forEach(schemaElementMatch -> {
QueryFilter queryFilter = new QueryFilter();
queryFilter.setValue(schemaElementMatch.getWord());
queryFilter.setElementID(schemaElementMatch.getElement().getId());
queryFilter.setName(schemaElementMatch.getElement().getName());
queryFilter.setOperator(FilterOperatorEnum.EQUALS);
queryFilter.setBizName(schemaElementMatch.getElement().getBizName());
semanticParseInfo.getDimensionFilters().add(queryFilter);
});
}
}

View File

@@ -1,63 +1,54 @@
package com.tencent.supersonic.chat.parser.plugin.embedding; package com.tencent.supersonic.chat.parser.plugin.embedding;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.component.SemanticParser; import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.parser.ParseMode; import com.tencent.supersonic.chat.parser.ParseMode;
import com.tencent.supersonic.chat.parser.plugin.PluginParser;
import com.tencent.supersonic.chat.plugin.Plugin; import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginManager; import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.plugin.PluginParseResult; import com.tencent.supersonic.chat.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import java.util.Comparator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.HashMap;
import java.util.Comparator;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Pair;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
@Slf4j @Slf4j
public class EmbeddingBasedParser implements SemanticParser { public class EmbeddingBasedParser extends PluginParser {
@Override @Override
public void parse(QueryContext queryContext, ChatContext chatContext) { public boolean checkPreCondition(QueryContext queryContext) {
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class); EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
if (StringUtils.isBlank(embeddingConfig.getUrl())) { if (StringUtils.isBlank(embeddingConfig.getUrl())) {
return; return false;
} }
log.info("EmbeddingBasedParser parser query ctx: {}, chat ctx: {}", queryContext, chatContext); List<SemanticQuery> semanticQueries = queryContext.getCandidateQueries();
String text = queryContext.getRequest().getQueryText(); for (SemanticQuery semanticQuery : semanticQueries) {
List<RecallRetrieval> embeddingRetrievals = recallResult(text); if (queryContext.getRequest().getQueryText().length() <= semanticQuery.getParseInfo().getScore()) {
choosePlugin(embeddingRetrievals, queryContext); return false;
}
}
return true;
} }
private void choosePlugin(List<RecallRetrieval> embeddingRetrievals, @Override
QueryContext queryContext) { public PluginRecallResult recallPlugin(QueryContext queryContext) {
String text = queryContext.getRequest().getQueryText();
List<RecallRetrieval> embeddingRetrievals = embeddingRecall(text);
if (CollectionUtils.isEmpty(embeddingRetrievals)) { if (CollectionUtils.isEmpty(embeddingRetrievals)) {
return; return null;
} }
List<Plugin> plugins = getPluginList(queryContext); List<Plugin> plugins = getPluginList(queryContext);
Map<Long, Plugin> pluginMap = plugins.stream().collect(Collectors.toMap(Plugin::getId, p -> p)); Map<Long, Plugin> pluginMap = plugins.stream().collect(Collectors.toMap(Plugin::getId, p -> p));
for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) { for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) {
Plugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId())); Plugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId()));
if (plugin == null || DslQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) { if (plugin == null) {
continue; continue;
} }
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, queryContext); Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, queryContext);
@@ -65,69 +56,19 @@ public class EmbeddingBasedParser implements SemanticParser {
if (pair.getLeft()) { if (pair.getLeft()) {
Set<Long> modelList = pair.getRight(); Set<Long> modelList = pair.getRight();
if (CollectionUtils.isEmpty(modelList)) { if (CollectionUtils.isEmpty(modelList)) {
return; continue;
} }
for (Long modelId : modelList) { plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
buildQuery(plugin, Double.parseDouble(embeddingRetrieval.getDistance()), modelId, queryContext, double distance = Double.parseDouble(embeddingRetrieval.getDistance());
queryContext.getMapInfo().getMatchedElements(modelId)); double score = queryContext.getRequest().getQueryText().length() * (1 - distance);
if (plugin.isContainsAllModel()) { return PluginRecallResult.builder()
break; .plugin(plugin).modelIds(modelList).score(score).distance(distance).build();
}
}
return;
} }
} }
return null;
} }
private void buildQuery(Plugin plugin, double distance, Long modelId, public List<RecallRetrieval> embeddingRecall(String embeddingText) {
QueryContext queryContext, List<SchemaElementMatch> schemaElementMatches) {
log.info("EmbeddingBasedParser Model: {} choose plugin: [{} {}]", modelId, plugin.getId(), plugin.getName());
PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType());
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, plugin, queryContext.getRequest(),
schemaElementMatches, distance);
double score = queryContext.getRequest().getQueryText().length() * (1 - distance);
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
semanticParseInfo.setScore(score);
pluginQuery.setParseInfo(semanticParseInfo);
queryContext.getCandidateQueries().add(pluginQuery);
}
private SemanticParseInfo buildSemanticParseInfo(Long modelId, Plugin plugin, QueryReq queryReq,
List<SchemaElementMatch> schemaElementMatches, double distance) {
if (modelId == null && !CollectionUtils.isEmpty(plugin.getModelList())) {
modelId = plugin.getModelList().get(0);
}
SchemaElement model = new SchemaElement();
model.setModel(modelId);
model.setId(modelId);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.setElementMatches(schemaElementMatches);
semanticParseInfo.setModel(model);
Map<String, Object> properties = new HashMap<>();
PluginParseResult pluginParseResult = new PluginParseResult();
pluginParseResult.setPlugin(plugin);
pluginParseResult.setRequest(queryReq);
pluginParseResult.setDistance(distance);
properties.put(Constants.CONTEXT, pluginParseResult);
properties.put("type", "plugin");
properties.put("name", plugin.getName());
semanticParseInfo.setProperties(properties);
semanticParseInfo.setScore(distance);
fillSemanticParseInfo(semanticParseInfo);
setEntity(modelId, semanticParseInfo);
return semanticParseInfo;
}
private void setEntity(Long modelId, SemanticParseInfo semanticParseInfo) {
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
ModelSchema modelSchema = semanticService.getModelSchema(modelId);
if (modelSchema != null && modelSchema.getEntity() != null) {
semanticParseInfo.setEntity(modelSchema.getEntity());
}
}
public List<RecallRetrieval> recallResult(String embeddingText) {
try { try {
PluginManager pluginManager = ContextUtils.getBean(PluginManager.class); PluginManager pluginManager = ContextUtils.getBean(PluginManager.class);
EmbeddingResp embeddingResp = pluginManager.recognize(embeddingText); EmbeddingResp embeddingResp = pluginManager.recognize(embeddingText);
@@ -144,26 +85,4 @@ public class EmbeddingBasedParser implements SemanticParser {
return Lists.newArrayList(); return Lists.newArrayList();
} }
private void fillSemanticParseInfo(SemanticParseInfo semanticParseInfo) {
List<SchemaElementMatch> schemaElementMatches = semanticParseInfo.getElementMatches();
if (!CollectionUtils.isEmpty(schemaElementMatches)) {
schemaElementMatches.stream().filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
.forEach(schemaElementMatch -> {
QueryFilter queryFilter = new QueryFilter();
queryFilter.setValue(schemaElementMatch.getWord());
queryFilter.setElementID(schemaElementMatch.getElement().getId());
queryFilter.setName(schemaElementMatch.getElement().getName());
queryFilter.setOperator(FilterOperatorEnum.EQUALS);
queryFilter.setBizName(schemaElementMatch.getElement().getBizName());
semanticParseInfo.getDimensionFilters().add(queryFilter);
});
}
}
protected List<Plugin> getPluginList(QueryContext queryContext) {
return PluginManager.getPluginAgentCanSupport(queryContext.getRequest().getAgentId());
}
} }

View File

@@ -1,9 +1,10 @@
package com.tencent.supersonic.chat.parser.plugin.embedding; package com.tencent.supersonic.chat.parser.plugin.embedding;
import java.util.List;
import lombok.Data; import lombok.Data;
import java.util.List;
@Data @Data
public class EmbeddingResp { public class EmbeddingResp {

View File

@@ -1,32 +1,23 @@
package com.tencent.supersonic.chat.parser.plugin.function; package com.tencent.supersonic.chat.parser.plugin.function;
import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSON;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.config.FunctionCallInfoConfig; import com.tencent.supersonic.chat.config.FunctionCallInfoConfig;
import com.tencent.supersonic.chat.parser.ParseMode; import com.tencent.supersonic.chat.parser.ParseMode;
import com.tencent.supersonic.chat.parser.SatisfactionChecker; import com.tencent.supersonic.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.chat.parser.plugin.PluginParser;
import com.tencent.supersonic.chat.plugin.Plugin; import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginManager; import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.plugin.PluginParseConfig; import com.tencent.supersonic.chat.plugin.PluginParseConfig;
import com.tencent.supersonic.chat.plugin.PluginParseResult; import com.tencent.supersonic.chat.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.query.llm.dsl.DslQuery; import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
import com.tencent.supersonic.chat.service.PluginService; import com.tencent.supersonic.chat.service.PluginService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import java.net.URI; import java.net.URI;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.Objects; import java.util.Objects;
import java.util.Map;
import java.util.HashMap;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@@ -42,85 +33,73 @@ import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.UriComponentsBuilder;
@Slf4j @Slf4j
public class FunctionBasedParser implements SemanticParser { public class FunctionBasedParser extends PluginParser {
@Override @Override
public void parse(QueryContext queryCtx, ChatContext chatCtx) { public boolean checkPreCondition(QueryContext queryContext) {
FunctionCallInfoConfig functionCallConfig = ContextUtils.getBean(FunctionCallInfoConfig.class); FunctionCallInfoConfig functionCallConfig = ContextUtils.getBean(FunctionCallInfoConfig.class);
PluginService pluginService = ContextUtils.getBean(PluginService.class);
String functionUrl = functionCallConfig.getUrl(); String functionUrl = functionCallConfig.getUrl();
if (StringUtils.isBlank(functionUrl) || SatisfactionChecker.check(queryCtx)) { if (StringUtils.isBlank(functionUrl) || SatisfactionChecker.check(queryContext)) {
log.info("functionUrl:{}, skip function parser, queryText:{}", functionUrl, log.info("functionUrl:{}, skip function parser, queryText:{}", functionUrl,
queryCtx.getRequest().getQueryText()); queryContext.getRequest().getQueryText());
return; return false;
} }
List<PluginParseConfig> functionDOList = getFunctionDO(queryCtx.getRequest().getModelId(), queryCtx); return true;
if (CollectionUtils.isEmpty(functionDOList)) { }
log.info("function call parser, plugin is empty, skip");
return; @Override
} public PluginRecallResult recallPlugin(QueryContext queryContext) {
FunctionResp functionResp = new FunctionResp(); PluginService pluginService = ContextUtils.getBean(PluginService.class);
if (functionDOList.size() == 1) { FunctionResp functionResp = functionCall(queryContext);
functionResp.setToolSelection(functionDOList.iterator().next().getName()); if (skipFunction(functionResp)) {
} else { return null;
FunctionReq functionReq = FunctionReq.builder()
.queryText(queryCtx.getRequest().getQueryText())
.pluginConfigs(functionDOList).build();
functionResp = requestFunction(functionUrl, functionReq);
} }
log.info("requestFunction result:{}", functionResp.getToolSelection()); log.info("requestFunction result:{}", functionResp.getToolSelection());
if (skipFunction(functionResp)) {
return;
}
PluginParseResult functionCallParseResult = new PluginParseResult();
String toolSelection = functionResp.getToolSelection(); String toolSelection = functionResp.getToolSelection();
Optional<Plugin> pluginOptional = pluginService.getPluginByName(toolSelection); Optional<Plugin> pluginOptional = pluginService.getPluginByName(toolSelection);
if (!pluginOptional.isPresent()) { if (!pluginOptional.isPresent()) {
log.info("pluginOptional is not exist:{}, skip the parse", toolSelection); log.info("pluginOptional is not exist:{}, skip the parse", toolSelection);
return; return null;
} }
Plugin plugin = pluginOptional.get(); Plugin plugin = pluginOptional.get();
plugin.setParseMode(ParseMode.FUNCTION_CALL); plugin.setParseMode(ParseMode.FUNCTION_CALL);
toolSelection = plugin.getType(); Pair<Boolean, Set<Long>> pluginResolveResult = PluginManager.resolve(plugin, queryContext);
functionCallParseResult.setPlugin(plugin); if (pluginResolveResult.getLeft()) {
log.info("QueryManager PluginQueryModes:{}", QueryManager.getPluginQueryModes()); Set<Long> modelList = pluginResolveResult.getRight();
PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(toolSelection); if (CollectionUtils.isEmpty(modelList)) {
ModelResolver modelResolver = ComponentFactory.getModelResolver(); return null;
log.info("plugin ModelList:{}", plugin.getModelList()); }
Pair<Boolean, Set<Long>> pluginResolveResult = PluginManager.resolve(plugin, queryCtx); double score = queryContext.getRequest().getQueryText().length();
Long modelId = modelResolver.resolve(queryCtx, chatCtx, pluginResolveResult.getRight()); return PluginRecallResult.builder().plugin(plugin).modelIds(modelList).score(score).build();
log.info("FunctionBasedParser modelId:{}", modelId);
if ((Objects.isNull(modelId) || modelId <= 0) && !plugin.isContainsAllModel()) {
log.info("Model is null, skip the parse, select tool: {}", toolSelection);
return;
} }
if (!plugin.getModelList().contains(modelId) && !plugin.isContainsAllModel()) { return null;
return; }
public FunctionResp functionCall(QueryContext queryContext) {
List<PluginParseConfig> pluginToFunctionCall =
getPluginToFunctionCall(queryContext.getRequest().getModelId(), queryContext);
if (CollectionUtils.isEmpty(pluginToFunctionCall)) {
log.info("function call parser, plugin is empty, skip");
return null;
} }
SemanticParseInfo parseInfo = semanticQuery.getParseInfo(); FunctionCallInfoConfig functionCallConfig = ContextUtils.getBean(FunctionCallInfoConfig.class);
if (Objects.nonNull(modelId) && modelId > 0) { FunctionResp functionResp = new FunctionResp();
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(modelId)); if (pluginToFunctionCall.size() == 1) {
functionResp.setToolSelection(pluginToFunctionCall.iterator().next().getName());
} else {
FunctionReq functionReq = FunctionReq.builder()
.queryText(queryContext.getRequest().getQueryText())
.pluginConfigs(pluginToFunctionCall).build();
functionResp = requestFunction(functionCallConfig.getUrl(), functionReq);
} }
functionCallParseResult.setRequest(queryCtx.getRequest()); return functionResp;
Map<String, Object> properties = new HashMap<>();
properties.put(Constants.CONTEXT, functionCallParseResult);
properties.put("type", "plugin");
properties.put("name", plugin.getName());
parseInfo.setProperties(properties);
parseInfo.setScore(queryCtx.getRequest().getQueryText().length());
parseInfo.setQueryMode(semanticQuery.getQueryMode());
SchemaElement model = new SchemaElement();
model.setModel(modelId);
model.setId(modelId);
parseInfo.setModel(model);
queryCtx.getCandidateQueries().add(semanticQuery);
} }
private boolean skipFunction(FunctionResp functionResp) { private boolean skipFunction(FunctionResp functionResp) {
return Objects.isNull(functionResp) || StringUtils.isBlank(functionResp.getToolSelection()); return Objects.isNull(functionResp) || StringUtils.isBlank(functionResp.getToolSelection());
} }
private List<PluginParseConfig> getFunctionDO(Long modelId, QueryContext queryContext) { private List<PluginParseConfig> getPluginToFunctionCall(Long modelId, QueryContext queryContext) {
log.info("user decide Model:{}", modelId); log.info("user decide Model:{}", modelId);
List<Plugin> plugins = getPluginList(queryContext); List<Plugin> plugins = getPluginList(queryContext);
List<PluginParseConfig> functionDOList = plugins.stream().filter(plugin -> { List<PluginParseConfig> functionDOList = plugins.stream().filter(plugin -> {
@@ -150,7 +129,7 @@ public class FunctionBasedParser implements SemanticParser {
return true; return true;
} }
}).map(o -> JsonUtil.toObject(o.getParseModeConfig(), PluginParseConfig.class)).collect(Collectors.toList()); }).map(o -> JsonUtil.toObject(o.getParseModeConfig(), PluginParseConfig.class)).collect(Collectors.toList());
log.info("getFunctionDO:{}", JsonUtil.toString(functionDOList)); log.info("PluginToFunctionCall: {}", JsonUtil.toString(functionDOList));
return functionDOList; return functionDOList;
} }
@@ -173,8 +152,4 @@ public class FunctionBasedParser implements SemanticParser {
} }
return null; return null;
} }
protected List<Plugin> getPluginList(QueryContext queryContext) {
return PluginManager.getPluginAgentCanSupport(queryContext.getRequest().getAgentId());
}
} }

View File

@@ -1,7 +1,8 @@
package com.tencent.supersonic.chat.parser.plugin.function; package com.tencent.supersonic.chat.parser.plugin.function;
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
import java.util.List; import java.util.List;
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;

View File

@@ -4,7 +4,6 @@ import lombok.Data;
@Data @Data
public class ModelMatchResult { public class ModelMatchResult {
private Integer count = 0; private Integer count = 0;
private double maxSimilarity; private double maxSimilarity;
} }

View File

@@ -1,8 +1,8 @@
package com.tencent.supersonic.chat.parser.plugin.function; package com.tencent.supersonic.chat.parser.plugin.function;
import lombok.Data;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import lombok.Data;
@Data @Data
public class Parameters { public class Parameters {

View File

@@ -68,6 +68,7 @@ public class AggregateTypeParser implements SemanticParser {
Map<AggregateTypeEnum, Integer> aggregateCount = new HashMap<>(REGX_MAP.size()); Map<AggregateTypeEnum, Integer> aggregateCount = new HashMap<>(REGX_MAP.size());
Map<AggregateTypeEnum, String> aggregateWord = new HashMap<>(REGX_MAP.size()); Map<AggregateTypeEnum, String> aggregateWord = new HashMap<>(REGX_MAP.size());
for (Map.Entry<AggregateTypeEnum, Pattern> entry : REGX_MAP.entrySet()) { for (Map.Entry<AggregateTypeEnum, Pattern> entry : REGX_MAP.entrySet()) {
Matcher matcher = entry.getValue().matcher(queryText); Matcher matcher = entry.getValue().matcher(queryText);
int count = 0; int count = 0;
@@ -90,7 +91,6 @@ public class AggregateTypeParser implements SemanticParser {
@AllArgsConstructor @AllArgsConstructor
class AggregateConf { class AggregateConf {
public AggregateTypeEnum type; public AggregateTypeEnum type;
public String detectWord; public String detectWord;
} }

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.chat.persistence.dataobject; package com.tencent.supersonic.chat.persistence.dataobject;
import java.util.Date; import java.util.Date;
import lombok.Data; import lombok.Data;
import lombok.ToString; import lombok.ToString;

View File

@@ -6,6 +6,7 @@ import lombok.Data;
public class ChatDO { public class ChatDO {
private long chatId; private long chatId;
private Integer agentId;
private String chatName; private String chatName;
private String createTime; private String createTime;
private String lastTime; private String lastTime;

View File

@@ -1,10 +1,11 @@
package com.tencent.supersonic.chat.persistence.dataobject; package com.tencent.supersonic.chat.persistence.dataobject;
import com.tencent.supersonic.chat.config.DefaultMetric;
import com.tencent.supersonic.chat.config.Dim4Dict;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import com.tencent.supersonic.chat.config.DefaultMetric;
import com.tencent.supersonic.chat.config.Dim4Dict;
import lombok.Data; import lombok.Data;
import lombok.ToString; import lombok.ToString;

View File

@@ -3,7 +3,6 @@ package com.tencent.supersonic.chat.persistence.dataobject;
import java.util.Date; import java.util.Date;
public class PluginDO { public class PluginDO {
/** /**
* *
*/ */

View File

@@ -5,7 +5,6 @@ import java.util.Date;
import java.util.List; import java.util.List;
public class PluginDOExample { public class PluginDOExample {
/** /**
* s2_plugin * s2_plugin
*/ */
@@ -149,7 +148,6 @@ public class PluginDOExample {
* s2_plugin null * s2_plugin null
*/ */
protected abstract static class GeneratedCriteria { protected abstract static class GeneratedCriteria {
protected List<Criterion> criteria; protected List<Criterion> criteria;
protected GeneratedCriteria() { protected GeneratedCriteria() {
@@ -875,7 +873,6 @@ public class PluginDOExample {
* s2_plugin null * s2_plugin null
*/ */
public static class Criterion { public static class Criterion {
private String condition; private String condition;
private Object value; private Object value;

View File

@@ -10,7 +10,7 @@ public interface ChatMapper {
boolean createChat(ChatDO chatDO); boolean createChat(ChatDO chatDO);
List<ChatDO> getAll(String creator); List<ChatDO> getAll(String creator, Integer agentId);
Boolean updateChatName(Long chatId, String chatName, String lastTime, String creator); Boolean updateChatName(Long chatId, String chatName, String lastTime, String creator);

View File

@@ -2,9 +2,10 @@ package com.tencent.supersonic.chat.persistence.mapper;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO; import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample; import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample;
import java.util.List;
import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Mapper;
import java.util.List;
@Mapper @Mapper
public interface ChatQueryDOMapper { public interface ChatQueryDOMapper {

View File

@@ -2,58 +2,67 @@ package com.tencent.supersonic.chat.persistence.mapper;
import com.tencent.supersonic.chat.persistence.dataobject.PluginDO; import com.tencent.supersonic.chat.persistence.dataobject.PluginDO;
import com.tencent.supersonic.chat.persistence.dataobject.PluginDOExample; import com.tencent.supersonic.chat.persistence.dataobject.PluginDOExample;
import java.util.List;
import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Mapper;
import java.util.List;
@Mapper @Mapper
public interface PluginDOMapper { public interface PluginDOMapper {
/** /**
*
* @mbg.generated * @mbg.generated
*/ */
long countByExample(PluginDOExample example); long countByExample(PluginDOExample example);
/** /**
*
* @mbg.generated * @mbg.generated
*/ */
int deleteByPrimaryKey(Long id); int deleteByPrimaryKey(Long id);
/** /**
*
* @mbg.generated * @mbg.generated
*/ */
int insert(PluginDO record); int insert(PluginDO record);
/** /**
*
* @mbg.generated * @mbg.generated
*/ */
int insertSelective(PluginDO record); int insertSelective(PluginDO record);
/** /**
*
* @mbg.generated * @mbg.generated
*/ */
List<PluginDO> selectByExampleWithBLOBs(PluginDOExample example); List<PluginDO> selectByExampleWithBLOBs(PluginDOExample example);
/** /**
*
* @mbg.generated * @mbg.generated
*/ */
List<PluginDO> selectByExample(PluginDOExample example); List<PluginDO> selectByExample(PluginDOExample example);
/** /**
*
* @mbg.generated * @mbg.generated
*/ */
PluginDO selectByPrimaryKey(Long id); PluginDO selectByPrimaryKey(Long id);
/** /**
*
* @mbg.generated * @mbg.generated
*/ */
int updateByPrimaryKeySelective(PluginDO record); int updateByPrimaryKeySelective(PluginDO record);
/** /**
*
* @mbg.generated * @mbg.generated
*/ */
int updateByPrimaryKeyWithBLOBs(PluginDO record); int updateByPrimaryKeyWithBLOBs(PluginDO record);
/** /**
*
* @mbg.generated * @mbg.generated
*/ */
int updateByPrimaryKey(PluginDO record); int updateByPrimaryKey(PluginDO record);

View File

@@ -1,9 +1,10 @@
package com.tencent.supersonic.chat.persistence.repository; package com.tencent.supersonic.chat.persistence.repository;
import com.tencent.supersonic.chat.config.ChatConfig;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter; import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp; import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.config.ChatConfig;
import java.util.List; import java.util.List;
public interface ChatConfigRepository { public interface ChatConfigRepository {

View File

@@ -8,7 +8,7 @@ public interface ChatRepository {
boolean createChat(ChatDO chatDO); boolean createChat(ChatDO chatDO);
List<ChatDO> getAll(String creator); List<ChatDO> getAll(String creator, Integer agentId);
Boolean updateChatName(Long chatId, String chatName, String lastTime, String creator); Boolean updateChatName(Long chatId, String chatName, String lastTime, String creator);

View File

@@ -2,10 +2,10 @@ package com.tencent.supersonic.chat.persistence.repository;
import com.tencent.supersonic.chat.persistence.dataobject.PluginDO; import com.tencent.supersonic.chat.persistence.dataobject.PluginDO;
import com.tencent.supersonic.chat.persistence.dataobject.PluginDOExample; import com.tencent.supersonic.chat.persistence.dataobject.PluginDOExample;
import java.util.List; import java.util.List;
public interface PluginRepository { public interface PluginRepository {
List<PluginDO> getPlugins(); List<PluginDO> getPlugins();
List<PluginDO> fetchPluginDOs(String queryText, String type); List<PluginDO> fetchPluginDOs(String queryText, String type);

View File

@@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.persistence.repository.impl;
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO; import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.persistence.dataobject.QueryDO; import com.tencent.supersonic.chat.persistence.dataobject.QueryDO;
import com.tencent.supersonic.chat.persistence.mapper.ChatMapper;
import com.tencent.supersonic.chat.persistence.repository.ChatRepository; import com.tencent.supersonic.chat.persistence.repository.ChatRepository;
import com.tencent.supersonic.chat.persistence.mapper.ChatMapper;
import java.util.List; import java.util.List;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Primary; import org.springframework.context.annotation.Primary;
@@ -26,8 +26,8 @@ public class ChatRepositoryImpl implements ChatRepository {
} }
@Override @Override
public List<ChatDO> getAll(String creator) { public List<ChatDO> getAll(String creator, Integer agentId) {
return chatMapper.getAll(creator); return chatMapper.getAll(creator, agentId);
} }

View File

@@ -5,10 +5,10 @@ import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.parser.ParseMode; import com.tencent.supersonic.chat.parser.ParseMode;
import com.tencent.supersonic.common.pojo.RecordInfo; import com.tencent.supersonic.common.pojo.RecordInfo;
import java.util.List;
import lombok.Data; import lombok.Data;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import java.util.List;
@Data @Data
public class Plugin extends RecordInfo { public class Plugin extends RecordInfo {

View File

@@ -30,9 +30,7 @@ import java.util.HashSet;
import java.util.HashMap; import java.util.HashMap;
import java.util.Objects; import java.util.Objects;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Pair;
@@ -52,8 +50,6 @@ import org.springframework.web.util.UriComponentsBuilder;
@Component @Component
public class PluginManager { public class PluginManager {
private static Map<String, Plugin> internalPluginMap = new ConcurrentHashMap<>();
private EmbeddingConfig embeddingConfig; private EmbeddingConfig embeddingConfig;
private RestTemplate restTemplate; private RestTemplate restTemplate;

View File

@@ -0,0 +1,23 @@
package com.tencent.supersonic.chat.plugin;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.Set;
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class PluginRecallResult {
private Plugin plugin;
private Set<Long> modelIds;
private double score;
private double distance;
}

View File

@@ -4,6 +4,7 @@ import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery; import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import java.util.List; import java.util.List;
@@ -12,18 +13,18 @@ import java.util.OptionalDouble;
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery; import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery; import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
@Slf4j @Slf4j
public class HeuristicQuerySelector implements QuerySelector { public class HeuristicQuerySelector implements QuerySelector {
private static final double CANDIDATE_THRESHOLD = 0.2;
@Override @Override
public List<SemanticQuery> select(List<SemanticQuery> candidateQueries, QueryReq queryReq) { public List<SemanticQuery> select(List<SemanticQuery> candidateQueries, QueryReq queryReq) {
List<SemanticQuery> selectedQueries = new ArrayList<>(); List<SemanticQuery> selectedQueries = new ArrayList<>();
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
Double candidateThreshold = optimizationConfig.getCandidateThreshold();
if (CollectionUtils.isNotEmpty(candidateQueries) && candidateQueries.size() == 1) { if (CollectionUtils.isNotEmpty(candidateQueries) && candidateQueries.size() == 1) {
selectedQueries.addAll(candidateQueries); selectedQueries.addAll(candidateQueries);
} else { } else {
@@ -35,7 +36,7 @@ public class HeuristicQuerySelector implements QuerySelector {
candidateQueries.stream().forEach(query -> { candidateQueries.stream().forEach(query -> {
SemanticParseInfo parseInfo = query.getParseInfo(); SemanticParseInfo parseInfo = query.getParseInfo();
if (!checkFullyInherited(query) if (!checkFullyInherited(query)
&& (maxScore - parseInfo.getScore()) / maxScore <= CANDIDATE_THRESHOLD && (maxScore - parseInfo.getScore()) / maxScore <= candidateThreshold
&& checkSatisfyOtherRules(query, candidateQueries)) { && checkSatisfyOtherRules(query, candidateQueries)) {
selectedQueries.add(query); selectedQueries.add(query);
} }

View File

@@ -5,6 +5,7 @@ import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery; import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.query.rule.entity.EntitySemanticQuery; import com.tencent.supersonic.chat.query.rule.entity.EntitySemanticQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricSemanticQuery; import com.tencent.supersonic.chat.query.rule.metric.MetricSemanticQuery;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@@ -55,7 +56,6 @@ public class QueryManager {
throw new RuntimeException("no supported queryMode :" + queryMode); throw new RuntimeException("no supported queryMode :" + queryMode);
} }
} }
public static boolean containsRuleQuery(String queryMode) { public static boolean containsRuleQuery(String queryMode) {
if (queryMode == null) { if (queryMode == null) {
return false; return false;

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.chat.query.metricinterpret; package com.tencent.supersonic.chat.query.metricinterpret;
import lombok.Data; import lombok.Data;
@Data @Data

View File

@@ -1,8 +1,8 @@
package com.tencent.supersonic.chat.query.plugin; package com.tencent.supersonic.chat.query.plugin;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import java.util.List;
import lombok.Data; import lombok.Data;
import java.util.List;
@Data @Data
public class WebBase { public class WebBase {

View File

@@ -1,8 +1,8 @@
package com.tencent.supersonic.chat.query.plugin; package com.tencent.supersonic.chat.query.plugin;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import java.util.List;
import lombok.Data; import lombok.Data;
import java.util.List;
@Data @Data
public class WebBaseResult { public class WebBaseResult {

View File

@@ -1,8 +1,8 @@
package com.tencent.supersonic.chat.query.plugin.webpage; package com.tencent.supersonic.chat.query.plugin.webpage;
import com.tencent.supersonic.chat.query.plugin.WebBaseResult; import com.tencent.supersonic.chat.query.plugin.WebBaseResult;
import java.util.List;
import lombok.Data; import lombok.Data;
import java.util.List;
@Data @Data
public class WebPageResponse { public class WebPageResponse {

View File

@@ -1,3 +1,4 @@
package com.tencent.supersonic.chat.query.rule; package com.tencent.supersonic.chat.query.rule;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
@@ -291,6 +292,7 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
} }
protected QueryStructReq convertQueryStruct() { protected QueryStructReq convertQueryStruct() {
return QueryReqBuilder.buildStructReq(parseInfo); return QueryReqBuilder.buildStructReq(parseInfo);
} }

View File

@@ -1,9 +1,5 @@
package com.tencent.supersonic.chat.query.rule.entity; package com.tencent.supersonic.chat.query.rule.entity;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ENTITY;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
@@ -15,12 +11,17 @@ import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.service.ConfigService; import com.tencent.supersonic.chat.service.ConfigService;
import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import java.time.LocalDate; import java.time.LocalDate;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils; import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ENTITY;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
@Slf4j @Slf4j
public abstract class EntitySemanticQuery extends RuleSemanticQuery { public abstract class EntitySemanticQuery extends RuleSemanticQuery {
@@ -34,7 +35,7 @@ public abstract class EntitySemanticQuery extends RuleSemanticQuery {
@Override @Override
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches, public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
QueryContext queryCtx) { QueryContext queryCtx) {
candidateElementMatches = filterElementMatches(candidateElementMatches); candidateElementMatches = filterElementMatches(candidateElementMatches);
return super.match(candidateElementMatches, queryCtx); return super.match(candidateElementMatches, queryCtx);
} }

View File

@@ -40,7 +40,7 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery {
@Override @Override
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches, public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
QueryContext queryCtx) { QueryContext queryCtx) {
candidateElementMatches = filterElementMatches(candidateElementMatches); candidateElementMatches = filterElementMatches(candidateElementMatches);
return super.match(candidateElementMatches, queryCtx); return super.match(candidateElementMatches, queryCtx);
} }

View File

@@ -3,8 +3,8 @@ package com.tencent.supersonic.chat.query.rule.metric;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.DIMENSION; import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.DIMENSION;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.VALUE; import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.VALUE;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.OPTIONAL; import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.OPTIONAL;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST; import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.common.pojo.Constants.DESC_UPPER; import static com.tencent.supersonic.common.pojo.Constants.DESC_UPPER;
import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.ChatContext;
@@ -13,11 +13,12 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.common.pojo.Order; import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum; import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import org.springframework.stereotype.Component;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.regex.Matcher; import java.util.regex.Matcher;
import java.util.regex.Pattern; import java.util.regex.Pattern;
import org.springframework.stereotype.Component;
@Component @Component
public class MetricTopNQuery extends MetricSemanticQuery { public class MetricTopNQuery extends MetricSemanticQuery {
@@ -35,7 +36,7 @@ public class MetricTopNQuery extends MetricSemanticQuery {
@Override @Override
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches, public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
QueryContext queryCtx) { QueryContext queryCtx) {
Matcher matcher = INTENT_PATTERN.matcher(queryCtx.getRequest().getQueryText()); Matcher matcher = INTENT_PATTERN.matcher(queryCtx.getRequest().getQueryText());
if (matcher.matches()) { if (matcher.matches()) {
return super.match(candidateElementMatches, queryCtx); return super.match(candidateElementMatches, queryCtx);

View File

@@ -30,14 +30,16 @@ public class ChatController {
@PostMapping("/save") @PostMapping("/save")
public Boolean save(@RequestParam(value = "chatName") String chatName, public Boolean save(@RequestParam(value = "chatName") String chatName,
@RequestParam(value = "agentId", required = false) Integer agentId,
HttpServletRequest request, HttpServletResponse response) { HttpServletRequest request, HttpServletResponse response) {
return chatService.addChat(UserHolder.findUser(request, response), chatName); return chatService.addChat(UserHolder.findUser(request, response), chatName, agentId);
} }
@GetMapping("/getAll") @GetMapping("/getAll")
public List<ChatDO> getAllConversions(HttpServletRequest request, HttpServletResponse response) { public List<ChatDO> getAllConversions(@RequestParam(value = "agentId", required = false) Integer agentId,
HttpServletRequest request, HttpServletResponse response) {
String userName = UserHolder.findUser(request, response).getName(); String userName = UserHolder.findUser(request, response).getName();
return chatService.getAll(userName); return chatService.getAll(userName, agentId);
} }
@PostMapping("/delete") @PostMapping("/delete")

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.rest;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq; import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
@@ -33,7 +34,7 @@ public class ChatQueryController {
@PostMapping("search") @PostMapping("search")
public Object search(@RequestBody QueryReq queryCtx, HttpServletRequest request, public Object search(@RequestBody QueryReq queryCtx, HttpServletRequest request,
HttpServletResponse response) { HttpServletResponse response) {
queryCtx.setUser(UserHolder.findUser(request, response)); queryCtx.setUser(UserHolder.findUser(request, response));
return searchService.search(queryCtx); return searchService.search(queryCtx);
} }
@@ -54,7 +55,7 @@ public class ChatQueryController {
@PostMapping("execute") @PostMapping("execute")
public Object execute(@RequestBody ExecuteQueryReq queryCtx, public Object execute(@RequestBody ExecuteQueryReq queryCtx,
HttpServletRequest request, HttpServletResponse response) HttpServletRequest request, HttpServletResponse response)
throws Exception { throws Exception {
queryCtx.setUser(UserHolder.findUser(request, response)); queryCtx.setUser(UserHolder.findUser(request, response));
return queryService.performExecution(queryCtx); return queryService.performExecution(queryCtx);
@@ -62,16 +63,23 @@ public class ChatQueryController {
@PostMapping("queryContext") @PostMapping("queryContext")
public Object queryContext(@RequestBody QueryReq queryCtx, HttpServletRequest request, public Object queryContext(@RequestBody QueryReq queryCtx, HttpServletRequest request,
HttpServletResponse response) throws Exception { HttpServletResponse response) throws Exception {
queryCtx.setUser(UserHolder.findUser(request, response)); queryCtx.setUser(UserHolder.findUser(request, response));
return queryService.queryContext(queryCtx); return queryService.queryContext(queryCtx);
} }
@PostMapping("queryData") @PostMapping("queryData")
public Object queryData(@RequestBody QueryDataReq queryData, public Object queryData(@RequestBody QueryDataReq queryData,
HttpServletRequest request, HttpServletResponse response) HttpServletRequest request, HttpServletResponse response)
throws Exception { throws Exception {
return queryService.executeDirectQuery(queryData, UserHolder.findUser(request, response)); return queryService.executeDirectQuery(queryData, UserHolder.findUser(request, response));
} }
@PostMapping("queryDimensionValue")
public Object queryDimensionValue(@RequestBody DimensionValueReq dimensionValueReq,
HttpServletRequest request, HttpServletResponse response)
throws Exception {
return queryService.queryDimensionValue(dimensionValueReq, UserHolder.findUser(request, response));
}
} }

View File

@@ -10,6 +10,7 @@ import com.tencent.supersonic.knowledge.dictionary.DimValueDictInfo;
import java.util.List; import java.util.List;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.DeleteMapping; import org.springframework.web.bind.annotation.DeleteMapping;
import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.GetMapping;

View File

@@ -30,9 +30,9 @@ public interface ChatService {
public void switchContext(ChatContext chatCtx); public void switchContext(ChatContext chatCtx);
public Boolean addChat(User user, String chatName); public Boolean addChat(User user, String chatName, Integer agentId);
public List<ChatDO> getAll(String userName); public List<ChatDO> getAll(String userName, Integer agentId);
public boolean updateChatName(Long chatId, String chatName, String userName); public boolean updateChatName(Long chatId, String chatName, String userName);

View File

@@ -7,6 +7,7 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatConfigEditReqReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter; import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp; import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp; import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
import java.util.List; import java.util.List;
public interface ConfigService { public interface ConfigService {

View File

@@ -5,10 +5,10 @@ import com.tencent.supersonic.knowledge.dictionary.DictConfig;
import com.tencent.supersonic.knowledge.dictionary.DictTaskFilter; import com.tencent.supersonic.knowledge.dictionary.DictTaskFilter;
import com.tencent.supersonic.knowledge.dictionary.DimValue2DictCommand; import com.tencent.supersonic.knowledge.dictionary.DimValue2DictCommand;
import com.tencent.supersonic.knowledge.dictionary.DimValueDictInfo; import com.tencent.supersonic.knowledge.dictionary.DimValueDictInfo;
import java.util.List; import java.util.List;
public interface DictionaryService { public interface DictionaryService {
Long addDictTask(DimValue2DictCommand dimValue2DictCommend, User user); Long addDictTask(DimValue2DictCommand dimValue2DictCommend, User user);
Long deleteDictTask(DimValue2DictCommand dimValue2DictCommend, User user); Long deleteDictTask(DimValue2DictCommand dimValue2DictCommend, User user);

View File

@@ -2,8 +2,9 @@ package com.tencent.supersonic.chat.service;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.request.PluginQueryReq;
import com.tencent.supersonic.chat.plugin.Plugin; import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.api.pojo.request.PluginQueryReq;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;

View File

@@ -2,11 +2,12 @@ package com.tencent.supersonic.chat.service;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
import org.apache.calcite.sql.parser.SqlParseException; import org.apache.calcite.sql.parser.SqlParseException;
/*** /***
@@ -24,4 +25,5 @@ public interface QueryService {
QueryResult executeDirectQuery(QueryDataReq queryData, User user) throws SqlParseException; QueryResult executeDirectQuery(QueryDataReq queryData, User user) throws SqlParseException;
Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception;
} }

View File

@@ -4,6 +4,7 @@ package com.tencent.supersonic.chat.service;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.RecommendQuestionResp; import com.tencent.supersonic.chat.api.pojo.response.RecommendQuestionResp;
import com.tencent.supersonic.chat.api.pojo.response.RecommendResp; import com.tencent.supersonic.chat.api.pojo.response.RecommendResp;
import java.util.List; import java.util.List;
/*** /***

View File

@@ -26,9 +26,9 @@ import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp; import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatDefaultRichConfigResp; import com.tencent.supersonic.chat.api.pojo.response.ChatDefaultRichConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.DataInfo; import com.tencent.supersonic.chat.api.pojo.response.DataInfo;
import com.tencent.supersonic.chat.api.pojo.response.ModelInfo;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo; import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.MetricInfo; import com.tencent.supersonic.chat.api.pojo.response.MetricInfo;
import com.tencent.supersonic.chat.api.pojo.response.ModelInfo;
import com.tencent.supersonic.chat.config.AggregatorConfig; import com.tencent.supersonic.chat.config.AggregatorConfig;
import com.tencent.supersonic.chat.utils.ComponentFactory; import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.QueryReqBuilder; import com.tencent.supersonic.chat.utils.QueryReqBuilder;
@@ -332,7 +332,7 @@ public class SemanticService {
} }
public AggregateInfo getAggregateInfo(User user, SemanticParseInfo semanticParseInfo, public AggregateInfo getAggregateInfo(User user, SemanticParseInfo semanticParseInfo,
QueryResultWithSchemaResp result) { QueryResultWithSchemaResp result) {
if (CollectionUtils.isEmpty(semanticParseInfo.getMetrics()) || !aggregatorConfig.getEnableRatio()) { if (CollectionUtils.isEmpty(semanticParseInfo.getMetrics()) || !aggregatorConfig.getEnableRatio()) {
return new AggregateInfo(); return new AggregateInfo();
} }
@@ -384,7 +384,7 @@ public class SemanticService {
} }
private MetricInfo queryRatio(User user, SemanticParseInfo semanticParseInfo, SchemaElement metric, private MetricInfo queryRatio(User user, SemanticParseInfo semanticParseInfo, SchemaElement metric,
AggOperatorEnum aggOperatorEnum, QueryResultWithSchemaResp results) { AggOperatorEnum aggOperatorEnum, QueryResultWithSchemaResp results) {
MetricInfo metricInfo = new MetricInfo(); MetricInfo metricInfo = new MetricInfo();
metricInfo.setStatistics(new HashMap<>()); metricInfo.setStatistics(new HashMap<>());
QueryStructReq queryStructReq = QueryReqBuilder.buildStructRatioReq(semanticParseInfo, metric, aggOperatorEnum); QueryStructReq queryStructReq = QueryReqBuilder.buildStructRatioReq(semanticParseInfo, metric, aggOperatorEnum);
@@ -432,7 +432,7 @@ public class SemanticService {
} }
private DateConf getRatioDateConf(AggOperatorEnum aggOperatorEnum, SemanticParseInfo semanticParseInfo, private DateConf getRatioDateConf(AggOperatorEnum aggOperatorEnum, SemanticParseInfo semanticParseInfo,
QueryResultWithSchemaResp results) { QueryResultWithSchemaResp results) {
String dateField = QueryReqBuilder.getDateField(semanticParseInfo.getDateInfo()); String dateField = QueryReqBuilder.getDateField(semanticParseInfo.getDateInfo());
Optional<String> lastDayOp = results.getResultList().stream() Optional<String> lastDayOp = results.getResultList().stream()
.map(r -> r.get(dateField).toString()) .map(r -> r.get(dateField).toString())

Some files were not shown because too many files have changed in this diff Show More