From 53b6c032887479f338d7a190abaad7bd3dcc40f4 Mon Sep 17 00:00:00 2001 From: LXW <1264174498@qq.com> Date: Mon, 20 May 2024 12:58:21 +0800 Subject: [PATCH] (improvement)(Chat) add extend config for agent (#1010) --- .../chat/api/pojo/enums/DefaultShowType.java | 9 + .../supersonic/chat/server/agent/Agent.java | 4 + .../chat/server/agent/MultiTurnConfig.java | 15 + .../chat/server/agent/VisualConfig.java | 15 + .../persistence/dataobject/AgentDO.java | 163 +--------- .../persistence/mapper/AgentDOMapper.java | 65 +--- .../repository/AgentRepository.java | 18 -- .../repository/impl/AgentRepositoryImpl.java | 43 --- .../chat/server/rest/AgentController.java | 6 +- .../server/service/impl/AgentServiceImpl.java | 50 +-- .../chat/server/util/QueryReqConverter.java | 1 + .../main/resources/mapper/AgentDOMapper.xml | 303 ------------------ common/pom.xml | 4 + .../common/pojo/enums/S2ModelProvider.java | 9 + .../headless/api/pojo/LLMConfig.java | 32 ++ .../headless/api/pojo/request/QueryReq.java | 2 + .../api/pojo/response/DataSetMapInfo.java | 5 +- .../chat/parser/llm/BaseSqlGeneration.java | 31 ++ .../chat/parser/llm/LLMRequestService.java | 1 + .../parser/llm/OnePassSCSqlGeneration.java | 23 +- .../chat/parser/llm/OnePassSqlGeneration.java | 23 +- .../parser/llm/TwoPassSCSqlGeneration.java | 21 +- .../chat/parser/llm/TwoPassSqlGeneration.java | 21 +- .../core/chat/query/llm/s2sql/LLMReq.java | 3 + .../headless/core/pojo/QueryContext.java | 2 + .../core/utils/S2ChatModelProvider.java | 41 +++ .../service/impl/DatabaseServiceImpl.java | 3 - .../service/impl/MetricServiceImpl.java | 21 +- .../dev/langchain4j/S2EmbeddingModel.java | 1 + .../java/dev/langchain4j/S2ModelProvider.java | 9 - .../tencent/supersonic/ChatDemoLoader.java | 8 + .../resources/config.update/sql-update.sql | 7 +- .../src/main/resources/db/schema-h2.sql | 3 + .../src/main/resources/db/schema-mysql.sql | 3 + .../src/test/resources/db/schema-h2.sql | 3 + pom.xml | 5 + webapp/.gitignore | 2 +- 37 files changed, 264 insertions(+), 711 deletions(-) create mode 100644 chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/enums/DefaultShowType.java create mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/MultiTurnConfig.java create mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/VisualConfig.java delete mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/AgentRepository.java delete mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/AgentRepositoryImpl.java delete mode 100644 chat/server/src/main/resources/mapper/AgentDOMapper.xml create mode 100644 common/src/main/java/com/tencent/supersonic/common/pojo/enums/S2ModelProvider.java create mode 100644 headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/LLMConfig.java create mode 100644 headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/BaseSqlGeneration.java create mode 100644 headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/S2ChatModelProvider.java delete mode 100644 launchers/common/src/main/java/dev/langchain4j/S2ModelProvider.java diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/enums/DefaultShowType.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/enums/DefaultShowType.java new file mode 100644 index 000000000..08829a417 --- /dev/null +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/enums/DefaultShowType.java @@ -0,0 +1,9 @@ +package com.tencent.supersonic.chat.api.pojo.enums; + +public enum DefaultShowType { + + TEXT, + TABLE, + WIDGET + +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java index f6f5e2a3c..1224947d3 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java @@ -4,6 +4,7 @@ package com.tencent.supersonic.chat.server.agent; import com.alibaba.fastjson.JSONObject; import com.google.common.collect.Lists; import com.google.common.collect.Sets; +import com.tencent.supersonic.headless.api.pojo.LLMConfig; import com.tencent.supersonic.common.pojo.RecordInfo; import lombok.Data; import org.springframework.util.CollectionUtils; @@ -30,6 +31,9 @@ public class Agent extends RecordInfo { private Integer status; private List examples; private String agentConfig; + private LLMConfig llmConfig; + private MultiTurnConfig multiTurnConfig; + private VisualConfig visualConfig; public List getTools(AgentToolType type) { Map map = JSONObject.parseObject(agentConfig, Map.class); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/MultiTurnConfig.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/MultiTurnConfig.java new file mode 100644 index 000000000..c290cf262 --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/MultiTurnConfig.java @@ -0,0 +1,15 @@ +package com.tencent.supersonic.chat.server.agent; + + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@AllArgsConstructor +@NoArgsConstructor +public class MultiTurnConfig { + + private boolean enableMultiTurn; + +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/VisualConfig.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/VisualConfig.java new file mode 100644 index 000000000..f6ee1fc78 --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/VisualConfig.java @@ -0,0 +1,15 @@ +package com.tencent.supersonic.chat.server.agent; + +import com.tencent.supersonic.chat.api.pojo.enums.DefaultShowType; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +@AllArgsConstructor +public class VisualConfig { + + private DefaultShowType defaultShowType; + +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/AgentDO.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/AgentDO.java index 4f2c9d3e9..42c1a075a 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/AgentDO.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/AgentDO.java @@ -1,10 +1,18 @@ package com.tencent.supersonic.chat.server.persistence.dataobject; +import com.baomidou.mybatisplus.annotation.IdType; +import com.baomidou.mybatisplus.annotation.TableId; +import com.baomidou.mybatisplus.annotation.TableName; +import lombok.Data; + import java.util.Date; +@Data +@TableName("s2_agent") public class AgentDO { /** */ + @TableId(type = IdType.AUTO) private Integer id; /** @@ -48,159 +56,10 @@ public class AgentDO { */ private Integer enableSearch; - /** - * @return id - */ - public Integer getId() { - return id; - } + private String llmConfig; - /** - * @param id - */ - public void setId(Integer id) { - this.id = id; - } + private String multiTurnConfig; - /** - * @return name - */ - public String getName() { - return name; - } + private String visualConfig; - /** - * @param name - */ - public void setName(String name) { - this.name = name == null ? null : name.trim(); - } - - /** - * @return description - */ - public String getDescription() { - return description; - } - - /** - * @param description - */ - public void setDescription(String description) { - this.description = description == null ? null : description.trim(); - } - - /** - * 0 offline, 1 online - * @return status 0 offline, 1 online - */ - public Integer getStatus() { - return status; - } - - /** - * 0 offline, 1 online - * @param status 0 offline, 1 online - */ - public void setStatus(Integer status) { - this.status = status; - } - - /** - * @return examples - */ - public String getExamples() { - return examples; - } - - /** - * @param examples - */ - public void setExamples(String examples) { - this.examples = examples == null ? null : examples.trim(); - } - - /** - * @return config - */ - public String getConfig() { - return config; - } - - /** - * @param config - */ - public void setConfig(String config) { - this.config = config == null ? null : config.trim(); - } - - /** - * @return created_by - */ - public String getCreatedBy() { - return createdBy; - } - - /** - * @param createdBy - */ - public void setCreatedBy(String createdBy) { - this.createdBy = createdBy == null ? null : createdBy.trim(); - } - - /** - * @return created_at - */ - public Date getCreatedAt() { - return createdAt; - } - - /** - * @param createdAt - */ - public void setCreatedAt(Date createdAt) { - this.createdAt = createdAt; - } - - /** - * @return updated_by - */ - public String getUpdatedBy() { - return updatedBy; - } - - /** - * @param updatedBy - */ - public void setUpdatedBy(String updatedBy) { - this.updatedBy = updatedBy == null ? null : updatedBy.trim(); - } - - /** - * @return updated_at - */ - public Date getUpdatedAt() { - return updatedAt; - } - - /** - * @param updatedAt - */ - public void setUpdatedAt(Date updatedAt) { - this.updatedAt = updatedAt; - } - - /** - * @return enable_search - */ - public Integer getEnableSearch() { - return enableSearch; - } - - /** - * @param enableSearch - */ - public void setEnableSearch(Integer enableSearch) { - this.enableSearch = enableSearch; - } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/AgentDOMapper.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/AgentDOMapper.java index 6beb31b37..4334782e4 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/AgentDOMapper.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/AgentDOMapper.java @@ -1,71 +1,10 @@ package com.tencent.supersonic.chat.server.persistence.mapper; +import com.baomidou.mybatisplus.core.mapper.BaseMapper; import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO; -import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDOExample; import org.apache.ibatis.annotations.Mapper; -import org.apache.ibatis.annotations.Param; - -import java.util.List; @Mapper -public interface AgentDOMapper { - /** - * - * @mbg.generated - */ - long countByExample(AgentDOExample example); +public interface AgentDOMapper extends BaseMapper { - /** - * - * @mbg.generated - */ - int deleteByPrimaryKey(Integer id); - - /** - * - * @mbg.generated - */ - int insert(AgentDO record); - - /** - * - * @mbg.generated - */ - int insertSelective(AgentDO record); - - /** - * - * @mbg.generated - */ - List selectByExample(AgentDOExample example); - - /** - * - * @mbg.generated - */ - AgentDO selectByPrimaryKey(Integer id); - - /** - * - * @mbg.generated - */ - int updateByExampleSelective(@Param("record") AgentDO record, @Param("example") AgentDOExample example); - - /** - * - * @mbg.generated - */ - int updateByExample(@Param("record") AgentDO record, @Param("example") AgentDOExample example); - - /** - * - * @mbg.generated - */ - int updateByPrimaryKeySelective(AgentDO record); - - /** - * - * @mbg.generated - */ - int updateByPrimaryKey(AgentDO record); } \ No newline at end of file diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/AgentRepository.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/AgentRepository.java deleted file mode 100644 index 49fb2e38f..000000000 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/AgentRepository.java +++ /dev/null @@ -1,18 +0,0 @@ -package com.tencent.supersonic.chat.server.persistence.repository; - -import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO; - -import java.util.List; - -public interface AgentRepository { - - List getAgents(); - - void createAgent(AgentDO agentDO); - - void updateAgent(AgentDO agentDO); - - AgentDO getAgent(Integer id); - - void deleteAgent(Integer id); -} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/AgentRepositoryImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/AgentRepositoryImpl.java deleted file mode 100644 index ebe670f87..000000000 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/AgentRepositoryImpl.java +++ /dev/null @@ -1,43 +0,0 @@ -package com.tencent.supersonic.chat.server.persistence.repository.impl; - -import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO; -import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDOExample; -import com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper; -import com.tencent.supersonic.chat.server.persistence.repository.AgentRepository; -import org.springframework.stereotype.Repository; -import java.util.List; - -@Repository -public class AgentRepositoryImpl implements AgentRepository { - - private AgentDOMapper agentDOMapper; - - public AgentRepositoryImpl(AgentDOMapper agentDOMapper) { - this.agentDOMapper = agentDOMapper; - } - - @Override - public List getAgents() { - return agentDOMapper.selectByExample(new AgentDOExample()); - } - - @Override - public void createAgent(AgentDO agentDO) { - agentDOMapper.insert(agentDO); - } - - @Override - public void updateAgent(AgentDO agentDO) { - agentDOMapper.updateByPrimaryKey(agentDO); - } - - @Override - public AgentDO getAgent(Integer id) { - return agentDOMapper.selectByPrimaryKey(id); - } - - @Override - public void deleteAgent(Integer id) { - agentDOMapper.deleteByPrimaryKey(id); - } -} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/AgentController.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/AgentController.java index ba374343c..0081c11a1 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/AgentController.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/AgentController.java @@ -5,6 +5,7 @@ import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.AgentToolType; import com.tencent.supersonic.chat.server.service.AgentService; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.DeleteMapping; import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.PostMapping; @@ -21,12 +22,9 @@ import java.util.Map; @RequestMapping({"/api/chat/agent", "/openapi/chat/agent"}) public class AgentController { + @Autowired private AgentService agentService; - public AgentController(AgentService agentService) { - this.agentService = agentService; - } - @PostMapping public boolean createAgent(@RequestBody Agent agent, HttpServletRequest httpServletRequest, diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java index a9bff8c10..92c08c7b0 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java @@ -1,25 +1,23 @@ package com.tencent.supersonic.chat.server.service.impl; -import com.alibaba.fastjson.JSONObject; +import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.server.agent.Agent; +import com.tencent.supersonic.headless.api.pojo.LLMConfig; +import com.tencent.supersonic.chat.server.agent.MultiTurnConfig; +import com.tencent.supersonic.chat.server.agent.VisualConfig; import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO; -import com.tencent.supersonic.chat.server.persistence.repository.AgentRepository; +import com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper; import com.tencent.supersonic.chat.server.service.AgentService; -import java.util.Date; -import java.util.List; -import java.util.stream.Collectors; +import com.tencent.supersonic.common.util.JsonUtil; import org.springframework.beans.BeanUtils; import org.springframework.stereotype.Service; +import java.util.List; +import java.util.stream.Collectors; @Service -public class AgentServiceImpl implements AgentService { - - private AgentRepository agentRepository; - - public AgentServiceImpl(AgentRepository agentRepository) { - this.agentRepository = agentRepository; - } +public class AgentServiceImpl extends ServiceImpl + implements AgentService { @Override public List getAgents() { @@ -29,12 +27,14 @@ public class AgentServiceImpl implements AgentService { @Override public void createAgent(Agent agent, User user) { - agentRepository.createAgent(convert(agent, user)); + agent.createdBy(user.getName()); + save(convert(agent)); } @Override public void updateAgent(Agent agent, User user) { - agentRepository.updateAgent(convert(agent, user)); + agent.updatedBy(user.getName()); + updateById(convert(agent)); } @Override @@ -42,16 +42,16 @@ public class AgentServiceImpl implements AgentService { if (id == null) { return null; } - return convert(agentRepository.getAgent(id)); + return convert(getById(id)); } @Override public void deleteAgent(Integer id) { - agentRepository.deleteAgent(id); + removeById(id); } private List getAgentDOList() { - return agentRepository.getAgents(); + return list(); } private Agent convert(AgentDO agentDO) { @@ -61,19 +61,21 @@ public class AgentServiceImpl implements AgentService { Agent agent = new Agent(); BeanUtils.copyProperties(agentDO, agent); agent.setAgentConfig(agentDO.getConfig()); - agent.setExamples(JSONObject.parseArray(agentDO.getExamples(), String.class)); + agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class)); + agent.setLlmConfig(JsonUtil.toObject(agentDO.getLlmConfig(), LLMConfig.class)); + agent.setMultiTurnConfig(JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class)); + agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class)); return agent; } - private AgentDO convert(Agent agent, User user) { + private AgentDO convert(Agent agent) { AgentDO agentDO = new AgentDO(); BeanUtils.copyProperties(agent, agentDO); agentDO.setConfig(agent.getAgentConfig()); - agentDO.setExamples(JSONObject.toJSONString(agent.getExamples())); - agentDO.setCreatedAt(new Date()); - agentDO.setCreatedBy(user.getName()); - agentDO.setUpdatedAt(new Date()); - agentDO.setUpdatedBy(user.getName()); + agentDO.setExamples(JsonUtil.toString(agent.getExamples())); + agentDO.setLlmConfig(JsonUtil.toString(agent.getLlmConfig())); + agentDO.setMultiTurnConfig(JsonUtil.toString(agent.getMultiTurnConfig())); + agentDO.setVisualConfig(JsonUtil.toString(agent.getVisualConfig())); if (agentDO.getStatus() == null) { agentDO.setStatus(1); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java index c1328e3da..3b2afad81 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java @@ -30,6 +30,7 @@ public class QueryReqConverter { && MapUtils.isNotEmpty(queryReq.getMapInfo().getDataSetElementMatches())) { queryReq.setMapInfo(queryReq.getMapInfo()); } + queryReq.setLlmConfig(agent.getLlmConfig()); return queryReq; } diff --git a/chat/server/src/main/resources/mapper/AgentDOMapper.xml b/chat/server/src/main/resources/mapper/AgentDOMapper.xml deleted file mode 100644 index 9c1c3556f..000000000 --- a/chat/server/src/main/resources/mapper/AgentDOMapper.xml +++ /dev/null @@ -1,303 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - and ${criterion.condition} - - - and ${criterion.condition} #{criterion.value} - - - and ${criterion.condition} #{criterion.value} and #{criterion.secondValue} - - - and ${criterion.condition} - - #{listItem} - - - - - - - - - - - - - - - - - - and ${criterion.condition} - - - and ${criterion.condition} #{criterion.value} - - - and ${criterion.condition} #{criterion.value} and #{criterion.secondValue} - - - and ${criterion.condition} - - #{listItem} - - - - - - - - - - - id, name, description, status, examples, config, created_by, created_at, updated_by, - updated_at, enable_search - - - - - delete from s2_agent - where id = #{id,jdbcType=INTEGER} - - - insert into s2_agent (id, name, description, - status, examples, config, - created_by, created_at, updated_by, - updated_at, enable_search) - values (#{id,jdbcType=INTEGER}, #{name,jdbcType=VARCHAR}, #{description,jdbcType=VARCHAR}, - #{status,jdbcType=INTEGER}, #{examples,jdbcType=VARCHAR}, #{config,jdbcType=VARCHAR}, - #{createdBy,jdbcType=VARCHAR}, #{createdAt,jdbcType=TIMESTAMP}, #{updatedBy,jdbcType=VARCHAR}, - #{updatedAt,jdbcType=TIMESTAMP}, #{enableSearch,jdbcType=INTEGER}) - - - insert into s2_agent - - - id, - - - name, - - - description, - - - status, - - - examples, - - - config, - - - created_by, - - - created_at, - - - updated_by, - - - updated_at, - - - enable_search, - - - - - #{id,jdbcType=INTEGER}, - - - #{name,jdbcType=VARCHAR}, - - - #{description,jdbcType=VARCHAR}, - - - #{status,jdbcType=INTEGER}, - - - #{examples,jdbcType=VARCHAR}, - - - #{config,jdbcType=VARCHAR}, - - - #{createdBy,jdbcType=VARCHAR}, - - - #{createdAt,jdbcType=TIMESTAMP}, - - - #{updatedBy,jdbcType=VARCHAR}, - - - #{updatedAt,jdbcType=TIMESTAMP}, - - - #{enableSearch,jdbcType=INTEGER}, - - - - - - update s2_agent - - - id = #{record.id,jdbcType=INTEGER}, - - - name = #{record.name,jdbcType=VARCHAR}, - - - description = #{record.description,jdbcType=VARCHAR}, - - - status = #{record.status,jdbcType=INTEGER}, - - - examples = #{record.examples,jdbcType=VARCHAR}, - - - config = #{record.config,jdbcType=VARCHAR}, - - - created_by = #{record.createdBy,jdbcType=VARCHAR}, - - - created_at = #{record.createdAt,jdbcType=TIMESTAMP}, - - - updated_by = #{record.updatedBy,jdbcType=VARCHAR}, - - - updated_at = #{record.updatedAt,jdbcType=TIMESTAMP}, - - - enable_search = #{record.enableSearch,jdbcType=INTEGER}, - - - - - - - - update s2_agent - set id = #{record.id,jdbcType=INTEGER}, - name = #{record.name,jdbcType=VARCHAR}, - description = #{record.description,jdbcType=VARCHAR}, - status = #{record.status,jdbcType=INTEGER}, - examples = #{record.examples,jdbcType=VARCHAR}, - config = #{record.config,jdbcType=VARCHAR}, - created_by = #{record.createdBy,jdbcType=VARCHAR}, - created_at = #{record.createdAt,jdbcType=TIMESTAMP}, - updated_by = #{record.updatedBy,jdbcType=VARCHAR}, - updated_at = #{record.updatedAt,jdbcType=TIMESTAMP}, - enable_search = #{record.enableSearch,jdbcType=INTEGER} - - - - - - update s2_agent - - - name = #{name,jdbcType=VARCHAR}, - - - description = #{description,jdbcType=VARCHAR}, - - - status = #{status,jdbcType=INTEGER}, - - - examples = #{examples,jdbcType=VARCHAR}, - - - config = #{config,jdbcType=VARCHAR}, - - - created_by = #{createdBy,jdbcType=VARCHAR}, - - - created_at = #{createdAt,jdbcType=TIMESTAMP}, - - - updated_by = #{updatedBy,jdbcType=VARCHAR}, - - - updated_at = #{updatedAt,jdbcType=TIMESTAMP}, - - - enable_search = #{enableSearch,jdbcType=INTEGER}, - - - where id = #{id,jdbcType=INTEGER} - - - update s2_agent - set name = #{name,jdbcType=VARCHAR}, - description = #{description,jdbcType=VARCHAR}, - status = #{status,jdbcType=INTEGER}, - examples = #{examples,jdbcType=VARCHAR}, - config = #{config,jdbcType=VARCHAR}, - created_by = #{createdBy,jdbcType=VARCHAR}, - created_at = #{createdAt,jdbcType=TIMESTAMP}, - updated_by = #{updatedBy,jdbcType=VARCHAR}, - updated_at = #{updatedAt,jdbcType=TIMESTAMP}, - enable_search = #{enableSearch,jdbcType=INTEGER} - where id = #{id,jdbcType=INTEGER} - - \ No newline at end of file diff --git a/common/pom.xml b/common/pom.xml index aed6cc3eb..c3cfd33c0 100644 --- a/common/pom.xml +++ b/common/pom.xml @@ -172,6 +172,10 @@ dev.langchain4j langchain4j-open-ai + + dev.langchain4j + langchain4j-local-ai + dev.langchain4j langchain4j diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/S2ModelProvider.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/S2ModelProvider.java new file mode 100644 index 000000000..ba1065886 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/S2ModelProvider.java @@ -0,0 +1,9 @@ +package com.tencent.supersonic.common.pojo.enums; + +public enum S2ModelProvider { + + OPEN_AI, + HUGGING_FACE, + LOCAL_AI, + IN_PROCESS +} \ No newline at end of file diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/LLMConfig.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/LLMConfig.java new file mode 100644 index 000000000..84ea32ef8 --- /dev/null +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/LLMConfig.java @@ -0,0 +1,32 @@ +package com.tencent.supersonic.headless.api.pojo; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@AllArgsConstructor +@NoArgsConstructor +public class LLMConfig { + + private String provider; + + private String baseUrl; + + private String apiKey; + + private String modelName; + + private Double temperature; + + private Long timeOut; + + public LLMConfig(String provider, String baseUrl, String apiKey, String modelName) { + this.provider = provider; + this.baseUrl = baseUrl; + this.apiKey = apiKey; + this.modelName = modelName; + this.temperature = 0.0d; + this.timeOut = 60L; + } +} diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryReq.java index 55f360b42..251d2be08 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryReq.java @@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.api.pojo.request; import com.google.common.collect.Sets; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.common.pojo.enums.Text2SQLType; +import com.tencent.supersonic.headless.api.pojo.LLMConfig; import com.tencent.supersonic.headless.api.pojo.QueryDataType; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; @@ -22,4 +23,5 @@ public class QueryReq { private MapModeEnum mapModeEnum = MapModeEnum.STRICT; private SchemaMapInfo mapInfo = new SchemaMapInfo(); private QueryDataType queryDataType = QueryDataType.ALL; + private LLMConfig llmConfig; } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/DataSetMapInfo.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/DataSetMapInfo.java index c4676f150..e32182092 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/DataSetMapInfo.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/DataSetMapInfo.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.headless.api.pojo.response; +import com.google.common.collect.Lists; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import lombok.Data; import java.util.List; @@ -11,8 +12,8 @@ public class DataSetMapInfo { private String description; - private List mapFields; + private List mapFields = Lists.newArrayList(); - private List topFields; + private List topFields = Lists.newArrayList(); } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/BaseSqlGeneration.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/BaseSqlGeneration.java new file mode 100644 index 000000000..55441b1b1 --- /dev/null +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/BaseSqlGeneration.java @@ -0,0 +1,31 @@ +package com.tencent.supersonic.headless.core.chat.parser.llm; + +import com.tencent.supersonic.headless.api.pojo.LLMConfig; +import com.tencent.supersonic.headless.core.config.OptimizationConfig; +import com.tencent.supersonic.headless.core.utils.S2ChatModelProvider; +import dev.langchain4j.model.chat.ChatLanguageModel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; + +@Service +public abstract class BaseSqlGeneration implements SqlGeneration, InitializingBean { + + protected static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); + + @Autowired + protected SqlExamplarLoader sqlExamplarLoader; + + @Autowired + protected OptimizationConfig optimizationConfig; + + @Autowired + protected SqlPromptGenerator sqlPromptGenerator; + + protected ChatLanguageModel getChatLanguageModel(LLMConfig llmConfig) { + return S2ChatModelProvider.provide(llmConfig); + } + +} diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java index 0a37ea86c..1cd3ddae8 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java @@ -96,6 +96,7 @@ public class LLMRequestService { } llmReq.setCurrentDate(currentDate); llmReq.setSqlGenerationMode(optimizationConfig.getSqlGenerationMode().getName()); + llmReq.setLlmConfig(queryCtx.getLlmConfig()); return llmReq; } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSCSqlGeneration.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSCSqlGeneration.java index 43325b79e..6444bd73e 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSCSqlGeneration.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSCSqlGeneration.java @@ -1,7 +1,6 @@ package com.tencent.supersonic.headless.core.chat.parser.llm; import com.tencent.supersonic.common.util.JsonUtil; -import com.tencent.supersonic.headless.core.config.OptimizationConfig; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp; @@ -12,10 +11,6 @@ import dev.langchain4j.model.input.PromptTemplate; import dev.langchain4j.model.output.Response; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.tuple.Pair; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.beans.factory.InitializingBean; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import java.util.HashMap; @@ -26,20 +21,7 @@ import java.util.stream.Collectors; @Service @Slf4j -public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean { - - private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); - @Autowired - private ChatLanguageModel chatLanguageModel; - - @Autowired - private SqlExamplarLoader sqlExamplarLoader; - - @Autowired - private OptimizationConfig optimizationConfig; - - @Autowired - private SqlPromptGenerator sqlPromptGenerator; +public class OnePassSCSqlGeneration extends BaseSqlGeneration { @Override public LLMResp generation(LLMReq llmReq, Long dataSetId) { @@ -59,7 +41,8 @@ public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean { Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingSqlPrompt)) .apply(new HashMap<>()); keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage()); - Response response = chatLanguageModel.generate(prompt.toSystemMessage()); + ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig()); + Response response = chatLanguageModel.generate(prompt.toSystemMessage()); String result = response.content().text(); llmResults.add(result); keyPipelineLog.info("model response:{}", result); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSqlGeneration.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSqlGeneration.java index 738583554..c1c8328f4 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSqlGeneration.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSqlGeneration.java @@ -1,7 +1,6 @@ package com.tencent.supersonic.headless.core.chat.parser.llm; import com.tencent.supersonic.common.util.JsonUtil; -import com.tencent.supersonic.headless.core.config.OptimizationConfig; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp; @@ -12,10 +11,6 @@ import dev.langchain4j.model.input.Prompt; import dev.langchain4j.model.input.PromptTemplate; import dev.langchain4j.model.output.Response; import lombok.extern.slf4j.Slf4j; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.beans.factory.InitializingBean; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import java.util.HashMap; @@ -24,26 +19,13 @@ import java.util.Map; @Service @Slf4j -public class OnePassSqlGeneration implements SqlGeneration, InitializingBean { - - private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); - @Autowired - private ChatLanguageModel chatLanguageModel; - - @Autowired - private SqlExamplarLoader sqlExampleLoader; - - @Autowired - private OptimizationConfig optimizationConfig; - - @Autowired - private SqlPromptGenerator sqlPromptGenerator; +public class OnePassSqlGeneration extends BaseSqlGeneration { @Override public LLMResp generation(LLMReq llmReq, Long dataSetId) { //1.retriever sqlExamples keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq); - List> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(), + List> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(), optimizationConfig.getText2sqlExampleNum()); //2.generator linking and sql prompt by sqlExamples,and generate response. @@ -51,6 +33,7 @@ public class OnePassSqlGeneration implements SqlGeneration, InitializingBean { Prompt prompt = PromptTemplate.from(JsonUtil.toString(promptStr)).apply(new HashMap<>()); keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage()); + ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig()); Response response = chatLanguageModel.generate(prompt.toSystemMessage()); String result = response.content().text(); keyPipelineLog.info("model response:{}", result); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSCSqlGeneration.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSCSqlGeneration.java index 199ff5da7..78e3f559c 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSCSqlGeneration.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSCSqlGeneration.java @@ -1,7 +1,6 @@ package com.tencent.supersonic.headless.core.chat.parser.llm; import com.tencent.supersonic.common.util.JsonUtil; -import com.tencent.supersonic.headless.core.config.OptimizationConfig; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp; @@ -11,10 +10,6 @@ import dev.langchain4j.model.input.Prompt; import dev.langchain4j.model.input.PromptTemplate; import dev.langchain4j.model.output.Response; import org.apache.commons.lang3.tuple.Pair; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.beans.factory.InitializingBean; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import java.util.HashMap; @@ -23,20 +18,7 @@ import java.util.Map; import java.util.concurrent.CopyOnWriteArrayList; @Service -public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean { - - private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); - @Autowired - private ChatLanguageModel chatLanguageModel; - - @Autowired - private SqlExamplarLoader sqlExamplarLoader; - - @Autowired - private OptimizationConfig optimizationConfig; - - @Autowired - private SqlPromptGenerator sqlPromptGenerator; +public class TwoPassSCSqlGeneration extends BaseSqlGeneration { @Override public LLMResp generation(LLMReq llmReq, Long dataSetId) { @@ -51,6 +33,7 @@ public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean { //2.generator linking prompt,and parallel generate response. List linkingPromptPool = sqlPromptGenerator.generatePromptPool(llmReq, exampleListPool, false); List linkingResults = new CopyOnWriteArrayList<>(); + ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig()); linkingPromptPool.parallelStream().forEach( linkingPrompt -> { Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPrompt)).apply(new HashMap<>()); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSqlGeneration.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSqlGeneration.java index 00da5dc11..148c6b3c9 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSqlGeneration.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSqlGeneration.java @@ -1,7 +1,6 @@ package com.tencent.supersonic.headless.core.chat.parser.llm; import com.tencent.supersonic.common.util.JsonUtil; -import com.tencent.supersonic.headless.core.config.OptimizationConfig; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp; @@ -11,10 +10,6 @@ import dev.langchain4j.model.input.Prompt; import dev.langchain4j.model.input.PromptTemplate; import dev.langchain4j.model.output.Response; import lombok.extern.slf4j.Slf4j; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.beans.factory.InitializingBean; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import java.util.HashMap; @@ -23,20 +18,7 @@ import java.util.Map; @Service @Slf4j -public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean { - - private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); - @Autowired - private ChatLanguageModel chatLanguageModel; - - @Autowired - private SqlExamplarLoader sqlExamplarLoader; - - @Autowired - private OptimizationConfig optimizationConfig; - - @Autowired - private SqlPromptGenerator sqlPromptGenerator; +public class TwoPassSqlGeneration extends BaseSqlGeneration { @Override public LLMResp generation(LLMReq llmReq, Long dataSetId) { @@ -48,6 +30,7 @@ public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean { Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>()); keyPipelineLog.info("step one request prompt:{}", prompt.toSystemMessage()); + ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig()); Response response = chatLanguageModel.generate(prompt.toSystemMessage()); keyPipelineLog.info("step one model response:{}", response.content().text()); String schemaLinkStr = OutputFormat.getSchemaLink(response.content().text()); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/llm/s2sql/LLMReq.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/llm/s2sql/LLMReq.java index 525b306e4..f5b909f5c 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/llm/s2sql/LLMReq.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/llm/s2sql/LLMReq.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.headless.core.chat.query.llm.s2sql; import com.fasterxml.jackson.annotation.JsonValue; +import com.tencent.supersonic.headless.api.pojo.LLMConfig; import lombok.Data; import java.util.List; @@ -22,6 +23,8 @@ public class LLMReq { private String sqlGenerationMode; + private LLMConfig llmConfig; + @Data public static class ElementValue { diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/QueryContext.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/QueryContext.java index dfb0ce50e..5b7c34b9d 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/QueryContext.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/QueryContext.java @@ -4,6 +4,7 @@ import com.fasterxml.jackson.annotation.JsonIgnore; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.headless.api.pojo.LLMConfig; import com.tencent.supersonic.headless.api.pojo.QueryDataType; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; @@ -47,6 +48,7 @@ public class QueryContext { @JsonIgnore private WorkflowState workflowState; private QueryDataType queryDataType = QueryDataType.ALL; + private LLMConfig llmConfig; public List getCandidateQueries() { OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/S2ChatModelProvider.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/S2ChatModelProvider.java new file mode 100644 index 000000000..eed8041fe --- /dev/null +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/S2ChatModelProvider.java @@ -0,0 +1,41 @@ +package com.tencent.supersonic.headless.core.utils; + +import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.headless.api.pojo.LLMConfig; +import com.tencent.supersonic.common.pojo.enums.S2ModelProvider; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.localai.LocalAiChatModel; +import dev.langchain4j.model.openai.OpenAiChatModel; +import org.apache.commons.lang3.StringUtils; +import java.time.Duration; + +public class S2ChatModelProvider { + + public static ChatLanguageModel provide(LLMConfig llmConfig) { + ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class); + if (StringUtils.isBlank(llmConfig.getProvider()) + || StringUtils.isBlank(llmConfig.getBaseUrl())) { + return chatLanguageModel; + } + if (S2ModelProvider.OPEN_AI.name().equalsIgnoreCase(llmConfig.getProvider())) { + return OpenAiChatModel + .builder() + .baseUrl(llmConfig.getBaseUrl()) + .modelName(llmConfig.getModelName()) + .apiKey(llmConfig.getApiKey()) + .temperature(llmConfig.getTemperature()) + .timeout(Duration.ofSeconds(llmConfig.getTimeOut())) + .build(); + } else if (S2ModelProvider.LOCAL_AI.name().equalsIgnoreCase(llmConfig.getProvider())) { + return LocalAiChatModel + .builder() + .baseUrl(llmConfig.getBaseUrl()) + .modelName(llmConfig.getModelName()) + .temperature(llmConfig.getTemperature()) + .timeout(Duration.ofSeconds(llmConfig.getTimeOut())) + .build(); + } + throw new RuntimeException("unsupported provider: " + llmConfig.getProvider()); + } + +} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java index 4c2b4e33b..97e643ae9 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java @@ -22,7 +22,6 @@ import com.tencent.supersonic.headless.server.service.DatabaseService; import com.tencent.supersonic.headless.server.service.ModelService; import com.tencent.supersonic.headless.server.utils.DatabaseConverter; import lombok.extern.slf4j.Slf4j; -import org.springframework.beans.factory.annotation.Value; import org.springframework.context.annotation.Lazy; import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; @@ -36,8 +35,6 @@ import java.util.stream.Collectors; @Slf4j @Service public class DatabaseServiceImpl implements DatabaseService { - @Value("${inMemoryEmbeddingStore.persistent.path:/tmp}") - private String embeddingStorePersistentPath; private final SqlUtils sqlUtils; private DatabaseRepository databaseRepository; diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/MetricServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/MetricServiceImpl.java index 5cff01e8d..4f917c5a5 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/MetricServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/MetricServiceImpl.java @@ -66,6 +66,13 @@ import com.tencent.supersonic.headless.server.service.TagMetaService; import com.tencent.supersonic.headless.server.utils.MetricCheckUtils; import com.tencent.supersonic.headless.server.utils.MetricConverter; import com.tencent.supersonic.headless.server.utils.ModelClusterBuilder; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.BeanUtils; +import org.springframework.context.ApplicationEventPublisher; +import org.springframework.context.annotation.Lazy; +import org.springframework.stereotype.Service; +import org.springframework.util.CollectionUtils; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -79,13 +86,6 @@ import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.StringUtils; -import org.springframework.beans.BeanUtils; -import org.springframework.context.ApplicationEventPublisher; -import org.springframework.context.annotation.Lazy; -import org.springframework.stereotype.Service; -import org.springframework.util.CollectionUtils; @Service @Slf4j @@ -293,12 +293,13 @@ public class MetricServiceImpl implements MetricService { queryMapReq.setUser(user); queryMapReq.setMapModeEnum(MapModeEnum.LOOSE); MapInfoResp mapMeta = metaDiscoveryService.getMapMeta(queryMapReq); - Map dataSetMapInfo = mapMeta.getDataSetMapInfo(); - if (CollectionUtils.isEmpty(dataSetMapInfo)) { + Map dataSetMapInfoMap = mapMeta.getDataSetMapInfo(); + if (CollectionUtils.isEmpty(dataSetMapInfoMap)) { return metricRespPageInfo; } - Map result = dataSetMapInfo.values().stream() + Map result = dataSetMapInfoMap.values().stream() .map(DataSetMapInfo::getMapFields) + .filter(Objects::nonNull) .flatMap(Collection::stream).filter(schemaElementMatch -> SchemaElementType.METRIC.equals(schemaElementMatch.getElement().getType())) .collect(Collectors.toMap(schemaElementMatch -> diff --git a/launchers/common/src/main/java/dev/langchain4j/S2EmbeddingModel.java b/launchers/common/src/main/java/dev/langchain4j/S2EmbeddingModel.java index 94ffbe0ea..e096168b2 100644 --- a/launchers/common/src/main/java/dev/langchain4j/S2EmbeddingModel.java +++ b/launchers/common/src/main/java/dev/langchain4j/S2EmbeddingModel.java @@ -1,5 +1,6 @@ package dev.langchain4j; +import com.tencent.supersonic.common.pojo.enums.S2ModelProvider; import org.springframework.boot.context.properties.NestedConfigurationProperty; class S2EmbeddingModel { diff --git a/launchers/common/src/main/java/dev/langchain4j/S2ModelProvider.java b/launchers/common/src/main/java/dev/langchain4j/S2ModelProvider.java deleted file mode 100644 index 606e6a237..000000000 --- a/launchers/common/src/main/java/dev/langchain4j/S2ModelProvider.java +++ /dev/null @@ -1,9 +0,0 @@ -package dev.langchain4j; - -enum S2ModelProvider { - - OPEN_AI, - HUGGING_FACE, - LOCAL_AI, - IN_PROCESS -} \ No newline at end of file diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/ChatDemoLoader.java b/launchers/standalone/src/main/java/com/tencent/supersonic/ChatDemoLoader.java index b6cdd3704..7eb86c3c1 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/ChatDemoLoader.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/ChatDemoLoader.java @@ -9,6 +9,7 @@ import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.AgentConfig; import com.tencent.supersonic.chat.server.agent.AgentToolType; import com.tencent.supersonic.chat.server.agent.LLMParserTool; +import com.tencent.supersonic.chat.server.agent.MultiTurnConfig; import com.tencent.supersonic.chat.server.agent.RuleParserTool; import com.tencent.supersonic.chat.server.plugin.Plugin; import com.tencent.supersonic.chat.server.plugin.PluginParseConfig; @@ -21,6 +22,8 @@ import com.tencent.supersonic.chat.server.service.PluginService; import com.tencent.supersonic.common.pojo.SysParameter; import com.tencent.supersonic.common.service.SysParameterService; import com.tencent.supersonic.common.util.JsonUtil; +import com.tencent.supersonic.headless.api.pojo.LLMConfig; +import com.tencent.supersonic.common.pojo.enums.S2ModelProvider; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; @@ -174,6 +177,11 @@ public class ChatDemoLoader implements CommandLineRunner { agentConfig.getTools().add(llmParserTool); } agent.setAgentConfig(JSONObject.toJSONString(agentConfig)); + LLMConfig llmConfig = new LLMConfig(S2ModelProvider.OPEN_AI.name(), + "", "your_key", "gpt-3.5-turbo"); + MultiTurnConfig multiTurnConfig = new MultiTurnConfig(false); + agent.setLlmConfig(llmConfig); + agent.setMultiTurnConfig(multiTurnConfig); agentService.createAgent(agent, User.getFakeUser()); } diff --git a/launchers/standalone/src/main/resources/config.update/sql-update.sql b/launchers/standalone/src/main/resources/config.update/sql-update.sql index c5cca5769..45003f1e9 100644 --- a/launchers/standalone/src/main/resources/config.update/sql-update.sql +++ b/launchers/standalone/src/main/resources/config.update/sql-update.sql @@ -307,4 +307,9 @@ CREATE TABLE IF NOT EXISTS `s2_term` ( `updated_at` datetime DEFAULT NULL , `updated_by` varchar(100) DEFAULT NULL , PRIMARY KEY (`id`) -); \ No newline at end of file +); + +--20240520 +alter table s2_agent add column `llm_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL; +alter table s2_agent add column `multi_turn_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL; +alter table s2_agent add column `visual_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL; \ No newline at end of file diff --git a/launchers/standalone/src/main/resources/db/schema-h2.sql b/launchers/standalone/src/main/resources/db/schema-h2.sql index 9412b14b9..c2c06f598 100644 --- a/launchers/standalone/src/main/resources/db/schema-h2.sql +++ b/launchers/standalone/src/main/resources/db/schema-h2.sql @@ -351,6 +351,9 @@ CREATE TABLE IF NOT EXISTS s2_agent status int null, examples varchar(500) null, config varchar(2000) null, + llm_config varchar(2000) null, + multi_turn_config varchar(2000) null, + visual_config varchar(2000) null, created_by varchar(100) null, created_at TIMESTAMP null, updated_by varchar(100) null, diff --git a/launchers/standalone/src/main/resources/db/schema-mysql.sql b/launchers/standalone/src/main/resources/db/schema-mysql.sql index 779d25461..1e7e5fcd3 100644 --- a/launchers/standalone/src/main/resources/db/schema-mysql.sql +++ b/launchers/standalone/src/main/resources/db/schema-mysql.sql @@ -72,6 +72,9 @@ CREATE TABLE `s2_agent` ( `status` int(11) DEFAULT NULL, `model` varchar(100) COLLATE utf8_unicode_ci DEFAULT NULL, `config` varchar(6000) COLLATE utf8_unicode_ci DEFAULT NULL, + `llm_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL, + `multi_turn_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL, + `visual_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL, `created_by` varchar(100) COLLATE utf8_unicode_ci DEFAULT NULL, `created_at` datetime DEFAULT NULL, `updated_by` varchar(100) COLLATE utf8_unicode_ci DEFAULT NULL, diff --git a/launchers/standalone/src/test/resources/db/schema-h2.sql b/launchers/standalone/src/test/resources/db/schema-h2.sql index 9412b14b9..c2c06f598 100644 --- a/launchers/standalone/src/test/resources/db/schema-h2.sql +++ b/launchers/standalone/src/test/resources/db/schema-h2.sql @@ -351,6 +351,9 @@ CREATE TABLE IF NOT EXISTS s2_agent status int null, examples varchar(500) null, config varchar(2000) null, + llm_config varchar(2000) null, + multi_turn_config varchar(2000) null, + visual_config varchar(2000) null, created_by varchar(100) null, created_at TIMESTAMP null, updated_by varchar(100) null, diff --git a/pom.xml b/pom.xml index 497d1b4d3..4a0dc0ba7 100644 --- a/pom.xml +++ b/pom.xml @@ -125,6 +125,11 @@ langchain4j-open-ai ${langchain4j.version} + + dev.langchain4j + langchain4j-local-ai + ${langchain4j.version} + dev.langchain4j langchain4j-hugging-face diff --git a/webapp/.gitignore b/webapp/.gitignore index 33d65bce4..f78504428 100644 --- a/webapp/.gitignore +++ b/webapp/.gitignore @@ -13,7 +13,7 @@ /dist -/supersonic-webapp +/webapp ../assembly/build/supersonic-webapp.tar.gz supersonic-webapp.tar.gz