mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(Chat) add extend config for agent (#1010)
This commit is contained in:
@@ -0,0 +1,9 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.enums;
|
||||
|
||||
public enum DefaultShowType {
|
||||
|
||||
TEXT,
|
||||
TABLE,
|
||||
WIDGET
|
||||
|
||||
}
|
||||
@@ -4,6 +4,7 @@ package com.tencent.supersonic.chat.server.agent;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
|
||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||
import lombok.Data;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
@@ -30,6 +31,9 @@ public class Agent extends RecordInfo {
|
||||
private Integer status;
|
||||
private List<String> examples;
|
||||
private String agentConfig;
|
||||
private LLMConfig llmConfig;
|
||||
private MultiTurnConfig multiTurnConfig;
|
||||
private VisualConfig visualConfig;
|
||||
|
||||
public List<String> getTools(AgentToolType type) {
|
||||
Map map = JSONObject.parseObject(agentConfig, Map.class);
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
package com.tencent.supersonic.chat.server.agent;
|
||||
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class MultiTurnConfig {
|
||||
|
||||
private boolean enableMultiTurn;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package com.tencent.supersonic.chat.server.agent;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.DefaultShowType;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class VisualConfig {
|
||||
|
||||
private DefaultShowType defaultShowType;
|
||||
|
||||
}
|
||||
@@ -1,10 +1,18 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
||||
|
||||
import com.baomidou.mybatisplus.annotation.IdType;
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.Date;
|
||||
|
||||
@Data
|
||||
@TableName("s2_agent")
|
||||
public class AgentDO {
|
||||
/**
|
||||
*/
|
||||
@TableId(type = IdType.AUTO)
|
||||
private Integer id;
|
||||
|
||||
/**
|
||||
@@ -48,159 +56,10 @@ public class AgentDO {
|
||||
*/
|
||||
private Integer enableSearch;
|
||||
|
||||
/**
|
||||
* @return id
|
||||
*/
|
||||
public Integer getId() {
|
||||
return id;
|
||||
}
|
||||
private String llmConfig;
|
||||
|
||||
/**
|
||||
* @param id
|
||||
*/
|
||||
public void setId(Integer id) {
|
||||
this.id = id;
|
||||
}
|
||||
private String multiTurnConfig;
|
||||
|
||||
/**
|
||||
* @return name
|
||||
*/
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
private String visualConfig;
|
||||
|
||||
/**
|
||||
* @param name
|
||||
*/
|
||||
public void setName(String name) {
|
||||
this.name = name == null ? null : name.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
* @return description
|
||||
*/
|
||||
public String getDescription() {
|
||||
return description;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param description
|
||||
*/
|
||||
public void setDescription(String description) {
|
||||
this.description = description == null ? null : description.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
* 0 offline, 1 online
|
||||
* @return status 0 offline, 1 online
|
||||
*/
|
||||
public Integer getStatus() {
|
||||
return status;
|
||||
}
|
||||
|
||||
/**
|
||||
* 0 offline, 1 online
|
||||
* @param status 0 offline, 1 online
|
||||
*/
|
||||
public void setStatus(Integer status) {
|
||||
this.status = status;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return examples
|
||||
*/
|
||||
public String getExamples() {
|
||||
return examples;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param examples
|
||||
*/
|
||||
public void setExamples(String examples) {
|
||||
this.examples = examples == null ? null : examples.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
* @return config
|
||||
*/
|
||||
public String getConfig() {
|
||||
return config;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param config
|
||||
*/
|
||||
public void setConfig(String config) {
|
||||
this.config = config == null ? null : config.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
* @return created_by
|
||||
*/
|
||||
public String getCreatedBy() {
|
||||
return createdBy;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param createdBy
|
||||
*/
|
||||
public void setCreatedBy(String createdBy) {
|
||||
this.createdBy = createdBy == null ? null : createdBy.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
* @return created_at
|
||||
*/
|
||||
public Date getCreatedAt() {
|
||||
return createdAt;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param createdAt
|
||||
*/
|
||||
public void setCreatedAt(Date createdAt) {
|
||||
this.createdAt = createdAt;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return updated_by
|
||||
*/
|
||||
public String getUpdatedBy() {
|
||||
return updatedBy;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param updatedBy
|
||||
*/
|
||||
public void setUpdatedBy(String updatedBy) {
|
||||
this.updatedBy = updatedBy == null ? null : updatedBy.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
* @return updated_at
|
||||
*/
|
||||
public Date getUpdatedAt() {
|
||||
return updatedAt;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param updatedAt
|
||||
*/
|
||||
public void setUpdatedAt(Date updatedAt) {
|
||||
this.updatedAt = updatedAt;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return enable_search
|
||||
*/
|
||||
public Integer getEnableSearch() {
|
||||
return enableSearch;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param enableSearch
|
||||
*/
|
||||
public void setEnableSearch(Integer enableSearch) {
|
||||
this.enableSearch = enableSearch;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,71 +1,10 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.mapper;
|
||||
|
||||
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDOExample;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
import org.apache.ibatis.annotations.Param;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Mapper
|
||||
public interface AgentDOMapper {
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
long countByExample(AgentDOExample example);
|
||||
public interface AgentDOMapper extends BaseMapper<AgentDO> {
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int deleteByPrimaryKey(Integer id);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int insert(AgentDO record);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int insertSelective(AgentDO record);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
List<AgentDO> selectByExample(AgentDOExample example);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
AgentDO selectByPrimaryKey(Integer id);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int updateByExampleSelective(@Param("record") AgentDO record, @Param("example") AgentDOExample example);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int updateByExample(@Param("record") AgentDO record, @Param("example") AgentDOExample example);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int updateByPrimaryKeySelective(AgentDO record);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int updateByPrimaryKey(AgentDO record);
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.repository;
|
||||
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface AgentRepository {
|
||||
|
||||
List<AgentDO> getAgents();
|
||||
|
||||
void createAgent(AgentDO agentDO);
|
||||
|
||||
void updateAgent(AgentDO agentDO);
|
||||
|
||||
AgentDO getAgent(Integer id);
|
||||
|
||||
void deleteAgent(Integer id);
|
||||
}
|
||||
@@ -1,43 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.repository.impl;
|
||||
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDOExample;
|
||||
import com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.AgentRepository;
|
||||
import org.springframework.stereotype.Repository;
|
||||
import java.util.List;
|
||||
|
||||
@Repository
|
||||
public class AgentRepositoryImpl implements AgentRepository {
|
||||
|
||||
private AgentDOMapper agentDOMapper;
|
||||
|
||||
public AgentRepositoryImpl(AgentDOMapper agentDOMapper) {
|
||||
this.agentDOMapper = agentDOMapper;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AgentDO> getAgents() {
|
||||
return agentDOMapper.selectByExample(new AgentDOExample());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void createAgent(AgentDO agentDO) {
|
||||
agentDOMapper.insert(agentDO);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateAgent(AgentDO agentDO) {
|
||||
agentDOMapper.updateByPrimaryKey(agentDO);
|
||||
}
|
||||
|
||||
@Override
|
||||
public AgentDO getAgent(Integer id) {
|
||||
return agentDOMapper.selectByPrimaryKey(id);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteAgent(Integer id) {
|
||||
agentDOMapper.deleteByPrimaryKey(id);
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.DeleteMapping;
|
||||
import org.springframework.web.bind.annotation.PathVariable;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
@@ -21,12 +22,9 @@ import java.util.Map;
|
||||
@RequestMapping({"/api/chat/agent", "/openapi/chat/agent"})
|
||||
public class AgentController {
|
||||
|
||||
@Autowired
|
||||
private AgentService agentService;
|
||||
|
||||
public AgentController(AgentService agentService) {
|
||||
this.agentService = agentService;
|
||||
}
|
||||
|
||||
@PostMapping
|
||||
public boolean createAgent(@RequestBody Agent agent,
|
||||
HttpServletRequest httpServletRequest,
|
||||
|
||||
@@ -1,25 +1,23 @@
|
||||
package com.tencent.supersonic.chat.server.service.impl;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
|
||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||
import com.tencent.supersonic.chat.server.agent.VisualConfig;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.AgentRepository;
|
||||
import com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Service
|
||||
public class AgentServiceImpl implements AgentService {
|
||||
|
||||
private AgentRepository agentRepository;
|
||||
|
||||
public AgentServiceImpl(AgentRepository agentRepository) {
|
||||
this.agentRepository = agentRepository;
|
||||
}
|
||||
public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
||||
implements AgentService {
|
||||
|
||||
@Override
|
||||
public List<Agent> getAgents() {
|
||||
@@ -29,12 +27,14 @@ public class AgentServiceImpl implements AgentService {
|
||||
|
||||
@Override
|
||||
public void createAgent(Agent agent, User user) {
|
||||
agentRepository.createAgent(convert(agent, user));
|
||||
agent.createdBy(user.getName());
|
||||
save(convert(agent));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateAgent(Agent agent, User user) {
|
||||
agentRepository.updateAgent(convert(agent, user));
|
||||
agent.updatedBy(user.getName());
|
||||
updateById(convert(agent));
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -42,16 +42,16 @@ public class AgentServiceImpl implements AgentService {
|
||||
if (id == null) {
|
||||
return null;
|
||||
}
|
||||
return convert(agentRepository.getAgent(id));
|
||||
return convert(getById(id));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteAgent(Integer id) {
|
||||
agentRepository.deleteAgent(id);
|
||||
removeById(id);
|
||||
}
|
||||
|
||||
private List<AgentDO> getAgentDOList() {
|
||||
return agentRepository.getAgents();
|
||||
return list();
|
||||
}
|
||||
|
||||
private Agent convert(AgentDO agentDO) {
|
||||
@@ -61,19 +61,21 @@ public class AgentServiceImpl implements AgentService {
|
||||
Agent agent = new Agent();
|
||||
BeanUtils.copyProperties(agentDO, agent);
|
||||
agent.setAgentConfig(agentDO.getConfig());
|
||||
agent.setExamples(JSONObject.parseArray(agentDO.getExamples(), String.class));
|
||||
agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class));
|
||||
agent.setLlmConfig(JsonUtil.toObject(agentDO.getLlmConfig(), LLMConfig.class));
|
||||
agent.setMultiTurnConfig(JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class));
|
||||
agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class));
|
||||
return agent;
|
||||
}
|
||||
|
||||
private AgentDO convert(Agent agent, User user) {
|
||||
private AgentDO convert(Agent agent) {
|
||||
AgentDO agentDO = new AgentDO();
|
||||
BeanUtils.copyProperties(agent, agentDO);
|
||||
agentDO.setConfig(agent.getAgentConfig());
|
||||
agentDO.setExamples(JSONObject.toJSONString(agent.getExamples()));
|
||||
agentDO.setCreatedAt(new Date());
|
||||
agentDO.setCreatedBy(user.getName());
|
||||
agentDO.setUpdatedAt(new Date());
|
||||
agentDO.setUpdatedBy(user.getName());
|
||||
agentDO.setExamples(JsonUtil.toString(agent.getExamples()));
|
||||
agentDO.setLlmConfig(JsonUtil.toString(agent.getLlmConfig()));
|
||||
agentDO.setMultiTurnConfig(JsonUtil.toString(agent.getMultiTurnConfig()));
|
||||
agentDO.setVisualConfig(JsonUtil.toString(agent.getVisualConfig()));
|
||||
if (agentDO.getStatus() == null) {
|
||||
agentDO.setStatus(1);
|
||||
}
|
||||
|
||||
@@ -30,6 +30,7 @@ public class QueryReqConverter {
|
||||
&& MapUtils.isNotEmpty(queryReq.getMapInfo().getDataSetElementMatches())) {
|
||||
queryReq.setMapInfo(queryReq.getMapInfo());
|
||||
}
|
||||
queryReq.setLlmConfig(agent.getLlmConfig());
|
||||
return queryReq;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,303 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
|
||||
<mapper namespace="com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper">
|
||||
<resultMap id="BaseResultMap" type="com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO">
|
||||
<id column="id" jdbcType="INTEGER" property="id" />
|
||||
<result column="name" jdbcType="VARCHAR" property="name" />
|
||||
<result column="description" jdbcType="VARCHAR" property="description" />
|
||||
<result column="status" jdbcType="INTEGER" property="status" />
|
||||
<result column="examples" jdbcType="VARCHAR" property="examples" />
|
||||
<result column="config" jdbcType="VARCHAR" property="config" />
|
||||
<result column="created_by" jdbcType="VARCHAR" property="createdBy" />
|
||||
<result column="created_at" jdbcType="TIMESTAMP" property="createdAt" />
|
||||
<result column="updated_by" jdbcType="VARCHAR" property="updatedBy" />
|
||||
<result column="updated_at" jdbcType="TIMESTAMP" property="updatedAt" />
|
||||
<result column="enable_search" jdbcType="INTEGER" property="enableSearch" />
|
||||
</resultMap>
|
||||
<sql id="Example_Where_Clause">
|
||||
<where>
|
||||
<foreach collection="oredCriteria" item="criteria" separator="or">
|
||||
<if test="criteria.valid">
|
||||
<trim prefix="(" prefixOverrides="and" suffix=")">
|
||||
<foreach collection="criteria.criteria" item="criterion">
|
||||
<choose>
|
||||
<when test="criterion.noValue">
|
||||
and ${criterion.condition}
|
||||
</when>
|
||||
<when test="criterion.singleValue">
|
||||
and ${criterion.condition} #{criterion.value}
|
||||
</when>
|
||||
<when test="criterion.betweenValue">
|
||||
and ${criterion.condition} #{criterion.value} and #{criterion.secondValue}
|
||||
</when>
|
||||
<when test="criterion.listValue">
|
||||
and ${criterion.condition}
|
||||
<foreach close=")" collection="criterion.value" item="listItem" open="(" separator=",">
|
||||
#{listItem}
|
||||
</foreach>
|
||||
</when>
|
||||
</choose>
|
||||
</foreach>
|
||||
</trim>
|
||||
</if>
|
||||
</foreach>
|
||||
</where>
|
||||
</sql>
|
||||
<sql id="Update_By_Example_Where_Clause">
|
||||
<where>
|
||||
<foreach collection="example.oredCriteria" item="criteria" separator="or">
|
||||
<if test="criteria.valid">
|
||||
<trim prefix="(" prefixOverrides="and" suffix=")">
|
||||
<foreach collection="criteria.criteria" item="criterion">
|
||||
<choose>
|
||||
<when test="criterion.noValue">
|
||||
and ${criterion.condition}
|
||||
</when>
|
||||
<when test="criterion.singleValue">
|
||||
and ${criterion.condition} #{criterion.value}
|
||||
</when>
|
||||
<when test="criterion.betweenValue">
|
||||
and ${criterion.condition} #{criterion.value} and #{criterion.secondValue}
|
||||
</when>
|
||||
<when test="criterion.listValue">
|
||||
and ${criterion.condition}
|
||||
<foreach close=")" collection="criterion.value" item="listItem" open="(" separator=",">
|
||||
#{listItem}
|
||||
</foreach>
|
||||
</when>
|
||||
</choose>
|
||||
</foreach>
|
||||
</trim>
|
||||
</if>
|
||||
</foreach>
|
||||
</where>
|
||||
</sql>
|
||||
<sql id="Base_Column_List">
|
||||
id, name, description, status, examples, config, created_by, created_at, updated_by,
|
||||
updated_at, enable_search
|
||||
</sql>
|
||||
<select id="selectByExample" parameterType="com.tencent.supersonic.chat.server.persistence.dataobject.AgentDOExample" resultMap="BaseResultMap">
|
||||
select
|
||||
<if test="distinct">
|
||||
distinct
|
||||
</if>
|
||||
<include refid="Base_Column_List" />
|
||||
from s2_agent
|
||||
<if test="_parameter != null">
|
||||
<include refid="Example_Where_Clause" />
|
||||
</if>
|
||||
<if test="orderByClause != null">
|
||||
order by ${orderByClause}
|
||||
</if>
|
||||
<if test="limitStart != null and limitStart>=0">
|
||||
limit #{limitStart} , #{limitEnd}
|
||||
</if>
|
||||
</select>
|
||||
<select id="selectByPrimaryKey" parameterType="java.lang.Integer" resultMap="BaseResultMap">
|
||||
select
|
||||
<include refid="Base_Column_List" />
|
||||
from s2_agent
|
||||
where id = #{id,jdbcType=INTEGER}
|
||||
</select>
|
||||
<delete id="deleteByPrimaryKey" parameterType="java.lang.Integer">
|
||||
delete from s2_agent
|
||||
where id = #{id,jdbcType=INTEGER}
|
||||
</delete>
|
||||
<insert id="insert" parameterType="com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO">
|
||||
insert into s2_agent (id, name, description,
|
||||
status, examples, config,
|
||||
created_by, created_at, updated_by,
|
||||
updated_at, enable_search)
|
||||
values (#{id,jdbcType=INTEGER}, #{name,jdbcType=VARCHAR}, #{description,jdbcType=VARCHAR},
|
||||
#{status,jdbcType=INTEGER}, #{examples,jdbcType=VARCHAR}, #{config,jdbcType=VARCHAR},
|
||||
#{createdBy,jdbcType=VARCHAR}, #{createdAt,jdbcType=TIMESTAMP}, #{updatedBy,jdbcType=VARCHAR},
|
||||
#{updatedAt,jdbcType=TIMESTAMP}, #{enableSearch,jdbcType=INTEGER})
|
||||
</insert>
|
||||
<insert id="insertSelective" parameterType="com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO">
|
||||
insert into s2_agent
|
||||
<trim prefix="(" suffix=")" suffixOverrides=",">
|
||||
<if test="id != null">
|
||||
id,
|
||||
</if>
|
||||
<if test="name != null">
|
||||
name,
|
||||
</if>
|
||||
<if test="description != null">
|
||||
description,
|
||||
</if>
|
||||
<if test="status != null">
|
||||
status,
|
||||
</if>
|
||||
<if test="examples != null">
|
||||
examples,
|
||||
</if>
|
||||
<if test="config != null">
|
||||
config,
|
||||
</if>
|
||||
<if test="createdBy != null">
|
||||
created_by,
|
||||
</if>
|
||||
<if test="createdAt != null">
|
||||
created_at,
|
||||
</if>
|
||||
<if test="updatedBy != null">
|
||||
updated_by,
|
||||
</if>
|
||||
<if test="updatedAt != null">
|
||||
updated_at,
|
||||
</if>
|
||||
<if test="enableSearch != null">
|
||||
enable_search,
|
||||
</if>
|
||||
</trim>
|
||||
<trim prefix="values (" suffix=")" suffixOverrides=",">
|
||||
<if test="id != null">
|
||||
#{id,jdbcType=INTEGER},
|
||||
</if>
|
||||
<if test="name != null">
|
||||
#{name,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="description != null">
|
||||
#{description,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="status != null">
|
||||
#{status,jdbcType=INTEGER},
|
||||
</if>
|
||||
<if test="examples != null">
|
||||
#{examples,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="config != null">
|
||||
#{config,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="createdBy != null">
|
||||
#{createdBy,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="createdAt != null">
|
||||
#{createdAt,jdbcType=TIMESTAMP},
|
||||
</if>
|
||||
<if test="updatedBy != null">
|
||||
#{updatedBy,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="updatedAt != null">
|
||||
#{updatedAt,jdbcType=TIMESTAMP},
|
||||
</if>
|
||||
<if test="enableSearch != null">
|
||||
#{enableSearch,jdbcType=INTEGER},
|
||||
</if>
|
||||
</trim>
|
||||
</insert>
|
||||
<select id="countByExample" parameterType="com.tencent.supersonic.chat.server.persistence.dataobject.AgentDOExample" resultType="java.lang.Long">
|
||||
select count(*) from s2_agent
|
||||
<if test="_parameter != null">
|
||||
<include refid="Example_Where_Clause" />
|
||||
</if>
|
||||
</select>
|
||||
<update id="updateByExampleSelective" parameterType="map">
|
||||
update s2_agent
|
||||
<set>
|
||||
<if test="record.id != null">
|
||||
id = #{record.id,jdbcType=INTEGER},
|
||||
</if>
|
||||
<if test="record.name != null">
|
||||
name = #{record.name,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="record.description != null">
|
||||
description = #{record.description,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="record.status != null">
|
||||
status = #{record.status,jdbcType=INTEGER},
|
||||
</if>
|
||||
<if test="record.examples != null">
|
||||
examples = #{record.examples,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="record.config != null">
|
||||
config = #{record.config,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="record.createdBy != null">
|
||||
created_by = #{record.createdBy,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="record.createdAt != null">
|
||||
created_at = #{record.createdAt,jdbcType=TIMESTAMP},
|
||||
</if>
|
||||
<if test="record.updatedBy != null">
|
||||
updated_by = #{record.updatedBy,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="record.updatedAt != null">
|
||||
updated_at = #{record.updatedAt,jdbcType=TIMESTAMP},
|
||||
</if>
|
||||
<if test="record.enableSearch != null">
|
||||
enable_search = #{record.enableSearch,jdbcType=INTEGER},
|
||||
</if>
|
||||
</set>
|
||||
<if test="_parameter != null">
|
||||
<include refid="Update_By_Example_Where_Clause" />
|
||||
</if>
|
||||
</update>
|
||||
<update id="updateByExample" parameterType="map">
|
||||
update s2_agent
|
||||
set id = #{record.id,jdbcType=INTEGER},
|
||||
name = #{record.name,jdbcType=VARCHAR},
|
||||
description = #{record.description,jdbcType=VARCHAR},
|
||||
status = #{record.status,jdbcType=INTEGER},
|
||||
examples = #{record.examples,jdbcType=VARCHAR},
|
||||
config = #{record.config,jdbcType=VARCHAR},
|
||||
created_by = #{record.createdBy,jdbcType=VARCHAR},
|
||||
created_at = #{record.createdAt,jdbcType=TIMESTAMP},
|
||||
updated_by = #{record.updatedBy,jdbcType=VARCHAR},
|
||||
updated_at = #{record.updatedAt,jdbcType=TIMESTAMP},
|
||||
enable_search = #{record.enableSearch,jdbcType=INTEGER}
|
||||
<if test="_parameter != null">
|
||||
<include refid="Update_By_Example_Where_Clause" />
|
||||
</if>
|
||||
</update>
|
||||
<update id="updateByPrimaryKeySelective" parameterType="com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO">
|
||||
update s2_agent
|
||||
<set>
|
||||
<if test="name != null">
|
||||
name = #{name,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="description != null">
|
||||
description = #{description,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="status != null">
|
||||
status = #{status,jdbcType=INTEGER},
|
||||
</if>
|
||||
<if test="examples != null">
|
||||
examples = #{examples,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="config != null">
|
||||
config = #{config,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="createdBy != null">
|
||||
created_by = #{createdBy,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="createdAt != null">
|
||||
created_at = #{createdAt,jdbcType=TIMESTAMP},
|
||||
</if>
|
||||
<if test="updatedBy != null">
|
||||
updated_by = #{updatedBy,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="updatedAt != null">
|
||||
updated_at = #{updatedAt,jdbcType=TIMESTAMP},
|
||||
</if>
|
||||
<if test="enableSearch != null">
|
||||
enable_search = #{enableSearch,jdbcType=INTEGER},
|
||||
</if>
|
||||
</set>
|
||||
where id = #{id,jdbcType=INTEGER}
|
||||
</update>
|
||||
<update id="updateByPrimaryKey" parameterType="com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO">
|
||||
update s2_agent
|
||||
set name = #{name,jdbcType=VARCHAR},
|
||||
description = #{description,jdbcType=VARCHAR},
|
||||
status = #{status,jdbcType=INTEGER},
|
||||
examples = #{examples,jdbcType=VARCHAR},
|
||||
config = #{config,jdbcType=VARCHAR},
|
||||
created_by = #{createdBy,jdbcType=VARCHAR},
|
||||
created_at = #{createdAt,jdbcType=TIMESTAMP},
|
||||
updated_by = #{updatedBy,jdbcType=VARCHAR},
|
||||
updated_at = #{updatedAt,jdbcType=TIMESTAMP},
|
||||
enable_search = #{enableSearch,jdbcType=INTEGER}
|
||||
where id = #{id,jdbcType=INTEGER}
|
||||
</update>
|
||||
</mapper>
|
||||
@@ -172,6 +172,10 @@
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-open-ai</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-local-ai</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j</artifactId>
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
package com.tencent.supersonic.common.pojo.enums;
|
||||
|
||||
public enum S2ModelProvider {
|
||||
|
||||
OPEN_AI,
|
||||
HUGGING_FACE,
|
||||
LOCAL_AI,
|
||||
IN_PROCESS
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
package com.tencent.supersonic.headless.api.pojo;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class LLMConfig {
|
||||
|
||||
private String provider;
|
||||
|
||||
private String baseUrl;
|
||||
|
||||
private String apiKey;
|
||||
|
||||
private String modelName;
|
||||
|
||||
private Double temperature;
|
||||
|
||||
private Long timeOut;
|
||||
|
||||
public LLMConfig(String provider, String baseUrl, String apiKey, String modelName) {
|
||||
this.provider = provider;
|
||||
this.baseUrl = baseUrl;
|
||||
this.apiKey = apiKey;
|
||||
this.modelName = modelName;
|
||||
this.temperature = 0.0d;
|
||||
this.timeOut = 60L;
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.api.pojo.request;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||
@@ -22,4 +23,5 @@ public class QueryReq {
|
||||
private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
|
||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||
private QueryDataType queryDataType = QueryDataType.ALL;
|
||||
private LLMConfig llmConfig;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.headless.api.pojo.response;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import lombok.Data;
|
||||
import java.util.List;
|
||||
@@ -11,8 +12,8 @@ public class DataSetMapInfo {
|
||||
|
||||
private String description;
|
||||
|
||||
private List<SchemaElementMatch> mapFields;
|
||||
private List<SchemaElementMatch> mapFields = Lists.newArrayList();
|
||||
|
||||
private List<SchemaElementMatch> topFields;
|
||||
private List<SchemaElementMatch> topFields = Lists.newArrayList();
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
|
||||
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.headless.core.utils.S2ChatModelProvider;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
public abstract class BaseSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
|
||||
protected static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
|
||||
@Autowired
|
||||
protected SqlExamplarLoader sqlExamplarLoader;
|
||||
|
||||
@Autowired
|
||||
protected OptimizationConfig optimizationConfig;
|
||||
|
||||
@Autowired
|
||||
protected SqlPromptGenerator sqlPromptGenerator;
|
||||
|
||||
protected ChatLanguageModel getChatLanguageModel(LLMConfig llmConfig) {
|
||||
return S2ChatModelProvider.provide(llmConfig);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -96,6 +96,7 @@ public class LLMRequestService {
|
||||
}
|
||||
llmReq.setCurrentDate(currentDate);
|
||||
llmReq.setSqlGenerationMode(optimizationConfig.getSqlGenerationMode().getName());
|
||||
llmReq.setLlmConfig(queryCtx.getLlmConfig());
|
||||
return llmReq;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||
@@ -12,10 +11,6 @@ import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.HashMap;
|
||||
@@ -26,20 +21,7 @@ import java.util.stream.Collectors;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
@Autowired
|
||||
private ChatLanguageModel chatLanguageModel;
|
||||
|
||||
@Autowired
|
||||
private SqlExamplarLoader sqlExamplarLoader;
|
||||
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
|
||||
@Autowired
|
||||
private SqlPromptGenerator sqlPromptGenerator;
|
||||
public class OnePassSCSqlGeneration extends BaseSqlGeneration {
|
||||
|
||||
@Override
|
||||
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
|
||||
@@ -59,7 +41,8 @@ public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingSqlPrompt))
|
||||
.apply(new HashMap<>());
|
||||
keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
|
||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
|
||||
String result = response.content().text();
|
||||
llmResults.add(result);
|
||||
keyPipelineLog.info("model response:{}", result);
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||
@@ -12,10 +11,6 @@ import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.HashMap;
|
||||
@@ -24,26 +19,13 @@ import java.util.Map;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
@Autowired
|
||||
private ChatLanguageModel chatLanguageModel;
|
||||
|
||||
@Autowired
|
||||
private SqlExamplarLoader sqlExampleLoader;
|
||||
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
|
||||
@Autowired
|
||||
private SqlPromptGenerator sqlPromptGenerator;
|
||||
public class OnePassSqlGeneration extends BaseSqlGeneration {
|
||||
|
||||
@Override
|
||||
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
|
||||
//1.retriever sqlExamples
|
||||
keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq);
|
||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlExampleNum());
|
||||
|
||||
//2.generator linking and sql prompt by sqlExamples,and generate response.
|
||||
@@ -51,6 +33,7 @@ public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
|
||||
Prompt prompt = PromptTemplate.from(JsonUtil.toString(promptStr)).apply(new HashMap<>());
|
||||
keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage());
|
||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
|
||||
String result = response.content().text();
|
||||
keyPipelineLog.info("model response:{}", result);
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||
@@ -11,10 +10,6 @@ import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.HashMap;
|
||||
@@ -23,20 +18,7 @@ import java.util.Map;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
|
||||
@Service
|
||||
public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
@Autowired
|
||||
private ChatLanguageModel chatLanguageModel;
|
||||
|
||||
@Autowired
|
||||
private SqlExamplarLoader sqlExamplarLoader;
|
||||
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
|
||||
@Autowired
|
||||
private SqlPromptGenerator sqlPromptGenerator;
|
||||
public class TwoPassSCSqlGeneration extends BaseSqlGeneration {
|
||||
|
||||
@Override
|
||||
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
|
||||
@@ -51,6 +33,7 @@ public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
//2.generator linking prompt,and parallel generate response.
|
||||
List<String> linkingPromptPool = sqlPromptGenerator.generatePromptPool(llmReq, exampleListPool, false);
|
||||
List<String> linkingResults = new CopyOnWriteArrayList<>();
|
||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
|
||||
linkingPromptPool.parallelStream().forEach(
|
||||
linkingPrompt -> {
|
||||
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPrompt)).apply(new HashMap<>());
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||
@@ -11,10 +10,6 @@ import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.HashMap;
|
||||
@@ -23,20 +18,7 @@ import java.util.Map;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
@Autowired
|
||||
private ChatLanguageModel chatLanguageModel;
|
||||
|
||||
@Autowired
|
||||
private SqlExamplarLoader sqlExamplarLoader;
|
||||
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
|
||||
@Autowired
|
||||
private SqlPromptGenerator sqlPromptGenerator;
|
||||
public class TwoPassSqlGeneration extends BaseSqlGeneration {
|
||||
|
||||
@Override
|
||||
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
|
||||
@@ -48,6 +30,7 @@ public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
|
||||
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>());
|
||||
keyPipelineLog.info("step one request prompt:{}", prompt.toSystemMessage());
|
||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
|
||||
keyPipelineLog.info("step one model response:{}", response.content().text());
|
||||
String schemaLinkStr = OutputFormat.getSchemaLink(response.content().text());
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.headless.core.chat.query.llm.s2sql;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonValue;
|
||||
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
@@ -22,6 +23,8 @@ public class LLMReq {
|
||||
|
||||
private String sqlGenerationMode;
|
||||
|
||||
private LLMConfig llmConfig;
|
||||
|
||||
@Data
|
||||
public static class ElementValue {
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
@@ -47,6 +48,7 @@ public class QueryContext {
|
||||
@JsonIgnore
|
||||
private WorkflowState workflowState;
|
||||
private QueryDataType queryDataType = QueryDataType.ALL;
|
||||
private LLMConfig llmConfig;
|
||||
|
||||
public List<SemanticQuery> getCandidateQueries() {
|
||||
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
package com.tencent.supersonic.headless.core.utils;
|
||||
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
|
||||
import com.tencent.supersonic.common.pojo.enums.S2ModelProvider;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.localai.LocalAiChatModel;
|
||||
import dev.langchain4j.model.openai.OpenAiChatModel;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import java.time.Duration;
|
||||
|
||||
public class S2ChatModelProvider {
|
||||
|
||||
public static ChatLanguageModel provide(LLMConfig llmConfig) {
|
||||
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
|
||||
if (StringUtils.isBlank(llmConfig.getProvider())
|
||||
|| StringUtils.isBlank(llmConfig.getBaseUrl())) {
|
||||
return chatLanguageModel;
|
||||
}
|
||||
if (S2ModelProvider.OPEN_AI.name().equalsIgnoreCase(llmConfig.getProvider())) {
|
||||
return OpenAiChatModel
|
||||
.builder()
|
||||
.baseUrl(llmConfig.getBaseUrl())
|
||||
.modelName(llmConfig.getModelName())
|
||||
.apiKey(llmConfig.getApiKey())
|
||||
.temperature(llmConfig.getTemperature())
|
||||
.timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
|
||||
.build();
|
||||
} else if (S2ModelProvider.LOCAL_AI.name().equalsIgnoreCase(llmConfig.getProvider())) {
|
||||
return LocalAiChatModel
|
||||
.builder()
|
||||
.baseUrl(llmConfig.getBaseUrl())
|
||||
.modelName(llmConfig.getModelName())
|
||||
.temperature(llmConfig.getTemperature())
|
||||
.timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
|
||||
.build();
|
||||
}
|
||||
throw new RuntimeException("unsupported provider: " + llmConfig.getProvider());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -22,7 +22,6 @@ import com.tencent.supersonic.headless.server.service.DatabaseService;
|
||||
import com.tencent.supersonic.headless.server.service.ModelService;
|
||||
import com.tencent.supersonic.headless.server.utils.DatabaseConverter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
@@ -36,8 +35,6 @@ import java.util.stream.Collectors;
|
||||
@Slf4j
|
||||
@Service
|
||||
public class DatabaseServiceImpl implements DatabaseService {
|
||||
@Value("${inMemoryEmbeddingStore.persistent.path:/tmp}")
|
||||
private String embeddingStorePersistentPath;
|
||||
|
||||
private final SqlUtils sqlUtils;
|
||||
private DatabaseRepository databaseRepository;
|
||||
|
||||
@@ -66,6 +66,13 @@ import com.tencent.supersonic.headless.server.service.TagMetaService;
|
||||
import com.tencent.supersonic.headless.server.utils.MetricCheckUtils;
|
||||
import com.tencent.supersonic.headless.server.utils.MetricConverter;
|
||||
import com.tencent.supersonic.headless.server.utils.ModelClusterBuilder;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.context.ApplicationEventPublisher;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
@@ -79,13 +86,6 @@ import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.context.ApplicationEventPublisher;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
@@ -293,12 +293,13 @@ public class MetricServiceImpl implements MetricService {
|
||||
queryMapReq.setUser(user);
|
||||
queryMapReq.setMapModeEnum(MapModeEnum.LOOSE);
|
||||
MapInfoResp mapMeta = metaDiscoveryService.getMapMeta(queryMapReq);
|
||||
Map<String, DataSetMapInfo> dataSetMapInfo = mapMeta.getDataSetMapInfo();
|
||||
if (CollectionUtils.isEmpty(dataSetMapInfo)) {
|
||||
Map<String, DataSetMapInfo> dataSetMapInfoMap = mapMeta.getDataSetMapInfo();
|
||||
if (CollectionUtils.isEmpty(dataSetMapInfoMap)) {
|
||||
return metricRespPageInfo;
|
||||
}
|
||||
Map<Long, Double> result = dataSetMapInfo.values().stream()
|
||||
Map<Long, Double> result = dataSetMapInfoMap.values().stream()
|
||||
.map(DataSetMapInfo::getMapFields)
|
||||
.filter(Objects::nonNull)
|
||||
.flatMap(Collection::stream).filter(schemaElementMatch ->
|
||||
SchemaElementType.METRIC.equals(schemaElementMatch.getElement().getType()))
|
||||
.collect(Collectors.toMap(schemaElementMatch ->
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package dev.langchain4j;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.S2ModelProvider;
|
||||
import org.springframework.boot.context.properties.NestedConfigurationProperty;
|
||||
|
||||
class S2EmbeddingModel {
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
package dev.langchain4j;
|
||||
|
||||
enum S2ModelProvider {
|
||||
|
||||
OPEN_AI,
|
||||
HUGGING_FACE,
|
||||
LOCAL_AI,
|
||||
IN_PROCESS
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentConfig;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.server.agent.LLMParserTool;
|
||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
|
||||
import com.tencent.supersonic.chat.server.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginParseConfig;
|
||||
@@ -21,6 +22,8 @@ import com.tencent.supersonic.chat.server.service.PluginService;
|
||||
import com.tencent.supersonic.common.pojo.SysParameter;
|
||||
import com.tencent.supersonic.common.service.SysParameterService;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
|
||||
import com.tencent.supersonic.common.pojo.enums.S2ModelProvider;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
@@ -174,6 +177,11 @@ public class ChatDemoLoader implements CommandLineRunner {
|
||||
agentConfig.getTools().add(llmParserTool);
|
||||
}
|
||||
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
|
||||
LLMConfig llmConfig = new LLMConfig(S2ModelProvider.OPEN_AI.name(),
|
||||
"", "your_key", "gpt-3.5-turbo");
|
||||
MultiTurnConfig multiTurnConfig = new MultiTurnConfig(false);
|
||||
agent.setLlmConfig(llmConfig);
|
||||
agent.setMultiTurnConfig(multiTurnConfig);
|
||||
agentService.createAgent(agent, User.getFakeUser());
|
||||
}
|
||||
|
||||
|
||||
@@ -307,4 +307,9 @@ CREATE TABLE IF NOT EXISTS `s2_term` (
|
||||
`updated_at` datetime DEFAULT NULL ,
|
||||
`updated_by` varchar(100) DEFAULT NULL ,
|
||||
PRIMARY KEY (`id`)
|
||||
);
|
||||
);
|
||||
|
||||
--20240520
|
||||
alter table s2_agent add column `llm_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL;
|
||||
alter table s2_agent add column `multi_turn_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL;
|
||||
alter table s2_agent add column `visual_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL;
|
||||
@@ -351,6 +351,9 @@ CREATE TABLE IF NOT EXISTS s2_agent
|
||||
status int null,
|
||||
examples varchar(500) null,
|
||||
config varchar(2000) null,
|
||||
llm_config varchar(2000) null,
|
||||
multi_turn_config varchar(2000) null,
|
||||
visual_config varchar(2000) null,
|
||||
created_by varchar(100) null,
|
||||
created_at TIMESTAMP null,
|
||||
updated_by varchar(100) null,
|
||||
|
||||
@@ -72,6 +72,9 @@ CREATE TABLE `s2_agent` (
|
||||
`status` int(11) DEFAULT NULL,
|
||||
`model` varchar(100) COLLATE utf8_unicode_ci DEFAULT NULL,
|
||||
`config` varchar(6000) COLLATE utf8_unicode_ci DEFAULT NULL,
|
||||
`llm_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL,
|
||||
`multi_turn_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL,
|
||||
`visual_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL,
|
||||
`created_by` varchar(100) COLLATE utf8_unicode_ci DEFAULT NULL,
|
||||
`created_at` datetime DEFAULT NULL,
|
||||
`updated_by` varchar(100) COLLATE utf8_unicode_ci DEFAULT NULL,
|
||||
|
||||
@@ -351,6 +351,9 @@ CREATE TABLE IF NOT EXISTS s2_agent
|
||||
status int null,
|
||||
examples varchar(500) null,
|
||||
config varchar(2000) null,
|
||||
llm_config varchar(2000) null,
|
||||
multi_turn_config varchar(2000) null,
|
||||
visual_config varchar(2000) null,
|
||||
created_by varchar(100) null,
|
||||
created_at TIMESTAMP null,
|
||||
updated_by varchar(100) null,
|
||||
|
||||
5
pom.xml
5
pom.xml
@@ -125,6 +125,11 @@
|
||||
<artifactId>langchain4j-open-ai</artifactId>
|
||||
<version>${langchain4j.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-local-ai</artifactId>
|
||||
<version>${langchain4j.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-hugging-face</artifactId>
|
||||
|
||||
2
webapp/.gitignore
vendored
2
webapp/.gitignore
vendored
@@ -13,7 +13,7 @@
|
||||
|
||||
/dist
|
||||
|
||||
/supersonic-webapp
|
||||
/webapp
|
||||
|
||||
../assembly/build/supersonic-webapp.tar.gz
|
||||
supersonic-webapp.tar.gz
|
||||
|
||||
Reference in New Issue
Block a user