[improvement](supersonic) based on version 0.7.2 (#34)

Co-authored-by: zuopengge <hwzuopengge@tencent.com>
This commit is contained in:
mainmain
2023-08-20 17:30:35 +08:00
committed by GitHub
parent c93e60ced7
commit cf1b5336c3
122 changed files with 4045 additions and 1075 deletions

View File

@@ -6,9 +6,10 @@ com.tencent.supersonic.chat.api.component.SchemaMapper=\
com.tencent.supersonic.chat.api.component.SemanticParser=\
com.tencent.supersonic.chat.parser.rule.QueryModeParser, \
com.tencent.supersonic.chat.parser.rule.ContextInheritParser, \
com.tencent.supersonic.chat.parser.rule.AgentCheckParser, \
com.tencent.supersonic.chat.parser.rule.TimeRangeParser, \
com.tencent.supersonic.chat.parser.rule.AggregateTypeParser, \
com.tencent.supersonic.chat.parser.llm.LLMDSLParser, \
com.tencent.supersonic.chat.parser.llm.dsl.LLMDSLParser, \
com.tencent.supersonic.chat.parser.function.FunctionBasedParser
com.tencent.supersonic.chat.api.component.SemanticLayer=\
com.tencent.supersonic.knowledge.semantic.RemoteSemanticLayer
@@ -19,4 +20,13 @@ com.tencent.supersonic.chat.parser.function.ModelResolver=\
com.tencent.supersonic.auth.authentication.interceptor.AuthenticationInterceptor=\
com.tencent.supersonic.auth.authentication.interceptor.DefaultAuthenticationInterceptor
com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\
com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor
com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor
com.tencent.supersonic.chat.api.component.DSLOptimizer=\
com.tencent.supersonic.chat.query.dsl.optimizer.DateFieldCorrector, \
com.tencent.supersonic.chat.query.dsl.optimizer.FieldCorrector, \
com.tencent.supersonic.chat.query.dsl.optimizer.FunctionCorrector, \
com.tencent.supersonic.chat.query.dsl.optimizer.TableNameCorrector, \
com.tencent.supersonic.chat.query.dsl.optimizer.QueryFilterAppend, \
com.tencent.supersonic.chat.query.dsl.optimizer.SelectFieldAppendCorrector

View File

@@ -1,6 +1,12 @@
package com.tencent.supersonic;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.AgentConfig;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.RuleQueryTool;
import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigBaseReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq;
@@ -14,16 +20,14 @@ import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
import com.tencent.supersonic.chat.query.plugin.ParamOption;
import com.tencent.supersonic.chat.query.plugin.WebBase;
import com.tencent.supersonic.chat.service.ChatService;
import com.tencent.supersonic.chat.service.ConfigService;
import com.tencent.supersonic.chat.service.PluginService;
import com.tencent.supersonic.chat.service.QueryService;
import com.tencent.supersonic.chat.service.*;
import com.tencent.supersonic.common.util.JsonUtil;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.context.event.ApplicationReadyEvent;
import org.springframework.context.ApplicationListener;
import org.springframework.stereotype.Component;
@@ -32,6 +36,7 @@ import org.springframework.stereotype.Component;
@Slf4j
public class ConfigureDemo implements ApplicationListener<ApplicationReadyEvent> {
@Qualifier("chatQueryService")
@Autowired
private QueryService queryService;
@Autowired
@@ -40,6 +45,8 @@ public class ConfigureDemo implements ApplicationListener<ApplicationReadyEvent>
protected ConfigService configService;
@Autowired
private PluginService pluginService;
@Autowired
private AgentService agentService;
private User user = User.getFakeUser();
@@ -175,43 +182,25 @@ public class ConfigureDemo implements ApplicationListener<ApplicationReadyEvent>
pluginService.createPlugin(plugin_1, user);
}
private void addPlugin_2() {
Plugin plugin_2 = new Plugin();
plugin_2.setType("DSL");
plugin_2.setModelList(Arrays.asList(1L, 2L));
plugin_2.setPattern("");
plugin_2.setParseModeConfig(null);
plugin_2.setName("大模型语义解析");
List<String> examples = new ArrayList<>();
examples.add("超音数访问次数最高的部门是哪个");
examples.add("超音数访问人数最高的部门是哪个");
PluginParseConfig parseConfig = PluginParseConfig.builder()
.name("DSL")
.description("这个工具能够将用户的自然语言查询转化为SQL语句从而从数据库中的查询具体的数据。用于处理数据查询的问题提供基于事实的数据")
.examples(examples)
.build();
plugin_2.setParseModeConfig(JsonUtil.toString(parseConfig));
pluginService.createPlugin(plugin_2, user);
}
private void addPlugin_3() {
Plugin plugin_2 = new Plugin();
plugin_2.setType("CONTENT_INTERPRET");
plugin_2.setModelList(Arrays.asList(1L));
plugin_2.setPattern("超音数最近访问情况怎么样");
plugin_2.setParseModeConfig(null);
plugin_2.setName("内容解读");
List<String> examples = new ArrayList<>();
examples.add("超音数最近访问情况怎么样");
examples.add("超音数最近访问情况如何");
PluginParseConfig parseConfig = PluginParseConfig.builder()
.name("supersonic_content_interpret")
.description("这个工具能够先查询到相关的数据并交给大模型进行解读, 最后返回解读结果")
.examples(examples)
.build();
plugin_2.setParseModeConfig(JsonUtil.toString(parseConfig));
pluginService.createPlugin(plugin_2, user);
private void addAgent() {
Agent agent = new Agent();
agent.setId(1);
agent.setName("查信息");
agent.setDescription("查信息");
agent.setStatus(1);
agent.setEnableSearch(1);
agent.setExamples(Lists.newArrayList("超音数访问次数", "超音数访问人数", "alice 停留时长"));
AgentConfig agentConfig = new AgentConfig();
RuleQueryTool ruleQueryTool = new RuleQueryTool();
ruleQueryTool.setType(AgentToolType.RULE);
ruleQueryTool.setQueryModes(Lists.newArrayList(
"ENTITY_DETAIL", "ENTITY_LIST_FILTER", "ENTITY_ID", "METRIC_ENTITY",
"METRIC_FILTER", "METRIC_GROUPBY", "METRIC_MODEL", "METRIC_ORDERBY"
));
agentConfig.getTools().add(ruleQueryTool);
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
agentService.createAgent(agent, User.getFakeUser());
}
@Override
@@ -220,8 +209,7 @@ public class ConfigureDemo implements ApplicationListener<ApplicationReadyEvent>
addDemoChatConfig_1();
addDemoChatConfig_2();
addPlugin_1();
addPlugin_2();
addPlugin_3();
addAgent();
addSampleChats();
addSampleChats2();
} catch (Exception e) {

View File

@@ -6,9 +6,10 @@ com.tencent.supersonic.chat.api.component.SchemaMapper=\
com.tencent.supersonic.chat.api.component.SemanticParser=\
com.tencent.supersonic.chat.parser.rule.QueryModeParser, \
com.tencent.supersonic.chat.parser.rule.ContextInheritParser, \
com.tencent.supersonic.chat.parser.rule.AgentCheckParser, \
com.tencent.supersonic.chat.parser.rule.TimeRangeParser, \
com.tencent.supersonic.chat.parser.rule.AggregateTypeParser, \
com.tencent.supersonic.chat.parser.llm.LLMDSLParser, \
com.tencent.supersonic.chat.parser.llm.dsl.LLMDSLParser, \
com.tencent.supersonic.chat.parser.embedding.EmbeddingBasedParser, \
com.tencent.supersonic.chat.parser.function.FunctionBasedParser
com.tencent.supersonic.chat.api.component.SemanticLayer=\
@@ -21,3 +22,11 @@ com.tencent.supersonic.auth.authentication.interceptor.AuthenticationInterceptor
com.tencent.supersonic.auth.authentication.interceptor.DefaultAuthenticationInterceptor
com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\
com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor
com.tencent.supersonic.chat.api.component.DSLOptimizer=\
com.tencent.supersonic.chat.query.dsl.optimizer.DateFieldCorrector, \
com.tencent.supersonic.chat.query.dsl.optimizer.FieldCorrector, \
com.tencent.supersonic.chat.query.dsl.optimizer.FunctionCorrector, \
com.tencent.supersonic.chat.query.dsl.optimizer.TableNameCorrector, \
com.tencent.supersonic.chat.query.dsl.optimizer.QueryFilterAppend, \
com.tencent.supersonic.chat.query.dsl.optimizer.SelectFieldAppendCorrector

View File

@@ -648,6 +648,22 @@ CREATE TABLE IF NOT EXISTS `s2_plugin`
COMMENT
ON TABLE s2_plugin IS 'plugin information table';
CREATE TABLE IF NOT EXISTS s2_agent
(
id int AUTO_INCREMENT,
name varchar(100) null,
description varchar(500) null,
status int null,
examples varchar(500) null,
config varchar(2000) null,
created_by varchar(100) null,
created_at TIMESTAMP null,
updated_by varchar(100) null,
updated_at TIMESTAMP null,
enable_search int null,
PRIMARY KEY (`id`)
); COMMENT ON TABLE s2_agent IS 'assistant information table';
-------demo for semantic and chat
CREATE TABLE IF NOT EXISTS `s2_user_department`

View File

@@ -11,6 +11,7 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.service.ChatService;
import com.tencent.supersonic.chat.service.ConfigService;
import com.tencent.supersonic.chat.service.QueryService;
@@ -22,6 +23,7 @@ import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.test.context.ActiveProfiles;
import org.springframework.test.context.junit4.SpringRunner;
@@ -42,6 +44,8 @@ public class BaseQueryTest {
protected ChatService chatService;
@Autowired
protected ConfigService configService;
@MockBean
protected AgentService agentService;
protected QueryResult submitMultiTurnChat(String queryText) throws Exception {
ParseResp parseResp = submitParse(queryText);
@@ -78,6 +82,11 @@ public class BaseQueryTest {
return queryService.performParsing(queryContextReq);
}
protected ParseResp submitParseWithAgent(String queryText, Integer agentId) {
QueryReq queryContextReq = DataUtils.getQueryReqWithAgent(10, queryText, agentId);
return queryService.performParsing(queryContextReq);
}
protected void assertSchemaElements(Set<SchemaElement> expected, Set<SchemaElement> actual) {
Set<String> expectedNames = expected.stream().map(s -> s.getName())
.filter(s -> s != null).collect(Collectors.toSet());

View File

@@ -0,0 +1,56 @@
package com.tencent.supersonic.integration;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.StandaloneLauncher;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.parser.embedding.EmbeddingConfig;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.query.metricInterpret.LLmAnswerResp;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.service.QueryService;
import com.tencent.supersonic.util.DataUtils;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.http.ResponseEntity;
import org.springframework.test.context.ActiveProfiles;
import org.springframework.test.context.junit4.SpringRunner;
@RunWith(SpringRunner.class)
@SpringBootTest(classes = StandaloneLauncher.class)
@ActiveProfiles("local")
public class MetricInterpretTest {
@MockBean
private AgentService agentService;
@MockBean
private PluginManager pluginManager;
@MockBean
private EmbeddingConfig embeddingConfig;
@Autowired
@Qualifier("chatQueryService")
private QueryService queryService;
@Test
public void testMetricInterpret() throws Exception {
MockConfiguration.mockAgent(agentService);
MockConfiguration.mockEmbeddingUrl(embeddingConfig);
LLmAnswerResp lLmAnswerResp = new LLmAnswerResp();
lLmAnswerResp.setAssistant_message("alice最近在超音数的访问情况有增多");
MockConfiguration.mockPluginManagerDoRequest(pluginManager, "answer_with_plugin_call",
ResponseEntity.ok(JSONObject.toJSONString(lLmAnswerResp)));
QueryReq queryReq = DataUtils.getQueryReqWithAgent(1000, "能不能帮我解读分析下最近alice在超音数的访问情况",
DataUtils.getAgent().getId());
QueryResult queryResult = queryService.executeQuery(queryReq);
Assert.assertEquals(queryResult.getQueryResults().get(0).get("answer"), lLmAnswerResp.getAssistant_message());
}
}

View File

@@ -1,30 +1,31 @@
package com.tencent.supersonic.integration;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigEditReqReq;
import com.tencent.supersonic.chat.api.pojo.request.ItemVisibility;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigEditReqReq;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.api.pojo.request.ItemVisibility;
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricFilterQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricGroupByQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricTopNQuery;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import com.tencent.supersonic.util.DataUtils;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.junit.Assert;
import org.junit.Test;
import org.springframework.beans.BeanUtils;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.stream.Collectors;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.*;
public class MetricQueryTest extends BaseQueryTest {
@@ -50,6 +51,17 @@ public class MetricQueryTest extends BaseQueryTest {
assertQueryResult(expectedResult, actualResult);
}
@Test
public void queryTest_METRIC_FILTER_with_agent() {
//agent only support METRIC_ENTITY, METRIC_FILTER
MockConfiguration.mockAgent(agentService);
ParseResp parseResp = submitParseWithAgent("alice的访问次数", DataUtils.getAgent().getId());
Assert.assertNotNull(parseResp.getSelectedParses());
List<String> queryModes = parseResp.getSelectedParses().stream()
.map(SemanticParseInfo::getQueryMode).collect(Collectors.toList());
Assert.assertTrue(queryModes.contains("METRIC_FILTER"));
}
@Test
public void queryTest_METRIC_DOMAIN() throws Exception {
QueryResult actualResult = submitNewChat("超音数的访问次数");
@@ -69,6 +81,16 @@ public class MetricQueryTest extends BaseQueryTest {
assertQueryResult(expectedResult, actualResult);
}
@Test
public void queryTest_METRIC_MODEL_with_agent() {
//agent only support METRIC_ENTITY, METRIC_FILTER
MockConfiguration.mockAgent(agentService);
ParseResp parseResp = submitParseWithAgent("超音数的访问次数", DataUtils.getAgent().getId());
List<String> queryModes = parseResp.getSelectedParses().stream()
.map(SemanticParseInfo::getQueryMode).collect(Collectors.toList());
Assert.assertTrue(queryModes.contains("METRIC_MODEL"));
}
@Test
public void queryTest_METRIC_GROUPBY() throws Exception {
QueryResult actualResult = submitNewChat("超音数各部门的访问次数");

View File

@@ -1,22 +1,23 @@
package com.tencent.supersonic.integration.plugin;
package com.tencent.supersonic.integration;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.notNull;
import static org.mockito.Mockito.when;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.parser.embedding.EmbeddingConfig;
import com.tencent.supersonic.chat.parser.embedding.EmbeddingResp;
import com.tencent.supersonic.chat.parser.embedding.RecallRetrieval;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.util.DataUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.ResponseEntity;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.notNull;
import static org.mockito.Mockito.when;
@Configuration
@Slf4j
public class PluginMockConfiguration {
public class MockConfiguration {
public static void mockEmbeddingRecognize(PluginManager pluginManager, String text, String id) {
EmbeddingResp embeddingResp = new EmbeddingResp();
@@ -33,9 +34,12 @@ public class PluginMockConfiguration {
when(embeddingConfig.getUrl()).thenReturn("test");
}
public static void mockPluginManagerDoRequest(PluginManager pluginManager, String path,
ResponseEntity<String> responseEntity) {
public static void mockPluginManagerDoRequest(PluginManager pluginManager, String path, ResponseEntity<String> responseEntity) {
when(pluginManager.doRequest(eq(path), notNull(String.class))).thenReturn(responseEntity);
}
public static void mockAgent(AgentService agentService) {
when(agentService.getAgent(1)).thenReturn(DataUtils.getAgent());
}
}

View File

@@ -1,23 +1,23 @@
package com.tencent.supersonic.integration.plugin;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.parser.embedding.EmbeddingConfig;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.query.ContentInterpret.LLmAnswerResp;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.service.QueryService;
import com.tencent.supersonic.integration.MockConfiguration;
import com.tencent.supersonic.util.DataUtils;
import org.junit.Assert;
import org.junit.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.http.ResponseEntity;
public class PluginRecognizeTest extends BasePluginTest {
public class PluginRecognizeTest extends BasePluginTest{
@MockBean
private EmbeddingConfig embeddingConfig;
@@ -25,14 +25,17 @@ public class PluginRecognizeTest extends BasePluginTest {
@MockBean
protected PluginManager pluginManager;
@MockBean
protected AgentService agentService;
@Autowired
@Qualifier("chatQueryService")
private QueryService queryService;
@Test
public void webPageRecognize() throws Exception {
PluginMockConfiguration.mockEmbeddingRecognize(pluginManager, "alice最近的访问情况怎么样", "1");
PluginMockConfiguration.mockEmbeddingUrl(embeddingConfig);
MockConfiguration.mockEmbeddingRecognize(pluginManager, "alice最近的访问情况怎么样","1");
MockConfiguration.mockEmbeddingUrl(embeddingConfig);
QueryReq queryContextReq = DataUtils.getQueryContextReq(1000, "alice最近的访问情况怎么样");
QueryResult queryResult = queryService.executeQuery(queryContextReq);
assertPluginRecognizeResult(queryResult);
@@ -40,8 +43,8 @@ public class PluginRecognizeTest extends BasePluginTest {
@Test
public void webPageRecognizeWithQueryFilter() throws Exception {
PluginMockConfiguration.mockEmbeddingRecognize(pluginManager, "在超音数最近的情况怎么样", "1");
PluginMockConfiguration.mockEmbeddingUrl(embeddingConfig);
MockConfiguration.mockEmbeddingRecognize(pluginManager, "在超音数最近的情况怎么样","1");
MockConfiguration.mockEmbeddingUrl(embeddingConfig);
QueryReq queryRequest = DataUtils.getQueryContextReq(1000, "在超音数最近的情况怎么样");
QueryFilters queryFilters = new QueryFilters();
QueryFilter queryFilter = new QueryFilter();
@@ -55,17 +58,15 @@ public class PluginRecognizeTest extends BasePluginTest {
}
@Test
public void contentInterpretRecognize() throws Exception {
PluginMockConfiguration.mockEmbeddingRecognize(pluginManager, "超音数最近访问情况怎么样", "3");
PluginMockConfiguration.mockEmbeddingUrl(embeddingConfig);
LLmAnswerResp lLmAnswerResp = new LLmAnswerResp();
lLmAnswerResp.setAssistant_message("超音数最近访问情况不错");
PluginMockConfiguration.mockPluginManagerDoRequest(pluginManager, "answer_with_plugin_call",
ResponseEntity.ok(JSONObject.toJSONString(lLmAnswerResp)));
QueryReq queryRequest = DataUtils.getQueryContextReq(1000, "超音数最近访问情况怎么样");
QueryResult queryResult = queryService.executeQuery(queryRequest);
Assert.assertEquals(queryResult.getResponse(), lLmAnswerResp.getAssistant_message());
System.out.println();
public void pluginRecognizeWithAgent() {
MockConfiguration.mockEmbeddingRecognize(pluginManager, "alice最近访问情况怎么样","1");
MockConfiguration.mockEmbeddingUrl(embeddingConfig);
MockConfiguration.mockAgent(agentService);
QueryReq queryContextReq = DataUtils.getQueryReqWithAgent(1000, "alice最近访问情况怎么样",
DataUtils.getAgent().getId());
ParseResp parseResp = queryService.performParsing(queryContextReq);
Assert.assertTrue(parseResp.getSelectedParses() != null
&& parseResp.getSelectedParses().size() > 0);
}
}

View File

@@ -1,16 +1,26 @@
package com.tencent.supersonic.util;
import static java.time.LocalDate.now;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.AgentConfig;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.MetricInterpretTool;
import com.tencent.supersonic.chat.agent.tool.PluginTool;
import com.tencent.supersonic.chat.agent.tool.RuleQueryTool;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.parser.llm.interpret.MetricOption;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import java.util.Set;
import static java.time.LocalDate.now;
public class DataUtils {
private static final User user_test = new User(1L, "admin", "admin", "admin@email");
@@ -27,6 +37,15 @@ public class DataUtils {
return queryContextReq;
}
public static QueryReq getQueryReqWithAgent(Integer id, String query, Integer agentId) {
QueryReq queryReq = new QueryReq();
queryReq.setQueryText(query);//"alice的访问次数"
queryReq.setChatId(id);
queryReq.setUser(user_test);
queryReq.setAgentId(agentId);
return queryReq;
}
public static SchemaElement getSchemaElement(String name) {
return SchemaElement.builder()
.name(name)
@@ -55,9 +74,8 @@ public class DataUtils {
.build();
}
public static QueryFilter getFilter(String bizName, FilterOperatorEnum filterOperatorEnum, Object value,
String name,
Long elementId) {
public static QueryFilter getFilter(String bizName, FilterOperatorEnum filterOperatorEnum, Object value, String name,
Long elementId) {
QueryFilter filter = new QueryFilter();
filter.setBizName(bizName);
filter.setOperator(filterOperatorEnum);
@@ -77,8 +95,7 @@ public class DataUtils {
return dateInfo;
}
public static DateConf getDateConf(DateConf.DateMode dateMode, Integer unit, String period, String startDate,
String endDate) {
public static DateConf getDateConf(DateConf.DateMode dateMode, Integer unit, String period, String startDate, String endDate) {
DateConf dateInfo = new DateConf();
dateInfo.setUnit(unit);
dateInfo.setDateMode(dateMode);
@@ -129,4 +146,43 @@ public class DataUtils {
return dimensionFilterExist;
}
public static Agent getAgent() {
Agent agent = new Agent();
agent.setId(1);
agent.setName("查信息");
agent.setDescription("查信息");
AgentConfig agentConfig = new AgentConfig();
agentConfig.getTools().add(getRuleQueryTool());
agentConfig.getTools().add(getPluginTool());
agentConfig.getTools().add(getMetricInterpretTool());
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
return agent;
}
private static RuleQueryTool getRuleQueryTool() {
RuleQueryTool ruleQueryTool = new RuleQueryTool();
ruleQueryTool.setType(AgentToolType.RULE);
ruleQueryTool.setQueryModes(Lists.newArrayList("METRIC_ENTITY", "METRIC_FILTER", "METRIC_MODEL"));
return ruleQueryTool;
}
private static PluginTool getPluginTool() {
PluginTool pluginTool = new PluginTool();
pluginTool.setType(AgentToolType.PLUGIN);
pluginTool.setPlugins(Lists.newArrayList(1L));
return pluginTool;
}
private static MetricInterpretTool getMetricInterpretTool() {
MetricInterpretTool metricInterpretTool = new MetricInterpretTool();
metricInterpretTool.setModelId(1L);
metricInterpretTool.setType(AgentToolType.INTERPRET);
metricInterpretTool.setMetricOptions(Lists.newArrayList(
new MetricOption(1L),
new MetricOption(2L),
new MetricOption(3L)));
return metricInterpretTool;
}
}

View File

@@ -3,8 +3,10 @@ com.tencent.supersonic.chat.api.component.SchemaMapper=\
com.tencent.supersonic.chat.api.component.SemanticParser=\
com.tencent.supersonic.chat.parser.rule.QueryModeParser, \
com.tencent.supersonic.chat.parser.rule.ContextInheritParser, \
com.tencent.supersonic.chat.parser.rule.AgentCheckParser, \
com.tencent.supersonic.chat.parser.rule.TimeRangeParser, \
com.tencent.supersonic.chat.parser.rule.AggregateTypeParser
com.tencent.supersonic.chat.parser.rule.AggregateTypeParser, \
com.tencent.supersonic.chat.parser.llm.interpret.MetricInterpretParser
# com.tencent.supersonic.chat.parser.llm.DSLQueryFunction
com.tencent.supersonic.chat.api.component.QueryProcessor=\
com.tencent.supersonic.chat.application.processor.SemanticQueryProcessor