(improvement)(chat) fix embedding null pointer (#712)

This commit is contained in:
mainmain
2024-02-02 10:59:33 +08:00
committed by GitHub
parent 1004f71ba4
commit 4d4922d269
12 changed files with 132 additions and 49 deletions

View File

@@ -11,6 +11,8 @@ import com.tencent.supersonic.chat.core.utils.HanlpHelper;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import java.util.List;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
@@ -37,7 +39,9 @@ public class EmbeddingMapper extends BaseMapper {
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(matchResult.getMetadata()),
SchemaElement.class);
if (Objects.isNull(matchResult.getMetadata())) {
continue;
}
String modelIdStr = matchResult.getMetadata().get("modelId");
if (StringUtils.isBlank(modelIdStr)) {
continue;

View File

@@ -106,7 +106,9 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
if (StringUtils.isBlank(modelIdStr)) {
return true;
}
return detectModelIds.contains(Long.parseLong(modelIdStr));
//return detectModelIds.contains(Long.parseLong(modelIdStr));
Double modelId = Double.parseDouble(modelIdStr);
return detectModelIds.contains(modelId.longValue());
});
}
}

View File

@@ -1,10 +1,8 @@
# 评测流程
1. 正常启动项目(必须包括LLM服务)
2. 将要评测问题放到evalution/data目录下internet.txt将要评测问题对应的SQL也放到evalution/data目录下gold_example_dusql.txt
3. 执行evalution.sh脚本主要包括构建表数据、获取模型预测结果执行对比逻辑。可以在命令行看到执行准确率错误case会写到同目录的eval.json文件中。
2. 执行evalution.sh脚本主要包括构建表数据、获取模型预测结果执行对比逻辑。可以在命令行看到执行准确率错误case会写到同目录的error_case.json文件中
# 评测意义
制定大模型评估框架对于提示词或代码更改的影响至关重要,可以帮助我们了解这些变化是否会提高或降低准确率、响应速度。 随着产品规模的扩大,如果没有这样的框架,就会发现自己在盲目地调整黑匣子,有助于帮助我们减少问题、提高效率、增强模型能力。大模型评测的核心目的是确定模型的"聪明"程度,深入探讨其性能、特点和局限性,为行业应用提供方向。
通过评测,我们可以更好地了解模型的性能、特点、价值、局限性和潜在风险,并为其发展和应用提供支持,具有重要意义。
制定评估工具对于提示词或代码更改的影响至关重要,方便supersonic快速对接其他模型、更改配置可以帮助我们了解这些变化是否会提高或降低准确率、响应速度。

View File

@@ -48,7 +48,7 @@ def get_pred_result():
config_file=current_directory+"/config/config.yaml"
with open(config_file, 'r') as file:
config = yaml.safe_load(file)
input_path=current_directory+"/data/"+config["domain"]+".txt"
input_path=current_directory+"/data/"+"internet.txt"
pred_sql_path = current_directory+"/data/"+"pred_example_dusql.txt"
pred_sql_exist=os.path.exists(pred_sql_path)
if pred_sql_exist:

View File

@@ -288,7 +288,7 @@ def build_table():
with open(config_file, 'r') as file:
config = yaml.safe_load(file)
db_path=current_directory+"/data/"
db_file=db_path+config["domain"]+".db"
db_file=db_path+"internet.db"
db_exist=os.path.exists(db_file)
if db_exist:
os.remove(db_file)
@@ -301,7 +301,7 @@ if __name__ == '__main__':
with open(config_file, 'r') as file:
config = yaml.safe_load(file)
db_path=current_directory+"/data/"
db_file=db_path+config["domain"]+".db"
db_file=db_path+"internet.db"
db_exist=os.path.exists(db_file)
if db_exist:
os.remove(db_file)
@@ -311,30 +311,3 @@ if __name__ == '__main__':
#build_china_travel_agency(path)
# sql="SELECT T3.company_name, T3.annual_turnover, T2.brand_name, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id"
# query_sql(sql)
# sql =" select * from ( SELECT brand_name, company_name, sum(revenue_proportion), sum(annual_turnover) FROM (select `sys_imp_date`, `company_name`, `brand_name`, `annual_turnover`, `revenue_proportion`from(select `sys_imp_date`, `company_name`, `brand_name`, `annual_turnover` as `annual_turnover`, `revenue_proportion` as `revenue_proportion`from(select `annual_turnover` as `annual_turnover`, `revenue_proportion` as `revenue_proportion`, `sys_imp_date`, `company_name`, `brand_name`from(select `src1_brand`.`sys_imp_date` as `sys_imp_date`, `src1_company`.`annual_turnover` as `annual_turnover`, `src1_company_revenue`.`revenue_proportion` as `revenue_proportion`, `src1_brand`.`company_id` as `company_id`, `src1_company`.`company_name` as `company_name`, `src1_brand`.`brand_name` as `brand_name`, `src1_brand`.`brand_id` as `brand_id`from(select `annual_turnover` as `annual_turnover`, `imp_date` as `sys_imp_date`, `company_id`, `company_name` as `company_name`, `imp_date` as `imp_date`from(select `imp_date`, `company_id`, `company_name`, `headquarter_address`, `company_established_time`, `founder`, `ceo`, `annual_turnover`, `employee_count`from`company`) as `company`) as `src1_company`inner join (select `revenue_proportion` as `revenue_proportion`, `imp_date` as `sys_imp_date`, `company_id`, `brand_id`, `imp_date` as `imp_date`from(select `imp_date`, `company_id`, `brand_id`, `revenue_proportion`, `profit_proportion`, `expenditure_proportion`from`company_revenue`) as `company_revenue`) as `src1_company_revenue` on `src1_company`.`company_id` = `src1_company_revenue`.`company_id`inner join (select `imp_date` as `sys_imp_date`, `company_id`, `brand_name` as `brand_name`, `brand_id`, `imp_date` as `imp_date`from(select `imp_date`, `brand_id`, `brand_name`, `brand_established_time`, `company_id`, `legal_representative`, `registered_capital`from`brand`) as `brand`) as `src1_brand` on `src1_company`.`company_id` = `src1_brand`.`company_id`) as `src11_`) as `company_company_revenue_brand_0`) as `company_company_revenue_brand_1`) t_103 WHERE sys_imp_date = '2024-01-11' GROUP BY brand_name, company_name, sys_imp_date ORDER BY sum(revenue_proportion) DESC ) a limit 1000 "
# a=query_sql(sql)
# print(a[0][0])
# import pymysql
#
# db = pymysql.connect(host="11.154.212.211", user="semantic", password="semantic2023", database="internet", port=3306)
#
# # 使用cursor()方法获取操作游标
# cursor = db.cursor()
#
# # 使用execute方法执行SQL语句
# cursor.execute("select * from company")
#
# # 使用 fetchone() 方法获取一条数据
# data = cursor.fetchone()
# print(data)

View File

@@ -1,4 +1,3 @@
chat_id: 3
agent_id: 4
domain: internet
url: http://localhost:9080

View File

@@ -0,0 +1,17 @@
[
{
"query": "在各公司所有品牌收入排名中,给出每一个品牌,其所在公司以及收入占该公司的总收入比例,同时给出该公司的年营业额",
"gold_sql": "SELECT T3.company_name, T3.annual_turnover, T2.brand_name, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id",
"pred_sql": "select * from tablea"
},
{
"query": "在各公司所有品牌收入排名中,给出每一个品牌,其所在公司以及收入占该公司的总收入比例",
"gold_sql": "SELECT T3.company_name, T2.brand_name, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id",
"pred_sql": "select * from tablea"
},
{
"query": "在各公司所有品牌收入排名中,给出每一个品牌和其法人,其所在公司以及收入占该公司的总收入比例",
"gold_sql": "SELECT T3.company_name, T2.brand_name, T2.legal_representative, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id",
"pred_sql": "select * from tablea"
}
]

View File

@@ -513,7 +513,6 @@ def evaluate(gold, predict, db_dir, etype, kmaps,query_path):
db_name = db
# db = os.path.join(db_dir, db, db + ".sqlite")
db = os.path.join(db_dir,db + ".db")
print(db)
schema = Schema(get_schema(db))
g_sql = get_sql(schema, g_str)
hardness = evaluator.eval_hardness(g_sql)
@@ -597,7 +596,7 @@ def evaluate(gold, predict, db_dir, etype, kmaps,query_path):
print_scores(scores, etype)
print(scores['all']['exec'])
current_directory = os.path.dirname(os.path.abspath(__file__))
file_name=current_directory+"/eval.json"
file_name=current_directory+"/error_case.json"
json_exist=os.path.exists(file_name)
if json_exist:
os.remove(file_name)
@@ -884,11 +883,11 @@ def get_evaluation_result():
config = yaml.safe_load(file)
db_dir=current_directory+"/data"
db_path=current_directory+"/data/"
db_file=db_path+config["domain"]+".db"
db_file=db_path+"internet.db"
pred = current_directory+"/data/"+"pred_example_dusql.txt"
gold = current_directory+"/data/"+"gold_example_dusql.txt"
table= current_directory+"/data/"+"tables_dusql.json"
query_path=current_directory+"/data/"+config["domain"]+".txt"
query_path=current_directory+"/data/"+"internet.txt"
etype="exec"
kmaps = build_foreign_key_map_from_json(table)
@@ -900,7 +899,7 @@ def remove_unused_file():
with open(config_file, 'r') as file:
config = yaml.safe_load(file)
db_path=current_directory+"/data/"
db_file=db_path+config["domain"]+".db"
db_file=db_path+"internet.db"
pred_file = current_directory+"/data/"+"pred_example_dusql.txt"
db_exist=os.path.exists(db_file)

View File

@@ -6,18 +6,28 @@ import com.tencent.supersonic.common.pojo.JoinCondition;
import com.tencent.supersonic.common.pojo.ModelRela;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.headless.api.pojo.enums.DimensionType;
import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType;
import com.tencent.supersonic.common.pojo.enums.TimeMode;
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.ModelDetail;
import com.tencent.supersonic.headless.api.pojo.Dim;
import com.tencent.supersonic.headless.api.pojo.DimensionTimeTypeParams;
import com.tencent.supersonic.headless.api.pojo.ViewDetail;
import com.tencent.supersonic.headless.api.pojo.ViewModelConfig;
import com.tencent.supersonic.headless.api.pojo.Identify;
import com.tencent.supersonic.headless.api.pojo.Measure;
import com.tencent.supersonic.headless.api.pojo.ModelDetail;
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.MetricTypeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.enums.DimensionType;
import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType;
import com.tencent.supersonic.headless.api.pojo.request.DomainReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelReq;
import com.tencent.supersonic.headless.api.pojo.request.ViewReq;
import com.tencent.supersonic.headless.server.service.DomainService;
import com.tencent.supersonic.headless.server.service.ModelRelaService;
import com.tencent.supersonic.headless.server.service.ModelService;
import com.tencent.supersonic.headless.server.service.ViewService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
@@ -40,6 +50,9 @@ public class BenchMarkDemoDataLoader {
@Autowired
private ModelRelaService modelRelaService;
@Autowired
private ViewService viewService;
public void doRun() {
try {
addDomain();
@@ -47,6 +60,7 @@ public class BenchMarkDemoDataLoader {
addModel_2();
addModel_3();
addModel_4();
addView_1();
addModelRela_1();
addModelRela_2();
addModelRela_3();
@@ -194,6 +208,42 @@ public class BenchMarkDemoDataLoader {
modelService.createModel(modelReq, user);
}
public void addView_1() {
ViewReq viewReq = new ViewReq();
viewReq.setName("cspider");
viewReq.setBizName("singer");
viewReq.setDomainId(3L);
viewReq.setDescription("包含cspider数据集相关标签和指标信息");
viewReq.setAdmins(Lists.newArrayList("admin"));
List<ViewModelConfig> viewModelConfigs = Lists.newArrayList(
new ViewModelConfig(5L, Lists.newArrayList(8L), Lists.newArrayList()),
new ViewModelConfig(6L, Lists.newArrayList(9L, 10L), Lists.newArrayList()),
new ViewModelConfig(7L, Lists.newArrayList(11L, 12L), Lists.newArrayList()),
new ViewModelConfig(8L, Lists.newArrayList(13L, 14L, 15L), Lists.newArrayList(8L, 9L))
);
ViewDetail viewDetail = new ViewDetail();
viewDetail.setViewModelConfigs(viewModelConfigs);
viewReq.setViewDetail(viewDetail);
viewReq.setTypeEnum(TypeEnums.VIEW);
QueryConfig queryConfig = new QueryConfig();
TagTypeDefaultConfig tagTypeDefaultConfig = new TagTypeDefaultConfig();
TimeDefaultConfig tagTimeDefaultConfig = new TimeDefaultConfig();
tagTimeDefaultConfig.setTimeMode(TimeMode.LAST);
tagTimeDefaultConfig.setUnit(7);
tagTypeDefaultConfig.setTimeDefaultConfig(tagTimeDefaultConfig);
tagTypeDefaultConfig.setDimensionIds(Lists.newArrayList());
tagTypeDefaultConfig.setMetricIds(Lists.newArrayList());
MetricTypeDefaultConfig metricTypeDefaultConfig = new MetricTypeDefaultConfig();
TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig();
timeDefaultConfig.setTimeMode(TimeMode.RECENT);
timeDefaultConfig.setUnit(7);
metricTypeDefaultConfig.setTimeDefaultConfig(timeDefaultConfig);
queryConfig.setTagTypeDefaultConfig(tagTypeDefaultConfig);
queryConfig.setMetricTypeDefaultConfig(metricTypeDefaultConfig);
viewReq.setQueryConfig(queryConfig);
viewService.save(viewReq, User.getFakeUser());
}
public void addModelRela_1() {
List<JoinCondition> joinConditions = Lists.newArrayList();
joinConditions.add(new JoinCondition("g_name", "g_name", FilterOperatorEnum.EQUALS));
@@ -253,4 +303,4 @@ public class BenchMarkDemoDataLoader {
modelRelaReq.setJoinConditions(joinConditions);
modelRelaService.save(modelRelaReq, user);
}
}
}

View File

@@ -220,7 +220,7 @@ public class ChatDemoLoader implements CommandLineRunner {
LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setViewIds(Lists.newArrayList(5L, 6L, 7L, 8L));
llmParserTool.setViewIds(Lists.newArrayList(3L));
agentConfig.getTools().add(llmParserTool);
}
@@ -242,7 +242,7 @@ public class ChatDemoLoader implements CommandLineRunner {
LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setViewIds(Lists.newArrayList(9L, 10L, 11L, 12L));
llmParserTool.setViewIds(Lists.newArrayList(4L));
agentConfig.getTools().add(llmParserTool);
}

View File

@@ -5,6 +5,8 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.JoinCondition;
import com.tencent.supersonic.common.pojo.ModelRela;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeMode;
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.headless.api.pojo.enums.DimensionType;
import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType;
@@ -13,14 +15,21 @@ import com.tencent.supersonic.headless.api.pojo.DimensionTimeTypeParams;
import com.tencent.supersonic.headless.api.pojo.Identify;
import com.tencent.supersonic.headless.api.pojo.Measure;
import com.tencent.supersonic.headless.api.pojo.ModelDetail;
import com.tencent.supersonic.headless.api.pojo.ViewDetail;
import com.tencent.supersonic.headless.api.pojo.ViewModelConfig;
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
import com.tencent.supersonic.headless.api.pojo.MetricTypeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.request.DomainReq;
import com.tencent.supersonic.headless.api.pojo.request.MetricReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelReq;
import com.tencent.supersonic.headless.api.pojo.request.ViewReq;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.server.service.DomainService;
import com.tencent.supersonic.headless.server.service.MetricService;
import com.tencent.supersonic.headless.server.service.ModelRelaService;
import com.tencent.supersonic.headless.server.service.ModelService;
import com.tencent.supersonic.headless.server.service.ViewService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
@@ -46,6 +55,9 @@ public class DuSQLDemoDataLoader {
@Autowired
private MetricService metricService;
@Autowired
private ViewService viewService;
public void doRun() {
try {
addDomain();
@@ -53,6 +65,7 @@ public class DuSQLDemoDataLoader {
addModel_2();
addModel_3();
addModel_4();
addView_1();
addModelRela_1();
addModelRela_2();
addModelRela_3();
@@ -241,6 +254,34 @@ public class DuSQLDemoDataLoader {
}
public void addView_1() {
ViewReq viewReq = new ViewReq();
viewReq.setName("DuSQL 互联网企业");
viewReq.setBizName("internet");
viewReq.setDomainId(4L);
viewReq.setDescription("DuSQL互联网企业数据源相关的指标和维度等");
viewReq.setAdmins(Lists.newArrayList("admin"));
List<ViewModelConfig> viewModelConfigs = Lists.newArrayList(
new ViewModelConfig(9L, Lists.newArrayList(16L, 17L, 18L, 19L, 20L), Lists.newArrayList(10L, 11L)),
new ViewModelConfig(10L, Lists.newArrayList(21L, 22L, 23L), Lists.newArrayList(12L)),
new ViewModelConfig(11L, Lists.newArrayList(), Lists.newArrayList(13L, 14L, 15L)),
new ViewModelConfig(12L, Lists.newArrayList(24L), Lists.newArrayList(16L, 17L, 18L, 19L)));
ViewDetail viewDetail = new ViewDetail();
viewDetail.setViewModelConfigs(viewModelConfigs);
viewReq.setViewDetail(viewDetail);
viewReq.setTypeEnum(TypeEnums.VIEW);
QueryConfig queryConfig = new QueryConfig();
MetricTypeDefaultConfig metricTypeDefaultConfig = new MetricTypeDefaultConfig();
TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig();
timeDefaultConfig.setTimeMode(TimeMode.RECENT);
timeDefaultConfig.setUnit(1);
metricTypeDefaultConfig.setTimeDefaultConfig(timeDefaultConfig);
queryConfig.setMetricTypeDefaultConfig(metricTypeDefaultConfig);
viewReq.setQueryConfig(queryConfig);
viewService.save(viewReq, User.getFakeUser());
}
public void addModelRela_1() {
List<JoinCondition> joinConditions = Lists.newArrayList();
joinConditions.add(new JoinCondition("company_id", "company_id", FilterOperatorEnum.EQUALS));

View File

@@ -548,4 +548,4 @@ public class ModelDemoDataLoader {
return relateDimension;
}
}
}