(improvement)(Chat) add extend config for agent (#1010)

This commit is contained in:
LXW
2024-05-20 12:58:21 +08:00
committed by GitHub
parent 2d8c5c379c
commit 53b6c03288
37 changed files with 264 additions and 711 deletions

View File

@@ -0,0 +1,9 @@
package com.tencent.supersonic.chat.api.pojo.enums;
public enum DefaultShowType {
TEXT,
TABLE,
WIDGET
}

View File

@@ -4,6 +4,7 @@ package com.tencent.supersonic.chat.server.agent;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
import com.tencent.supersonic.common.pojo.RecordInfo;
import lombok.Data;
import org.springframework.util.CollectionUtils;
@@ -30,6 +31,9 @@ public class Agent extends RecordInfo {
private Integer status;
private List<String> examples;
private String agentConfig;
private LLMConfig llmConfig;
private MultiTurnConfig multiTurnConfig;
private VisualConfig visualConfig;
public List<String> getTools(AgentToolType type) {
Map map = JSONObject.parseObject(agentConfig, Map.class);

View File

@@ -0,0 +1,15 @@
package com.tencent.supersonic.chat.server.agent;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@AllArgsConstructor
@NoArgsConstructor
public class MultiTurnConfig {
private boolean enableMultiTurn;
}

View File

@@ -0,0 +1,15 @@
package com.tencent.supersonic.chat.server.agent;
import com.tencent.supersonic.chat.api.pojo.enums.DefaultShowType;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@NoArgsConstructor
@AllArgsConstructor
public class VisualConfig {
private DefaultShowType defaultShowType;
}

View File

@@ -1,10 +1,18 @@
package com.tencent.supersonic.chat.server.persistence.dataobject;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;
import java.util.Date;
@Data
@TableName("s2_agent")
public class AgentDO {
/**
*/
@TableId(type = IdType.AUTO)
private Integer id;
/**
@@ -48,159 +56,10 @@ public class AgentDO {
*/
private Integer enableSearch;
/**
* @return id
*/
public Integer getId() {
return id;
}
private String llmConfig;
/**
* @param id
*/
public void setId(Integer id) {
this.id = id;
}
private String multiTurnConfig;
/**
* @return name
*/
public String getName() {
return name;
}
private String visualConfig;
/**
* @param name
*/
public void setName(String name) {
this.name = name == null ? null : name.trim();
}
/**
* @return description
*/
public String getDescription() {
return description;
}
/**
* @param description
*/
public void setDescription(String description) {
this.description = description == null ? null : description.trim();
}
/**
* 0 offline, 1 online
* @return status 0 offline, 1 online
*/
public Integer getStatus() {
return status;
}
/**
* 0 offline, 1 online
* @param status 0 offline, 1 online
*/
public void setStatus(Integer status) {
this.status = status;
}
/**
* @return examples
*/
public String getExamples() {
return examples;
}
/**
* @param examples
*/
public void setExamples(String examples) {
this.examples = examples == null ? null : examples.trim();
}
/**
* @return config
*/
public String getConfig() {
return config;
}
/**
* @param config
*/
public void setConfig(String config) {
this.config = config == null ? null : config.trim();
}
/**
* @return created_by
*/
public String getCreatedBy() {
return createdBy;
}
/**
* @param createdBy
*/
public void setCreatedBy(String createdBy) {
this.createdBy = createdBy == null ? null : createdBy.trim();
}
/**
* @return created_at
*/
public Date getCreatedAt() {
return createdAt;
}
/**
* @param createdAt
*/
public void setCreatedAt(Date createdAt) {
this.createdAt = createdAt;
}
/**
* @return updated_by
*/
public String getUpdatedBy() {
return updatedBy;
}
/**
* @param updatedBy
*/
public void setUpdatedBy(String updatedBy) {
this.updatedBy = updatedBy == null ? null : updatedBy.trim();
}
/**
* @return updated_at
*/
public Date getUpdatedAt() {
return updatedAt;
}
/**
* @param updatedAt
*/
public void setUpdatedAt(Date updatedAt) {
this.updatedAt = updatedAt;
}
/**
* @return enable_search
*/
public Integer getEnableSearch() {
return enableSearch;
}
/**
* @param enableSearch
*/
public void setEnableSearch(Integer enableSearch) {
this.enableSearch = enableSearch;
}
}

View File

@@ -1,71 +1,10 @@
package com.tencent.supersonic.chat.server.persistence.mapper;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDOExample;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
import java.util.List;
@Mapper
public interface AgentDOMapper {
/**
*
* @mbg.generated
*/
long countByExample(AgentDOExample example);
public interface AgentDOMapper extends BaseMapper<AgentDO> {
/**
*
* @mbg.generated
*/
int deleteByPrimaryKey(Integer id);
/**
*
* @mbg.generated
*/
int insert(AgentDO record);
/**
*
* @mbg.generated
*/
int insertSelective(AgentDO record);
/**
*
* @mbg.generated
*/
List<AgentDO> selectByExample(AgentDOExample example);
/**
*
* @mbg.generated
*/
AgentDO selectByPrimaryKey(Integer id);
/**
*
* @mbg.generated
*/
int updateByExampleSelective(@Param("record") AgentDO record, @Param("example") AgentDOExample example);
/**
*
* @mbg.generated
*/
int updateByExample(@Param("record") AgentDO record, @Param("example") AgentDOExample example);
/**
*
* @mbg.generated
*/
int updateByPrimaryKeySelective(AgentDO record);
/**
*
* @mbg.generated
*/
int updateByPrimaryKey(AgentDO record);
}

View File

@@ -1,18 +0,0 @@
package com.tencent.supersonic.chat.server.persistence.repository;
import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
import java.util.List;
public interface AgentRepository {
List<AgentDO> getAgents();
void createAgent(AgentDO agentDO);
void updateAgent(AgentDO agentDO);
AgentDO getAgent(Integer id);
void deleteAgent(Integer id);
}

View File

@@ -1,43 +0,0 @@
package com.tencent.supersonic.chat.server.persistence.repository.impl;
import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDOExample;
import com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper;
import com.tencent.supersonic.chat.server.persistence.repository.AgentRepository;
import org.springframework.stereotype.Repository;
import java.util.List;
@Repository
public class AgentRepositoryImpl implements AgentRepository {
private AgentDOMapper agentDOMapper;
public AgentRepositoryImpl(AgentDOMapper agentDOMapper) {
this.agentDOMapper = agentDOMapper;
}
@Override
public List<AgentDO> getAgents() {
return agentDOMapper.selectByExample(new AgentDOExample());
}
@Override
public void createAgent(AgentDO agentDO) {
agentDOMapper.insert(agentDO);
}
@Override
public void updateAgent(AgentDO agentDO) {
agentDOMapper.updateByPrimaryKey(agentDO);
}
@Override
public AgentDO getAgent(Integer id) {
return agentDOMapper.selectByPrimaryKey(id);
}
@Override
public void deleteAgent(Integer id) {
agentDOMapper.deleteByPrimaryKey(id);
}
}

View File

@@ -5,6 +5,7 @@ import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.service.AgentService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.DeleteMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping;
@@ -21,12 +22,9 @@ import java.util.Map;
@RequestMapping({"/api/chat/agent", "/openapi/chat/agent"})
public class AgentController {
@Autowired
private AgentService agentService;
public AgentController(AgentService agentService) {
this.agentService = agentService;
}
@PostMapping
public boolean createAgent(@RequestBody Agent agent,
HttpServletRequest httpServletRequest,

View File

@@ -1,25 +1,23 @@
package com.tencent.supersonic.chat.server.service.impl;
import com.alibaba.fastjson.JSONObject;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.agent.VisualConfig;
import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
import com.tencent.supersonic.chat.server.persistence.repository.AgentRepository;
import com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper;
import com.tencent.supersonic.chat.server.service.AgentService;
import java.util.Date;
import java.util.List;
import java.util.stream.Collectors;
import com.tencent.supersonic.common.util.JsonUtil;
import org.springframework.beans.BeanUtils;
import org.springframework.stereotype.Service;
import java.util.List;
import java.util.stream.Collectors;
@Service
public class AgentServiceImpl implements AgentService {
private AgentRepository agentRepository;
public AgentServiceImpl(AgentRepository agentRepository) {
this.agentRepository = agentRepository;
}
public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
implements AgentService {
@Override
public List<Agent> getAgents() {
@@ -29,12 +27,14 @@ public class AgentServiceImpl implements AgentService {
@Override
public void createAgent(Agent agent, User user) {
agentRepository.createAgent(convert(agent, user));
agent.createdBy(user.getName());
save(convert(agent));
}
@Override
public void updateAgent(Agent agent, User user) {
agentRepository.updateAgent(convert(agent, user));
agent.updatedBy(user.getName());
updateById(convert(agent));
}
@Override
@@ -42,16 +42,16 @@ public class AgentServiceImpl implements AgentService {
if (id == null) {
return null;
}
return convert(agentRepository.getAgent(id));
return convert(getById(id));
}
@Override
public void deleteAgent(Integer id) {
agentRepository.deleteAgent(id);
removeById(id);
}
private List<AgentDO> getAgentDOList() {
return agentRepository.getAgents();
return list();
}
private Agent convert(AgentDO agentDO) {
@@ -61,19 +61,21 @@ public class AgentServiceImpl implements AgentService {
Agent agent = new Agent();
BeanUtils.copyProperties(agentDO, agent);
agent.setAgentConfig(agentDO.getConfig());
agent.setExamples(JSONObject.parseArray(agentDO.getExamples(), String.class));
agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class));
agent.setLlmConfig(JsonUtil.toObject(agentDO.getLlmConfig(), LLMConfig.class));
agent.setMultiTurnConfig(JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class));
agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class));
return agent;
}
private AgentDO convert(Agent agent, User user) {
private AgentDO convert(Agent agent) {
AgentDO agentDO = new AgentDO();
BeanUtils.copyProperties(agent, agentDO);
agentDO.setConfig(agent.getAgentConfig());
agentDO.setExamples(JSONObject.toJSONString(agent.getExamples()));
agentDO.setCreatedAt(new Date());
agentDO.setCreatedBy(user.getName());
agentDO.setUpdatedAt(new Date());
agentDO.setUpdatedBy(user.getName());
agentDO.setExamples(JsonUtil.toString(agent.getExamples()));
agentDO.setLlmConfig(JsonUtil.toString(agent.getLlmConfig()));
agentDO.setMultiTurnConfig(JsonUtil.toString(agent.getMultiTurnConfig()));
agentDO.setVisualConfig(JsonUtil.toString(agent.getVisualConfig()));
if (agentDO.getStatus() == null) {
agentDO.setStatus(1);
}

View File

@@ -30,6 +30,7 @@ public class QueryReqConverter {
&& MapUtils.isNotEmpty(queryReq.getMapInfo().getDataSetElementMatches())) {
queryReq.setMapInfo(queryReq.getMapInfo());
}
queryReq.setLlmConfig(agent.getLlmConfig());
return queryReq;
}

View File

@@ -1,303 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper">
<resultMap id="BaseResultMap" type="com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO">
<id column="id" jdbcType="INTEGER" property="id" />
<result column="name" jdbcType="VARCHAR" property="name" />
<result column="description" jdbcType="VARCHAR" property="description" />
<result column="status" jdbcType="INTEGER" property="status" />
<result column="examples" jdbcType="VARCHAR" property="examples" />
<result column="config" jdbcType="VARCHAR" property="config" />
<result column="created_by" jdbcType="VARCHAR" property="createdBy" />
<result column="created_at" jdbcType="TIMESTAMP" property="createdAt" />
<result column="updated_by" jdbcType="VARCHAR" property="updatedBy" />
<result column="updated_at" jdbcType="TIMESTAMP" property="updatedAt" />
<result column="enable_search" jdbcType="INTEGER" property="enableSearch" />
</resultMap>
<sql id="Example_Where_Clause">
<where>
<foreach collection="oredCriteria" item="criteria" separator="or">
<if test="criteria.valid">
<trim prefix="(" prefixOverrides="and" suffix=")">
<foreach collection="criteria.criteria" item="criterion">
<choose>
<when test="criterion.noValue">
and ${criterion.condition}
</when>
<when test="criterion.singleValue">
and ${criterion.condition} #{criterion.value}
</when>
<when test="criterion.betweenValue">
and ${criterion.condition} #{criterion.value} and #{criterion.secondValue}
</when>
<when test="criterion.listValue">
and ${criterion.condition}
<foreach close=")" collection="criterion.value" item="listItem" open="(" separator=",">
#{listItem}
</foreach>
</when>
</choose>
</foreach>
</trim>
</if>
</foreach>
</where>
</sql>
<sql id="Update_By_Example_Where_Clause">
<where>
<foreach collection="example.oredCriteria" item="criteria" separator="or">
<if test="criteria.valid">
<trim prefix="(" prefixOverrides="and" suffix=")">
<foreach collection="criteria.criteria" item="criterion">
<choose>
<when test="criterion.noValue">
and ${criterion.condition}
</when>
<when test="criterion.singleValue">
and ${criterion.condition} #{criterion.value}
</when>
<when test="criterion.betweenValue">
and ${criterion.condition} #{criterion.value} and #{criterion.secondValue}
</when>
<when test="criterion.listValue">
and ${criterion.condition}
<foreach close=")" collection="criterion.value" item="listItem" open="(" separator=",">
#{listItem}
</foreach>
</when>
</choose>
</foreach>
</trim>
</if>
</foreach>
</where>
</sql>
<sql id="Base_Column_List">
id, name, description, status, examples, config, created_by, created_at, updated_by,
updated_at, enable_search
</sql>
<select id="selectByExample" parameterType="com.tencent.supersonic.chat.server.persistence.dataobject.AgentDOExample" resultMap="BaseResultMap">
select
<if test="distinct">
distinct
</if>
<include refid="Base_Column_List" />
from s2_agent
<if test="_parameter != null">
<include refid="Example_Where_Clause" />
</if>
<if test="orderByClause != null">
order by ${orderByClause}
</if>
<if test="limitStart != null and limitStart>=0">
limit #{limitStart} , #{limitEnd}
</if>
</select>
<select id="selectByPrimaryKey" parameterType="java.lang.Integer" resultMap="BaseResultMap">
select
<include refid="Base_Column_List" />
from s2_agent
where id = #{id,jdbcType=INTEGER}
</select>
<delete id="deleteByPrimaryKey" parameterType="java.lang.Integer">
delete from s2_agent
where id = #{id,jdbcType=INTEGER}
</delete>
<insert id="insert" parameterType="com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO">
insert into s2_agent (id, name, description,
status, examples, config,
created_by, created_at, updated_by,
updated_at, enable_search)
values (#{id,jdbcType=INTEGER}, #{name,jdbcType=VARCHAR}, #{description,jdbcType=VARCHAR},
#{status,jdbcType=INTEGER}, #{examples,jdbcType=VARCHAR}, #{config,jdbcType=VARCHAR},
#{createdBy,jdbcType=VARCHAR}, #{createdAt,jdbcType=TIMESTAMP}, #{updatedBy,jdbcType=VARCHAR},
#{updatedAt,jdbcType=TIMESTAMP}, #{enableSearch,jdbcType=INTEGER})
</insert>
<insert id="insertSelective" parameterType="com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO">
insert into s2_agent
<trim prefix="(" suffix=")" suffixOverrides=",">
<if test="id != null">
id,
</if>
<if test="name != null">
name,
</if>
<if test="description != null">
description,
</if>
<if test="status != null">
status,
</if>
<if test="examples != null">
examples,
</if>
<if test="config != null">
config,
</if>
<if test="createdBy != null">
created_by,
</if>
<if test="createdAt != null">
created_at,
</if>
<if test="updatedBy != null">
updated_by,
</if>
<if test="updatedAt != null">
updated_at,
</if>
<if test="enableSearch != null">
enable_search,
</if>
</trim>
<trim prefix="values (" suffix=")" suffixOverrides=",">
<if test="id != null">
#{id,jdbcType=INTEGER},
</if>
<if test="name != null">
#{name,jdbcType=VARCHAR},
</if>
<if test="description != null">
#{description,jdbcType=VARCHAR},
</if>
<if test="status != null">
#{status,jdbcType=INTEGER},
</if>
<if test="examples != null">
#{examples,jdbcType=VARCHAR},
</if>
<if test="config != null">
#{config,jdbcType=VARCHAR},
</if>
<if test="createdBy != null">
#{createdBy,jdbcType=VARCHAR},
</if>
<if test="createdAt != null">
#{createdAt,jdbcType=TIMESTAMP},
</if>
<if test="updatedBy != null">
#{updatedBy,jdbcType=VARCHAR},
</if>
<if test="updatedAt != null">
#{updatedAt,jdbcType=TIMESTAMP},
</if>
<if test="enableSearch != null">
#{enableSearch,jdbcType=INTEGER},
</if>
</trim>
</insert>
<select id="countByExample" parameterType="com.tencent.supersonic.chat.server.persistence.dataobject.AgentDOExample" resultType="java.lang.Long">
select count(*) from s2_agent
<if test="_parameter != null">
<include refid="Example_Where_Clause" />
</if>
</select>
<update id="updateByExampleSelective" parameterType="map">
update s2_agent
<set>
<if test="record.id != null">
id = #{record.id,jdbcType=INTEGER},
</if>
<if test="record.name != null">
name = #{record.name,jdbcType=VARCHAR},
</if>
<if test="record.description != null">
description = #{record.description,jdbcType=VARCHAR},
</if>
<if test="record.status != null">
status = #{record.status,jdbcType=INTEGER},
</if>
<if test="record.examples != null">
examples = #{record.examples,jdbcType=VARCHAR},
</if>
<if test="record.config != null">
config = #{record.config,jdbcType=VARCHAR},
</if>
<if test="record.createdBy != null">
created_by = #{record.createdBy,jdbcType=VARCHAR},
</if>
<if test="record.createdAt != null">
created_at = #{record.createdAt,jdbcType=TIMESTAMP},
</if>
<if test="record.updatedBy != null">
updated_by = #{record.updatedBy,jdbcType=VARCHAR},
</if>
<if test="record.updatedAt != null">
updated_at = #{record.updatedAt,jdbcType=TIMESTAMP},
</if>
<if test="record.enableSearch != null">
enable_search = #{record.enableSearch,jdbcType=INTEGER},
</if>
</set>
<if test="_parameter != null">
<include refid="Update_By_Example_Where_Clause" />
</if>
</update>
<update id="updateByExample" parameterType="map">
update s2_agent
set id = #{record.id,jdbcType=INTEGER},
name = #{record.name,jdbcType=VARCHAR},
description = #{record.description,jdbcType=VARCHAR},
status = #{record.status,jdbcType=INTEGER},
examples = #{record.examples,jdbcType=VARCHAR},
config = #{record.config,jdbcType=VARCHAR},
created_by = #{record.createdBy,jdbcType=VARCHAR},
created_at = #{record.createdAt,jdbcType=TIMESTAMP},
updated_by = #{record.updatedBy,jdbcType=VARCHAR},
updated_at = #{record.updatedAt,jdbcType=TIMESTAMP},
enable_search = #{record.enableSearch,jdbcType=INTEGER}
<if test="_parameter != null">
<include refid="Update_By_Example_Where_Clause" />
</if>
</update>
<update id="updateByPrimaryKeySelective" parameterType="com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO">
update s2_agent
<set>
<if test="name != null">
name = #{name,jdbcType=VARCHAR},
</if>
<if test="description != null">
description = #{description,jdbcType=VARCHAR},
</if>
<if test="status != null">
status = #{status,jdbcType=INTEGER},
</if>
<if test="examples != null">
examples = #{examples,jdbcType=VARCHAR},
</if>
<if test="config != null">
config = #{config,jdbcType=VARCHAR},
</if>
<if test="createdBy != null">
created_by = #{createdBy,jdbcType=VARCHAR},
</if>
<if test="createdAt != null">
created_at = #{createdAt,jdbcType=TIMESTAMP},
</if>
<if test="updatedBy != null">
updated_by = #{updatedBy,jdbcType=VARCHAR},
</if>
<if test="updatedAt != null">
updated_at = #{updatedAt,jdbcType=TIMESTAMP},
</if>
<if test="enableSearch != null">
enable_search = #{enableSearch,jdbcType=INTEGER},
</if>
</set>
where id = #{id,jdbcType=INTEGER}
</update>
<update id="updateByPrimaryKey" parameterType="com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO">
update s2_agent
set name = #{name,jdbcType=VARCHAR},
description = #{description,jdbcType=VARCHAR},
status = #{status,jdbcType=INTEGER},
examples = #{examples,jdbcType=VARCHAR},
config = #{config,jdbcType=VARCHAR},
created_by = #{createdBy,jdbcType=VARCHAR},
created_at = #{createdAt,jdbcType=TIMESTAMP},
updated_by = #{updatedBy,jdbcType=VARCHAR},
updated_at = #{updatedAt,jdbcType=TIMESTAMP},
enable_search = #{enableSearch,jdbcType=INTEGER}
where id = #{id,jdbcType=INTEGER}
</update>
</mapper>

View File

@@ -172,6 +172,10 @@
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-local-ai</artifactId>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j</artifactId>

View File

@@ -0,0 +1,9 @@
package com.tencent.supersonic.common.pojo.enums;
public enum S2ModelProvider {
OPEN_AI,
HUGGING_FACE,
LOCAL_AI,
IN_PROCESS
}

View File

@@ -0,0 +1,32 @@
package com.tencent.supersonic.headless.api.pojo;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@AllArgsConstructor
@NoArgsConstructor
public class LLMConfig {
private String provider;
private String baseUrl;
private String apiKey;
private String modelName;
private Double temperature;
private Long timeOut;
public LLMConfig(String provider, String baseUrl, String apiKey, String modelName) {
this.provider = provider;
this.baseUrl = baseUrl;
this.apiKey = apiKey;
this.modelName = modelName;
this.temperature = 0.0d;
this.timeOut = 60L;
}
}

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.api.pojo.request;
import com.google.common.collect.Sets;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
@@ -22,4 +23,5 @@ public class QueryReq {
private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
private SchemaMapInfo mapInfo = new SchemaMapInfo();
private QueryDataType queryDataType = QueryDataType.ALL;
private LLMConfig llmConfig;
}

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.headless.api.pojo.response;
import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import lombok.Data;
import java.util.List;
@@ -11,8 +12,8 @@ public class DataSetMapInfo {
private String description;
private List<SchemaElementMatch> mapFields;
private List<SchemaElementMatch> mapFields = Lists.newArrayList();
private List<SchemaElementMatch> topFields;
private List<SchemaElementMatch> topFields = Lists.newArrayList();
}

View File

@@ -0,0 +1,31 @@
package com.tencent.supersonic.headless.core.chat.parser.llm;
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.utils.S2ChatModelProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@Service
public abstract class BaseSqlGeneration implements SqlGeneration, InitializingBean {
protected static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
@Autowired
protected SqlExamplarLoader sqlExamplarLoader;
@Autowired
protected OptimizationConfig optimizationConfig;
@Autowired
protected SqlPromptGenerator sqlPromptGenerator;
protected ChatLanguageModel getChatLanguageModel(LLMConfig llmConfig) {
return S2ChatModelProvider.provide(llmConfig);
}
}

View File

@@ -96,6 +96,7 @@ public class LLMRequestService {
}
llmReq.setCurrentDate(currentDate);
llmReq.setSqlGenerationMode(optimizationConfig.getSqlGenerationMode().getName());
llmReq.setLlmConfig(queryCtx.getLlmConfig());
return llmReq;
}

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.headless.core.chat.parser.llm;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
@@ -12,10 +11,6 @@ import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.HashMap;
@@ -26,20 +21,7 @@ import java.util.stream.Collectors;
@Service
@Slf4j
public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
@Autowired
private ChatLanguageModel chatLanguageModel;
@Autowired
private SqlExamplarLoader sqlExamplarLoader;
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private SqlPromptGenerator sqlPromptGenerator;
public class OnePassSCSqlGeneration extends BaseSqlGeneration {
@Override
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
@@ -59,6 +41,7 @@ public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingSqlPrompt))
.apply(new HashMap<>());
keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage());
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
String result = response.content().text();
llmResults.add(result);

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.headless.core.chat.parser.llm;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
@@ -12,10 +11,6 @@ import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.HashMap;
@@ -24,26 +19,13 @@ import java.util.Map;
@Service
@Slf4j
public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
@Autowired
private ChatLanguageModel chatLanguageModel;
@Autowired
private SqlExamplarLoader sqlExampleLoader;
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private SqlPromptGenerator sqlPromptGenerator;
public class OnePassSqlGeneration extends BaseSqlGeneration {
@Override
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
//1.retriever sqlExamples
keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq);
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
optimizationConfig.getText2sqlExampleNum());
//2.generator linking and sql prompt by sqlExamples,and generate response.
@@ -51,6 +33,7 @@ public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
Prompt prompt = PromptTemplate.from(JsonUtil.toString(promptStr)).apply(new HashMap<>());
keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage());
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
String result = response.content().text();
keyPipelineLog.info("model response:{}", result);

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.headless.core.chat.parser.llm;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
@@ -11,10 +10,6 @@ import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.HashMap;
@@ -23,20 +18,7 @@ import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
@Service
public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
@Autowired
private ChatLanguageModel chatLanguageModel;
@Autowired
private SqlExamplarLoader sqlExamplarLoader;
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private SqlPromptGenerator sqlPromptGenerator;
public class TwoPassSCSqlGeneration extends BaseSqlGeneration {
@Override
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
@@ -51,6 +33,7 @@ public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
//2.generator linking prompt,and parallel generate response.
List<String> linkingPromptPool = sqlPromptGenerator.generatePromptPool(llmReq, exampleListPool, false);
List<String> linkingResults = new CopyOnWriteArrayList<>();
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
linkingPromptPool.parallelStream().forEach(
linkingPrompt -> {
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPrompt)).apply(new HashMap<>());

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.headless.core.chat.parser.llm;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
@@ -11,10 +10,6 @@ import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.HashMap;
@@ -23,20 +18,7 @@ import java.util.Map;
@Service
@Slf4j
public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
@Autowired
private ChatLanguageModel chatLanguageModel;
@Autowired
private SqlExamplarLoader sqlExamplarLoader;
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private SqlPromptGenerator sqlPromptGenerator;
public class TwoPassSqlGeneration extends BaseSqlGeneration {
@Override
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
@@ -48,6 +30,7 @@ public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>());
keyPipelineLog.info("step one request prompt:{}", prompt.toSystemMessage());
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
keyPipelineLog.info("step one model response:{}", response.content().text());
String schemaLinkStr = OutputFormat.getSchemaLink(response.content().text());

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless.core.chat.query.llm.s2sql;
import com.fasterxml.jackson.annotation.JsonValue;
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
import lombok.Data;
import java.util.List;
@@ -22,6 +23,8 @@ public class LLMReq {
private String sqlGenerationMode;
private LLMConfig llmConfig;
@Data
public static class ElementValue {

View File

@@ -4,6 +4,7 @@ import com.fasterxml.jackson.annotation.JsonIgnore;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
@@ -47,6 +48,7 @@ public class QueryContext {
@JsonIgnore
private WorkflowState workflowState;
private QueryDataType queryDataType = QueryDataType.ALL;
private LLMConfig llmConfig;
public List<SemanticQuery> getCandidateQueries() {
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);

View File

@@ -0,0 +1,41 @@
package com.tencent.supersonic.headless.core.utils;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
import com.tencent.supersonic.common.pojo.enums.S2ModelProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.localai.LocalAiChatModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import org.apache.commons.lang3.StringUtils;
import java.time.Duration;
public class S2ChatModelProvider {
public static ChatLanguageModel provide(LLMConfig llmConfig) {
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
if (StringUtils.isBlank(llmConfig.getProvider())
|| StringUtils.isBlank(llmConfig.getBaseUrl())) {
return chatLanguageModel;
}
if (S2ModelProvider.OPEN_AI.name().equalsIgnoreCase(llmConfig.getProvider())) {
return OpenAiChatModel
.builder()
.baseUrl(llmConfig.getBaseUrl())
.modelName(llmConfig.getModelName())
.apiKey(llmConfig.getApiKey())
.temperature(llmConfig.getTemperature())
.timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
.build();
} else if (S2ModelProvider.LOCAL_AI.name().equalsIgnoreCase(llmConfig.getProvider())) {
return LocalAiChatModel
.builder()
.baseUrl(llmConfig.getBaseUrl())
.modelName(llmConfig.getModelName())
.temperature(llmConfig.getTemperature())
.timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
.build();
}
throw new RuntimeException("unsupported provider: " + llmConfig.getProvider());
}
}

View File

@@ -22,7 +22,6 @@ import com.tencent.supersonic.headless.server.service.DatabaseService;
import com.tencent.supersonic.headless.server.service.ModelService;
import com.tencent.supersonic.headless.server.utils.DatabaseConverter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@@ -36,8 +35,6 @@ import java.util.stream.Collectors;
@Slf4j
@Service
public class DatabaseServiceImpl implements DatabaseService {
@Value("${inMemoryEmbeddingStore.persistent.path:/tmp}")
private String embeddingStorePersistentPath;
private final SqlUtils sqlUtils;
private DatabaseRepository databaseRepository;

View File

@@ -66,6 +66,13 @@ import com.tencent.supersonic.headless.server.service.TagMetaService;
import com.tencent.supersonic.headless.server.utils.MetricCheckUtils;
import com.tencent.supersonic.headless.server.utils.MetricConverter;
import com.tencent.supersonic.headless.server.utils.ModelClusterBuilder;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
@@ -79,13 +86,6 @@ import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@Service
@Slf4j
@@ -293,12 +293,13 @@ public class MetricServiceImpl implements MetricService {
queryMapReq.setUser(user);
queryMapReq.setMapModeEnum(MapModeEnum.LOOSE);
MapInfoResp mapMeta = metaDiscoveryService.getMapMeta(queryMapReq);
Map<String, DataSetMapInfo> dataSetMapInfo = mapMeta.getDataSetMapInfo();
if (CollectionUtils.isEmpty(dataSetMapInfo)) {
Map<String, DataSetMapInfo> dataSetMapInfoMap = mapMeta.getDataSetMapInfo();
if (CollectionUtils.isEmpty(dataSetMapInfoMap)) {
return metricRespPageInfo;
}
Map<Long, Double> result = dataSetMapInfo.values().stream()
Map<Long, Double> result = dataSetMapInfoMap.values().stream()
.map(DataSetMapInfo::getMapFields)
.filter(Objects::nonNull)
.flatMap(Collection::stream).filter(schemaElementMatch ->
SchemaElementType.METRIC.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toMap(schemaElementMatch ->

View File

@@ -1,5 +1,6 @@
package dev.langchain4j;
import com.tencent.supersonic.common.pojo.enums.S2ModelProvider;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
class S2EmbeddingModel {

View File

@@ -1,9 +0,0 @@
package dev.langchain4j;
enum S2ModelProvider {
OPEN_AI,
HUGGING_FACE,
LOCAL_AI,
IN_PROCESS
}

View File

@@ -9,6 +9,7 @@ import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentConfig;
import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.LLMParserTool;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
import com.tencent.supersonic.chat.server.plugin.Plugin;
import com.tencent.supersonic.chat.server.plugin.PluginParseConfig;
@@ -21,6 +22,8 @@ import com.tencent.supersonic.chat.server.service.PluginService;
import com.tencent.supersonic.common.pojo.SysParameter;
import com.tencent.supersonic.common.service.SysParameterService;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
import com.tencent.supersonic.common.pojo.enums.S2ModelProvider;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
@@ -174,6 +177,11 @@ public class ChatDemoLoader implements CommandLineRunner {
agentConfig.getTools().add(llmParserTool);
}
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
LLMConfig llmConfig = new LLMConfig(S2ModelProvider.OPEN_AI.name(),
"", "your_key", "gpt-3.5-turbo");
MultiTurnConfig multiTurnConfig = new MultiTurnConfig(false);
agent.setLlmConfig(llmConfig);
agent.setMultiTurnConfig(multiTurnConfig);
agentService.createAgent(agent, User.getFakeUser());
}

View File

@@ -308,3 +308,8 @@ CREATE TABLE IF NOT EXISTS `s2_term` (
`updated_by` varchar(100) DEFAULT NULL ,
PRIMARY KEY (`id`)
);
--20240520
alter table s2_agent add column `llm_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL;
alter table s2_agent add column `multi_turn_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL;
alter table s2_agent add column `visual_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL;

View File

@@ -351,6 +351,9 @@ CREATE TABLE IF NOT EXISTS s2_agent
status int null,
examples varchar(500) null,
config varchar(2000) null,
llm_config varchar(2000) null,
multi_turn_config varchar(2000) null,
visual_config varchar(2000) null,
created_by varchar(100) null,
created_at TIMESTAMP null,
updated_by varchar(100) null,

View File

@@ -72,6 +72,9 @@ CREATE TABLE `s2_agent` (
`status` int(11) DEFAULT NULL,
`model` varchar(100) COLLATE utf8_unicode_ci DEFAULT NULL,
`config` varchar(6000) COLLATE utf8_unicode_ci DEFAULT NULL,
`llm_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL,
`multi_turn_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL,
`visual_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL,
`created_by` varchar(100) COLLATE utf8_unicode_ci DEFAULT NULL,
`created_at` datetime DEFAULT NULL,
`updated_by` varchar(100) COLLATE utf8_unicode_ci DEFAULT NULL,

View File

@@ -351,6 +351,9 @@ CREATE TABLE IF NOT EXISTS s2_agent
status int null,
examples varchar(500) null,
config varchar(2000) null,
llm_config varchar(2000) null,
multi_turn_config varchar(2000) null,
visual_config varchar(2000) null,
created_by varchar(100) null,
created_at TIMESTAMP null,
updated_by varchar(100) null,

View File

@@ -125,6 +125,11 @@
<artifactId>langchain4j-open-ai</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-local-ai</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-hugging-face</artifactId>

2
webapp/.gitignore vendored
View File

@@ -13,7 +13,7 @@
/dist
/supersonic-webapp
/webapp
../assembly/build/supersonic-webapp.tar.gz
supersonic-webapp.tar.gz