mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(project) support for modifying filter conditions and fix group by pushdown and add windows scipt (#49)
Co-authored-by: lexluo <lexluo@tencent.com>
This commit is contained in:
@@ -3,6 +3,7 @@ package com.tencent.supersonic.auth.api.authentication.adaptor;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.Organization;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authentication.request.UserReq;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
package com.tencent.supersonic.auth.api.authentication.pojo;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class Organization {
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package com.tencent.supersonic.auth.api.authentication.service;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.Organization;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authentication.request.UserReq;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.auth.api.authentication.service;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package com.tencent.supersonic.auth.api.authorization.request;
|
||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRes;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import com.tencent.supersonic.auth.api.authorization.pojo.AuthResGrp;
|
||||
import com.tencent.supersonic.auth.api.authorization.pojo.DimensionFilter;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
|
||||
@@ -6,7 +6,6 @@ import com.tencent.supersonic.auth.api.authorization.request.QueryAuthResReq;
|
||||
import com.tencent.supersonic.auth.api.authorization.response.AuthorizedResourceResp;
|
||||
import java.util.List;
|
||||
|
||||
|
||||
public interface AuthService {
|
||||
|
||||
List<AuthGroup> queryAuthGroups(String domainId, Integer groupId);
|
||||
|
||||
@@ -11,10 +11,10 @@ import com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO;
|
||||
import com.tencent.supersonic.auth.authentication.persistence.repository.UserRepository;
|
||||
import com.tencent.supersonic.auth.authentication.utils.UserTokenUtils;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
|
||||
public class DefaultUserAdaptor implements UserAdaptor {
|
||||
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
package com.tencent.supersonic.auth.authentication.config;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
@Data
|
||||
@Configuration
|
||||
public class TppConfig {
|
||||
|
||||
@Value(value = "${auth.app.secret:}")
|
||||
private String appSecret;
|
||||
|
||||
@Value(value = "${auth.app.key:}")
|
||||
private String appKey;
|
||||
|
||||
@Value(value = "${auth.oa.url:}")
|
||||
private String tppOaUrl;
|
||||
|
||||
}
|
||||
@@ -1,11 +1,12 @@
|
||||
package com.tencent.supersonic.auth.authentication.interceptor;
|
||||
|
||||
import java.util.List;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.core.io.support.SpringFactoriesLoader;
|
||||
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
|
||||
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Configuration
|
||||
public class InterceptorFactory implements WebMvcConfigurer {
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ package com.tencent.supersonic.auth.authentication.persistence.repository.impl;
|
||||
|
||||
import com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO;
|
||||
import com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDOExample;
|
||||
import com.tencent.supersonic.auth.authentication.persistence.mapper.UserDOMapper;
|
||||
import com.tencent.supersonic.auth.authentication.persistence.repository.UserRepository;
|
||||
import com.tencent.supersonic.auth.authentication.persistence.mapper.UserDOMapper;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@@ -7,6 +7,7 @@ import com.tencent.supersonic.auth.api.authentication.service.UserService;
|
||||
import com.tencent.supersonic.auth.authentication.utils.ComponentFactory;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package com.tencent.supersonic.auth.authentication.utils;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor;
|
||||
import java.util.Objects;
|
||||
import org.springframework.core.io.support.SpringFactoriesLoader;
|
||||
import java.util.Objects;
|
||||
|
||||
public class ComponentFactory {
|
||||
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
|
||||
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
|
||||
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
|
||||
<mapper namespace="com.tencent.supersonic.auth.authentication.persistence.mapper.UserDOMapper">
|
||||
<resultMap id="BaseResultMap"
|
||||
type="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO">
|
||||
<id column="id" jdbcType="BIGINT" property="id"/>
|
||||
<result column="name" jdbcType="VARCHAR" property="name"/>
|
||||
<result column="password" jdbcType="VARCHAR" property="password"/>
|
||||
<result column="display_name" jdbcType="VARCHAR" property="displayName"/>
|
||||
<result column="email" jdbcType="VARCHAR" property="email"/>
|
||||
<resultMap id="BaseResultMap" type="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO">
|
||||
<id column="id" jdbcType="BIGINT" property="id" />
|
||||
<result column="name" jdbcType="VARCHAR" property="name" />
|
||||
<result column="password" jdbcType="VARCHAR" property="password" />
|
||||
<result column="display_name" jdbcType="VARCHAR" property="displayName" />
|
||||
<result column="email" jdbcType="VARCHAR" property="email" />
|
||||
</resultMap>
|
||||
<sql id="Example_Where_Clause">
|
||||
<where>
|
||||
@@ -24,13 +22,11 @@
|
||||
and ${criterion.condition} #{criterion.value}
|
||||
</when>
|
||||
<when test="criterion.betweenValue">
|
||||
and ${criterion.condition} #{criterion.value} and
|
||||
#{criterion.secondValue}
|
||||
and ${criterion.condition} #{criterion.value} and #{criterion.secondValue}
|
||||
</when>
|
||||
<when test="criterion.listValue">
|
||||
and ${criterion.condition}
|
||||
<foreach close=")" collection="criterion.value" item="listItem"
|
||||
open="(" separator=",">
|
||||
<foreach close=")" collection="criterion.value" item="listItem" open="(" separator=",">
|
||||
#{listItem}
|
||||
</foreach>
|
||||
</when>
|
||||
@@ -42,20 +38,17 @@
|
||||
</where>
|
||||
</sql>
|
||||
<sql id="Base_Column_List">
|
||||
id
|
||||
, name, password, display_name, email
|
||||
id, name, password, display_name, email
|
||||
</sql>
|
||||
<select id="selectByExample"
|
||||
parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDOExample"
|
||||
resultMap="BaseResultMap">
|
||||
<select id="selectByExample" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDOExample" resultMap="BaseResultMap">
|
||||
select
|
||||
<if test="distinct">
|
||||
distinct
|
||||
</if>
|
||||
<include refid="Base_Column_List"/>
|
||||
<include refid="Base_Column_List" />
|
||||
from s2_user
|
||||
<if test="_parameter != null">
|
||||
<include refid="Example_Where_Clause"/>
|
||||
<include refid="Example_Where_Clause" />
|
||||
</if>
|
||||
<if test="orderByClause != null">
|
||||
order by ${orderByClause}
|
||||
@@ -66,24 +59,21 @@
|
||||
</select>
|
||||
<select id="selectByPrimaryKey" parameterType="java.lang.Long" resultMap="BaseResultMap">
|
||||
select
|
||||
<include refid="Base_Column_List"/>
|
||||
<include refid="Base_Column_List" />
|
||||
from s2_user
|
||||
where id = #{id,jdbcType=BIGINT}
|
||||
</select>
|
||||
<delete id="deleteByPrimaryKey" parameterType="java.lang.Long">
|
||||
delete
|
||||
from s2_user
|
||||
delete from s2_user
|
||||
where id = #{id,jdbcType=BIGINT}
|
||||
</delete>
|
||||
<insert id="insert"
|
||||
parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO">
|
||||
<insert id="insert" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO">
|
||||
insert into s2_user (id, name, password,
|
||||
display_name, email)
|
||||
values (#{id,jdbcType=BIGINT}, #{name,jdbcType=VARCHAR}, #{password,jdbcType=VARCHAR},
|
||||
#{displayName,jdbcType=VARCHAR}, #{email,jdbcType=VARCHAR})
|
||||
</insert>
|
||||
<insert id="insertSelective"
|
||||
parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO">
|
||||
<insert id="insertSelective" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO">
|
||||
insert into s2_user
|
||||
<trim prefix="(" suffix=")" suffixOverrides=",">
|
||||
<if test="id != null">
|
||||
@@ -120,16 +110,13 @@
|
||||
</if>
|
||||
</trim>
|
||||
</insert>
|
||||
<select id="countByExample"
|
||||
parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDOExample"
|
||||
resultType="java.lang.Long">
|
||||
<select id="countByExample" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDOExample" resultType="java.lang.Long">
|
||||
select count(*) from s2_user
|
||||
<if test="_parameter != null">
|
||||
<include refid="Example_Where_Clause"/>
|
||||
<include refid="Example_Where_Clause" />
|
||||
</if>
|
||||
</select>
|
||||
<update id="updateByPrimaryKeySelective"
|
||||
parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO">
|
||||
<update id="updateByPrimaryKeySelective" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO">
|
||||
update s2_user
|
||||
<set>
|
||||
<if test="name != null">
|
||||
@@ -147,8 +134,7 @@
|
||||
</set>
|
||||
where id = #{id,jdbcType=BIGINT}
|
||||
</update>
|
||||
<update id="updateByPrimaryKey"
|
||||
parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO">
|
||||
<update id="updateByPrimaryKey" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO">
|
||||
update s2_user
|
||||
set name = #{name,jdbcType=VARCHAR},
|
||||
password = #{password,jdbcType=VARCHAR},
|
||||
|
||||
@@ -18,4 +18,5 @@ public class CorrectionInfo {
|
||||
|
||||
private String sql;
|
||||
|
||||
private String preSql;
|
||||
}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
package com.tencent.supersonic.chat.api.pojo;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ModelSchema {
|
||||
|
||||
@@ -2,11 +2,12 @@ package com.tencent.supersonic.chat.api.pojo;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
public class QueryContext {
|
||||
|
||||
@@ -1,39 +1,30 @@
|
||||
package com.tencent.supersonic.chat.api.pojo;
|
||||
|
||||
import com.google.common.base.Objects;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.List;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.Getter;
|
||||
import lombok.Builder;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@Getter
|
||||
@Builder
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class SchemaElement implements Serializable {
|
||||
|
||||
private Long model;
|
||||
private Long id;
|
||||
private String name;
|
||||
private String bizName;
|
||||
private Long useCnt;
|
||||
private SchemaElementType type;
|
||||
|
||||
private List<String> alias;
|
||||
|
||||
public SchemaElement(Long model, Long id, String name, String bizName,
|
||||
Long useCnt, SchemaElementType type, List<String> alias) {
|
||||
this.model = model;
|
||||
this.id = id;
|
||||
this.name = name;
|
||||
this.bizName = bizName;
|
||||
this.useCnt = useCnt;
|
||||
this.type = type;
|
||||
this.alias = alias;
|
||||
}
|
||||
private List<SchemaValueMap> schemaValueMaps;
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
@@ -54,4 +45,5 @@ public class SchemaElement implements Serializable {
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(model, id, name, bizName, useCnt, type);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
package com.tencent.supersonic.chat.api.pojo;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class SchemaValueMap {
|
||||
|
||||
/**
|
||||
* dimension value in db
|
||||
*/
|
||||
private String techName;
|
||||
|
||||
/**
|
||||
* dimension value for result show
|
||||
*/
|
||||
private String bizName;
|
||||
|
||||
/**
|
||||
* dimension value for user query
|
||||
*/
|
||||
private List<String> alias = new ArrayList<>();
|
||||
}
|
||||
@@ -7,7 +7,6 @@ import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class SemanticSchema implements Serializable {
|
||||
|
||||
private List<ModelSchema> modelSchemaList;
|
||||
|
||||
public SemanticSchema(List<ModelSchema> modelSchemaList) {
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class ChatAggConfigReq {
|
||||
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
|
||||
import java.util.List;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* extended information command about model
|
||||
*/
|
||||
|
||||
@@ -2,9 +2,10 @@ package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ChatDefaultConfigReq {
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class ChatDetailConfigReq {
|
||||
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class DimensionValueReq {
|
||||
private Long modelId;
|
||||
|
||||
private String bizName;
|
||||
|
||||
private Object value;
|
||||
}
|
||||
@@ -7,13 +7,12 @@ import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ExecuteQueryReq {
|
||||
|
||||
private User user;
|
||||
private Integer agentId;
|
||||
private Integer chatId;
|
||||
private String queryText;
|
||||
private Long queryId;
|
||||
private Integer parseId;
|
||||
private Long queryId = 7L;
|
||||
private Integer parseId = 2;
|
||||
private SemanticParseInfo parseInfo;
|
||||
private boolean saveAnswer = true;
|
||||
}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
* advanced knowledge config
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
|
||||
|
||||
import javax.validation.constraints.NotNull;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
|
||||
@@ -4,19 +4,20 @@ package com.tencent.supersonic.chat.api.pojo.request;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.Order;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class QueryDataReq {
|
||||
|
||||
String queryMode;
|
||||
SchemaElement model;
|
||||
Set<SchemaElement> metrics = new HashSet<>();
|
||||
Set<SchemaElement> dimensions = new HashSet<>();
|
||||
Set<QueryFilter> dimensionFilters = new HashSet<>();
|
||||
Set<QueryFilter> metricFilters = new HashSet<>();
|
||||
private AggregateTypeEnum aggType = AggregateTypeEnum.NONE;
|
||||
private Set<Order> orders = new HashSet<>();
|
||||
private DateConf dateInfo;
|
||||
private Long limit;
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
import lombok.Data;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class QueryFilters {
|
||||
|
||||
private List<QueryFilter> filters = new ArrayList<>();
|
||||
private Map<String, Object> params = new HashMap<>();
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class QueryReq {
|
||||
|
||||
private String queryText;
|
||||
private Integer chatId;
|
||||
private Long modelId = 0L;
|
||||
|
||||
@@ -2,9 +2,10 @@ package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.request.KnowledgeAdvancedConfig;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.KnowledgeInfoReq;
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class ChatAggRichConfigResp {
|
||||
|
||||
|
||||
@@ -4,8 +4,10 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatDetailConfigReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.RecommendedQuestionReq;
|
||||
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
|
||||
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
|
||||
@@ -4,6 +4,7 @@ import com.tencent.supersonic.chat.api.pojo.request.RecommendedQuestionReq;
|
||||
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
|
||||
@@ -4,9 +4,10 @@ package com.tencent.supersonic.chat.api.pojo.response;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class ChatDefaultRichConfigResp {
|
||||
|
||||
|
||||
@@ -2,9 +2,10 @@ package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.request.KnowledgeAdvancedConfig;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.KnowledgeInfoReq;
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class ChatDetailRichConfigResp {
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class EntityRichInfoResp {
|
||||
|
||||
/**
|
||||
* entity alias
|
||||
*/
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.request.RecommendedQuestionReq;
|
||||
import java.util.List;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
public class RecommendQuestionResp {
|
||||
|
||||
private Long modelId;
|
||||
private List<RecommendedQuestionReq> recommendedQuestions;
|
||||
}
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class RecommendResp {
|
||||
|
||||
private List<SchemaElement> dimensions;
|
||||
private List<SchemaElement> metrics;
|
||||
}
|
||||
|
||||
@@ -136,6 +136,14 @@
|
||||
<artifactId>xk-time</artifactId>
|
||||
<version>${xk.time.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.mockito</groupId>
|
||||
<artifactId>mockito-inline</artifactId>
|
||||
<version>${mockito-inline.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
</project>
|
||||
|
||||
@@ -3,12 +3,13 @@ package com.tencent.supersonic.chat.config;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatDetailConfigReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.RecommendedQuestionReq;
|
||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
|
||||
import java.util.List;
|
||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@ToString
|
||||
public class ChatConfig {
|
||||
|
||||
@@ -7,7 +7,6 @@ import org.springframework.context.annotation.Configuration;
|
||||
@Configuration
|
||||
@Data
|
||||
public class FunctionCallInfoConfig {
|
||||
|
||||
@Value("${functionCall.url:}")
|
||||
private String url;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
package com.tencent.supersonic.chat.config;
|
||||
|
||||
import lombok.Data;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.context.annotation.PropertySource;
|
||||
|
||||
@Configuration
|
||||
@Data
|
||||
@PropertySource("classpath:optimization.properties")
|
||||
//@ComponentScan(basePackages = "com.tencent.supersonic.chat")
|
||||
public class OptimizationConfig {
|
||||
|
||||
@Value("${one.detection.size}")
|
||||
private Integer oneDetectionSize;
|
||||
@Value("${one.detection.max.size}")
|
||||
private Integer oneDetectionMaxSize;
|
||||
|
||||
@Value("${metric.dimension.min.threshold}")
|
||||
private Double metricDimensionMinThresholdConfig;
|
||||
|
||||
@Value("${metric.dimension.threshold}")
|
||||
private Double metricDimensionThresholdConfig;
|
||||
|
||||
@Value("${dimension.value.threshold}")
|
||||
private Double dimensionValueThresholdConfig;
|
||||
|
||||
@Value("${function.bonus.threshold}")
|
||||
private Double functionBonusThreshold;
|
||||
|
||||
@Value("${long.text.threshold}")
|
||||
private Double longTextThreshold;
|
||||
|
||||
@Value("${short.text.threshold}")
|
||||
private Double shortTextThreshold;
|
||||
|
||||
@Value("${query.text.length.threshold}")
|
||||
private Integer queryTextLengthThreshold;
|
||||
|
||||
@Value("${candidate.threshold}")
|
||||
private Double candidateThreshold;
|
||||
|
||||
}
|
||||
@@ -20,6 +20,7 @@ public class DateFieldCorrector extends BaseSemanticCorrector {
|
||||
String currentDate = DSLDateHelper.getCurrentDate(correctionInfo.getParseInfo().getModelId());
|
||||
sql = SqlParserUpdateHelper.addWhere(sql, DATE_FIELD, currentDate);
|
||||
}
|
||||
correctionInfo.setPreSql(correctionInfo.getSql());
|
||||
correctionInfo.setSql(sql);
|
||||
return correctionInfo;
|
||||
}
|
||||
|
||||
@@ -9,9 +9,11 @@ public class FieldCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
|
||||
String replaceFields = SqlParserUpdateHelper.replaceFields(correctionInfo.getSql(),
|
||||
String preSql = correctionInfo.getSql();
|
||||
correctionInfo.setPreSql(preSql);
|
||||
String sql = SqlParserUpdateHelper.replaceFields(preSql,
|
||||
getFieldToBizName(correctionInfo.getParseInfo().getModelId()));
|
||||
correctionInfo.setSql(replaceFields);
|
||||
correctionInfo.setSql(sql);
|
||||
return correctionInfo;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
|
||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class FieldNameCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
|
||||
|
||||
Object context = correctionInfo.getParseInfo().getProperties().get(Constants.CONTEXT);
|
||||
if (Objects.isNull(context)) {
|
||||
return correctionInfo;
|
||||
}
|
||||
|
||||
DSLParseResult dslParseResult = JsonUtil.toObject(JsonUtil.toString(context), DSLParseResult.class);
|
||||
if (Objects.isNull(dslParseResult) || Objects.isNull(dslParseResult.getLlmReq())) {
|
||||
return correctionInfo;
|
||||
}
|
||||
LLMReq llmReq = dslParseResult.getLlmReq();
|
||||
List<ElementValue> linking = llmReq.getLinking();
|
||||
if (CollectionUtils.isEmpty(linking)) {
|
||||
return correctionInfo;
|
||||
}
|
||||
|
||||
Map<String, Set<String>> fieldValueToFieldNames = linking.stream().collect(
|
||||
Collectors.groupingBy(ElementValue::getFieldValue,
|
||||
Collectors.mapping(ElementValue::getFieldName, Collectors.toSet())));
|
||||
|
||||
String preSql = correctionInfo.getSql();
|
||||
correctionInfo.setPreSql(preSql);
|
||||
String sql = SqlParserUpdateHelper.replaceFieldNameByValue(preSql, fieldValueToFieldNames);
|
||||
correctionInfo.setSql(sql);
|
||||
return correctionInfo;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,18 +1,19 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
|
||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.logging.log4j.util.Strings;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
@@ -20,29 +21,61 @@ public class FieldValueCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
Long modelId = correctionInfo.getParseInfo().getModel().getId();
|
||||
List<SchemaElement> dimensions = semanticSchema.getDimensions().stream()
|
||||
.filter(schemaElement -> modelId.equals(schemaElement.getModel()))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
Object context = correctionInfo.getParseInfo().getProperties().get(Constants.CONTEXT);
|
||||
if (Objects.isNull(context)) {
|
||||
if (CollectionUtils.isEmpty(dimensions)) {
|
||||
return correctionInfo;
|
||||
}
|
||||
|
||||
DSLParseResult dslParseResult = JsonUtil.toObject(JsonUtil.toString(context), DSLParseResult.class);
|
||||
if (Objects.isNull(dslParseResult) || Objects.isNull(dslParseResult.getLlmReq())) {
|
||||
return correctionInfo;
|
||||
}
|
||||
LLMReq llmReq = dslParseResult.getLlmReq();
|
||||
List<ElementValue> linking = llmReq.getLinking();
|
||||
if (CollectionUtils.isEmpty(linking)) {
|
||||
return correctionInfo;
|
||||
}
|
||||
|
||||
Map<String, Set<String>> fieldValueToFieldNames = linking.stream().collect(
|
||||
Collectors.groupingBy(ElementValue::getFieldValue,
|
||||
Collectors.mapping(ElementValue::getFieldName, Collectors.toSet())));
|
||||
|
||||
String sql = SqlParserUpdateHelper.replaceValueFields(correctionInfo.getSql(), fieldValueToFieldNames);
|
||||
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
|
||||
String preSql = correctionInfo.getSql();
|
||||
correctionInfo.setPreSql(preSql);
|
||||
String sql = SqlParserUpdateHelper.replaceValue(preSql, aliasAndBizNameToTechName);
|
||||
correctionInfo.setSql(sql);
|
||||
return correctionInfo;
|
||||
}
|
||||
|
||||
|
||||
private Map<String, Map<String, String>> getAliasAndBizNameToTechName(List<SchemaElement> dimensions) {
|
||||
if (CollectionUtils.isEmpty(dimensions)) {
|
||||
return new HashMap<>();
|
||||
}
|
||||
|
||||
Map<String, Map<String, String>> result = new HashMap<>();
|
||||
|
||||
for (SchemaElement dimension : dimensions) {
|
||||
if (Objects.isNull(dimension)
|
||||
|| Strings.isEmpty(dimension.getBizName())
|
||||
|| CollectionUtils.isEmpty(dimension.getSchemaValueMaps())) {
|
||||
continue;
|
||||
}
|
||||
String bizName = dimension.getBizName();
|
||||
|
||||
Map<String, String> aliasAndBizNameToTechName = new HashMap<>();
|
||||
|
||||
for (SchemaValueMap valueMap : dimension.getSchemaValueMaps()) {
|
||||
if (Objects.isNull(valueMap) || Strings.isEmpty(valueMap.getTechName())) {
|
||||
continue;
|
||||
}
|
||||
if (Strings.isNotEmpty(valueMap.getBizName())) {
|
||||
aliasAndBizNameToTechName.put(valueMap.getBizName(), valueMap.getTechName());
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(valueMap.getAlias())) {
|
||||
valueMap.getAlias().stream().forEach(alias -> {
|
||||
if (Strings.isNotEmpty(alias)) {
|
||||
aliasAndBizNameToTechName.put(alias, valueMap.getTechName());
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(aliasAndBizNameToTechName)) {
|
||||
result.put(bizName, aliasAndBizNameToTechName);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,8 +9,10 @@ public class FunctionCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
|
||||
String replaceFunction = SqlParserUpdateHelper.replaceFunction(correctionInfo.getSql());
|
||||
correctionInfo.setSql(replaceFunction);
|
||||
String preSql = correctionInfo.getSql();
|
||||
correctionInfo.setPreSql(preSql);
|
||||
String sql = SqlParserUpdateHelper.replaceFunction(preSql);
|
||||
correctionInfo.setSql(sql);
|
||||
return correctionInfo;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,14 +20,15 @@ public class QueryFilterAppend extends BaseSemanticCorrector {
|
||||
@Override
|
||||
public CorrectionInfo corrector(CorrectionInfo correctionInfo) throws JSQLParserException {
|
||||
String queryFilter = getQueryFilter(correctionInfo.getQueryFilters());
|
||||
String sql = correctionInfo.getSql();
|
||||
String preSql = correctionInfo.getSql();
|
||||
|
||||
if (StringUtils.isNotEmpty(queryFilter)) {
|
||||
log.info("add queryFilter to sql :{}", queryFilter);
|
||||
log.info("add queryFilter to preSql :{}", queryFilter);
|
||||
Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
|
||||
sql = SqlParserUpdateHelper.addWhere(sql, expression);
|
||||
}
|
||||
String sql = SqlParserUpdateHelper.addWhere(preSql, expression);
|
||||
correctionInfo.setPreSql(preSql);
|
||||
correctionInfo.setSql(sql);
|
||||
}
|
||||
return correctionInfo;
|
||||
}
|
||||
|
||||
|
||||
@@ -15,24 +15,24 @@ public class SelectFieldAppendCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
|
||||
String sql = correctionInfo.getSql();
|
||||
if (SqlParserSelectHelper.hasAggregateFunction(sql)) {
|
||||
String preSql = correctionInfo.getSql();
|
||||
if (SqlParserSelectHelper.hasAggregateFunction(preSql)) {
|
||||
return correctionInfo;
|
||||
}
|
||||
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(sql));
|
||||
Set<String> whereFields = new HashSet<>(SqlParserSelectHelper.getWhereFields(sql));
|
||||
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(preSql));
|
||||
Set<String> whereFields = new HashSet<>(SqlParserSelectHelper.getWhereFields(preSql));
|
||||
|
||||
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(whereFields)) {
|
||||
return correctionInfo;
|
||||
}
|
||||
|
||||
whereFields.addAll(SqlParserSelectHelper.getOrderByFields(sql));
|
||||
whereFields.addAll(SqlParserSelectHelper.getOrderByFields(preSql));
|
||||
whereFields.removeAll(selectFields);
|
||||
whereFields.remove(TimeDimensionEnum.DAY.getName());
|
||||
whereFields.remove(TimeDimensionEnum.WEEK.getName());
|
||||
whereFields.remove(TimeDimensionEnum.MONTH.getName());
|
||||
|
||||
String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(sql, new ArrayList<>(whereFields));
|
||||
String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(preSql, new ArrayList<>(whereFields));
|
||||
correctionInfo.setPreSql(preSql);
|
||||
correctionInfo.setSql(replaceFields);
|
||||
return correctionInfo;
|
||||
}
|
||||
|
||||
@@ -12,9 +12,10 @@ public class TableNameCorrector extends BaseSemanticCorrector {
|
||||
@Override
|
||||
public CorrectionInfo corrector(CorrectionInfo correctionInfo) {
|
||||
Long modelId = correctionInfo.getParseInfo().getModelId();
|
||||
String sqlOutput = correctionInfo.getSql();
|
||||
String replaceTable = SqlParserUpdateHelper.replaceTable(sqlOutput, TABLE_PREFIX + modelId);
|
||||
correctionInfo.setSql(replaceTable);
|
||||
String preSql = correctionInfo.getSql();
|
||||
correctionInfo.setPreSql(preSql);
|
||||
String sql = SqlParserUpdateHelper.replaceTable(preSql, TABLE_PREFIX + modelId);
|
||||
correctionInfo.setSql(sql);
|
||||
return correctionInfo;
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
||||
@@ -105,8 +106,9 @@ public class FuzzyNameMapper implements SchemaMapper {
|
||||
|
||||
private Double getThreshold(QueryContext queryContext, MapperHelper mapperHelper) {
|
||||
|
||||
Double metricDimensionThresholdConfig = mapperHelper.getMetricDimensionThresholdConfig();
|
||||
Double metricDimensionMinThresholdConfig = mapperHelper.getMetricDimensionMinThresholdConfig();
|
||||
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
|
||||
Double metricDimensionThresholdConfig = optimizationConfig.getMetricDimensionThresholdConfig();
|
||||
Double metricDimensionMinThresholdConfig = optimizationConfig.getMetricDimensionMinThresholdConfig();
|
||||
|
||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo()
|
||||
.getModelElementMatches();
|
||||
|
||||
@@ -9,13 +9,13 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
|
||||
import com.tencent.supersonic.knowledge.dictionary.MapResult;
|
||||
import com.tencent.supersonic.knowledge.dictionary.builder.BaseWordBuilder;
|
||||
import com.tencent.supersonic.knowledge.dictionary.builder.WordBuilderFactory;
|
||||
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
||||
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -25,6 +25,8 @@ import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
|
||||
@Slf4j
|
||||
public class HanlpDictMapper implements SchemaMapper {
|
||||
@@ -83,11 +85,14 @@ public class HanlpDictMapper implements SchemaMapper {
|
||||
Long elementID = baseWordBuilder.getElementID(nature);
|
||||
Long frequency = wordNatureToFrequency.get(mapResult.getName() + nature);
|
||||
|
||||
SchemaElement element = modelSchema.getElement(elementType, elementID);
|
||||
if (Objects.isNull(element)) {
|
||||
SchemaElement elementDb = modelSchema.getElement(elementType, elementID);
|
||||
if (Objects.isNull(elementDb)) {
|
||||
log.info("element is null, elementType:{},elementID:{}", elementType, elementID);
|
||||
continue;
|
||||
}
|
||||
SchemaElement element = new SchemaElement();
|
||||
BeanUtils.copyProperties(elementDb, element);
|
||||
element.setAlias(getAlias(elementDb));
|
||||
if (element.getType().equals(SchemaElementType.VALUE)) {
|
||||
element.setName(mapResult.getName());
|
||||
}
|
||||
@@ -124,4 +129,16 @@ public class HanlpDictMapper implements SchemaMapper {
|
||||
}
|
||||
return matches;
|
||||
}
|
||||
|
||||
public List<String> getAlias(SchemaElement element) {
|
||||
if (!SchemaElementType.VALUE.equals(element.getType())) {
|
||||
return element.getAlias();
|
||||
}
|
||||
if (CollectionUtils.isNotEmpty(element.getAlias()) && StringUtils.isNotEmpty(element.getName())) {
|
||||
return element.getAlias().stream()
|
||||
.filter(aliasItem -> aliasItem.contains(element.getName()))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
return element.getAlias();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.mapper;
|
||||
|
||||
import com.hankcs.hanlp.algorithm.EditDistance;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
||||
@@ -13,7 +14,7 @@ import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
/**
|
||||
@@ -25,17 +26,8 @@ import org.springframework.stereotype.Service;
|
||||
@Slf4j
|
||||
public class MapperHelper {
|
||||
|
||||
@Value("${one.detection.size:8}")
|
||||
private Integer oneDetectionSize;
|
||||
@Value("${one.detection.max.size:20}")
|
||||
private Integer oneDetectionMaxSize;
|
||||
@Value("${metric.dimension.threshold:0.3}")
|
||||
private Double metricDimensionThresholdConfig;
|
||||
|
||||
@Value("${metric.dimension.min.threshold:0.3}")
|
||||
private Double metricDimensionMinThresholdConfig;
|
||||
@Value("${dimension.value.threshold:0.5}")
|
||||
private Double dimensionValueThresholdConfig;
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
|
||||
public Integer getStepIndex(Map<Integer, Integer> regOffsetToLength, Integer index) {
|
||||
Integer subRegLength = regOffsetToLength.get(index);
|
||||
@@ -57,10 +49,11 @@ public class MapperHelper {
|
||||
}
|
||||
|
||||
public double getThresholdMatch(List<String> natures) {
|
||||
log.info("optimizationConfig:{}", optimizationConfig);
|
||||
if (existDimensionValues(natures)) {
|
||||
return dimensionValueThresholdConfig;
|
||||
return optimizationConfig.getDimensionValueThresholdConfig();
|
||||
}
|
||||
return metricDimensionThresholdConfig;
|
||||
return optimizationConfig.getMetricDimensionThresholdConfig();
|
||||
}
|
||||
|
||||
/***
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.mapper;
|
||||
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.knowledge.dictionary.MapResult;
|
||||
import com.tencent.supersonic.knowledge.service.SearchService;
|
||||
@@ -31,6 +32,9 @@ public class QueryMatchStrategy implements MatchStrategy {
|
||||
@Autowired
|
||||
private MapperHelper mapperHelper;
|
||||
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<MapResult>> match(QueryReq queryReq, List<Term> terms, Set<Long> detectModelIds) {
|
||||
String text = queryReq.getQueryText();
|
||||
@@ -111,7 +115,7 @@ public class QueryMatchStrategy implements MatchStrategy {
|
||||
String detectSegment = text.substring(index, i);
|
||||
|
||||
// step1. pre search
|
||||
Integer oneDetectionMaxSize = mapperHelper.getOneDetectionMaxSize();
|
||||
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
|
||||
LinkedHashSet<MapResult> mapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize, agentId,
|
||||
detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
// step2. suffix search
|
||||
@@ -153,7 +157,7 @@ public class QueryMatchStrategy implements MatchStrategy {
|
||||
if (CollectionUtils.isNotEmpty(dimensionMetrics)) {
|
||||
return dimensionMetrics;
|
||||
} else {
|
||||
return mapResults.stream().limit(mapperHelper.getOneDetectionSize()).collect(Collectors.toList());
|
||||
return mapResults.stream().limit(optimizationConfig.getOneDetectionSize()).collect(Collectors.toList());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4,7 +4,9 @@ package com.tencent.supersonic.chat.parser;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
/**
|
||||
@@ -15,10 +17,6 @@ import lombok.extern.slf4j.Slf4j;
|
||||
@Slf4j
|
||||
public class SatisfactionChecker {
|
||||
|
||||
private static final double LONG_TEXT_THRESHOLD = 0.8;
|
||||
private static final double SHORT_TEXT_THRESHOLD = 0.5;
|
||||
private static final int QUERY_TEXT_LENGTH_THRESHOLD = 10;
|
||||
|
||||
// check all the parse info in candidate
|
||||
public static boolean check(QueryContext queryContext) {
|
||||
for (SemanticQuery query : queryContext.getCandidateQueries()) {
|
||||
@@ -35,11 +33,12 @@ public class SatisfactionChecker {
|
||||
private static boolean checkThreshold(String queryText, SemanticParseInfo semanticParseInfo) {
|
||||
int queryTextLength = queryText.replaceAll(" ", "").length();
|
||||
double degree = semanticParseInfo.getScore() / queryTextLength;
|
||||
if (queryTextLength > QUERY_TEXT_LENGTH_THRESHOLD) {
|
||||
if (degree < LONG_TEXT_THRESHOLD) {
|
||||
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
|
||||
if (queryTextLength > optimizationConfig.getQueryTextLengthThreshold()) {
|
||||
if (degree < optimizationConfig.getLongTextThreshold()) {
|
||||
return false;
|
||||
}
|
||||
} else if (degree < SHORT_TEXT_THRESHOLD) {
|
||||
} else if (degree < optimizationConfig.getShortTextThreshold()) {
|
||||
return false;
|
||||
}
|
||||
log.info("queryMode:{}, degree:{}, parse info:{}",
|
||||
|
||||
@@ -93,19 +93,23 @@ public class LLMDslParser implements SemanticParser {
|
||||
|
||||
SemanticParseInfo parseInfo = getParseInfo(queryCtx, modelId, dslTool, dslParseResult);
|
||||
|
||||
String correctorSql = getCorrectorSql(queryCtx, parseInfo, llmResp.getSqlOutput());
|
||||
CorrectionInfo correctionInfo = getCorrectorSql(queryCtx, parseInfo, llmResp.getSqlOutput());
|
||||
|
||||
llmResp.setCorrectorSql(correctorSql);
|
||||
llmResp.setCorrectorSql(correctionInfo.getSql());
|
||||
|
||||
setFilter(correctorSql, modelId, parseInfo);
|
||||
setFilter(correctionInfo, modelId, parseInfo);
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("LLMDSLParser error", e);
|
||||
}
|
||||
}
|
||||
|
||||
public void setFilter(String correctorSql, Long modelId, SemanticParseInfo parseInfo) {
|
||||
public void setFilter(CorrectionInfo correctionInfo, Long modelId, SemanticParseInfo parseInfo) {
|
||||
|
||||
String correctorSql = correctionInfo.getPreSql();
|
||||
if (StringUtils.isEmpty(correctorSql)) {
|
||||
correctorSql = correctionInfo.getSql();
|
||||
}
|
||||
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctorSql);
|
||||
if (CollectionUtils.isEmpty(expressions)) {
|
||||
return;
|
||||
@@ -200,7 +204,7 @@ public class LLMDslParser implements SemanticParser {
|
||||
return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue());
|
||||
}
|
||||
|
||||
private String getCorrectorSql(QueryContext queryCtx, SemanticParseInfo parseInfo, String sql) {
|
||||
private CorrectionInfo getCorrectorSql(QueryContext queryCtx, SemanticParseInfo parseInfo, String sql) {
|
||||
|
||||
CorrectionInfo correctionInfo = CorrectionInfo.builder()
|
||||
.queryFilters(queryCtx.getRequest().getQueryFilters()).sql(sql)
|
||||
@@ -217,7 +221,7 @@ public class LLMDslParser implements SemanticParser {
|
||||
log.error("sqlCorrection:{} execute error,correctionInfo:{}", dslCorrection, correctionInfo, e);
|
||||
}
|
||||
});
|
||||
return correctionInfo.getSql();
|
||||
return correctionInfo;
|
||||
}
|
||||
|
||||
private SemanticParseInfo getParseInfo(QueryContext queryCtx, Long modelId, DslTool dslTool,
|
||||
|
||||
@@ -0,0 +1,118 @@
|
||||
package com.tencent.supersonic.chat.parser.plugin;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public abstract class PluginParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
if (!checkPreCondition(queryContext)) {
|
||||
return;
|
||||
}
|
||||
PluginRecallResult pluginRecallResult = recallPlugin(queryContext);
|
||||
if (pluginRecallResult == null) {
|
||||
return;
|
||||
}
|
||||
buildQuery(queryContext, pluginRecallResult);
|
||||
}
|
||||
|
||||
public abstract boolean checkPreCondition(QueryContext queryContext);
|
||||
|
||||
public abstract PluginRecallResult recallPlugin(QueryContext queryContext);
|
||||
|
||||
public void buildQuery(QueryContext queryContext, PluginRecallResult pluginRecallResult) {
|
||||
Plugin plugin = pluginRecallResult.getPlugin();
|
||||
for (Long modelId : pluginRecallResult.getModelIds()) {
|
||||
PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType());
|
||||
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, plugin, queryContext.getRequest(),
|
||||
queryContext.getMapInfo().getMatchedElements(modelId), pluginRecallResult.getDistance());
|
||||
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
|
||||
semanticParseInfo.setScore(pluginRecallResult.getScore());
|
||||
pluginQuery.setParseInfo(semanticParseInfo);
|
||||
queryContext.getCandidateQueries().add(pluginQuery);
|
||||
if (plugin.isContainsAllModel()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected List<Plugin> getPluginList(QueryContext queryContext) {
|
||||
return PluginManager.getPluginAgentCanSupport(queryContext.getRequest().getAgentId());
|
||||
}
|
||||
|
||||
protected SemanticParseInfo buildSemanticParseInfo(Long modelId, Plugin plugin, QueryReq queryReq,
|
||||
List<SchemaElementMatch> schemaElementMatches, double distance) {
|
||||
if (modelId == null && !CollectionUtils.isEmpty(plugin.getModelList())) {
|
||||
modelId = plugin.getModelList().get(0);
|
||||
}
|
||||
SchemaElement model = new SchemaElement();
|
||||
model.setModel(modelId);
|
||||
model.setId(modelId);
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
semanticParseInfo.setElementMatches(schemaElementMatches);
|
||||
semanticParseInfo.setModel(model);
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
PluginParseResult pluginParseResult = new PluginParseResult();
|
||||
pluginParseResult.setPlugin(plugin);
|
||||
pluginParseResult.setRequest(queryReq);
|
||||
pluginParseResult.setDistance(distance);
|
||||
properties.put(Constants.CONTEXT, pluginParseResult);
|
||||
properties.put("type", "plugin");
|
||||
properties.put("name", plugin.getName());
|
||||
semanticParseInfo.setProperties(properties);
|
||||
semanticParseInfo.setScore(distance);
|
||||
fillSemanticParseInfo(semanticParseInfo);
|
||||
setEntity(modelId, semanticParseInfo);
|
||||
return semanticParseInfo;
|
||||
}
|
||||
|
||||
private void setEntity(Long modelId, SemanticParseInfo semanticParseInfo) {
|
||||
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||
ModelSchema modelSchema = semanticService.getModelSchema(modelId);
|
||||
if (modelSchema != null && modelSchema.getEntity() != null) {
|
||||
semanticParseInfo.setEntity(modelSchema.getEntity());
|
||||
}
|
||||
}
|
||||
|
||||
private void fillSemanticParseInfo(SemanticParseInfo semanticParseInfo) {
|
||||
List<SchemaElementMatch> schemaElementMatches = semanticParseInfo.getElementMatches();
|
||||
if (CollectionUtils.isEmpty(schemaElementMatches)) {
|
||||
return;
|
||||
}
|
||||
schemaElementMatches.stream().filter(schemaElementMatch ->
|
||||
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|
||||
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
|
||||
.forEach(schemaElementMatch -> {
|
||||
QueryFilter queryFilter = new QueryFilter();
|
||||
queryFilter.setValue(schemaElementMatch.getWord());
|
||||
queryFilter.setElementID(schemaElementMatch.getElement().getId());
|
||||
queryFilter.setName(schemaElementMatch.getElement().getName());
|
||||
queryFilter.setOperator(FilterOperatorEnum.EQUALS);
|
||||
queryFilter.setBizName(schemaElementMatch.getElement().getBizName());
|
||||
semanticParseInfo.getDimensionFilters().add(queryFilter);
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1,63 +1,54 @@
|
||||
package com.tencent.supersonic.chat.parser.plugin.embedding;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.parser.ParseMode;
|
||||
import com.tencent.supersonic.chat.parser.plugin.PluginParser;
|
||||
import com.tencent.supersonic.chat.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
|
||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.chat.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.HashMap;
|
||||
import java.util.Comparator;
|
||||
import java.util.stream.Collectors;
|
||||
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class EmbeddingBasedParser implements SemanticParser {
|
||||
public class EmbeddingBasedParser extends PluginParser {
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
public boolean checkPreCondition(QueryContext queryContext) {
|
||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||
if (StringUtils.isBlank(embeddingConfig.getUrl())) {
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
log.info("EmbeddingBasedParser parser query ctx: {}, chat ctx: {}", queryContext, chatContext);
|
||||
String text = queryContext.getRequest().getQueryText();
|
||||
List<RecallRetrieval> embeddingRetrievals = recallResult(text);
|
||||
choosePlugin(embeddingRetrievals, queryContext);
|
||||
List<SemanticQuery> semanticQueries = queryContext.getCandidateQueries();
|
||||
for (SemanticQuery semanticQuery : semanticQueries) {
|
||||
if (queryContext.getRequest().getQueryText().length() <= semanticQuery.getParseInfo().getScore()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private void choosePlugin(List<RecallRetrieval> embeddingRetrievals,
|
||||
QueryContext queryContext) {
|
||||
@Override
|
||||
public PluginRecallResult recallPlugin(QueryContext queryContext) {
|
||||
String text = queryContext.getRequest().getQueryText();
|
||||
List<RecallRetrieval> embeddingRetrievals = embeddingRecall(text);
|
||||
if (CollectionUtils.isEmpty(embeddingRetrievals)) {
|
||||
return;
|
||||
return null;
|
||||
}
|
||||
List<Plugin> plugins = getPluginList(queryContext);
|
||||
Map<Long, Plugin> pluginMap = plugins.stream().collect(Collectors.toMap(Plugin::getId, p -> p));
|
||||
for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) {
|
||||
Plugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId()));
|
||||
if (plugin == null || DslQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) {
|
||||
if (plugin == null) {
|
||||
continue;
|
||||
}
|
||||
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, queryContext);
|
||||
@@ -65,69 +56,19 @@ public class EmbeddingBasedParser implements SemanticParser {
|
||||
if (pair.getLeft()) {
|
||||
Set<Long> modelList = pair.getRight();
|
||||
if (CollectionUtils.isEmpty(modelList)) {
|
||||
return;
|
||||
continue;
|
||||
}
|
||||
for (Long modelId : modelList) {
|
||||
buildQuery(plugin, Double.parseDouble(embeddingRetrieval.getDistance()), modelId, queryContext,
|
||||
queryContext.getMapInfo().getMatchedElements(modelId));
|
||||
if (plugin.isContainsAllModel()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void buildQuery(Plugin plugin, double distance, Long modelId,
|
||||
QueryContext queryContext, List<SchemaElementMatch> schemaElementMatches) {
|
||||
log.info("EmbeddingBasedParser Model: {} choose plugin: [{} {}]", modelId, plugin.getId(), plugin.getName());
|
||||
PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType());
|
||||
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
|
||||
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, plugin, queryContext.getRequest(),
|
||||
schemaElementMatches, distance);
|
||||
double distance = Double.parseDouble(embeddingRetrieval.getDistance());
|
||||
double score = queryContext.getRequest().getQueryText().length() * (1 - distance);
|
||||
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
|
||||
semanticParseInfo.setScore(score);
|
||||
pluginQuery.setParseInfo(semanticParseInfo);
|
||||
queryContext.getCandidateQueries().add(pluginQuery);
|
||||
return PluginRecallResult.builder()
|
||||
.plugin(plugin).modelIds(modelList).score(score).distance(distance).build();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private SemanticParseInfo buildSemanticParseInfo(Long modelId, Plugin plugin, QueryReq queryReq,
|
||||
List<SchemaElementMatch> schemaElementMatches, double distance) {
|
||||
if (modelId == null && !CollectionUtils.isEmpty(plugin.getModelList())) {
|
||||
modelId = plugin.getModelList().get(0);
|
||||
}
|
||||
SchemaElement model = new SchemaElement();
|
||||
model.setModel(modelId);
|
||||
model.setId(modelId);
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
semanticParseInfo.setElementMatches(schemaElementMatches);
|
||||
semanticParseInfo.setModel(model);
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
PluginParseResult pluginParseResult = new PluginParseResult();
|
||||
pluginParseResult.setPlugin(plugin);
|
||||
pluginParseResult.setRequest(queryReq);
|
||||
pluginParseResult.setDistance(distance);
|
||||
properties.put(Constants.CONTEXT, pluginParseResult);
|
||||
properties.put("type", "plugin");
|
||||
properties.put("name", plugin.getName());
|
||||
semanticParseInfo.setProperties(properties);
|
||||
semanticParseInfo.setScore(distance);
|
||||
fillSemanticParseInfo(semanticParseInfo);
|
||||
setEntity(modelId, semanticParseInfo);
|
||||
return semanticParseInfo;
|
||||
}
|
||||
|
||||
private void setEntity(Long modelId, SemanticParseInfo semanticParseInfo) {
|
||||
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||
ModelSchema modelSchema = semanticService.getModelSchema(modelId);
|
||||
if (modelSchema != null && modelSchema.getEntity() != null) {
|
||||
semanticParseInfo.setEntity(modelSchema.getEntity());
|
||||
}
|
||||
}
|
||||
|
||||
public List<RecallRetrieval> recallResult(String embeddingText) {
|
||||
public List<RecallRetrieval> embeddingRecall(String embeddingText) {
|
||||
try {
|
||||
PluginManager pluginManager = ContextUtils.getBean(PluginManager.class);
|
||||
EmbeddingResp embeddingResp = pluginManager.recognize(embeddingText);
|
||||
@@ -144,26 +85,4 @@ public class EmbeddingBasedParser implements SemanticParser {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
|
||||
private void fillSemanticParseInfo(SemanticParseInfo semanticParseInfo) {
|
||||
List<SchemaElementMatch> schemaElementMatches = semanticParseInfo.getElementMatches();
|
||||
if (!CollectionUtils.isEmpty(schemaElementMatches)) {
|
||||
schemaElementMatches.stream().filter(schemaElementMatch ->
|
||||
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|
||||
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
|
||||
.forEach(schemaElementMatch -> {
|
||||
QueryFilter queryFilter = new QueryFilter();
|
||||
queryFilter.setValue(schemaElementMatch.getWord());
|
||||
queryFilter.setElementID(schemaElementMatch.getElement().getId());
|
||||
queryFilter.setName(schemaElementMatch.getElement().getName());
|
||||
queryFilter.setOperator(FilterOperatorEnum.EQUALS);
|
||||
queryFilter.setBizName(schemaElementMatch.getElement().getBizName());
|
||||
semanticParseInfo.getDimensionFilters().add(queryFilter);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
protected List<Plugin> getPluginList(QueryContext queryContext) {
|
||||
return PluginManager.getPluginAgentCanSupport(queryContext.getRequest().getAgentId());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
package com.tencent.supersonic.chat.parser.plugin.embedding;
|
||||
|
||||
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class EmbeddingResp {
|
||||
|
||||
|
||||
@@ -1,32 +1,23 @@
|
||||
package com.tencent.supersonic.chat.parser.plugin.function;
|
||||
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.config.FunctionCallInfoConfig;
|
||||
import com.tencent.supersonic.chat.parser.ParseMode;
|
||||
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
|
||||
import com.tencent.supersonic.chat.parser.plugin.PluginParser;
|
||||
import com.tencent.supersonic.chat.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
|
||||
import com.tencent.supersonic.chat.service.PluginService;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import java.net.URI;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.Objects;
|
||||
import java.util.Map;
|
||||
import java.util.HashMap;
|
||||
import java.util.stream.Collectors;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -42,85 +33,73 @@ import org.springframework.web.client.RestTemplate;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
|
||||
@Slf4j
|
||||
public class FunctionBasedParser implements SemanticParser {
|
||||
public class FunctionBasedParser extends PluginParser {
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryCtx, ChatContext chatCtx) {
|
||||
public boolean checkPreCondition(QueryContext queryContext) {
|
||||
FunctionCallInfoConfig functionCallConfig = ContextUtils.getBean(FunctionCallInfoConfig.class);
|
||||
PluginService pluginService = ContextUtils.getBean(PluginService.class);
|
||||
String functionUrl = functionCallConfig.getUrl();
|
||||
if (StringUtils.isBlank(functionUrl) || SatisfactionChecker.check(queryCtx)) {
|
||||
if (StringUtils.isBlank(functionUrl) || SatisfactionChecker.check(queryContext)) {
|
||||
log.info("functionUrl:{}, skip function parser, queryText:{}", functionUrl,
|
||||
queryCtx.getRequest().getQueryText());
|
||||
return;
|
||||
queryContext.getRequest().getQueryText());
|
||||
return false;
|
||||
}
|
||||
List<PluginParseConfig> functionDOList = getFunctionDO(queryCtx.getRequest().getModelId(), queryCtx);
|
||||
if (CollectionUtils.isEmpty(functionDOList)) {
|
||||
log.info("function call parser, plugin is empty, skip");
|
||||
return;
|
||||
return true;
|
||||
}
|
||||
FunctionResp functionResp = new FunctionResp();
|
||||
if (functionDOList.size() == 1) {
|
||||
functionResp.setToolSelection(functionDOList.iterator().next().getName());
|
||||
} else {
|
||||
FunctionReq functionReq = FunctionReq.builder()
|
||||
.queryText(queryCtx.getRequest().getQueryText())
|
||||
.pluginConfigs(functionDOList).build();
|
||||
functionResp = requestFunction(functionUrl, functionReq);
|
||||
|
||||
@Override
|
||||
public PluginRecallResult recallPlugin(QueryContext queryContext) {
|
||||
PluginService pluginService = ContextUtils.getBean(PluginService.class);
|
||||
FunctionResp functionResp = functionCall(queryContext);
|
||||
if (skipFunction(functionResp)) {
|
||||
return null;
|
||||
}
|
||||
log.info("requestFunction result:{}", functionResp.getToolSelection());
|
||||
if (skipFunction(functionResp)) {
|
||||
return;
|
||||
}
|
||||
PluginParseResult functionCallParseResult = new PluginParseResult();
|
||||
String toolSelection = functionResp.getToolSelection();
|
||||
Optional<Plugin> pluginOptional = pluginService.getPluginByName(toolSelection);
|
||||
if (!pluginOptional.isPresent()) {
|
||||
log.info("pluginOptional is not exist:{}, skip the parse", toolSelection);
|
||||
return;
|
||||
return null;
|
||||
}
|
||||
Plugin plugin = pluginOptional.get();
|
||||
plugin.setParseMode(ParseMode.FUNCTION_CALL);
|
||||
toolSelection = plugin.getType();
|
||||
functionCallParseResult.setPlugin(plugin);
|
||||
log.info("QueryManager PluginQueryModes:{}", QueryManager.getPluginQueryModes());
|
||||
PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(toolSelection);
|
||||
ModelResolver modelResolver = ComponentFactory.getModelResolver();
|
||||
log.info("plugin ModelList:{}", plugin.getModelList());
|
||||
Pair<Boolean, Set<Long>> pluginResolveResult = PluginManager.resolve(plugin, queryCtx);
|
||||
Long modelId = modelResolver.resolve(queryCtx, chatCtx, pluginResolveResult.getRight());
|
||||
log.info("FunctionBasedParser modelId:{}", modelId);
|
||||
if ((Objects.isNull(modelId) || modelId <= 0) && !plugin.isContainsAllModel()) {
|
||||
log.info("Model is null, skip the parse, select tool: {}", toolSelection);
|
||||
return;
|
||||
Pair<Boolean, Set<Long>> pluginResolveResult = PluginManager.resolve(plugin, queryContext);
|
||||
if (pluginResolveResult.getLeft()) {
|
||||
Set<Long> modelList = pluginResolveResult.getRight();
|
||||
if (CollectionUtils.isEmpty(modelList)) {
|
||||
return null;
|
||||
}
|
||||
if (!plugin.getModelList().contains(modelId) && !plugin.isContainsAllModel()) {
|
||||
return;
|
||||
double score = queryContext.getRequest().getQueryText().length();
|
||||
return PluginRecallResult.builder().plugin(plugin).modelIds(modelList).score(score).build();
|
||||
}
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
if (Objects.nonNull(modelId) && modelId > 0) {
|
||||
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(modelId));
|
||||
return null;
|
||||
}
|
||||
functionCallParseResult.setRequest(queryCtx.getRequest());
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
properties.put(Constants.CONTEXT, functionCallParseResult);
|
||||
properties.put("type", "plugin");
|
||||
properties.put("name", plugin.getName());
|
||||
parseInfo.setProperties(properties);
|
||||
parseInfo.setScore(queryCtx.getRequest().getQueryText().length());
|
||||
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
||||
SchemaElement model = new SchemaElement();
|
||||
model.setModel(modelId);
|
||||
model.setId(modelId);
|
||||
parseInfo.setModel(model);
|
||||
queryCtx.getCandidateQueries().add(semanticQuery);
|
||||
|
||||
public FunctionResp functionCall(QueryContext queryContext) {
|
||||
List<PluginParseConfig> pluginToFunctionCall =
|
||||
getPluginToFunctionCall(queryContext.getRequest().getModelId(), queryContext);
|
||||
if (CollectionUtils.isEmpty(pluginToFunctionCall)) {
|
||||
log.info("function call parser, plugin is empty, skip");
|
||||
return null;
|
||||
}
|
||||
FunctionCallInfoConfig functionCallConfig = ContextUtils.getBean(FunctionCallInfoConfig.class);
|
||||
FunctionResp functionResp = new FunctionResp();
|
||||
if (pluginToFunctionCall.size() == 1) {
|
||||
functionResp.setToolSelection(pluginToFunctionCall.iterator().next().getName());
|
||||
} else {
|
||||
FunctionReq functionReq = FunctionReq.builder()
|
||||
.queryText(queryContext.getRequest().getQueryText())
|
||||
.pluginConfigs(pluginToFunctionCall).build();
|
||||
functionResp = requestFunction(functionCallConfig.getUrl(), functionReq);
|
||||
}
|
||||
return functionResp;
|
||||
}
|
||||
|
||||
private boolean skipFunction(FunctionResp functionResp) {
|
||||
return Objects.isNull(functionResp) || StringUtils.isBlank(functionResp.getToolSelection());
|
||||
}
|
||||
|
||||
private List<PluginParseConfig> getFunctionDO(Long modelId, QueryContext queryContext) {
|
||||
private List<PluginParseConfig> getPluginToFunctionCall(Long modelId, QueryContext queryContext) {
|
||||
log.info("user decide Model:{}", modelId);
|
||||
List<Plugin> plugins = getPluginList(queryContext);
|
||||
List<PluginParseConfig> functionDOList = plugins.stream().filter(plugin -> {
|
||||
@@ -150,7 +129,7 @@ public class FunctionBasedParser implements SemanticParser {
|
||||
return true;
|
||||
}
|
||||
}).map(o -> JsonUtil.toObject(o.getParseModeConfig(), PluginParseConfig.class)).collect(Collectors.toList());
|
||||
log.info("getFunctionDO:{}", JsonUtil.toString(functionDOList));
|
||||
log.info("PluginToFunctionCall: {}", JsonUtil.toString(functionDOList));
|
||||
return functionDOList;
|
||||
}
|
||||
|
||||
@@ -173,8 +152,4 @@ public class FunctionBasedParser implements SemanticParser {
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
protected List<Plugin> getPluginList(QueryContext queryContext) {
|
||||
return PluginManager.getPluginAgentCanSupport(queryContext.getRequest().getAgentId());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
package com.tencent.supersonic.chat.parser.plugin.function;
|
||||
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
|
||||
import java.util.List;
|
||||
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ModelMatchResult {
|
||||
|
||||
private Integer count = 0;
|
||||
private double maxSimilarity;
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package com.tencent.supersonic.chat.parser.plugin.function;
|
||||
|
||||
import lombok.Data;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class Parameters {
|
||||
|
||||
@@ -68,6 +68,7 @@ public class AggregateTypeParser implements SemanticParser {
|
||||
Map<AggregateTypeEnum, Integer> aggregateCount = new HashMap<>(REGX_MAP.size());
|
||||
Map<AggregateTypeEnum, String> aggregateWord = new HashMap<>(REGX_MAP.size());
|
||||
|
||||
|
||||
for (Map.Entry<AggregateTypeEnum, Pattern> entry : REGX_MAP.entrySet()) {
|
||||
Matcher matcher = entry.getValue().matcher(queryText);
|
||||
int count = 0;
|
||||
@@ -90,7 +91,6 @@ public class AggregateTypeParser implements SemanticParser {
|
||||
|
||||
@AllArgsConstructor
|
||||
class AggregateConf {
|
||||
|
||||
public AggregateTypeEnum type;
|
||||
public String detectWord;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.chat.persistence.dataobject;
|
||||
|
||||
import java.util.Date;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import lombok.Data;
|
||||
public class ChatDO {
|
||||
|
||||
private long chatId;
|
||||
private Integer agentId;
|
||||
private String chatName;
|
||||
private String createTime;
|
||||
private String lastTime;
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
package com.tencent.supersonic.chat.persistence.dataobject;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.config.DefaultMetric;
|
||||
import com.tencent.supersonic.chat.config.Dim4Dict;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import com.tencent.supersonic.chat.config.DefaultMetric;
|
||||
import com.tencent.supersonic.chat.config.Dim4Dict;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ package com.tencent.supersonic.chat.persistence.dataobject;
|
||||
import java.util.Date;
|
||||
|
||||
public class PluginDO {
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
|
||||
@@ -5,7 +5,6 @@ import java.util.Date;
|
||||
import java.util.List;
|
||||
|
||||
public class PluginDOExample {
|
||||
|
||||
/**
|
||||
* s2_plugin
|
||||
*/
|
||||
@@ -149,7 +148,6 @@ public class PluginDOExample {
|
||||
* s2_plugin null
|
||||
*/
|
||||
protected abstract static class GeneratedCriteria {
|
||||
|
||||
protected List<Criterion> criteria;
|
||||
|
||||
protected GeneratedCriteria() {
|
||||
@@ -875,7 +873,6 @@ public class PluginDOExample {
|
||||
* s2_plugin null
|
||||
*/
|
||||
public static class Criterion {
|
||||
|
||||
private String condition;
|
||||
|
||||
private Object value;
|
||||
|
||||
@@ -10,7 +10,7 @@ public interface ChatMapper {
|
||||
|
||||
boolean createChat(ChatDO chatDO);
|
||||
|
||||
List<ChatDO> getAll(String creator);
|
||||
List<ChatDO> getAll(String creator, Integer agentId);
|
||||
|
||||
Boolean updateChatName(Long chatId, String chatName, String lastTime, String creator);
|
||||
|
||||
|
||||
@@ -2,9 +2,10 @@ package com.tencent.supersonic.chat.persistence.mapper;
|
||||
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample;
|
||||
import java.util.List;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Mapper
|
||||
public interface ChatQueryDOMapper {
|
||||
|
||||
|
||||
@@ -2,58 +2,67 @@ package com.tencent.supersonic.chat.persistence.mapper;
|
||||
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.PluginDO;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.PluginDOExample;
|
||||
import java.util.List;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
import java.util.List;
|
||||
|
||||
@Mapper
|
||||
public interface PluginDOMapper {
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
long countByExample(PluginDOExample example);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int deleteByPrimaryKey(Long id);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int insert(PluginDO record);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int insertSelective(PluginDO record);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
List<PluginDO> selectByExampleWithBLOBs(PluginDOExample example);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
List<PluginDO> selectByExample(PluginDOExample example);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
PluginDO selectByPrimaryKey(Long id);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int updateByPrimaryKeySelective(PluginDO record);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int updateByPrimaryKeyWithBLOBs(PluginDO record);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int updateByPrimaryKey(PluginDO record);
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
package com.tencent.supersonic.chat.persistence.repository;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.config.ChatConfig;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
|
||||
import com.tencent.supersonic.chat.config.ChatConfig;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface ChatConfigRepository {
|
||||
|
||||
@@ -8,7 +8,7 @@ public interface ChatRepository {
|
||||
|
||||
boolean createChat(ChatDO chatDO);
|
||||
|
||||
List<ChatDO> getAll(String creator);
|
||||
List<ChatDO> getAll(String creator, Integer agentId);
|
||||
|
||||
Boolean updateChatName(Long chatId, String chatName, String lastTime, String creator);
|
||||
|
||||
|
||||
@@ -2,10 +2,10 @@ package com.tencent.supersonic.chat.persistence.repository;
|
||||
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.PluginDO;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.PluginDOExample;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface PluginRepository {
|
||||
|
||||
List<PluginDO> getPlugins();
|
||||
|
||||
List<PluginDO> fetchPluginDOs(String queryText, String type);
|
||||
|
||||
@@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.persistence.repository.impl;
|
||||
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.QueryDO;
|
||||
import com.tencent.supersonic.chat.persistence.mapper.ChatMapper;
|
||||
import com.tencent.supersonic.chat.persistence.repository.ChatRepository;
|
||||
import com.tencent.supersonic.chat.persistence.mapper.ChatMapper;
|
||||
import java.util.List;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.context.annotation.Primary;
|
||||
@@ -26,8 +26,8 @@ public class ChatRepositoryImpl implements ChatRepository {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatDO> getAll(String creator) {
|
||||
return chatMapper.getAll(creator);
|
||||
public List<ChatDO> getAll(String creator, Integer agentId) {
|
||||
return chatMapper.getAll(creator, agentId);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -5,10 +5,10 @@ import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.parser.ParseMode;
|
||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class Plugin extends RecordInfo {
|
||||
|
||||
@@ -30,9 +30,7 @@ import java.util.HashSet;
|
||||
import java.util.HashMap;
|
||||
import java.util.Objects;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
@@ -52,8 +50,6 @@ import org.springframework.web.util.UriComponentsBuilder;
|
||||
@Component
|
||||
public class PluginManager {
|
||||
|
||||
private static Map<String, Plugin> internalPluginMap = new ConcurrentHashMap<>();
|
||||
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
|
||||
private RestTemplate restTemplate;
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
package com.tencent.supersonic.chat.plugin;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import java.util.Set;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class PluginRecallResult {
|
||||
|
||||
private Plugin plugin;
|
||||
|
||||
private Set<Long> modelIds;
|
||||
|
||||
private double score;
|
||||
|
||||
private double distance;
|
||||
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
||||
|
||||
import java.util.List;
|
||||
@@ -12,18 +13,18 @@ import java.util.OptionalDouble;
|
||||
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class HeuristicQuerySelector implements QuerySelector {
|
||||
|
||||
private static final double CANDIDATE_THRESHOLD = 0.2;
|
||||
|
||||
@Override
|
||||
public List<SemanticQuery> select(List<SemanticQuery> candidateQueries, QueryReq queryReq) {
|
||||
List<SemanticQuery> selectedQueries = new ArrayList<>();
|
||||
|
||||
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
|
||||
Double candidateThreshold = optimizationConfig.getCandidateThreshold();
|
||||
if (CollectionUtils.isNotEmpty(candidateQueries) && candidateQueries.size() == 1) {
|
||||
selectedQueries.addAll(candidateQueries);
|
||||
} else {
|
||||
@@ -35,7 +36,7 @@ public class HeuristicQuerySelector implements QuerySelector {
|
||||
candidateQueries.stream().forEach(query -> {
|
||||
SemanticParseInfo parseInfo = query.getParseInfo();
|
||||
if (!checkFullyInherited(query)
|
||||
&& (maxScore - parseInfo.getScore()) / maxScore <= CANDIDATE_THRESHOLD
|
||||
&& (maxScore - parseInfo.getScore()) / maxScore <= candidateThreshold
|
||||
&& checkSatisfyOtherRules(query, candidateQueries)) {
|
||||
selectedQueries.add(query);
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.entity.EntitySemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricSemanticQuery;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -55,7 +56,6 @@ public class QueryManager {
|
||||
throw new RuntimeException("no supported queryMode :" + queryMode);
|
||||
}
|
||||
}
|
||||
|
||||
public static boolean containsRuleQuery(String queryMode) {
|
||||
if (queryMode == null) {
|
||||
return false;
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.chat.query.metricinterpret;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package com.tencent.supersonic.chat.query.plugin;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class WebBase {
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package com.tencent.supersonic.chat.query.plugin;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class WebBaseResult {
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package com.tencent.supersonic.chat.query.plugin.webpage;
|
||||
|
||||
import com.tencent.supersonic.chat.query.plugin.WebBaseResult;
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class WebPageResponse {
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
|
||||
package com.tencent.supersonic.chat.query.rule;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
@@ -291,6 +292,7 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
|
||||
}
|
||||
|
||||
|
||||
|
||||
protected QueryStructReq convertQueryStruct() {
|
||||
return QueryReqBuilder.buildStructReq(parseInfo);
|
||||
}
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
package com.tencent.supersonic.chat.query.rule.entity;
|
||||
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ENTITY;
|
||||
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
@@ -15,12 +11,17 @@ import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.chat.service.ConfigService;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
import java.time.LocalDate;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ENTITY;
|
||||
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
|
||||
@Slf4j
|
||||
public abstract class EntitySemanticQuery extends RuleSemanticQuery {
|
||||
|
||||
@@ -3,8 +3,8 @@ package com.tencent.supersonic.chat.query.rule.metric;
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.DIMENSION;
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.VALUE;
|
||||
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.OPTIONAL;
|
||||
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.DESC_UPPER;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
@@ -13,11 +13,12 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.common.pojo.Order;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
public class MetricTopNQuery extends MetricSemanticQuery {
|
||||
|
||||
@@ -30,14 +30,16 @@ public class ChatController {
|
||||
|
||||
@PostMapping("/save")
|
||||
public Boolean save(@RequestParam(value = "chatName") String chatName,
|
||||
@RequestParam(value = "agentId", required = false) Integer agentId,
|
||||
HttpServletRequest request, HttpServletResponse response) {
|
||||
return chatService.addChat(UserHolder.findUser(request, response), chatName);
|
||||
return chatService.addChat(UserHolder.findUser(request, response), chatName, agentId);
|
||||
}
|
||||
|
||||
@GetMapping("/getAll")
|
||||
public List<ChatDO> getAllConversions(HttpServletRequest request, HttpServletResponse response) {
|
||||
public List<ChatDO> getAllConversions(@RequestParam(value = "agentId", required = false) Integer agentId,
|
||||
HttpServletRequest request, HttpServletResponse response) {
|
||||
String userName = UserHolder.findUser(request, response).getName();
|
||||
return chatService.getAll(userName);
|
||||
return chatService.getAll(userName, agentId);
|
||||
}
|
||||
|
||||
@PostMapping("/delete")
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.rest;
|
||||
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
|
||||
@@ -74,4 +75,11 @@ public class ChatQueryController {
|
||||
return queryService.executeDirectQuery(queryData, UserHolder.findUser(request, response));
|
||||
}
|
||||
|
||||
@PostMapping("queryDimensionValue")
|
||||
public Object queryDimensionValue(@RequestBody DimensionValueReq dimensionValueReq,
|
||||
HttpServletRequest request, HttpServletResponse response)
|
||||
throws Exception {
|
||||
return queryService.queryDimensionValue(dimensionValueReq, UserHolder.findUser(request, response));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import com.tencent.supersonic.knowledge.dictionary.DimValueDictInfo;
|
||||
import java.util.List;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.DeleteMapping;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
|
||||
@@ -30,9 +30,9 @@ public interface ChatService {
|
||||
|
||||
public void switchContext(ChatContext chatCtx);
|
||||
|
||||
public Boolean addChat(User user, String chatName);
|
||||
public Boolean addChat(User user, String chatName, Integer agentId);
|
||||
|
||||
public List<ChatDO> getAll(String userName);
|
||||
public List<ChatDO> getAll(String userName, Integer agentId);
|
||||
|
||||
public boolean updateChatName(Long chatId, String chatName, String userName);
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatConfigEditReqReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface ConfigService {
|
||||
|
||||
@@ -5,10 +5,10 @@ import com.tencent.supersonic.knowledge.dictionary.DictConfig;
|
||||
import com.tencent.supersonic.knowledge.dictionary.DictTaskFilter;
|
||||
import com.tencent.supersonic.knowledge.dictionary.DimValue2DictCommand;
|
||||
import com.tencent.supersonic.knowledge.dictionary.DimValueDictInfo;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface DictionaryService {
|
||||
|
||||
Long addDictTask(DimValue2DictCommand dimValue2DictCommend, User user);
|
||||
|
||||
Long deleteDictTask(DimValue2DictCommand dimValue2DictCommend, User user);
|
||||
|
||||
@@ -2,8 +2,9 @@ package com.tencent.supersonic.chat.service;
|
||||
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PluginQueryReq;
|
||||
import com.tencent.supersonic.chat.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PluginQueryReq;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
|
||||
@@ -2,11 +2,12 @@ package com.tencent.supersonic.chat.service;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
|
||||
import org.apache.calcite.sql.parser.SqlParseException;
|
||||
|
||||
/***
|
||||
@@ -24,4 +25,5 @@ public interface QueryService {
|
||||
|
||||
QueryResult executeDirectQuery(QueryDataReq queryData, User user) throws SqlParseException;
|
||||
|
||||
Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception;
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ package com.tencent.supersonic.chat.service;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.RecommendQuestionResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.RecommendResp;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/***
|
||||
|
||||
@@ -26,9 +26,9 @@ import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatDefaultRichConfigResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.DataInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ModelInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.MetricInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ModelInfo;
|
||||
import com.tencent.supersonic.chat.config.AggregatorConfig;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
|
||||
|
||||
@@ -82,21 +82,22 @@ public class ChatServiceImpl implements ChatService {
|
||||
|
||||
|
||||
@Override
|
||||
public Boolean addChat(User user, String chatName) {
|
||||
ChatDO intelligentConversionDO = new ChatDO();
|
||||
intelligentConversionDO.setChatName(chatName);
|
||||
intelligentConversionDO.setCreator(user.getName());
|
||||
intelligentConversionDO.setCreateTime(getCurrentTime());
|
||||
intelligentConversionDO.setIsDelete(0);
|
||||
intelligentConversionDO.setLastTime(getCurrentTime());
|
||||
intelligentConversionDO.setLastQuestion("Hello, welcome to using supersonic");
|
||||
intelligentConversionDO.setIsTop(0);
|
||||
return chatRepository.createChat(intelligentConversionDO);
|
||||
public Boolean addChat(User user, String chatName, Integer agentId) {
|
||||
ChatDO chatDO = new ChatDO();
|
||||
chatDO.setChatName(chatName);
|
||||
chatDO.setCreator(user.getName());
|
||||
chatDO.setCreateTime(getCurrentTime());
|
||||
chatDO.setIsDelete(0);
|
||||
chatDO.setLastTime(getCurrentTime());
|
||||
chatDO.setLastQuestion("Hello, welcome to using supersonic");
|
||||
chatDO.setIsTop(0);
|
||||
chatDO.setAgentId(agentId);
|
||||
return chatRepository.createChat(chatDO);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatDO> getAll(String userName) {
|
||||
return chatRepository.getAll(userName);
|
||||
public List<ChatDO> getAll(String userName, Integer agentId) {
|
||||
return chatRepository.getAll(userName, agentId);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user