[improvement][headless]Clean code logic of headless translator.

This commit is contained in:
jerryjzhang
2024-11-27 11:29:29 +08:00
parent 7bf1ba09c5
commit dad065d0ba
10 changed files with 184 additions and 276 deletions

View File

@@ -41,8 +41,8 @@ public class S2CompanyDemo extends S2BaseDemo {
ModelResp model_brand = addModel_2(domain, demoDatabase);
ModelResp model_brand_revenue = addModel_3(domain, demoDatabase);
addModelRela(domain, model_company, model_brand, "company_id");
addModelRela(domain, model_brand, model_brand_revenue, "brand_id");
addModelRela(domain, model_brand, model_company, "company_id");
addModelRela(domain, model_brand_revenue, model_brand, "brand_id");
DataSetResp dataset = addDataSet(domain);
addAgent(dataset.getId());
@@ -106,8 +106,7 @@ public class S2CompanyDemo extends S2BaseDemo {
modelDetail.setMeasures(measures);
modelDetail.setQueryType("sql_query");
modelDetail.setSqlQuery("SELECT company_id,company_name,headquarter_address,"
+ "company_established_time,founder,ceo,annual_turnover,employee_count FROM company");
modelDetail.setSqlQuery("SELECT * FROM company");
modelReq.setModelDetail(modelDetail);
ModelResp companyModel = modelService.createModel(modelReq, defaultUser);
@@ -146,8 +145,7 @@ public class S2CompanyDemo extends S2BaseDemo {
modelDetail.setMeasures(measures);
modelDetail.setQueryType("sql_query");
modelDetail.setSqlQuery("SELECT brand_id,brand_name,brand_established_time,"
+ "company_id,legal_representative,registered_capital FROM brand");
modelDetail.setSqlQuery("SELECT * FROM brand");
modelReq.setModelDetail(modelDetail);
ModelResp brandModel = modelService.createModel(modelReq, defaultUser);
@@ -187,8 +185,7 @@ public class S2CompanyDemo extends S2BaseDemo {
modelDetail.setMeasures(measures);
modelDetail.setQueryType("sql_query");
modelDetail.setSqlQuery("SELECT year_time,brand_id,revenue,profit,"
+ "revenue_growth_year_on_year,profit_growth_year_on_year FROM brand_revenue");
modelDetail.setSqlQuery("SELECT * FROM brand_revenue");
modelReq.setModelDetail(modelDetail);
return modelService.createModel(modelReq, defaultUser);
}
@@ -227,7 +224,7 @@ public class S2CompanyDemo extends S2BaseDemo {
modelRelaReq.setDomainId(domain.getId());
modelRelaReq.setFromModelId(fromModel.getId());
modelRelaReq.setToModelId(toModel.getId());
modelRelaReq.setJoinType("left join");
modelRelaReq.setJoinType("inner join");
modelRelaReq.setJoinConditions(joinConditions);
modelRelaService.save(modelRelaReq, defaultUser);
}

View File

@@ -199,6 +199,7 @@ public class S2VisitsDemo extends S2BaseDemo {
List<Dim> dimensions = new ArrayList<>();
dimensions.add(new Dim("部门", "department", DimensionType.categorical, 1));
// dimensions.add(new Dim("用户", "user_name", DimensionType.categorical, 1));
modelDetail.setDimensions(dimensions);
List<Field> fields = Lists.newArrayList();
fields.add(Field.builder().fieldName("user_name").dataType("Varchar").build());

View File

@@ -5,7 +5,10 @@ import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.tencent.supersonic.chat.BaseTest;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.agent.*;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.DatasetTool;
import com.tencent.supersonic.chat.server.agent.ToolConfig;
import com.tencent.supersonic.common.config.ChatModel;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.User;
@@ -133,11 +136,28 @@ public class Text2SQLEval extends BaseTest {
assert result.getTextResult().contains("3");
}
@Test
public void test_detail_query() throws Exception {
long start = System.currentTimeMillis();
QueryResult result = submitNewChat("特斯拉旗下有哪些品牌", agentId);
durations.add(System.currentTimeMillis() - start);
assert result.getQueryColumns().size() >= 1;
assert result.getTextResult().contains("Model Y");
assert result.getTextResult().contains("Model 3");
}
public Agent getLLMAgent() {
Agent agent = new Agent();
agent.setName("Agent for Test");
ToolConfig toolConfig = new ToolConfig();
toolConfig.getTools().add(getDatasetTool());
DatasetTool datasetTool = new DatasetTool();
datasetTool.setType(AgentToolType.DATASET);
datasetTool.setDataSetIds(Lists.newArrayList(DataUtils.productDatasetId));
toolConfig.getTools().add(datasetTool);
DatasetTool datasetTool2 = new DatasetTool();
datasetTool2.setType(AgentToolType.DATASET);
datasetTool2.setDataSetIds(Lists.newArrayList(DataUtils.companyDatasetId));
toolConfig.getTools().add(datasetTool2);
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
// create chat model for this evaluation
ChatModel chatModel = new ChatModel();
@@ -154,11 +174,4 @@ public class Text2SQLEval extends BaseTest {
return agent;
}
private static DatasetTool getDatasetTool() {
DatasetTool datasetTool = new DatasetTool();
datasetTool.setType(AgentToolType.DATASET);
datasetTool.setDataSetIds(Lists.newArrayList(1L));
return datasetTool;
}
}