[improvement](supersonic) based on version 0.7.2 (#34)

Co-authored-by: zuopengge <hwzuopengge@tencent.com>
This commit is contained in:
mainmain
2023-08-20 17:30:35 +08:00
committed by GitHub
parent c93e60ced7
commit cf1b5336c3
122 changed files with 4045 additions and 1075 deletions

View File

@@ -11,6 +11,7 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.service.ChatService;
import com.tencent.supersonic.chat.service.ConfigService;
import com.tencent.supersonic.chat.service.QueryService;
@@ -22,6 +23,7 @@ import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.test.context.ActiveProfiles;
import org.springframework.test.context.junit4.SpringRunner;
@@ -42,6 +44,8 @@ public class BaseQueryTest {
protected ChatService chatService;
@Autowired
protected ConfigService configService;
@MockBean
protected AgentService agentService;
protected QueryResult submitMultiTurnChat(String queryText) throws Exception {
ParseResp parseResp = submitParse(queryText);
@@ -78,6 +82,11 @@ public class BaseQueryTest {
return queryService.performParsing(queryContextReq);
}
protected ParseResp submitParseWithAgent(String queryText, Integer agentId) {
QueryReq queryContextReq = DataUtils.getQueryReqWithAgent(10, queryText, agentId);
return queryService.performParsing(queryContextReq);
}
protected void assertSchemaElements(Set<SchemaElement> expected, Set<SchemaElement> actual) {
Set<String> expectedNames = expected.stream().map(s -> s.getName())
.filter(s -> s != null).collect(Collectors.toSet());

View File

@@ -0,0 +1,56 @@
package com.tencent.supersonic.integration;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.StandaloneLauncher;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.parser.embedding.EmbeddingConfig;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.query.metricInterpret.LLmAnswerResp;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.service.QueryService;
import com.tencent.supersonic.util.DataUtils;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.http.ResponseEntity;
import org.springframework.test.context.ActiveProfiles;
import org.springframework.test.context.junit4.SpringRunner;
@RunWith(SpringRunner.class)
@SpringBootTest(classes = StandaloneLauncher.class)
@ActiveProfiles("local")
public class MetricInterpretTest {
@MockBean
private AgentService agentService;
@MockBean
private PluginManager pluginManager;
@MockBean
private EmbeddingConfig embeddingConfig;
@Autowired
@Qualifier("chatQueryService")
private QueryService queryService;
@Test
public void testMetricInterpret() throws Exception {
MockConfiguration.mockAgent(agentService);
MockConfiguration.mockEmbeddingUrl(embeddingConfig);
LLmAnswerResp lLmAnswerResp = new LLmAnswerResp();
lLmAnswerResp.setAssistant_message("alice最近在超音数的访问情况有增多");
MockConfiguration.mockPluginManagerDoRequest(pluginManager, "answer_with_plugin_call",
ResponseEntity.ok(JSONObject.toJSONString(lLmAnswerResp)));
QueryReq queryReq = DataUtils.getQueryReqWithAgent(1000, "能不能帮我解读分析下最近alice在超音数的访问情况",
DataUtils.getAgent().getId());
QueryResult queryResult = queryService.executeQuery(queryReq);
Assert.assertEquals(queryResult.getQueryResults().get(0).get("answer"), lLmAnswerResp.getAssistant_message());
}
}

View File

@@ -1,30 +1,31 @@
package com.tencent.supersonic.integration;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigEditReqReq;
import com.tencent.supersonic.chat.api.pojo.request.ItemVisibility;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigEditReqReq;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.api.pojo.request.ItemVisibility;
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricFilterQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricGroupByQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricTopNQuery;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import com.tencent.supersonic.util.DataUtils;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.junit.Assert;
import org.junit.Test;
import org.springframework.beans.BeanUtils;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.stream.Collectors;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.*;
public class MetricQueryTest extends BaseQueryTest {
@@ -50,6 +51,17 @@ public class MetricQueryTest extends BaseQueryTest {
assertQueryResult(expectedResult, actualResult);
}
@Test
public void queryTest_METRIC_FILTER_with_agent() {
//agent only support METRIC_ENTITY, METRIC_FILTER
MockConfiguration.mockAgent(agentService);
ParseResp parseResp = submitParseWithAgent("alice的访问次数", DataUtils.getAgent().getId());
Assert.assertNotNull(parseResp.getSelectedParses());
List<String> queryModes = parseResp.getSelectedParses().stream()
.map(SemanticParseInfo::getQueryMode).collect(Collectors.toList());
Assert.assertTrue(queryModes.contains("METRIC_FILTER"));
}
@Test
public void queryTest_METRIC_DOMAIN() throws Exception {
QueryResult actualResult = submitNewChat("超音数的访问次数");
@@ -69,6 +81,16 @@ public class MetricQueryTest extends BaseQueryTest {
assertQueryResult(expectedResult, actualResult);
}
@Test
public void queryTest_METRIC_MODEL_with_agent() {
//agent only support METRIC_ENTITY, METRIC_FILTER
MockConfiguration.mockAgent(agentService);
ParseResp parseResp = submitParseWithAgent("超音数的访问次数", DataUtils.getAgent().getId());
List<String> queryModes = parseResp.getSelectedParses().stream()
.map(SemanticParseInfo::getQueryMode).collect(Collectors.toList());
Assert.assertTrue(queryModes.contains("METRIC_MODEL"));
}
@Test
public void queryTest_METRIC_GROUPBY() throws Exception {
QueryResult actualResult = submitNewChat("超音数各部门的访问次数");

View File

@@ -1,22 +1,23 @@
package com.tencent.supersonic.integration.plugin;
package com.tencent.supersonic.integration;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.notNull;
import static org.mockito.Mockito.when;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.parser.embedding.EmbeddingConfig;
import com.tencent.supersonic.chat.parser.embedding.EmbeddingResp;
import com.tencent.supersonic.chat.parser.embedding.RecallRetrieval;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.util.DataUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.ResponseEntity;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.notNull;
import static org.mockito.Mockito.when;
@Configuration
@Slf4j
public class PluginMockConfiguration {
public class MockConfiguration {
public static void mockEmbeddingRecognize(PluginManager pluginManager, String text, String id) {
EmbeddingResp embeddingResp = new EmbeddingResp();
@@ -33,9 +34,12 @@ public class PluginMockConfiguration {
when(embeddingConfig.getUrl()).thenReturn("test");
}
public static void mockPluginManagerDoRequest(PluginManager pluginManager, String path,
ResponseEntity<String> responseEntity) {
public static void mockPluginManagerDoRequest(PluginManager pluginManager, String path, ResponseEntity<String> responseEntity) {
when(pluginManager.doRequest(eq(path), notNull(String.class))).thenReturn(responseEntity);
}
public static void mockAgent(AgentService agentService) {
when(agentService.getAgent(1)).thenReturn(DataUtils.getAgent());
}
}

View File

@@ -1,23 +1,23 @@
package com.tencent.supersonic.integration.plugin;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.parser.embedding.EmbeddingConfig;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.query.ContentInterpret.LLmAnswerResp;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.service.QueryService;
import com.tencent.supersonic.integration.MockConfiguration;
import com.tencent.supersonic.util.DataUtils;
import org.junit.Assert;
import org.junit.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.http.ResponseEntity;
public class PluginRecognizeTest extends BasePluginTest {
public class PluginRecognizeTest extends BasePluginTest{
@MockBean
private EmbeddingConfig embeddingConfig;
@@ -25,14 +25,17 @@ public class PluginRecognizeTest extends BasePluginTest {
@MockBean
protected PluginManager pluginManager;
@MockBean
protected AgentService agentService;
@Autowired
@Qualifier("chatQueryService")
private QueryService queryService;
@Test
public void webPageRecognize() throws Exception {
PluginMockConfiguration.mockEmbeddingRecognize(pluginManager, "alice最近的访问情况怎么样", "1");
PluginMockConfiguration.mockEmbeddingUrl(embeddingConfig);
MockConfiguration.mockEmbeddingRecognize(pluginManager, "alice最近的访问情况怎么样","1");
MockConfiguration.mockEmbeddingUrl(embeddingConfig);
QueryReq queryContextReq = DataUtils.getQueryContextReq(1000, "alice最近的访问情况怎么样");
QueryResult queryResult = queryService.executeQuery(queryContextReq);
assertPluginRecognizeResult(queryResult);
@@ -40,8 +43,8 @@ public class PluginRecognizeTest extends BasePluginTest {
@Test
public void webPageRecognizeWithQueryFilter() throws Exception {
PluginMockConfiguration.mockEmbeddingRecognize(pluginManager, "在超音数最近的情况怎么样", "1");
PluginMockConfiguration.mockEmbeddingUrl(embeddingConfig);
MockConfiguration.mockEmbeddingRecognize(pluginManager, "在超音数最近的情况怎么样","1");
MockConfiguration.mockEmbeddingUrl(embeddingConfig);
QueryReq queryRequest = DataUtils.getQueryContextReq(1000, "在超音数最近的情况怎么样");
QueryFilters queryFilters = new QueryFilters();
QueryFilter queryFilter = new QueryFilter();
@@ -55,17 +58,15 @@ public class PluginRecognizeTest extends BasePluginTest {
}
@Test
public void contentInterpretRecognize() throws Exception {
PluginMockConfiguration.mockEmbeddingRecognize(pluginManager, "超音数最近访问情况怎么样", "3");
PluginMockConfiguration.mockEmbeddingUrl(embeddingConfig);
LLmAnswerResp lLmAnswerResp = new LLmAnswerResp();
lLmAnswerResp.setAssistant_message("超音数最近访问情况不错");
PluginMockConfiguration.mockPluginManagerDoRequest(pluginManager, "answer_with_plugin_call",
ResponseEntity.ok(JSONObject.toJSONString(lLmAnswerResp)));
QueryReq queryRequest = DataUtils.getQueryContextReq(1000, "超音数最近访问情况怎么样");
QueryResult queryResult = queryService.executeQuery(queryRequest);
Assert.assertEquals(queryResult.getResponse(), lLmAnswerResp.getAssistant_message());
System.out.println();
public void pluginRecognizeWithAgent() {
MockConfiguration.mockEmbeddingRecognize(pluginManager, "alice最近访问情况怎么样","1");
MockConfiguration.mockEmbeddingUrl(embeddingConfig);
MockConfiguration.mockAgent(agentService);
QueryReq queryContextReq = DataUtils.getQueryReqWithAgent(1000, "alice最近访问情况怎么样",
DataUtils.getAgent().getId());
ParseResp parseResp = queryService.performParsing(queryContextReq);
Assert.assertTrue(parseResp.getSelectedParses() != null
&& parseResp.getSelectedParses().size() > 0);
}
}

View File

@@ -1,16 +1,26 @@
package com.tencent.supersonic.util;
import static java.time.LocalDate.now;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.AgentConfig;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.MetricInterpretTool;
import com.tencent.supersonic.chat.agent.tool.PluginTool;
import com.tencent.supersonic.chat.agent.tool.RuleQueryTool;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.parser.llm.interpret.MetricOption;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import java.util.Set;
import static java.time.LocalDate.now;
public class DataUtils {
private static final User user_test = new User(1L, "admin", "admin", "admin@email");
@@ -27,6 +37,15 @@ public class DataUtils {
return queryContextReq;
}
public static QueryReq getQueryReqWithAgent(Integer id, String query, Integer agentId) {
QueryReq queryReq = new QueryReq();
queryReq.setQueryText(query);//"alice的访问次数"
queryReq.setChatId(id);
queryReq.setUser(user_test);
queryReq.setAgentId(agentId);
return queryReq;
}
public static SchemaElement getSchemaElement(String name) {
return SchemaElement.builder()
.name(name)
@@ -55,9 +74,8 @@ public class DataUtils {
.build();
}
public static QueryFilter getFilter(String bizName, FilterOperatorEnum filterOperatorEnum, Object value,
String name,
Long elementId) {
public static QueryFilter getFilter(String bizName, FilterOperatorEnum filterOperatorEnum, Object value, String name,
Long elementId) {
QueryFilter filter = new QueryFilter();
filter.setBizName(bizName);
filter.setOperator(filterOperatorEnum);
@@ -77,8 +95,7 @@ public class DataUtils {
return dateInfo;
}
public static DateConf getDateConf(DateConf.DateMode dateMode, Integer unit, String period, String startDate,
String endDate) {
public static DateConf getDateConf(DateConf.DateMode dateMode, Integer unit, String period, String startDate, String endDate) {
DateConf dateInfo = new DateConf();
dateInfo.setUnit(unit);
dateInfo.setDateMode(dateMode);
@@ -129,4 +146,43 @@ public class DataUtils {
return dimensionFilterExist;
}
public static Agent getAgent() {
Agent agent = new Agent();
agent.setId(1);
agent.setName("查信息");
agent.setDescription("查信息");
AgentConfig agentConfig = new AgentConfig();
agentConfig.getTools().add(getRuleQueryTool());
agentConfig.getTools().add(getPluginTool());
agentConfig.getTools().add(getMetricInterpretTool());
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
return agent;
}
private static RuleQueryTool getRuleQueryTool() {
RuleQueryTool ruleQueryTool = new RuleQueryTool();
ruleQueryTool.setType(AgentToolType.RULE);
ruleQueryTool.setQueryModes(Lists.newArrayList("METRIC_ENTITY", "METRIC_FILTER", "METRIC_MODEL"));
return ruleQueryTool;
}
private static PluginTool getPluginTool() {
PluginTool pluginTool = new PluginTool();
pluginTool.setType(AgentToolType.PLUGIN);
pluginTool.setPlugins(Lists.newArrayList(1L));
return pluginTool;
}
private static MetricInterpretTool getMetricInterpretTool() {
MetricInterpretTool metricInterpretTool = new MetricInterpretTool();
metricInterpretTool.setModelId(1L);
metricInterpretTool.setType(AgentToolType.INTERPRET);
metricInterpretTool.setMetricOptions(Lists.newArrayList(
new MetricOption(1L),
new MetricOption(2L),
new MetricOption(3L)));
return metricInterpretTool;
}
}

View File

@@ -3,8 +3,10 @@ com.tencent.supersonic.chat.api.component.SchemaMapper=\
com.tencent.supersonic.chat.api.component.SemanticParser=\
com.tencent.supersonic.chat.parser.rule.QueryModeParser, \
com.tencent.supersonic.chat.parser.rule.ContextInheritParser, \
com.tencent.supersonic.chat.parser.rule.AgentCheckParser, \
com.tencent.supersonic.chat.parser.rule.TimeRangeParser, \
com.tencent.supersonic.chat.parser.rule.AggregateTypeParser
com.tencent.supersonic.chat.parser.rule.AggregateTypeParser, \
com.tencent.supersonic.chat.parser.llm.interpret.MetricInterpretParser
# com.tencent.supersonic.chat.parser.llm.DSLQueryFunction
com.tencent.supersonic.chat.api.component.QueryProcessor=\
com.tencent.supersonic.chat.application.processor.SemanticQueryProcessor