[release][project] supersonic 0.7.3 version backend update (#40)

* [improvement] add some features

* [improvement] revise CHANGELOG

---------

Co-authored-by: zuopengge <hwzuopengge@tencent.com>
This commit is contained in:
mainmain
2023-08-29 20:06:34 +08:00
committed by GitHub
parent 6fe9ab79ed
commit e1911bc81b
260 changed files with 6466 additions and 7108 deletions

View File

@@ -51,6 +51,8 @@ public class BaseQueryTest {
ParseResp parseResp = submitParse(queryText);
ExecuteQueryReq request = new ExecuteQueryReq();
request.setQueryId(parseResp.getQueryId());
request.setParseId(parseResp.getSelectedParses().get(0).getId());
request.setChatId(parseResp.getChatId());
request.setQueryText(parseResp.getQueryText());
request.setUser(DataUtils.getUser());
@@ -63,6 +65,8 @@ public class BaseQueryTest {
ParseResp parseResp = submitParse(queryText);
ExecuteQueryReq request = new ExecuteQueryReq();
request.setQueryId(parseResp.getQueryId());
request.setParseId(parseResp.getSelectedParses().get(0).getId());
request.setChatId(parseResp.getChatId());
request.setQueryText(parseResp.getQueryText());
request.setUser(DataUtils.getUser());

View File

@@ -18,7 +18,7 @@ import org.junit.Test;
public class EntityQueryTest extends BaseQueryTest {
@Test
public void queryTest_METRIC_ENTITY_QUERY() throws Exception {
public void queryTest_metric_entity_query() throws Exception {
QueryResult actualResult = submitNewChat("艺人周杰伦的播放量");
QueryResult expectedResult = new QueryResult();
@@ -41,7 +41,7 @@ public class EntityQueryTest extends BaseQueryTest {
}
@Test
public void queryTest_ENTITY_LIST_FILTER() throws Exception {
public void queryTest_entity_list_filter() throws Exception {
QueryResult actualResult = submitNewChat("爱情、流行类型的艺人");
QueryResult expectedResult = new QueryResult();

View File

@@ -4,9 +4,9 @@ import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.StandaloneLauncher;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.parser.embedding.EmbeddingConfig;
import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingConfig;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.query.metricInterpret.LLmAnswerResp;
import com.tencent.supersonic.chat.query.metricinterpret.LLmAnswerResp;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.service.QueryService;
import com.tencent.supersonic.util.DataUtils;
@@ -44,13 +44,13 @@ public class MetricInterpretTest {
MockConfiguration.mockAgent(agentService);
MockConfiguration.mockEmbeddingUrl(embeddingConfig);
LLmAnswerResp lLmAnswerResp = new LLmAnswerResp();
lLmAnswerResp.setAssistant_message("alice最近在超音数的访问情况有增多");
lLmAnswerResp.setAssistantMessage("alice最近在超音数的访问情况有增多");
MockConfiguration.mockPluginManagerDoRequest(pluginManager, "answer_with_plugin_call",
ResponseEntity.ok(JSONObject.toJSONString(lLmAnswerResp)));
QueryReq queryReq = DataUtils.getQueryReqWithAgent(1000, "能不能帮我解读分析下最近alice在超音数的访问情况",
DataUtils.getAgent().getId());
QueryResult queryResult = queryService.executeQuery(queryReq);
Assert.assertEquals(queryResult.getQueryResults().get(0).get("answer"), lLmAnswerResp.getAssistant_message());
Assert.assertEquals(queryResult.getQueryResults().get(0).get("answer"), lLmAnswerResp.getAssistantMessage());
}
}

View File

@@ -21,16 +21,18 @@ import org.springframework.beans.BeanUtils;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.List;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.stream.Collectors;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.*;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
public class MetricQueryTest extends BaseQueryTest {
@Test
public void queryTest_METRIC_FILTER() throws Exception {
public void queryTest_metric_filter() throws Exception {
QueryResult actualResult = submitNewChat("alice的访问次数");
QueryResult expectedResult = new QueryResult();
@@ -52,7 +54,7 @@ public class MetricQueryTest extends BaseQueryTest {
}
@Test
public void queryTest_METRIC_FILTER_with_agent() {
public void queryTest_metric_filter_with_agent() {
//agent only support METRIC_ENTITY, METRIC_FILTER
MockConfiguration.mockAgent(agentService);
ParseResp parseResp = submitParseWithAgent("alice的访问次数", DataUtils.getAgent().getId());
@@ -63,7 +65,7 @@ public class MetricQueryTest extends BaseQueryTest {
}
@Test
public void queryTest_METRIC_DOMAIN() throws Exception {
public void queryTest_metric_domain() throws Exception {
QueryResult actualResult = submitNewChat("超音数的访问次数");
QueryResult expectedResult = new QueryResult();
@@ -82,7 +84,7 @@ public class MetricQueryTest extends BaseQueryTest {
}
@Test
public void queryTest_METRIC_MODEL_with_agent() {
public void queryTest_metric_model_with_agent() {
//agent only support METRIC_ENTITY, METRIC_FILTER
MockConfiguration.mockAgent(agentService);
ParseResp parseResp = submitParseWithAgent("超音数的访问次数", DataUtils.getAgent().getId());
@@ -92,7 +94,7 @@ public class MetricQueryTest extends BaseQueryTest {
}
@Test
public void queryTest_METRIC_GROUPBY() throws Exception {
public void queryTest_metric_groupby() throws Exception {
QueryResult actualResult = submitNewChat("超音数各部门的访问次数");
QueryResult expectedResult = new QueryResult();
@@ -112,7 +114,7 @@ public class MetricQueryTest extends BaseQueryTest {
}
@Test
public void queryTest_METRIC_FILTER_COMPARE() throws Exception {
public void queryTest_metric_filter_compare() throws Exception {
QueryResult actualResult = submitNewChat("对比alice和lucy的访问次数");
QueryResult expectedResult = new QueryResult();
@@ -137,7 +139,7 @@ public class MetricQueryTest extends BaseQueryTest {
}
@Test
public void queryTest_METRIC_TOPN() throws Exception {
public void queryTest_metric_topn() throws Exception {
QueryResult actualResult = submitNewChat("近3天访问次数最多的用户");
QueryResult expectedResult = new QueryResult();
@@ -157,7 +159,7 @@ public class MetricQueryTest extends BaseQueryTest {
}
@Test
public void queryTest_METRIC_GROUPBY_SUM() throws Exception {
public void queryTest_metric_groupby_sum() throws Exception {
QueryResult actualResult = submitNewChat("超音数各部门的访问次数总和");
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
@@ -176,7 +178,7 @@ public class MetricQueryTest extends BaseQueryTest {
}
@Test
public void queryTest_METRIC_FILTER_TIME() throws Exception {
public void queryTest_metric_filter_time() throws Exception {
DateFormat format = new SimpleDateFormat("yyyy-mm-dd");
DateFormat textFormat = new SimpleDateFormat("yyyy年mm月dd日");
String dateStr = textFormat.format(format.parse(startDay));
@@ -202,7 +204,7 @@ public class MetricQueryTest extends BaseQueryTest {
}
@Test
public void queryTest_CONFIG_VISIBILITY() throws Exception {
public void queryTest_config_visibility() throws Exception {
// 1. round_1 use blacklist
ChatConfigResp chatConfig = configService.fetchConfigByModelId(1L);
ChatConfigEditReqReq extendEditCmd = new ChatConfigEditReqReq();

View File

@@ -2,9 +2,9 @@ package com.tencent.supersonic.integration;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.parser.embedding.EmbeddingConfig;
import com.tencent.supersonic.chat.parser.embedding.EmbeddingResp;
import com.tencent.supersonic.chat.parser.embedding.RecallRetrieval;
import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingConfig;
import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingResp;
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.util.DataUtils;
@@ -34,7 +34,8 @@ public class MockConfiguration {
when(embeddingConfig.getUrl()).thenReturn("test");
}
public static void mockPluginManagerDoRequest(PluginManager pluginManager, String path, ResponseEntity<String> responseEntity) {
public static void mockPluginManagerDoRequest(PluginManager pluginManager, String path,
ResponseEntity<String> responseEntity) {
when(pluginManager.doRequest(eq(path), notNull(String.class))).thenReturn(responseEntity);
}

View File

@@ -0,0 +1,43 @@
package com.tencent.supersonic.integration.llm;
import static org.mockito.Mockito.when;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.config.LLMConfig;
import com.tencent.supersonic.chat.parser.llm.dsl.LLMDslParser;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.integration.BaseQueryTest;
import com.tencent.supersonic.util.DataUtils;
import org.junit.Test;
import org.springframework.boot.test.mock.mockito.MockBean;
public class LLMDslParserTest extends BaseQueryTest {
@MockBean
protected LLMConfig llmConfig;
@Test
public void parse() throws Exception {
String queryText = "周杰伦专辑十一月的萧邦有哪些歌曲";
QueryReq queryReq = DataUtils.getQueryContextReq(10, queryText);
QueryContext queryContext = new QueryContext();
queryContext.setRequest(queryReq);
SemanticParser dslParser = ComponentFactory.getSemanticParsers().stream().filter(parser -> {
if (parser instanceof LLMDslParser) {
return true;
} else {
return false;
}
}
).findFirst().get();
when(llmConfig.getUrl()).thenReturn("llmUrl");
ChatContext chatCtx = new ChatContext();
dslParser.parse(queryContext, chatCtx);
}
}

View File

@@ -0,0 +1,46 @@
package com.tencent.supersonic.integration.mapper;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.integration.BaseQueryTest;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import com.tencent.supersonic.util.DataUtils;
import org.junit.Test;
public class MapperTest extends BaseQueryTest {
@Test
public void hanlp() throws Exception {
QueryReq queryContextReq = DataUtils.getQueryContextReq(10, "艺人周杰伦的播放量");
queryContextReq.setAgentId(1);
QueryResult actualResult = submitNewChat("艺人周杰伦的播放量");
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
expectedResult.setChatContext(expectedParseInfo);
expectedResult.setQueryMode(MetricEntityQuery.QUERY_MODE);
expectedParseInfo.setAggType(NONE);
QueryFilter dimensionFilter = DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS, "周杰伦", "歌手名", 7L);
expectedParseInfo.getDimensionFilters().add(dimensionFilter);
SchemaElement metric = SchemaElement.builder().name("播放量").build();
expectedParseInfo.getMetrics().add(metric);
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.RECENT, 7, period, startDay, endDay));
expectedParseInfo.setNativeQuery(false);
assertQueryResult(expectedResult, actualResult);
}
}

View File

@@ -29,4 +29,4 @@ public class BasePluginTest {
Assert.assertEquals("alice", webPage.getParams().get(0).getValue());
}
}
}

View File

@@ -5,7 +5,7 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.parser.embedding.EmbeddingConfig;
import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingConfig;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.service.QueryService;
@@ -17,16 +17,16 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.test.mock.mockito.MockBean;
public class PluginRecognizeTest extends BasePluginTest{
@MockBean
private EmbeddingConfig embeddingConfig;
public class PluginRecognizeTest extends BasePluginTest {
@MockBean
protected PluginManager pluginManager;
@MockBean
protected AgentService agentService;
private EmbeddingConfig embeddingConfig;
@MockBean
private AgentService agentService;
@Autowired
@Qualifier("chatQueryService")
@@ -34,7 +34,7 @@ public class PluginRecognizeTest extends BasePluginTest{
@Test
public void webPageRecognize() throws Exception {
MockConfiguration.mockEmbeddingRecognize(pluginManager, "alice最近的访问情况怎么样","1");
MockConfiguration.mockEmbeddingRecognize(pluginManager, "alice最近的访问情况怎么样", "1");
MockConfiguration.mockEmbeddingUrl(embeddingConfig);
QueryReq queryContextReq = DataUtils.getQueryContextReq(1000, "alice最近的访问情况怎么样");
QueryResult queryResult = queryService.executeQuery(queryContextReq);
@@ -43,7 +43,7 @@ public class PluginRecognizeTest extends BasePluginTest{
@Test
public void webPageRecognizeWithQueryFilter() throws Exception {
MockConfiguration.mockEmbeddingRecognize(pluginManager, "在超音数最近的情况怎么样","1");
MockConfiguration.mockEmbeddingRecognize(pluginManager, "在超音数最近的情况怎么样", "1");
MockConfiguration.mockEmbeddingUrl(embeddingConfig);
QueryReq queryRequest = DataUtils.getQueryContextReq(1000, "在超音数最近的情况怎么样");
QueryFilters queryFilters = new QueryFilters();
@@ -59,7 +59,7 @@ public class PluginRecognizeTest extends BasePluginTest{
@Test
public void pluginRecognizeWithAgent() {
MockConfiguration.mockEmbeddingRecognize(pluginManager, "alice最近的访问情况怎么样","1");
MockConfiguration.mockEmbeddingRecognize(pluginManager, "alice最近的访问情况怎么样", "1");
MockConfiguration.mockEmbeddingUrl(embeddingConfig);
MockConfiguration.mockAgent(agentService);
QueryReq queryContextReq = DataUtils.getQueryReqWithAgent(1000, "alice最近的访问情况怎么样",

View File

@@ -31,7 +31,7 @@ public class DataUtils {
public static QueryReq getQueryContextReq(Integer id, String query) {
QueryReq queryContextReq = new QueryReq();
queryContextReq.setQueryText(query);//"alice的访问次数"
queryContextReq.setQueryText(query);
queryContextReq.setChatId(id);
queryContextReq.setUser(user_test);
return queryContextReq;
@@ -39,7 +39,7 @@ public class DataUtils {
public static QueryReq getQueryReqWithAgent(Integer id, String query, Integer agentId) {
QueryReq queryReq = new QueryReq();
queryReq.setQueryText(query);//"alice的访问次数"
queryReq.setQueryText(query);
queryReq.setChatId(id);
queryReq.setUser(user_test);
queryReq.setAgentId(agentId);
@@ -74,8 +74,8 @@ public class DataUtils {
.build();
}
public static QueryFilter getFilter(String bizName, FilterOperatorEnum filterOperatorEnum, Object value, String name,
Long elementId) {
public static QueryFilter getFilter(String bizName, FilterOperatorEnum filterOperatorEnum,
Object value, String name, Long elementId) {
QueryFilter filter = new QueryFilter();
filter.setBizName(bizName);
filter.setOperator(filterOperatorEnum);
@@ -95,7 +95,8 @@ public class DataUtils {
return dateInfo;
}
public static DateConf getDateConf(DateConf.DateMode dateMode, Integer unit, String period, String startDate, String endDate) {
public static DateConf getDateConf(DateConf.DateMode dateMode, Integer unit,
String period, String startDate, String endDate) {
DateConf dateInfo = new DateConf();
dateInfo.setUnit(unit);
dateInfo.setDateMode(dateMode);
@@ -114,9 +115,9 @@ public class DataUtils {
}
public static Boolean compareDate(DateConf dateInfo1, DateConf dateInfo2) {
Boolean timeFilterExist = dateInfo1.getUnit().equals(dateInfo2.getUnit()) &&
dateInfo1.getDateMode().equals(dateInfo2.getDateMode()) &&
dateInfo1.getPeriod().equals(dateInfo2.getPeriod());
Boolean timeFilterExist = dateInfo1.getUnit().equals(dateInfo2.getUnit())
&& dateInfo1.getDateMode().equals(dateInfo2.getDateMode())
&& dateInfo1.getPeriod().equals(dateInfo2.getPeriod());
return timeFilterExist;
}
@@ -135,11 +136,11 @@ public class DataUtils {
public static Boolean compareDimensionFilter(Set<QueryFilter> dimensionFilters, QueryFilter dimensionFilter) {
Boolean dimensionFilterExist = false;
for (QueryFilter filter : dimensionFilters) {
if (filter.getBizName().equals(dimensionFilter.getBizName()) &&
filter.getOperator().equals(dimensionFilter.getOperator()) &&
filter.getValue().toString().equals(dimensionFilter.getValue().toString()) &&
filter.getElementID().equals(dimensionFilter.getElementID()) &&
filter.getName().equals(dimensionFilter.getName())) {
if (filter.getBizName().equals(dimensionFilter.getBizName())
&& filter.getOperator().equals(dimensionFilter.getOperator())
&& filter.getValue().toString().equals(dimensionFilter.getValue().toString())
&& filter.getElementID().equals(dimensionFilter.getElementID())
&& filter.getName().equals(dimensionFilter.getName())) {
dimensionFilterExist = true;
}
}
@@ -163,6 +164,7 @@ public class DataUtils {
private static RuleQueryTool getRuleQueryTool() {
RuleQueryTool ruleQueryTool = new RuleQueryTool();
ruleQueryTool.setType(AgentToolType.RULE);
ruleQueryTool.setModelIds(Lists.newArrayList(1L, 2L));
ruleQueryTool.setQueryModes(Lists.newArrayList("METRIC_ENTITY", "METRIC_FILTER", "METRIC_MODEL"));
return ruleQueryTool;
}

View File

@@ -24,5 +24,8 @@ semantic:
url:
prefix: http://127.0.0.1:9081
time:
threshold: 100
mybatis:
mapper-locations=classpath:mappers/custom/*.xml,classpath*:/mappers/*.xml

File diff suppressed because it is too large Load Diff