mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-15 22:46:49 +00:00
[improvement][headless]Clean code logic of headless translator.
This commit is contained in:
@@ -41,8 +41,8 @@ public class S2CompanyDemo extends S2BaseDemo {
|
||||
ModelResp model_brand = addModel_2(domain, demoDatabase);
|
||||
ModelResp model_brand_revenue = addModel_3(domain, demoDatabase);
|
||||
|
||||
addModelRela(domain, model_company, model_brand, "company_id");
|
||||
addModelRela(domain, model_brand, model_brand_revenue, "brand_id");
|
||||
addModelRela(domain, model_brand, model_company, "company_id");
|
||||
addModelRela(domain, model_brand_revenue, model_brand, "brand_id");
|
||||
|
||||
DataSetResp dataset = addDataSet(domain);
|
||||
addAgent(dataset.getId());
|
||||
@@ -106,8 +106,7 @@ public class S2CompanyDemo extends S2BaseDemo {
|
||||
modelDetail.setMeasures(measures);
|
||||
|
||||
modelDetail.setQueryType("sql_query");
|
||||
modelDetail.setSqlQuery("SELECT company_id,company_name,headquarter_address,"
|
||||
+ "company_established_time,founder,ceo,annual_turnover,employee_count FROM company");
|
||||
modelDetail.setSqlQuery("SELECT * FROM company");
|
||||
modelReq.setModelDetail(modelDetail);
|
||||
ModelResp companyModel = modelService.createModel(modelReq, defaultUser);
|
||||
|
||||
@@ -146,8 +145,7 @@ public class S2CompanyDemo extends S2BaseDemo {
|
||||
modelDetail.setMeasures(measures);
|
||||
|
||||
modelDetail.setQueryType("sql_query");
|
||||
modelDetail.setSqlQuery("SELECT brand_id,brand_name,brand_established_time,"
|
||||
+ "company_id,legal_representative,registered_capital FROM brand");
|
||||
modelDetail.setSqlQuery("SELECT * FROM brand");
|
||||
modelReq.setModelDetail(modelDetail);
|
||||
ModelResp brandModel = modelService.createModel(modelReq, defaultUser);
|
||||
|
||||
@@ -187,8 +185,7 @@ public class S2CompanyDemo extends S2BaseDemo {
|
||||
modelDetail.setMeasures(measures);
|
||||
|
||||
modelDetail.setQueryType("sql_query");
|
||||
modelDetail.setSqlQuery("SELECT year_time,brand_id,revenue,profit,"
|
||||
+ "revenue_growth_year_on_year,profit_growth_year_on_year FROM brand_revenue");
|
||||
modelDetail.setSqlQuery("SELECT * FROM brand_revenue");
|
||||
modelReq.setModelDetail(modelDetail);
|
||||
return modelService.createModel(modelReq, defaultUser);
|
||||
}
|
||||
@@ -227,7 +224,7 @@ public class S2CompanyDemo extends S2BaseDemo {
|
||||
modelRelaReq.setDomainId(domain.getId());
|
||||
modelRelaReq.setFromModelId(fromModel.getId());
|
||||
modelRelaReq.setToModelId(toModel.getId());
|
||||
modelRelaReq.setJoinType("left join");
|
||||
modelRelaReq.setJoinType("inner join");
|
||||
modelRelaReq.setJoinConditions(joinConditions);
|
||||
modelRelaService.save(modelRelaReq, defaultUser);
|
||||
}
|
||||
|
||||
@@ -199,6 +199,7 @@ public class S2VisitsDemo extends S2BaseDemo {
|
||||
|
||||
List<Dim> dimensions = new ArrayList<>();
|
||||
dimensions.add(new Dim("部门", "department", DimensionType.categorical, 1));
|
||||
// dimensions.add(new Dim("用户", "user_name", DimensionType.categorical, 1));
|
||||
modelDetail.setDimensions(dimensions);
|
||||
List<Field> fields = Lists.newArrayList();
|
||||
fields.add(Field.builder().fieldName("user_name").dataType("Varchar").build());
|
||||
|
||||
@@ -5,7 +5,10 @@ import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Maps;
|
||||
import com.tencent.supersonic.chat.BaseTest;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.agent.*;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.server.agent.DatasetTool;
|
||||
import com.tencent.supersonic.chat.server.agent.ToolConfig;
|
||||
import com.tencent.supersonic.common.config.ChatModel;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
@@ -133,11 +136,28 @@ public class Text2SQLEval extends BaseTest {
|
||||
assert result.getTextResult().contains("3");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test_detail_query() throws Exception {
|
||||
long start = System.currentTimeMillis();
|
||||
QueryResult result = submitNewChat("特斯拉旗下有哪些品牌", agentId);
|
||||
durations.add(System.currentTimeMillis() - start);
|
||||
assert result.getQueryColumns().size() >= 1;
|
||||
assert result.getTextResult().contains("Model Y");
|
||||
assert result.getTextResult().contains("Model 3");
|
||||
}
|
||||
|
||||
public Agent getLLMAgent() {
|
||||
Agent agent = new Agent();
|
||||
agent.setName("Agent for Test");
|
||||
ToolConfig toolConfig = new ToolConfig();
|
||||
toolConfig.getTools().add(getDatasetTool());
|
||||
DatasetTool datasetTool = new DatasetTool();
|
||||
datasetTool.setType(AgentToolType.DATASET);
|
||||
datasetTool.setDataSetIds(Lists.newArrayList(DataUtils.productDatasetId));
|
||||
toolConfig.getTools().add(datasetTool);
|
||||
DatasetTool datasetTool2 = new DatasetTool();
|
||||
datasetTool2.setType(AgentToolType.DATASET);
|
||||
datasetTool2.setDataSetIds(Lists.newArrayList(DataUtils.companyDatasetId));
|
||||
toolConfig.getTools().add(datasetTool2);
|
||||
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
|
||||
// create chat model for this evaluation
|
||||
ChatModel chatModel = new ChatModel();
|
||||
@@ -154,11 +174,4 @@ public class Text2SQLEval extends BaseTest {
|
||||
return agent;
|
||||
}
|
||||
|
||||
private static DatasetTool getDatasetTool() {
|
||||
DatasetTool datasetTool = new DatasetTool();
|
||||
datasetTool.setType(AgentToolType.DATASET);
|
||||
datasetTool.setDataSetIds(Lists.newArrayList(1L));
|
||||
|
||||
return datasetTool;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user