mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
[improvement](supersonic) based on version 0.7.2 (#34)
Co-authored-by: zuopengge <hwzuopengge@tencent.com>
This commit is contained in:
@@ -6,9 +6,10 @@ com.tencent.supersonic.chat.api.component.SchemaMapper=\
|
||||
com.tencent.supersonic.chat.api.component.SemanticParser=\
|
||||
com.tencent.supersonic.chat.parser.rule.QueryModeParser, \
|
||||
com.tencent.supersonic.chat.parser.rule.ContextInheritParser, \
|
||||
com.tencent.supersonic.chat.parser.rule.AgentCheckParser, \
|
||||
com.tencent.supersonic.chat.parser.rule.TimeRangeParser, \
|
||||
com.tencent.supersonic.chat.parser.rule.AggregateTypeParser, \
|
||||
com.tencent.supersonic.chat.parser.llm.LLMDSLParser, \
|
||||
com.tencent.supersonic.chat.parser.llm.dsl.LLMDSLParser, \
|
||||
com.tencent.supersonic.chat.parser.function.FunctionBasedParser
|
||||
com.tencent.supersonic.chat.api.component.SemanticLayer=\
|
||||
com.tencent.supersonic.knowledge.semantic.RemoteSemanticLayer
|
||||
@@ -19,4 +20,13 @@ com.tencent.supersonic.chat.parser.function.ModelResolver=\
|
||||
com.tencent.supersonic.auth.authentication.interceptor.AuthenticationInterceptor=\
|
||||
com.tencent.supersonic.auth.authentication.interceptor.DefaultAuthenticationInterceptor
|
||||
com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\
|
||||
com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor
|
||||
com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor
|
||||
|
||||
|
||||
com.tencent.supersonic.chat.api.component.DSLOptimizer=\
|
||||
com.tencent.supersonic.chat.query.dsl.optimizer.DateFieldCorrector, \
|
||||
com.tencent.supersonic.chat.query.dsl.optimizer.FieldCorrector, \
|
||||
com.tencent.supersonic.chat.query.dsl.optimizer.FunctionCorrector, \
|
||||
com.tencent.supersonic.chat.query.dsl.optimizer.TableNameCorrector, \
|
||||
com.tencent.supersonic.chat.query.dsl.optimizer.QueryFilterAppend, \
|
||||
com.tencent.supersonic.chat.query.dsl.optimizer.SelectFieldAppendCorrector
|
||||
@@ -1,6 +1,12 @@
|
||||
package com.tencent.supersonic;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.agent.Agent;
|
||||
import com.tencent.supersonic.chat.agent.AgentConfig;
|
||||
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
|
||||
import com.tencent.supersonic.chat.agent.tool.RuleQueryTool;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigBaseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq;
|
||||
@@ -14,16 +20,14 @@ import com.tencent.supersonic.chat.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
|
||||
import com.tencent.supersonic.chat.query.plugin.ParamOption;
|
||||
import com.tencent.supersonic.chat.query.plugin.WebBase;
|
||||
import com.tencent.supersonic.chat.service.ChatService;
|
||||
import com.tencent.supersonic.chat.service.ConfigService;
|
||||
import com.tencent.supersonic.chat.service.PluginService;
|
||||
import com.tencent.supersonic.chat.service.QueryService;
|
||||
import com.tencent.supersonic.chat.service.*;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Qualifier;
|
||||
import org.springframework.boot.context.event.ApplicationReadyEvent;
|
||||
import org.springframework.context.ApplicationListener;
|
||||
import org.springframework.stereotype.Component;
|
||||
@@ -32,6 +36,7 @@ import org.springframework.stereotype.Component;
|
||||
@Slf4j
|
||||
public class ConfigureDemo implements ApplicationListener<ApplicationReadyEvent> {
|
||||
|
||||
@Qualifier("chatQueryService")
|
||||
@Autowired
|
||||
private QueryService queryService;
|
||||
@Autowired
|
||||
@@ -40,6 +45,8 @@ public class ConfigureDemo implements ApplicationListener<ApplicationReadyEvent>
|
||||
protected ConfigService configService;
|
||||
@Autowired
|
||||
private PluginService pluginService;
|
||||
@Autowired
|
||||
private AgentService agentService;
|
||||
|
||||
private User user = User.getFakeUser();
|
||||
|
||||
@@ -175,43 +182,25 @@ public class ConfigureDemo implements ApplicationListener<ApplicationReadyEvent>
|
||||
pluginService.createPlugin(plugin_1, user);
|
||||
}
|
||||
|
||||
private void addPlugin_2() {
|
||||
Plugin plugin_2 = new Plugin();
|
||||
plugin_2.setType("DSL");
|
||||
plugin_2.setModelList(Arrays.asList(1L, 2L));
|
||||
plugin_2.setPattern("");
|
||||
plugin_2.setParseModeConfig(null);
|
||||
plugin_2.setName("大模型语义解析");
|
||||
List<String> examples = new ArrayList<>();
|
||||
examples.add("超音数访问次数最高的部门是哪个");
|
||||
examples.add("超音数访问人数最高的部门是哪个");
|
||||
|
||||
PluginParseConfig parseConfig = PluginParseConfig.builder()
|
||||
.name("DSL")
|
||||
.description("这个工具能够将用户的自然语言查询转化为SQL语句,从而从数据库中的查询具体的数据。用于处理数据查询的问题,提供基于事实的数据")
|
||||
.examples(examples)
|
||||
.build();
|
||||
plugin_2.setParseModeConfig(JsonUtil.toString(parseConfig));
|
||||
pluginService.createPlugin(plugin_2, user);
|
||||
}
|
||||
|
||||
private void addPlugin_3() {
|
||||
Plugin plugin_2 = new Plugin();
|
||||
plugin_2.setType("CONTENT_INTERPRET");
|
||||
plugin_2.setModelList(Arrays.asList(1L));
|
||||
plugin_2.setPattern("超音数最近访问情况怎么样");
|
||||
plugin_2.setParseModeConfig(null);
|
||||
plugin_2.setName("内容解读");
|
||||
List<String> examples = new ArrayList<>();
|
||||
examples.add("超音数最近访问情况怎么样");
|
||||
examples.add("超音数最近访问情况如何");
|
||||
PluginParseConfig parseConfig = PluginParseConfig.builder()
|
||||
.name("supersonic_content_interpret")
|
||||
.description("这个工具能够先查询到相关的数据并交给大模型进行解读, 最后返回解读结果")
|
||||
.examples(examples)
|
||||
.build();
|
||||
plugin_2.setParseModeConfig(JsonUtil.toString(parseConfig));
|
||||
pluginService.createPlugin(plugin_2, user);
|
||||
private void addAgent() {
|
||||
Agent agent = new Agent();
|
||||
agent.setId(1);
|
||||
agent.setName("查信息");
|
||||
agent.setDescription("查信息");
|
||||
agent.setStatus(1);
|
||||
agent.setEnableSearch(1);
|
||||
agent.setExamples(Lists.newArrayList("超音数访问次数", "超音数访问人数", "alice 停留时长"));
|
||||
AgentConfig agentConfig = new AgentConfig();
|
||||
RuleQueryTool ruleQueryTool = new RuleQueryTool();
|
||||
ruleQueryTool.setType(AgentToolType.RULE);
|
||||
ruleQueryTool.setQueryModes(Lists.newArrayList(
|
||||
"ENTITY_DETAIL", "ENTITY_LIST_FILTER", "ENTITY_ID", "METRIC_ENTITY",
|
||||
"METRIC_FILTER", "METRIC_GROUPBY", "METRIC_MODEL", "METRIC_ORDERBY"
|
||||
));
|
||||
agentConfig.getTools().add(ruleQueryTool);
|
||||
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
|
||||
agentService.createAgent(agent, User.getFakeUser());
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -220,8 +209,7 @@ public class ConfigureDemo implements ApplicationListener<ApplicationReadyEvent>
|
||||
addDemoChatConfig_1();
|
||||
addDemoChatConfig_2();
|
||||
addPlugin_1();
|
||||
addPlugin_2();
|
||||
addPlugin_3();
|
||||
addAgent();
|
||||
addSampleChats();
|
||||
addSampleChats2();
|
||||
} catch (Exception e) {
|
||||
|
||||
@@ -6,9 +6,10 @@ com.tencent.supersonic.chat.api.component.SchemaMapper=\
|
||||
com.tencent.supersonic.chat.api.component.SemanticParser=\
|
||||
com.tencent.supersonic.chat.parser.rule.QueryModeParser, \
|
||||
com.tencent.supersonic.chat.parser.rule.ContextInheritParser, \
|
||||
com.tencent.supersonic.chat.parser.rule.AgentCheckParser, \
|
||||
com.tencent.supersonic.chat.parser.rule.TimeRangeParser, \
|
||||
com.tencent.supersonic.chat.parser.rule.AggregateTypeParser, \
|
||||
com.tencent.supersonic.chat.parser.llm.LLMDSLParser, \
|
||||
com.tencent.supersonic.chat.parser.llm.dsl.LLMDSLParser, \
|
||||
com.tencent.supersonic.chat.parser.embedding.EmbeddingBasedParser, \
|
||||
com.tencent.supersonic.chat.parser.function.FunctionBasedParser
|
||||
com.tencent.supersonic.chat.api.component.SemanticLayer=\
|
||||
@@ -21,3 +22,11 @@ com.tencent.supersonic.auth.authentication.interceptor.AuthenticationInterceptor
|
||||
com.tencent.supersonic.auth.authentication.interceptor.DefaultAuthenticationInterceptor
|
||||
com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\
|
||||
com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor
|
||||
|
||||
com.tencent.supersonic.chat.api.component.DSLOptimizer=\
|
||||
com.tencent.supersonic.chat.query.dsl.optimizer.DateFieldCorrector, \
|
||||
com.tencent.supersonic.chat.query.dsl.optimizer.FieldCorrector, \
|
||||
com.tencent.supersonic.chat.query.dsl.optimizer.FunctionCorrector, \
|
||||
com.tencent.supersonic.chat.query.dsl.optimizer.TableNameCorrector, \
|
||||
com.tencent.supersonic.chat.query.dsl.optimizer.QueryFilterAppend, \
|
||||
com.tencent.supersonic.chat.query.dsl.optimizer.SelectFieldAppendCorrector
|
||||
@@ -648,6 +648,22 @@ CREATE TABLE IF NOT EXISTS `s2_plugin`
|
||||
COMMENT
|
||||
ON TABLE s2_plugin IS 'plugin information table';
|
||||
|
||||
CREATE TABLE IF NOT EXISTS s2_agent
|
||||
(
|
||||
id int AUTO_INCREMENT,
|
||||
name varchar(100) null,
|
||||
description varchar(500) null,
|
||||
status int null,
|
||||
examples varchar(500) null,
|
||||
config varchar(2000) null,
|
||||
created_by varchar(100) null,
|
||||
created_at TIMESTAMP null,
|
||||
updated_by varchar(100) null,
|
||||
updated_at TIMESTAMP null,
|
||||
enable_search int null,
|
||||
PRIMARY KEY (`id`)
|
||||
); COMMENT ON TABLE s2_agent IS 'assistant information table';
|
||||
|
||||
|
||||
-------demo for semantic and chat
|
||||
CREATE TABLE IF NOT EXISTS `s2_user_department`
|
||||
|
||||
@@ -11,6 +11,7 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.chat.service.ChatService;
|
||||
import com.tencent.supersonic.chat.service.ConfigService;
|
||||
import com.tencent.supersonic.chat.service.QueryService;
|
||||
@@ -22,6 +23,7 @@ import org.junit.runner.RunWith;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Qualifier;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.boot.test.mock.mockito.MockBean;
|
||||
import org.springframework.test.context.ActiveProfiles;
|
||||
import org.springframework.test.context.junit4.SpringRunner;
|
||||
|
||||
@@ -42,6 +44,8 @@ public class BaseQueryTest {
|
||||
protected ChatService chatService;
|
||||
@Autowired
|
||||
protected ConfigService configService;
|
||||
@MockBean
|
||||
protected AgentService agentService;
|
||||
|
||||
protected QueryResult submitMultiTurnChat(String queryText) throws Exception {
|
||||
ParseResp parseResp = submitParse(queryText);
|
||||
@@ -78,6 +82,11 @@ public class BaseQueryTest {
|
||||
return queryService.performParsing(queryContextReq);
|
||||
}
|
||||
|
||||
protected ParseResp submitParseWithAgent(String queryText, Integer agentId) {
|
||||
QueryReq queryContextReq = DataUtils.getQueryReqWithAgent(10, queryText, agentId);
|
||||
return queryService.performParsing(queryContextReq);
|
||||
}
|
||||
|
||||
protected void assertSchemaElements(Set<SchemaElement> expected, Set<SchemaElement> actual) {
|
||||
Set<String> expectedNames = expected.stream().map(s -> s.getName())
|
||||
.filter(s -> s != null).collect(Collectors.toSet());
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
package com.tencent.supersonic.integration;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.tencent.supersonic.StandaloneLauncher;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.parser.embedding.EmbeddingConfig;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.query.metricInterpret.LLmAnswerResp;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.chat.service.QueryService;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Qualifier;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.boot.test.mock.mockito.MockBean;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.test.context.ActiveProfiles;
|
||||
import org.springframework.test.context.junit4.SpringRunner;
|
||||
|
||||
@RunWith(SpringRunner.class)
|
||||
@SpringBootTest(classes = StandaloneLauncher.class)
|
||||
@ActiveProfiles("local")
|
||||
public class MetricInterpretTest {
|
||||
|
||||
@MockBean
|
||||
private AgentService agentService;
|
||||
|
||||
@MockBean
|
||||
private PluginManager pluginManager;
|
||||
|
||||
@MockBean
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
|
||||
@Autowired
|
||||
@Qualifier("chatQueryService")
|
||||
private QueryService queryService;
|
||||
|
||||
@Test
|
||||
public void testMetricInterpret() throws Exception {
|
||||
MockConfiguration.mockAgent(agentService);
|
||||
MockConfiguration.mockEmbeddingUrl(embeddingConfig);
|
||||
LLmAnswerResp lLmAnswerResp = new LLmAnswerResp();
|
||||
lLmAnswerResp.setAssistant_message("alice最近在超音数的访问情况有增多");
|
||||
MockConfiguration.mockPluginManagerDoRequest(pluginManager, "answer_with_plugin_call",
|
||||
ResponseEntity.ok(JSONObject.toJSONString(lLmAnswerResp)));
|
||||
QueryReq queryReq = DataUtils.getQueryReqWithAgent(1000, "能不能帮我解读分析下最近alice在超音数的访问情况",
|
||||
DataUtils.getAgent().getId());
|
||||
QueryResult queryResult = queryService.executeQuery(queryReq);
|
||||
Assert.assertEquals(queryResult.getQueryResults().get(0).get("answer"), lLmAnswerResp.getAssistant_message());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,30 +1,31 @@
|
||||
package com.tencent.supersonic.integration;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
|
||||
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigEditReqReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ItemVisibility;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigEditReqReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ItemVisibility;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricFilterQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricGroupByQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricTopNQuery;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import java.text.DateFormat;
|
||||
import java.text.SimpleDateFormat;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
|
||||
import java.text.DateFormat;
|
||||
import java.text.SimpleDateFormat;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.*;
|
||||
|
||||
|
||||
public class MetricQueryTest extends BaseQueryTest {
|
||||
|
||||
@@ -50,6 +51,17 @@ public class MetricQueryTest extends BaseQueryTest {
|
||||
assertQueryResult(expectedResult, actualResult);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void queryTest_METRIC_FILTER_with_agent() {
|
||||
//agent only support METRIC_ENTITY, METRIC_FILTER
|
||||
MockConfiguration.mockAgent(agentService);
|
||||
ParseResp parseResp = submitParseWithAgent("alice的访问次数", DataUtils.getAgent().getId());
|
||||
Assert.assertNotNull(parseResp.getSelectedParses());
|
||||
List<String> queryModes = parseResp.getSelectedParses().stream()
|
||||
.map(SemanticParseInfo::getQueryMode).collect(Collectors.toList());
|
||||
Assert.assertTrue(queryModes.contains("METRIC_FILTER"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void queryTest_METRIC_DOMAIN() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("超音数的访问次数");
|
||||
@@ -69,6 +81,16 @@ public class MetricQueryTest extends BaseQueryTest {
|
||||
assertQueryResult(expectedResult, actualResult);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void queryTest_METRIC_MODEL_with_agent() {
|
||||
//agent only support METRIC_ENTITY, METRIC_FILTER
|
||||
MockConfiguration.mockAgent(agentService);
|
||||
ParseResp parseResp = submitParseWithAgent("超音数的访问次数", DataUtils.getAgent().getId());
|
||||
List<String> queryModes = parseResp.getSelectedParses().stream()
|
||||
.map(SemanticParseInfo::getQueryMode).collect(Collectors.toList());
|
||||
Assert.assertTrue(queryModes.contains("METRIC_MODEL"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void queryTest_METRIC_GROUPBY() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("超音数各部门的访问次数");
|
||||
|
||||
@@ -1,22 +1,23 @@
|
||||
package com.tencent.supersonic.integration.plugin;
|
||||
package com.tencent.supersonic.integration;
|
||||
|
||||
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.ArgumentMatchers.notNull;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.parser.embedding.EmbeddingConfig;
|
||||
import com.tencent.supersonic.chat.parser.embedding.EmbeddingResp;
|
||||
import com.tencent.supersonic.chat.parser.embedding.RecallRetrieval;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.ArgumentMatchers.notNull;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
@Configuration
|
||||
@Slf4j
|
||||
public class PluginMockConfiguration {
|
||||
public class MockConfiguration {
|
||||
|
||||
public static void mockEmbeddingRecognize(PluginManager pluginManager, String text, String id) {
|
||||
EmbeddingResp embeddingResp = new EmbeddingResp();
|
||||
@@ -33,9 +34,12 @@ public class PluginMockConfiguration {
|
||||
when(embeddingConfig.getUrl()).thenReturn("test");
|
||||
}
|
||||
|
||||
public static void mockPluginManagerDoRequest(PluginManager pluginManager, String path,
|
||||
ResponseEntity<String> responseEntity) {
|
||||
public static void mockPluginManagerDoRequest(PluginManager pluginManager, String path, ResponseEntity<String> responseEntity) {
|
||||
when(pluginManager.doRequest(eq(path), notNull(String.class))).thenReturn(responseEntity);
|
||||
}
|
||||
|
||||
public static void mockAgent(AgentService agentService) {
|
||||
when(agentService.getAgent(1)).thenReturn(DataUtils.getAgent());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,23 +1,23 @@
|
||||
package com.tencent.supersonic.integration.plugin;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.parser.embedding.EmbeddingConfig;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.query.ContentInterpret.LLmAnswerResp;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.chat.service.QueryService;
|
||||
import com.tencent.supersonic.integration.MockConfiguration;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Qualifier;
|
||||
import org.springframework.boot.test.mock.mockito.MockBean;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
|
||||
public class PluginRecognizeTest extends BasePluginTest {
|
||||
public class PluginRecognizeTest extends BasePluginTest{
|
||||
|
||||
@MockBean
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
@@ -25,14 +25,17 @@ public class PluginRecognizeTest extends BasePluginTest {
|
||||
@MockBean
|
||||
protected PluginManager pluginManager;
|
||||
|
||||
@MockBean
|
||||
protected AgentService agentService;
|
||||
|
||||
@Autowired
|
||||
@Qualifier("chatQueryService")
|
||||
private QueryService queryService;
|
||||
|
||||
@Test
|
||||
public void webPageRecognize() throws Exception {
|
||||
PluginMockConfiguration.mockEmbeddingRecognize(pluginManager, "alice最近的访问情况怎么样", "1");
|
||||
PluginMockConfiguration.mockEmbeddingUrl(embeddingConfig);
|
||||
MockConfiguration.mockEmbeddingRecognize(pluginManager, "alice最近的访问情况怎么样","1");
|
||||
MockConfiguration.mockEmbeddingUrl(embeddingConfig);
|
||||
QueryReq queryContextReq = DataUtils.getQueryContextReq(1000, "alice最近的访问情况怎么样");
|
||||
QueryResult queryResult = queryService.executeQuery(queryContextReq);
|
||||
assertPluginRecognizeResult(queryResult);
|
||||
@@ -40,8 +43,8 @@ public class PluginRecognizeTest extends BasePluginTest {
|
||||
|
||||
@Test
|
||||
public void webPageRecognizeWithQueryFilter() throws Exception {
|
||||
PluginMockConfiguration.mockEmbeddingRecognize(pluginManager, "在超音数最近的情况怎么样", "1");
|
||||
PluginMockConfiguration.mockEmbeddingUrl(embeddingConfig);
|
||||
MockConfiguration.mockEmbeddingRecognize(pluginManager, "在超音数最近的情况怎么样","1");
|
||||
MockConfiguration.mockEmbeddingUrl(embeddingConfig);
|
||||
QueryReq queryRequest = DataUtils.getQueryContextReq(1000, "在超音数最近的情况怎么样");
|
||||
QueryFilters queryFilters = new QueryFilters();
|
||||
QueryFilter queryFilter = new QueryFilter();
|
||||
@@ -55,17 +58,15 @@ public class PluginRecognizeTest extends BasePluginTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void contentInterpretRecognize() throws Exception {
|
||||
PluginMockConfiguration.mockEmbeddingRecognize(pluginManager, "超音数最近访问情况怎么样", "3");
|
||||
PluginMockConfiguration.mockEmbeddingUrl(embeddingConfig);
|
||||
LLmAnswerResp lLmAnswerResp = new LLmAnswerResp();
|
||||
lLmAnswerResp.setAssistant_message("超音数最近访问情况不错");
|
||||
PluginMockConfiguration.mockPluginManagerDoRequest(pluginManager, "answer_with_plugin_call",
|
||||
ResponseEntity.ok(JSONObject.toJSONString(lLmAnswerResp)));
|
||||
QueryReq queryRequest = DataUtils.getQueryContextReq(1000, "超音数最近访问情况怎么样");
|
||||
QueryResult queryResult = queryService.executeQuery(queryRequest);
|
||||
Assert.assertEquals(queryResult.getResponse(), lLmAnswerResp.getAssistant_message());
|
||||
System.out.println();
|
||||
public void pluginRecognizeWithAgent() {
|
||||
MockConfiguration.mockEmbeddingRecognize(pluginManager, "alice最近的访问情况怎么样","1");
|
||||
MockConfiguration.mockEmbeddingUrl(embeddingConfig);
|
||||
MockConfiguration.mockAgent(agentService);
|
||||
QueryReq queryContextReq = DataUtils.getQueryReqWithAgent(1000, "alice最近的访问情况怎么样",
|
||||
DataUtils.getAgent().getId());
|
||||
ParseResp parseResp = queryService.performParsing(queryContextReq);
|
||||
Assert.assertTrue(parseResp.getSelectedParses() != null
|
||||
&& parseResp.getSelectedParses().size() > 0);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,16 +1,26 @@
|
||||
package com.tencent.supersonic.util;
|
||||
|
||||
import static java.time.LocalDate.now;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.agent.Agent;
|
||||
import com.tencent.supersonic.chat.agent.AgentConfig;
|
||||
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
|
||||
import com.tencent.supersonic.chat.agent.tool.MetricInterpretTool;
|
||||
import com.tencent.supersonic.chat.agent.tool.PluginTool;
|
||||
import com.tencent.supersonic.chat.agent.tool.RuleQueryTool;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.parser.llm.interpret.MetricOption;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
import static java.time.LocalDate.now;
|
||||
|
||||
public class DataUtils {
|
||||
|
||||
private static final User user_test = new User(1L, "admin", "admin", "admin@email");
|
||||
@@ -27,6 +37,15 @@ public class DataUtils {
|
||||
return queryContextReq;
|
||||
}
|
||||
|
||||
public static QueryReq getQueryReqWithAgent(Integer id, String query, Integer agentId) {
|
||||
QueryReq queryReq = new QueryReq();
|
||||
queryReq.setQueryText(query);//"alice的访问次数"
|
||||
queryReq.setChatId(id);
|
||||
queryReq.setUser(user_test);
|
||||
queryReq.setAgentId(agentId);
|
||||
return queryReq;
|
||||
}
|
||||
|
||||
public static SchemaElement getSchemaElement(String name) {
|
||||
return SchemaElement.builder()
|
||||
.name(name)
|
||||
@@ -55,9 +74,8 @@ public class DataUtils {
|
||||
.build();
|
||||
}
|
||||
|
||||
public static QueryFilter getFilter(String bizName, FilterOperatorEnum filterOperatorEnum, Object value,
|
||||
String name,
|
||||
Long elementId) {
|
||||
public static QueryFilter getFilter(String bizName, FilterOperatorEnum filterOperatorEnum, Object value, String name,
|
||||
Long elementId) {
|
||||
QueryFilter filter = new QueryFilter();
|
||||
filter.setBizName(bizName);
|
||||
filter.setOperator(filterOperatorEnum);
|
||||
@@ -77,8 +95,7 @@ public class DataUtils {
|
||||
return dateInfo;
|
||||
}
|
||||
|
||||
public static DateConf getDateConf(DateConf.DateMode dateMode, Integer unit, String period, String startDate,
|
||||
String endDate) {
|
||||
public static DateConf getDateConf(DateConf.DateMode dateMode, Integer unit, String period, String startDate, String endDate) {
|
||||
DateConf dateInfo = new DateConf();
|
||||
dateInfo.setUnit(unit);
|
||||
dateInfo.setDateMode(dateMode);
|
||||
@@ -129,4 +146,43 @@ public class DataUtils {
|
||||
return dimensionFilterExist;
|
||||
}
|
||||
|
||||
|
||||
public static Agent getAgent() {
|
||||
Agent agent = new Agent();
|
||||
agent.setId(1);
|
||||
agent.setName("查信息");
|
||||
agent.setDescription("查信息");
|
||||
AgentConfig agentConfig = new AgentConfig();
|
||||
agentConfig.getTools().add(getRuleQueryTool());
|
||||
agentConfig.getTools().add(getPluginTool());
|
||||
agentConfig.getTools().add(getMetricInterpretTool());
|
||||
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
|
||||
return agent;
|
||||
}
|
||||
|
||||
private static RuleQueryTool getRuleQueryTool() {
|
||||
RuleQueryTool ruleQueryTool = new RuleQueryTool();
|
||||
ruleQueryTool.setType(AgentToolType.RULE);
|
||||
ruleQueryTool.setQueryModes(Lists.newArrayList("METRIC_ENTITY", "METRIC_FILTER", "METRIC_MODEL"));
|
||||
return ruleQueryTool;
|
||||
}
|
||||
|
||||
private static PluginTool getPluginTool() {
|
||||
PluginTool pluginTool = new PluginTool();
|
||||
pluginTool.setType(AgentToolType.PLUGIN);
|
||||
pluginTool.setPlugins(Lists.newArrayList(1L));
|
||||
return pluginTool;
|
||||
}
|
||||
|
||||
private static MetricInterpretTool getMetricInterpretTool() {
|
||||
MetricInterpretTool metricInterpretTool = new MetricInterpretTool();
|
||||
metricInterpretTool.setModelId(1L);
|
||||
metricInterpretTool.setType(AgentToolType.INTERPRET);
|
||||
metricInterpretTool.setMetricOptions(Lists.newArrayList(
|
||||
new MetricOption(1L),
|
||||
new MetricOption(2L),
|
||||
new MetricOption(3L)));
|
||||
return metricInterpretTool;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -3,8 +3,10 @@ com.tencent.supersonic.chat.api.component.SchemaMapper=\
|
||||
com.tencent.supersonic.chat.api.component.SemanticParser=\
|
||||
com.tencent.supersonic.chat.parser.rule.QueryModeParser, \
|
||||
com.tencent.supersonic.chat.parser.rule.ContextInheritParser, \
|
||||
com.tencent.supersonic.chat.parser.rule.AgentCheckParser, \
|
||||
com.tencent.supersonic.chat.parser.rule.TimeRangeParser, \
|
||||
com.tencent.supersonic.chat.parser.rule.AggregateTypeParser
|
||||
com.tencent.supersonic.chat.parser.rule.AggregateTypeParser, \
|
||||
com.tencent.supersonic.chat.parser.llm.interpret.MetricInterpretParser
|
||||
# com.tencent.supersonic.chat.parser.llm.DSLQueryFunction
|
||||
com.tencent.supersonic.chat.api.component.QueryProcessor=\
|
||||
com.tencent.supersonic.chat.application.processor.SemanticQueryProcessor
|
||||
|
||||
Reference in New Issue
Block a user