[feature][chat]Refactor chat model config related codes.#1739

This commit is contained in:
jerryjzhang
2024-10-09 17:27:07 +08:00
parent 60b0a1a1a1
commit 248f4f83f6
53 changed files with 275 additions and 251 deletions

View File

@@ -7,10 +7,10 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.BaseTest;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentConfig;
import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
import com.tencent.supersonic.chat.server.agent.ToolConfig;
import com.tencent.supersonic.chat.server.pojo.ChatModel;
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
import com.tencent.supersonic.util.DataUtils;
@@ -49,7 +49,7 @@ public class Text2SQLEval extends BaseTest {
QueryResult result = submitNewChat("近30天总访问次数", agentId);
durations.add(System.currentTimeMillis() - start);
assert result.getQueryColumns().size() == 1;
assert result.getQueryColumns().get(0).getName().contains("访问次数");
assert result.getTextResult().contains("511");
}
@Test
@@ -58,8 +58,8 @@ public class Text2SQLEval extends BaseTest {
QueryResult result = submitNewChat("近30日每天的访问次数", agentId);
durations.add(System.currentTimeMillis() - start);
assert result.getQueryColumns().size() == 2;
assert result.getQueryColumns().get(0).getName().equalsIgnoreCase("date");
assert result.getQueryColumns().get(1).getName().contains("访问次数");
assert result.getQueryResults().size() == 30;
assert result.getTextResult().contains("date");
}
@Test
@@ -68,9 +68,11 @@ public class Text2SQLEval extends BaseTest {
QueryResult result = submitNewChat("过去30天每个部门的汇总访问次数", agentId);
durations.add(System.currentTimeMillis() - start);
assert result.getQueryColumns().size() == 2;
assert result.getQueryColumns().get(0).getName().equalsIgnoreCase("部门");
assert result.getQueryColumns().get(1).getName().contains("访问次数");
assert result.getQueryResults().size() == 4;
assert result.getTextResult().contains("marketing");
assert result.getTextResult().contains("sales");
assert result.getTextResult().contains("strategy");
assert result.getTextResult().contains("HR");
}
@Test
@@ -134,16 +136,16 @@ public class Text2SQLEval extends BaseTest {
public Agent getLLMAgent(boolean enableMultiturn) {
Agent agent = new Agent();
agent.setName("Agent for Test");
AgentConfig agentConfig = new AgentConfig();
agentConfig.getTools().add(getLLMQueryTool());
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
ToolConfig toolConfig = new ToolConfig();
toolConfig.getTools().add(getLLMQueryTool());
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
ChatModel chatModel = new ChatModel();
chatModel.setName("Text2SQL LLM");
chatModel.setConfig(LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.OLLAMA_LLAMA3));
chatModel = chatModelService.createChatModel(chatModel, User.getFakeUser());
chatModel = chatModelService.createChatModel(chatModel, User.getDefaultUser());
Map<ChatModelType, Integer> chatModelConfig = Maps.newHashMap();
chatModelConfig.put(ChatModelType.TEXT_TO_SQL, chatModel.getId());
agent.setModelConfig(chatModelConfig);
agent.setChatModelConfig(chatModelConfig);
MultiTurnConfig multiTurnConfig = new MultiTurnConfig();
multiTurnConfig.setEnableMultiTurn(enableMultiturn);
agent.setMultiTurnConfig(multiTurnConfig);

View File

@@ -34,7 +34,7 @@ public class BaseTest extends BaseApplication {
private DomainRepository domainRepository;
protected SemanticQueryResp queryBySql(String sql) throws Exception {
return queryBySql(sql, User.getFakeUser());
return queryBySql(sql, User.getDefaultUser());
}
protected SemanticQueryResp queryBySql(String sql, User user) throws Exception {

View File

@@ -22,7 +22,7 @@ public class MetaDiscoveryTest extends BaseTest {
QueryMapReq queryMapReq = new QueryMapReq();
queryMapReq.setQueryText("对比alice和lucy的访问次数");
queryMapReq.setTopN(10);
queryMapReq.setUser(User.getFakeUser());
queryMapReq.setUser(User.getDefaultUser());
queryMapReq.setDataSetNames(Collections.singletonList("超音数数据集"));
MapInfoResp mapMeta = chatLayerService.map(queryMapReq);
@@ -36,7 +36,7 @@ public class MetaDiscoveryTest extends BaseTest {
QueryMapReq queryMapReq = new QueryMapReq();
queryMapReq.setQueryText("风格为流行的艺人");
queryMapReq.setTopN(10);
queryMapReq.setUser(User.getFakeUser());
queryMapReq.setUser(User.getDefaultUser());
queryMapReq.setDataSetNames(Collections.singletonList("艺人库"));
queryMapReq.setQueryDataType(QueryDataType.TAG);
MapInfoResp mapMeta = chatLayerService.map(queryMapReq);
@@ -48,7 +48,7 @@ public class MetaDiscoveryTest extends BaseTest {
QueryMapReq queryMapReq = new QueryMapReq();
queryMapReq.setQueryText("超音数访问次数最高的部门");
queryMapReq.setTopN(10);
queryMapReq.setUser(User.getFakeUser());
queryMapReq.setUser(User.getDefaultUser());
queryMapReq.setDataSetNames(Collections.singletonList("超音数"));
queryMapReq.setQueryDataType(QueryDataType.METRIC);
MapInfoResp mapMeta = chatLayerService.map(queryMapReq);

View File

@@ -23,7 +23,7 @@ public class QueryByMetricTest extends BaseTest {
QueryMetricReq queryMetricReq = new QueryMetricReq();
queryMetricReq.setMetricNames(Arrays.asList("stay_hours", "pv"));
queryMetricReq.setDimensionNames(Arrays.asList("user_name", "department"));
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getFakeUser());
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getDefaultUser());
Assert.assertNotNull(queryResp.getResultList());
Assert.assertEquals(6, queryResp.getResultList().size());
}
@@ -33,7 +33,7 @@ public class QueryByMetricTest extends BaseTest {
QueryMetricReq queryMetricReq = new QueryMetricReq();
queryMetricReq.setMetricNames(Arrays.asList("停留时长", "访问次数"));
queryMetricReq.setDimensionNames(Arrays.asList("用户", "部门"));
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getFakeUser());
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getDefaultUser());
Assert.assertNotNull(queryResp.getResultList());
Assert.assertEquals(6, queryResp.getResultList().size());
}
@@ -44,7 +44,7 @@ public class QueryByMetricTest extends BaseTest {
queryMetricReq.setDomainId(1L);
queryMetricReq.setMetricNames(Arrays.asList("stay_hours", "pv"));
queryMetricReq.setDimensionNames(Arrays.asList("user_name", "department"));
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getFakeUser());
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getDefaultUser());
Assert.assertNotNull(queryResp.getResultList());
Assert.assertEquals(6, queryResp.getResultList().size());
@@ -52,7 +52,7 @@ public class QueryByMetricTest extends BaseTest {
queryMetricReq.setMetricNames(Arrays.asList("stay_hours", "pv"));
queryMetricReq.setDimensionNames(Arrays.asList("user_name", "department"));
assertThrows(IllegalArgumentException.class,
() -> queryByMetric(queryMetricReq, User.getFakeUser()));
() -> queryByMetric(queryMetricReq, User.getDefaultUser()));
}
@Test
@@ -61,7 +61,7 @@ public class QueryByMetricTest extends BaseTest {
queryMetricReq.setDomainId(1L);
queryMetricReq.setMetricIds(Arrays.asList(1L, 3L));
queryMetricReq.setDimensionIds(Arrays.asList(1L, 2L));
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getFakeUser());
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getDefaultUser());
Assert.assertNotNull(queryResp.getResultList());
Assert.assertEquals(6, queryResp.getResultList().size());
}

View File

@@ -49,7 +49,7 @@ public class QueryByStructTest extends BaseTest {
QueryStructReq queryStructReq =
buildQueryStructReq(Arrays.asList("user_name", "department"), QueryType.DETAIL);
SemanticQueryResp semanticQueryResp =
semanticLayerService.queryByReq(queryStructReq, User.getFakeUser());
semanticLayerService.queryByReq(queryStructReq, User.getDefaultUser());
assertEquals(3, semanticQueryResp.getColumns().size());
QueryColumn firstColumn = semanticQueryResp.getColumns().get(0);
assertEquals("用户", firstColumn.getName());
@@ -64,7 +64,7 @@ public class QueryByStructTest extends BaseTest {
public void testSumQuery() throws Exception {
QueryStructReq queryStructReq = buildQueryStructReq(null);
SemanticQueryResp semanticQueryResp =
semanticLayerService.queryByReq(queryStructReq, User.getFakeUser());
semanticLayerService.queryByReq(queryStructReq, User.getDefaultUser());
assertEquals(1, semanticQueryResp.getColumns().size());
QueryColumn queryColumn = semanticQueryResp.getColumns().get(0);
assertEquals("访问次数", queryColumn.getName());
@@ -75,7 +75,7 @@ public class QueryByStructTest extends BaseTest {
public void testGroupByQuery() throws Exception {
QueryStructReq queryStructReq = buildQueryStructReq(Arrays.asList("department"));
SemanticQueryResp result =
semanticLayerService.queryByReq(queryStructReq, User.getFakeUser());
semanticLayerService.queryByReq(queryStructReq, User.getDefaultUser());
assertEquals(2, result.getColumns().size());
QueryColumn firstColumn = result.getColumns().get(0);
QueryColumn secondColumn = result.getColumns().get(1);
@@ -97,7 +97,7 @@ public class QueryByStructTest extends BaseTest {
queryStructReq.setDimensionFilters(dimensionFilters);
SemanticQueryResp result =
semanticLayerService.queryByReq(queryStructReq, User.getFakeUser());
semanticLayerService.queryByReq(queryStructReq, User.getDefaultUser());
assertEquals(2, result.getColumns().size());
QueryColumn firstColumn = result.getColumns().get(0);
QueryColumn secondColumn = result.getColumns().get(1);

View File

@@ -15,7 +15,7 @@ public class QueryDimensionTest extends BaseTest {
queryDimValueReq.setBizName("department");
SemanticQueryResp queryResp =
semanticLayerService.queryDimensionValue(queryDimValueReq, User.getFakeUser());
semanticLayerService.queryDimensionValue(queryDimValueReq, User.getDefaultUser());
Assert.assertNotNull(queryResp.getResultList());
Assert.assertEquals(4, queryResp.getResultList().size());
}

View File

@@ -22,7 +22,7 @@ public class QueryRuleTest extends BaseTest {
@Autowired
private QueryRuleService queryRuleService;
private User user = User.getFakeUser();
private User user = User.getDefaultUser();
public QueryRuleReq addSystemRule() {
QueryRuleReq queryRuleReq = new QueryRuleReq();

View File

@@ -18,7 +18,7 @@ public class TagObjectTest extends BaseTest {
@Test
void testCreateTagObject() throws Exception {
User user = User.getFakeUser();
User user = User.getDefaultUser();
TagObjectReq tagObjectReq = newTagObjectReq();
TagObjectResp tagObjectResp = tagObjectService.create(tagObjectReq, user);
tagObjectService.delete(tagObjectResp.getId(), user, false);
@@ -27,24 +27,25 @@ public class TagObjectTest extends BaseTest {
@Test
void testUpdateTagObject() throws Exception {
TagObjectReq tagObjectReq = newTagObjectReq();
TagObjectResp tagObjectResp = tagObjectService.create(tagObjectReq, User.getFakeUser());
TagObjectResp tagObjectResp = tagObjectService.create(tagObjectReq, User.getDefaultUser());
TagObjectReq tagObjectReqUpdate = new TagObjectReq();
BeanUtils.copyProperties(tagObjectResp, tagObjectReqUpdate);
tagObjectReqUpdate.setName("艺人1");
tagObjectService.update(tagObjectReqUpdate, User.getFakeUser());
tagObjectService.update(tagObjectReqUpdate, User.getDefaultUser());
TagObjectResp tagObject =
tagObjectService.getTagObject(tagObjectReqUpdate.getId(), User.getFakeUser());
tagObjectService.delete(tagObject.getId(), User.getFakeUser(), false);
tagObjectService.getTagObject(tagObjectReqUpdate.getId(), User.getDefaultUser());
tagObjectService.delete(tagObject.getId(), User.getDefaultUser(), false);
}
@Test
void testQueryTagObject() throws Exception {
TagObjectReq tagObjectReq = newTagObjectReq();
TagObjectResp tagObjectResp = tagObjectService.create(tagObjectReq, User.getFakeUser());
TagObjectResp tagObjectResp = tagObjectService.create(tagObjectReq, User.getDefaultUser());
TagObjectFilter filter = new TagObjectFilter();
List<TagObjectResp> tagObjects = tagObjectService.getTagObjects(filter, User.getFakeUser());
List<TagObjectResp> tagObjects =
tagObjectService.getTagObjects(filter, User.getDefaultUser());
tagObjects.size();
tagObjectService.delete(tagObjectResp.getId(), User.getFakeUser(), false);
tagObjectService.delete(tagObjectResp.getId(), User.getDefaultUser(), false);
}
private TagObjectReq newTagObjectReq() {

View File

@@ -21,7 +21,7 @@ public class TagTest extends BaseTest {
ItemValueReq itemValueReq = new ItemValueReq();
itemValueReq.setId(1L);
ItemValueResp itemValueResp =
tagQueryService.queryTagValue(itemValueReq, User.getFakeUser());
tagQueryService.queryTagValue(itemValueReq, User.getDefaultUser());
Assertions.assertNotNull(itemValueResp);
}
}

View File

@@ -19,7 +19,7 @@ public class TranslateTest extends BaseTest {
String sql = "SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 ";
SemanticTranslateResp explain = semanticLayerService.translate(
QueryReqBuilder.buildS2SQLReq(sql, DataUtils.getMetricAgentView()),
User.getFakeUser());
User.getDefaultUser());
assertNotNull(explain);
assertNotNull(explain.getQuerySQL());
assertTrue(explain.getQuerySQL().contains("department"));
@@ -30,7 +30,7 @@ public class TranslateTest extends BaseTest {
public void testStructExplain() throws Exception {
QueryStructReq queryStructReq = buildQueryStructReq(Arrays.asList("department"));
SemanticTranslateResp explain =
semanticLayerService.translate(queryStructReq, User.getFakeUser());
semanticLayerService.translate(queryStructReq, User.getDefaultUser());
assertNotNull(explain);
assertNotNull(explain.getQuerySQL());
assertTrue(explain.getQuerySQL().contains("department"));

View File

@@ -19,7 +19,7 @@ public class DataUtils {
public static final Integer tagAgentId = 2;
public static final Integer ONE_TURNS_CHAT_ID = 10;
public static final Integer MULTI_TURNS_CHAT_ID = 11;
private static final User user_test = User.getFakeUser();
private static final User user_test = User.getDefaultUser();
public static User getUser() {
return user_test;

View File

@@ -388,9 +388,9 @@ CREATE TABLE IF NOT EXISTS s2_agent
description varchar(500) null,
status int null,
examples varchar(500) null,
config varchar(2000) null,
tool_config varchar(2000) null,
llm_config varchar(2000) null,
model_config varchar(6000) null,
chat_model_config varchar(6000) null,
prompt_config varchar(5000) null,
multi_turn_config varchar(2000) null,
visual_config varchar(2000) null,