mirror of
https://github.com/tencentmusic/supersonic.git
synced 2026-04-22 14:54:21 +08:00
[feature][chat]Refactor chat model config related codes.#1739
This commit is contained in:
@@ -7,10 +7,10 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.BaseTest;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentConfig;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
|
||||
import com.tencent.supersonic.chat.server.agent.ToolConfig;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatModel;
|
||||
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
@@ -49,7 +49,7 @@ public class Text2SQLEval extends BaseTest {
|
||||
QueryResult result = submitNewChat("近30天总访问次数", agentId);
|
||||
durations.add(System.currentTimeMillis() - start);
|
||||
assert result.getQueryColumns().size() == 1;
|
||||
assert result.getQueryColumns().get(0).getName().contains("访问次数");
|
||||
assert result.getTextResult().contains("511");
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -58,8 +58,8 @@ public class Text2SQLEval extends BaseTest {
|
||||
QueryResult result = submitNewChat("近30日每天的访问次数", agentId);
|
||||
durations.add(System.currentTimeMillis() - start);
|
||||
assert result.getQueryColumns().size() == 2;
|
||||
assert result.getQueryColumns().get(0).getName().equalsIgnoreCase("date");
|
||||
assert result.getQueryColumns().get(1).getName().contains("访问次数");
|
||||
assert result.getQueryResults().size() == 30;
|
||||
assert result.getTextResult().contains("date");
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -68,9 +68,11 @@ public class Text2SQLEval extends BaseTest {
|
||||
QueryResult result = submitNewChat("过去30天每个部门的汇总访问次数", agentId);
|
||||
durations.add(System.currentTimeMillis() - start);
|
||||
assert result.getQueryColumns().size() == 2;
|
||||
assert result.getQueryColumns().get(0).getName().equalsIgnoreCase("部门");
|
||||
assert result.getQueryColumns().get(1).getName().contains("访问次数");
|
||||
assert result.getQueryResults().size() == 4;
|
||||
assert result.getTextResult().contains("marketing");
|
||||
assert result.getTextResult().contains("sales");
|
||||
assert result.getTextResult().contains("strategy");
|
||||
assert result.getTextResult().contains("HR");
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -134,16 +136,16 @@ public class Text2SQLEval extends BaseTest {
|
||||
public Agent getLLMAgent(boolean enableMultiturn) {
|
||||
Agent agent = new Agent();
|
||||
agent.setName("Agent for Test");
|
||||
AgentConfig agentConfig = new AgentConfig();
|
||||
agentConfig.getTools().add(getLLMQueryTool());
|
||||
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
|
||||
ToolConfig toolConfig = new ToolConfig();
|
||||
toolConfig.getTools().add(getLLMQueryTool());
|
||||
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
|
||||
ChatModel chatModel = new ChatModel();
|
||||
chatModel.setName("Text2SQL LLM");
|
||||
chatModel.setConfig(LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.OLLAMA_LLAMA3));
|
||||
chatModel = chatModelService.createChatModel(chatModel, User.getFakeUser());
|
||||
chatModel = chatModelService.createChatModel(chatModel, User.getDefaultUser());
|
||||
Map<ChatModelType, Integer> chatModelConfig = Maps.newHashMap();
|
||||
chatModelConfig.put(ChatModelType.TEXT_TO_SQL, chatModel.getId());
|
||||
agent.setModelConfig(chatModelConfig);
|
||||
agent.setChatModelConfig(chatModelConfig);
|
||||
MultiTurnConfig multiTurnConfig = new MultiTurnConfig();
|
||||
multiTurnConfig.setEnableMultiTurn(enableMultiturn);
|
||||
agent.setMultiTurnConfig(multiTurnConfig);
|
||||
|
||||
@@ -34,7 +34,7 @@ public class BaseTest extends BaseApplication {
|
||||
private DomainRepository domainRepository;
|
||||
|
||||
protected SemanticQueryResp queryBySql(String sql) throws Exception {
|
||||
return queryBySql(sql, User.getFakeUser());
|
||||
return queryBySql(sql, User.getDefaultUser());
|
||||
}
|
||||
|
||||
protected SemanticQueryResp queryBySql(String sql, User user) throws Exception {
|
||||
|
||||
@@ -22,7 +22,7 @@ public class MetaDiscoveryTest extends BaseTest {
|
||||
QueryMapReq queryMapReq = new QueryMapReq();
|
||||
queryMapReq.setQueryText("对比alice和lucy的访问次数");
|
||||
queryMapReq.setTopN(10);
|
||||
queryMapReq.setUser(User.getFakeUser());
|
||||
queryMapReq.setUser(User.getDefaultUser());
|
||||
queryMapReq.setDataSetNames(Collections.singletonList("超音数数据集"));
|
||||
MapInfoResp mapMeta = chatLayerService.map(queryMapReq);
|
||||
|
||||
@@ -36,7 +36,7 @@ public class MetaDiscoveryTest extends BaseTest {
|
||||
QueryMapReq queryMapReq = new QueryMapReq();
|
||||
queryMapReq.setQueryText("风格为流行的艺人");
|
||||
queryMapReq.setTopN(10);
|
||||
queryMapReq.setUser(User.getFakeUser());
|
||||
queryMapReq.setUser(User.getDefaultUser());
|
||||
queryMapReq.setDataSetNames(Collections.singletonList("艺人库"));
|
||||
queryMapReq.setQueryDataType(QueryDataType.TAG);
|
||||
MapInfoResp mapMeta = chatLayerService.map(queryMapReq);
|
||||
@@ -48,7 +48,7 @@ public class MetaDiscoveryTest extends BaseTest {
|
||||
QueryMapReq queryMapReq = new QueryMapReq();
|
||||
queryMapReq.setQueryText("超音数访问次数最高的部门");
|
||||
queryMapReq.setTopN(10);
|
||||
queryMapReq.setUser(User.getFakeUser());
|
||||
queryMapReq.setUser(User.getDefaultUser());
|
||||
queryMapReq.setDataSetNames(Collections.singletonList("超音数"));
|
||||
queryMapReq.setQueryDataType(QueryDataType.METRIC);
|
||||
MapInfoResp mapMeta = chatLayerService.map(queryMapReq);
|
||||
|
||||
@@ -23,7 +23,7 @@ public class QueryByMetricTest extends BaseTest {
|
||||
QueryMetricReq queryMetricReq = new QueryMetricReq();
|
||||
queryMetricReq.setMetricNames(Arrays.asList("stay_hours", "pv"));
|
||||
queryMetricReq.setDimensionNames(Arrays.asList("user_name", "department"));
|
||||
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getFakeUser());
|
||||
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getDefaultUser());
|
||||
Assert.assertNotNull(queryResp.getResultList());
|
||||
Assert.assertEquals(6, queryResp.getResultList().size());
|
||||
}
|
||||
@@ -33,7 +33,7 @@ public class QueryByMetricTest extends BaseTest {
|
||||
QueryMetricReq queryMetricReq = new QueryMetricReq();
|
||||
queryMetricReq.setMetricNames(Arrays.asList("停留时长", "访问次数"));
|
||||
queryMetricReq.setDimensionNames(Arrays.asList("用户", "部门"));
|
||||
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getFakeUser());
|
||||
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getDefaultUser());
|
||||
Assert.assertNotNull(queryResp.getResultList());
|
||||
Assert.assertEquals(6, queryResp.getResultList().size());
|
||||
}
|
||||
@@ -44,7 +44,7 @@ public class QueryByMetricTest extends BaseTest {
|
||||
queryMetricReq.setDomainId(1L);
|
||||
queryMetricReq.setMetricNames(Arrays.asList("stay_hours", "pv"));
|
||||
queryMetricReq.setDimensionNames(Arrays.asList("user_name", "department"));
|
||||
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getFakeUser());
|
||||
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getDefaultUser());
|
||||
Assert.assertNotNull(queryResp.getResultList());
|
||||
Assert.assertEquals(6, queryResp.getResultList().size());
|
||||
|
||||
@@ -52,7 +52,7 @@ public class QueryByMetricTest extends BaseTest {
|
||||
queryMetricReq.setMetricNames(Arrays.asList("stay_hours", "pv"));
|
||||
queryMetricReq.setDimensionNames(Arrays.asList("user_name", "department"));
|
||||
assertThrows(IllegalArgumentException.class,
|
||||
() -> queryByMetric(queryMetricReq, User.getFakeUser()));
|
||||
() -> queryByMetric(queryMetricReq, User.getDefaultUser()));
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -61,7 +61,7 @@ public class QueryByMetricTest extends BaseTest {
|
||||
queryMetricReq.setDomainId(1L);
|
||||
queryMetricReq.setMetricIds(Arrays.asList(1L, 3L));
|
||||
queryMetricReq.setDimensionIds(Arrays.asList(1L, 2L));
|
||||
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getFakeUser());
|
||||
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getDefaultUser());
|
||||
Assert.assertNotNull(queryResp.getResultList());
|
||||
Assert.assertEquals(6, queryResp.getResultList().size());
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ public class QueryByStructTest extends BaseTest {
|
||||
QueryStructReq queryStructReq =
|
||||
buildQueryStructReq(Arrays.asList("user_name", "department"), QueryType.DETAIL);
|
||||
SemanticQueryResp semanticQueryResp =
|
||||
semanticLayerService.queryByReq(queryStructReq, User.getFakeUser());
|
||||
semanticLayerService.queryByReq(queryStructReq, User.getDefaultUser());
|
||||
assertEquals(3, semanticQueryResp.getColumns().size());
|
||||
QueryColumn firstColumn = semanticQueryResp.getColumns().get(0);
|
||||
assertEquals("用户", firstColumn.getName());
|
||||
@@ -64,7 +64,7 @@ public class QueryByStructTest extends BaseTest {
|
||||
public void testSumQuery() throws Exception {
|
||||
QueryStructReq queryStructReq = buildQueryStructReq(null);
|
||||
SemanticQueryResp semanticQueryResp =
|
||||
semanticLayerService.queryByReq(queryStructReq, User.getFakeUser());
|
||||
semanticLayerService.queryByReq(queryStructReq, User.getDefaultUser());
|
||||
assertEquals(1, semanticQueryResp.getColumns().size());
|
||||
QueryColumn queryColumn = semanticQueryResp.getColumns().get(0);
|
||||
assertEquals("访问次数", queryColumn.getName());
|
||||
@@ -75,7 +75,7 @@ public class QueryByStructTest extends BaseTest {
|
||||
public void testGroupByQuery() throws Exception {
|
||||
QueryStructReq queryStructReq = buildQueryStructReq(Arrays.asList("department"));
|
||||
SemanticQueryResp result =
|
||||
semanticLayerService.queryByReq(queryStructReq, User.getFakeUser());
|
||||
semanticLayerService.queryByReq(queryStructReq, User.getDefaultUser());
|
||||
assertEquals(2, result.getColumns().size());
|
||||
QueryColumn firstColumn = result.getColumns().get(0);
|
||||
QueryColumn secondColumn = result.getColumns().get(1);
|
||||
@@ -97,7 +97,7 @@ public class QueryByStructTest extends BaseTest {
|
||||
queryStructReq.setDimensionFilters(dimensionFilters);
|
||||
|
||||
SemanticQueryResp result =
|
||||
semanticLayerService.queryByReq(queryStructReq, User.getFakeUser());
|
||||
semanticLayerService.queryByReq(queryStructReq, User.getDefaultUser());
|
||||
assertEquals(2, result.getColumns().size());
|
||||
QueryColumn firstColumn = result.getColumns().get(0);
|
||||
QueryColumn secondColumn = result.getColumns().get(1);
|
||||
|
||||
@@ -15,7 +15,7 @@ public class QueryDimensionTest extends BaseTest {
|
||||
queryDimValueReq.setBizName("department");
|
||||
|
||||
SemanticQueryResp queryResp =
|
||||
semanticLayerService.queryDimensionValue(queryDimValueReq, User.getFakeUser());
|
||||
semanticLayerService.queryDimensionValue(queryDimValueReq, User.getDefaultUser());
|
||||
Assert.assertNotNull(queryResp.getResultList());
|
||||
Assert.assertEquals(4, queryResp.getResultList().size());
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ public class QueryRuleTest extends BaseTest {
|
||||
@Autowired
|
||||
private QueryRuleService queryRuleService;
|
||||
|
||||
private User user = User.getFakeUser();
|
||||
private User user = User.getDefaultUser();
|
||||
|
||||
public QueryRuleReq addSystemRule() {
|
||||
QueryRuleReq queryRuleReq = new QueryRuleReq();
|
||||
|
||||
@@ -18,7 +18,7 @@ public class TagObjectTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
void testCreateTagObject() throws Exception {
|
||||
User user = User.getFakeUser();
|
||||
User user = User.getDefaultUser();
|
||||
TagObjectReq tagObjectReq = newTagObjectReq();
|
||||
TagObjectResp tagObjectResp = tagObjectService.create(tagObjectReq, user);
|
||||
tagObjectService.delete(tagObjectResp.getId(), user, false);
|
||||
@@ -27,24 +27,25 @@ public class TagObjectTest extends BaseTest {
|
||||
@Test
|
||||
void testUpdateTagObject() throws Exception {
|
||||
TagObjectReq tagObjectReq = newTagObjectReq();
|
||||
TagObjectResp tagObjectResp = tagObjectService.create(tagObjectReq, User.getFakeUser());
|
||||
TagObjectResp tagObjectResp = tagObjectService.create(tagObjectReq, User.getDefaultUser());
|
||||
TagObjectReq tagObjectReqUpdate = new TagObjectReq();
|
||||
BeanUtils.copyProperties(tagObjectResp, tagObjectReqUpdate);
|
||||
tagObjectReqUpdate.setName("艺人1");
|
||||
tagObjectService.update(tagObjectReqUpdate, User.getFakeUser());
|
||||
tagObjectService.update(tagObjectReqUpdate, User.getDefaultUser());
|
||||
TagObjectResp tagObject =
|
||||
tagObjectService.getTagObject(tagObjectReqUpdate.getId(), User.getFakeUser());
|
||||
tagObjectService.delete(tagObject.getId(), User.getFakeUser(), false);
|
||||
tagObjectService.getTagObject(tagObjectReqUpdate.getId(), User.getDefaultUser());
|
||||
tagObjectService.delete(tagObject.getId(), User.getDefaultUser(), false);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testQueryTagObject() throws Exception {
|
||||
TagObjectReq tagObjectReq = newTagObjectReq();
|
||||
TagObjectResp tagObjectResp = tagObjectService.create(tagObjectReq, User.getFakeUser());
|
||||
TagObjectResp tagObjectResp = tagObjectService.create(tagObjectReq, User.getDefaultUser());
|
||||
TagObjectFilter filter = new TagObjectFilter();
|
||||
List<TagObjectResp> tagObjects = tagObjectService.getTagObjects(filter, User.getFakeUser());
|
||||
List<TagObjectResp> tagObjects =
|
||||
tagObjectService.getTagObjects(filter, User.getDefaultUser());
|
||||
tagObjects.size();
|
||||
tagObjectService.delete(tagObjectResp.getId(), User.getFakeUser(), false);
|
||||
tagObjectService.delete(tagObjectResp.getId(), User.getDefaultUser(), false);
|
||||
}
|
||||
|
||||
private TagObjectReq newTagObjectReq() {
|
||||
|
||||
@@ -21,7 +21,7 @@ public class TagTest extends BaseTest {
|
||||
ItemValueReq itemValueReq = new ItemValueReq();
|
||||
itemValueReq.setId(1L);
|
||||
ItemValueResp itemValueResp =
|
||||
tagQueryService.queryTagValue(itemValueReq, User.getFakeUser());
|
||||
tagQueryService.queryTagValue(itemValueReq, User.getDefaultUser());
|
||||
Assertions.assertNotNull(itemValueResp);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ public class TranslateTest extends BaseTest {
|
||||
String sql = "SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 ";
|
||||
SemanticTranslateResp explain = semanticLayerService.translate(
|
||||
QueryReqBuilder.buildS2SQLReq(sql, DataUtils.getMetricAgentView()),
|
||||
User.getFakeUser());
|
||||
User.getDefaultUser());
|
||||
assertNotNull(explain);
|
||||
assertNotNull(explain.getQuerySQL());
|
||||
assertTrue(explain.getQuerySQL().contains("department"));
|
||||
@@ -30,7 +30,7 @@ public class TranslateTest extends BaseTest {
|
||||
public void testStructExplain() throws Exception {
|
||||
QueryStructReq queryStructReq = buildQueryStructReq(Arrays.asList("department"));
|
||||
SemanticTranslateResp explain =
|
||||
semanticLayerService.translate(queryStructReq, User.getFakeUser());
|
||||
semanticLayerService.translate(queryStructReq, User.getDefaultUser());
|
||||
assertNotNull(explain);
|
||||
assertNotNull(explain.getQuerySQL());
|
||||
assertTrue(explain.getQuerySQL().contains("department"));
|
||||
|
||||
@@ -19,7 +19,7 @@ public class DataUtils {
|
||||
public static final Integer tagAgentId = 2;
|
||||
public static final Integer ONE_TURNS_CHAT_ID = 10;
|
||||
public static final Integer MULTI_TURNS_CHAT_ID = 11;
|
||||
private static final User user_test = User.getFakeUser();
|
||||
private static final User user_test = User.getDefaultUser();
|
||||
|
||||
public static User getUser() {
|
||||
return user_test;
|
||||
|
||||
@@ -388,9 +388,9 @@ CREATE TABLE IF NOT EXISTS s2_agent
|
||||
description varchar(500) null,
|
||||
status int null,
|
||||
examples varchar(500) null,
|
||||
config varchar(2000) null,
|
||||
tool_config varchar(2000) null,
|
||||
llm_config varchar(2000) null,
|
||||
model_config varchar(6000) null,
|
||||
chat_model_config varchar(6000) null,
|
||||
prompt_config varchar(5000) null,
|
||||
multi_turn_config varchar(2000) null,
|
||||
visual_config varchar(2000) null,
|
||||
|
||||
Reference in New Issue
Block a user