diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/BaseSemanticCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/BaseSemanticCorrector.java index 2f43864c9..48011fe23 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/BaseSemanticCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/BaseSemanticCorrector.java @@ -77,7 +77,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { protected void addFieldsToSelect(SemanticParseInfo semanticParseInfo, String correctS2SQL) { Set selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL)); Set needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL)); - needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL)); + //needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL)); // If there is no aggregate function in the S2SQL statement and // there is a data field in 'WHERE' statement, add the field to the 'SELECT' statement. diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/GroupByCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/GroupByCorrector.java index 9b9f0681b..696c6cb60 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/GroupByCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/GroupByCorrector.java @@ -5,13 +5,21 @@ import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.api.pojo.response.SqlInfo; import com.tencent.supersonic.chat.core.pojo.QueryContext; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; +import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper; +import com.tencent.supersonic.headless.api.pojo.Dim; +import com.tencent.supersonic.headless.api.pojo.response.ModelResp; +import com.tencent.supersonic.headless.api.pojo.response.ViewResp; +import com.tencent.supersonic.headless.server.pojo.MetaFilter; +import com.tencent.supersonic.headless.server.service.ModelService; +import com.tencent.supersonic.headless.server.service.ViewService; import lombok.extern.slf4j.Slf4j; import org.springframework.util.CollectionUtils; import java.util.HashSet; import java.util.List; +import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; @@ -23,11 +31,35 @@ public class GroupByCorrector extends BaseSemanticCorrector { @Override public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { - + Boolean addGroupBy = addGroupBy(queryContext, semanticParseInfo); + log.info("addGroupBy:{}", addGroupBy); + if (!addGroupBy) { + return; + } addGroupByFields(queryContext, semanticParseInfo); } + private Boolean addGroupBy(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { + Long viewId = semanticParseInfo.getViewId(); + ViewService viewService = ContextUtils.getBean(ViewService.class); + ModelService modelService = ContextUtils.getBean(ModelService.class); + ViewResp viewResp = viewService.getView(viewId); + List modelIds = viewResp.getViewDetail().getViewModelConfigs().stream().map(config -> config.getId() + ).collect(Collectors.toList()); + MetaFilter metaFilter = new MetaFilter(modelIds); + List modelRespList = modelService.getModelList(metaFilter); + for (ModelResp modelResp : modelRespList) { + List dimList = modelResp.getModelDetail().getDimensions(); + for (Dim dim : dimList) { + if (Objects.nonNull(dim.getTypeParams()) && dim.getTypeParams().getTimeGranularity().equals("none")) { + return false; + } + } + } + return true; + } + private void addGroupByFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { Long viewId = semanticParseInfo.getViewId(); //add dimension group by diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/HavingCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/HavingCorrector.java index b4e40aa95..ba4e48e0d 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/HavingCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/HavingCorrector.java @@ -27,7 +27,7 @@ public class HavingCorrector extends BaseSemanticCorrector { addHaving(queryContext, semanticParseInfo); //add having expression filed to select - addHavingToSelect(semanticParseInfo); + //addHavingToSelect(semanticParseInfo); } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlSelectHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlSelectHelper.java index f884814e1..f1596d3d8 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlSelectHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlSelectHelper.java @@ -484,7 +484,15 @@ public class SqlSelectHelper { SelectBody selectBody = selectStatement.getSelectBody(); if (selectBody instanceof PlainSelect) { PlainSelect plainSelect = (PlainSelect) selectBody; - return (Table) plainSelect.getFromItem(); + if (plainSelect.getFromItem() instanceof Table) { + return (Table) plainSelect.getFromItem(); + } + if (plainSelect.getFromItem() instanceof SubSelect) { + + SubSelect subSelect = (SubSelect) plainSelect.getFromItem(); + return getTable(subSelect.getSelectBody().toString()); + } + } else if (selectBody instanceof SetOperationList) { SetOperationList setOperationList = (SetOperationList) selectBody; if (!CollectionUtils.isEmpty(setOperationList.getSelects())) { diff --git a/evaluation/README_CN.md b/evaluation/README_CN.md index 61ec9f2ea..ba1818e92 100644 --- a/evaluation/README_CN.md +++ b/evaluation/README_CN.md @@ -5,4 +5,4 @@ # 评测意义 -制定评估工具方便supersonic快速对接其他模型、更改参数配置,对于提示词或代码更改的影响至关重要,可以帮助我们了解这些变化是否会提高或降低准确率、响应速度。 +制定评测工具方便supersonic快速对接其他大模型、更改参数配置,对于评估提示词、代码更改所带来的影响至关重要,可以帮助我们了解这些变化是否会提高或降低准确率、响应速度。 diff --git a/evaluation/build_models.py b/evaluation/build_models.py index 69da4561c..23fe27ac5 100644 --- a/evaluation/build_models.py +++ b/evaluation/build_models.py @@ -9,7 +9,7 @@ import jwt def get_authorization(): - exp = time.time() + 1000 + exp = time.time() + 100000 token= jwt.encode({"token_user_name": "admin","exp": exp}, "secret", algorithm="HS512") return "Bearer "+token diff --git a/evaluation/build_pred_result.py b/evaluation/build_pred_result.py index a9b294734..e5ef4c18f 100644 --- a/evaluation/build_pred_result.py +++ b/evaluation/build_pred_result.py @@ -25,7 +25,7 @@ def get_pred_sql(query,url,agentId,chatId,authorization,default_sql): header["Authorization"] =authorization try: result = requests.post(url=url, headers=header, json=data) - print(result.json()) + #print(result.json()) print(result.json()["traceId"]) if result.status_code == 200: data = result.json()["data"] @@ -68,7 +68,7 @@ def get_pred_result(): for i in range(0,len(questions)): pred_sql=get_pred_sql(questions[i],url,agent_id,chat_id,authorization,default_sql) pred_sql_list.append(pred_sql) - time.sleep(30) + time.sleep(60) write_sql(pred_sql_path, pred_sql_list) if __name__ == "__main__": diff --git a/evaluation/build_tables.py b/evaluation/build_tables.py index e903da727..6c7779bee 100644 --- a/evaluation/build_tables.py +++ b/evaluation/build_tables.py @@ -56,10 +56,10 @@ def build_internet(path,day): VALUES (?, ?, ?,?, ?, ?,?) ''' data = [ - (imp_date,"item_enterprise_13_136","阿里云","2009年9月10日","item_enterprise_13_134","张勇",50000000), - (imp_date,"item_enterprise_13_137","天猫","2012年1月11日","item_enterprise_13_134","张勇",100000000), - (imp_date,"item_enterprise_13_138","腾讯游戏","2003","item_enterprise_13_131","马化腾",50000000), - (imp_date,"item_enterprise_13_139","度小满","2018","item_enterprise_13_132","朱光",100000000), + (imp_date,"item_enterprise_13_136","阿里云","2009年9月10日","item_enterprise_13_132","张勇",50000000), + (imp_date,"item_enterprise_13_137","天猫","2012年1月11日","item_enterprise_13_132","张勇",100000000), + (imp_date,"item_enterprise_13_138","腾讯游戏","2003","item_enterprise_13_133","马化腾",50000000), + (imp_date,"item_enterprise_13_139","度小满","2018","item_enterprise_13_131","朱光",100000000), (imp_date,"item_enterprise_13_140","京东金融","2017","item_enterprise_13_134","刘强东",100000000) ] cursor.executemany(insert_data_query, data) @@ -83,10 +83,10 @@ def build_internet(path,day): ''' data = [ (imp_date,"item_enterprise_13_131","item_enterprise_13_139",0.1,0.1,0.3), - (imp_date,"item_enterprise_13_134","item_enterprise_13_138",0.8,0.8,0.6), - (imp_date,"item_enterprise_13_135","item_enterprise_13_139",0.8,0.8,0.6), - (imp_date,"item_enterprise_13_131","item_enterprise_13_137",0.8,0.8,0.6), - (imp_date,"item_enterprise_13_135","item_enterprise_13_137",0.1,0.1,0.3) + (imp_date,"item_enterprise_13_133","item_enterprise_13_138",0.8,0.8,0.6), + (imp_date,"item_enterprise_13_134","item_enterprise_13_140",0.8,0.8,0.6), + (imp_date,"item_enterprise_13_132","item_enterprise_13_137",0.8,0.8,0.6), + (imp_date,"item_enterprise_13_132","item_enterprise_13_136",0.1,0.1,0.3) ] cursor.executemany(insert_data_query, data) conn.commit() @@ -113,7 +113,7 @@ def build_internet(path,day): (imp_date, "2019", "item_enterprise_13_136", 100000000000, 50000000000, 1, 0.5), (imp_date, "2018", "item_enterprise_13_137", 100000000000, 50000000000, 1, -0.1), (imp_date, "2018", "item_enterprise_13_139", 500000000, 50000000000, 0.1, 0.5), - (imp_date, "2018", "item_enterprise_13_138", 100000000000, -300000000, 0.1, 0.5) + (imp_date, "2018", "item_enterprise_13_140", 100000000000, -300000000, 0.1, 0.5) ] cursor.executemany(insert_data_query, data) conn.commit() diff --git a/evaluation/data/gold_example_dusql.txt b/evaluation/data/gold_example_dusql.txt index d5fc41160..b0542af32 100644 --- a/evaluation/data/gold_example_dusql.txt +++ b/evaluation/data/gold_example_dusql.txt @@ -38,10 +38,6 @@ SELECT T1.brand_name, T2.company_name, T1.registered_capital, T2.headquarter_add SELECT T1.brand_name, T2.company_name, T1.legal_representative, T2.headquarter_address FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id internet SELECT T2.company_name, T2.headquarter_address FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id internet SELECT T2.company_name, T2.annual_turnover FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id internet -SELECT T3.company_name, T3.headquarter_address, T2.brand_name, T1.revenue FROM company_brand_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T2.company_id = T3.company_id internet -SELECT T3.company_name, T3.annual_turnover, T2.brand_name, T1.revenue FROM company_brand_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T2.company_id = T3.company_id internet -SELECT T3.company_name, T2.brand_name, T2.legal_representative, T1.revenue FROM company_brand_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T2.company_id = T3.company_id internet -SELECT T3.company_name, T2.brand_name, T1.revenue FROM company_brand_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T2.company_id = T3.company_id internet SELECT T2.company_name, T2.headquarter_address, T1.legal_representative FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T1.registered_capital >= 100000000 internet SELECT T2.company_name, T2.headquarter_address, T1.legal_representative FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T1.registered_capital < 100000000 internet SELECT T2.company_name, T2.headquarter_address, T1.legal_representative FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T1.registered_capital > 100000000 internet @@ -98,3 +94,7 @@ SELECT T2.legal_representative, T2.brand_name, sum(T1.revenue) FROM company_bran SELECT T2.legal_representative, T2.brand_name, max(T1.revenue) FROM company_brand_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id GROUP BY T1.brand_id internet SELECT T2.headquarter_address, T2.company_name, sum(T1.registered_capital) FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id GROUP BY T1.company_id internet SELECT T2.headquarter_address, T2.company_name, max(T1.registered_capital) FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id GROUP BY T1.company_id internet +select T2.headquarter_address , T2.company_name, min(T1.registered_capital) from brand as T1 join company as T2 on T1.company_id = T2.company_id group by T1.company_id internet +select T2.headquarter_address , T2.company_name, avg(T1.registered_capital) from brand as T1 join company as T2 on T1.company_id = T2.company_id group by T1.company_id internet +select T2.headquarter_address , T2.company_name from company_revenue as T1 join company as T2 on T1.company_id = T2.company_id group by T1.company_id having sum(T1.revenue_proportion) > 0.5 internet +select T2.headquarter_address , T2.company_name from company_revenue as T1 join company as T2 on T1.company_id = T2.company_id group by T1.company_id having count(*) > 5 internet diff --git a/evaluation/data/internet.txt b/evaluation/data/internet.txt index 2c7eb4ded..834008ef1 100644 --- a/evaluation/data/internet.txt +++ b/evaluation/data/internet.txt @@ -2,9 +2,9 @@ 在各公司所有品牌收入排名中,给出每一个品牌,其所在公司以及收入占该公司的总收入比例 在各公司所有品牌收入排名中,给出每一个品牌和其法人,其所在公司以及收入占该公司的总收入比例 在各公司所有品牌收入排名中,给出每一个品牌,其所在公司以及收入占该公司的总收入比例,同时给出该公司总部所在地 -在公司各品牌收入排名的利润占比最多10%时,给出公司的名称品牌的名称并给出公司各品牌收入排名的营收占比 -在公司各品牌收入排名的利润占比小于10%时,给出公司的名称品牌的名称并给出公司各品牌收入排名的营收占比 -在公司各品牌收入排名的利润占比不止10%时,给出公司的名称品牌的名称并给出公司各品牌收入排名的营收占比 +在公司各品牌收入排名的利润占比最多0.1时,给出公司的名称品牌的名称并给出公司各品牌收入排名的营收占比 +在公司各品牌收入排名的利润占比小于0.1时,给出公司的名称品牌的名称并给出公司各品牌收入排名的营收占比 +在公司各品牌收入排名的利润占比不止0.1时,给出公司的名称品牌的名称并给出公司各品牌收入排名的营收占比 注册资本不小于1亿的品牌中,哪个品牌的平均营收占比最大?并给出它的法定代表人 注册资本大于1亿的品牌中,哪5个品牌收入最少?并给出它们的法定代表人 找到注册资本少于一个亿的品牌及其法人,并给出对应公司的平均营收占比 @@ -12,17 +12,17 @@ 给出注册资本不超过一亿的的品牌及其法人,并给出对应的公司的总营收占比 给出注册资本少于一亿的品牌及其法人,并给出对应的公司的最大营收占比 给出注册资本不超过1亿的品牌及其法人,并找出对应的公司的最大营收占比 -在注册资本超过一亿的公司中,给出公司各品牌收入品牌的平均营收占比正好50%的品牌及其法人 +在注册资本超过一亿的公司中,给出公司各品牌收入品牌的平均营收占比正好0.5的品牌及其法人 在注册资本不超过一亿的公司中,给出个品牌收入排名五的品牌及其法人 -在各品牌收入在公司排名中,当品牌的注册资本不大于1亿时,给出公司各品牌收入排名的支出占比的平均值小于等于45%的那些品牌的名称以及公司各品牌收入排名的营收占比的平均值 +在各品牌收入在公司排名中,当品牌的注册资本不大于1亿时,给出公司各品牌收入排名的支出占比的平均值小于等于0.45的那些品牌的名称以及公司各品牌收入排名的营收占比的平均值 在各品牌收入在公司排名中,当品牌的注册资本小于1亿时,给出公司各品牌收入排名数量小于等于5的那些品牌的名称以及公司各品牌收入排名的营收占比的最小值 在各品牌收入在公司排名中,给出每个品牌的名称,品牌的法定代表人,以及公司各品牌收入排名的营收占比的平均值 在各品牌收入在公司排名中,给出每个品牌的名称,品牌的法定代表人,以及公司各品牌收入排名的营收占比的最小值 在各品牌收入在公司排名中,给出每个品牌的名称,品牌的法定代表人,以及公司各品牌收入排名的营收占比的总和 在各品牌收入在公司排名中,给出每个品牌的名称,品牌的法定代表人,以及公司各品牌收入排名的营收占比的最大值 在各品牌收入在公司排名中,给出收入排名不超过5的品牌及其法人 -在各品牌收入在公司排名中,给出在收入排名中的平均营收占比超过50%的品牌及其法人 -在各品牌收入在公司排名中,当公司各品牌收入排名的利润占比的平均值小于等于60%时,给出品牌的名称以及公司各品牌收入排名的营收占比的平均值 +在各品牌收入在公司排名中,给出在收入排名中的平均营收占比超过0.5的品牌及其法人 +在各品牌收入在公司排名中,当公司各品牌收入排名的利润占比的平均值小于等于0.6时,给出品牌的名称以及公司各品牌收入排名的营收占比的平均值 在各品牌收入在公司排名中,当公司各品牌收入排名数量等于5时,给出品牌的名称以及公司各品牌收入排名的营收占比的最小值 哪个品牌收入的平均利润占比最大,给出品牌的法定代表人,以及其收入平均营收占比 哪3个品牌的收入最多,给出品牌的法定代表人,以及其收入总营收占比 @@ -38,10 +38,6 @@ 给出每一个品牌和其法人,所属的公司以及总部所在城市 有自己品牌的公司有哪些?给出这些公司和总部所在地 有自己品牌的公司有哪些?给出这些公司和年营业额 -在各公司其品牌的历年收入中,给出每一个品牌,其所属的公司和公司总部所在地点,并给出该品牌近几年的营收 -在各公司其品牌的历年收入中,给出每一个品牌,其所属的公司和公司年营业额,并给出该品牌近几年的营收 -在各公司其品牌的历年收入中,给出每一个品牌,其所属的公司和公司法人,并给出该品牌近几年的营收 -在各公司其品牌的历年收入中,给出每一个品牌,其所属的公司,以及该品牌近几年的营收 在品牌的注册资本至少1亿时,给出公司的名称以及公司的总部地点品牌的法定代表人 在品牌的注册资本少于1亿时,给出公司的名称以及公司的总部地点品牌的法定代表人 在品牌的注册资本超过1亿时,给出公司的名称以及公司的总部地点品牌的法定代表人 @@ -56,7 +52,7 @@ 给出注册资本不超过一亿,且年营业额少于288亿的公司,以及总部地点和法人 找出注册资本少于一亿,且年营业额不超过288亿的公司,以及总部在哪,法人是谁 在公司品牌历年收入的利润最多500亿时,给出公司的名称品牌的名称并给出公司品牌历年收入的营收 -在公司品牌历年收入的营收同比增长至少100%时,给出公司的名称品牌的名称并给出公司品牌历年收入的营收 +在公司品牌历年收入的营收同比增长至少1时,给出公司的名称品牌的名称并给出公司品牌历年收入的营收 年营业额不小于288亿的公司中,哪5个公司的平均营收占比最少?,并给出它们的总部地点 年营业额不小于288亿的公司中,哪个公司的平均营收占比最大?并给出它的总部地点 注册资本大于1亿的品牌中,哪个品牌历年收入的平均营收最大?并给出它的法定代表人 @@ -78,12 +74,12 @@ 给出不超过288亿年营业额的公司及总部地点,并给出这些品牌中的平均注册资本 给出不超过288亿年营业额的公司及其总部地点,并给出这些品牌的的最大注册资本 给出年营业额超过288亿的公司及其总部地点,并给出这些品牌的的总注册资本 -给出年营业额不超过288亿的各公司品牌中,给出收入排名中的营收占比加起来不超过50%的公司及其总部地点 -给出年营业额低于288亿的各公司的各品牌中,给出收入排名中的总营收占比不超过50%的公司及总部地点中 +给出年营业额不超过288亿的各公司品牌中,给出收入排名中的营收占比加起来不超过0.5的公司及其总部地点 +给出年营业额低于288亿的各公司的各品牌中,给出收入排名中的总营收占比不超过0.5的公司及总部地点中 在年营业额不少于288亿的公司中,给出品牌不超过5个的公司及其总部地点 在年营业额不超过288亿的公司中,给出品牌少于5个的公司及其总部地点 在各公司的各品牌收入排名中,当公司的年营业额大于288亿时,给出公司各品牌收入排名数量大于5的那些公司的名称以及公司各品牌收入排名的营收占比的最小值 -在各公司的各品牌收入排名中,当公司的年营业额小于288亿时,给出公司各品牌收入排名的利润占比的平均值大于75%的那些公司的名称以及公司各品牌收入排名的营收占比的最大值 +在各公司的各品牌收入排名中,当公司的年营业额小于288亿时,给出公司各品牌收入排名的利润占比的平均值大于0.75的那些公司的名称以及公司各品牌收入排名的营收占比的最大值 在各品牌的历年收入中,当品牌的注册资本不大于1亿时,给出公司品牌历年收入的利润同比增长的总和小于等于1000000的那些品牌的名称以及公司品牌历年收入的营收的最大值 在各品牌的历年收入中,当品牌的注册资本小于1亿时,给出公司品牌历年收入数量小于5的那些品牌的名称以及公司品牌历年收入的营收的最小值 在各品牌所属的公司中,当公司的年营业额大于288亿时,给出品牌数量小于5的那些公司的名称以及品牌的注册资本的总和 @@ -98,3 +94,7 @@ 在各品牌的历年收入中,给出每个品牌的名称,品牌的法定代表人,以及公司品牌历年收入的营收的最大值 在各品牌所属的公司中,给出每个公司的名称,公司的总部地点,以及品牌的注册资本的总和 在各品牌所属的公司中,给出每个公司的名称,公司的总部地点,以及品牌的注册资本的最大值 +在各品牌所属的公司中,给出每个公司的名称,公司的总部地点,以及品牌的注册资本的最小值 +在各品牌所属的公司中,给出每个公司的名称,公司的总部地点,以及品牌的注册资本的平均值 +在各公司的各品牌收入排名种,哪些公司的品牌排名的总营收占比超过0.5,并给出总部的地点 +在各公司的各品牌收入排名中,哪些公司的品牌收入排名超过5,并给出公司的总部地点 diff --git a/evaluation/evaluation.py b/evaluation/evaluation.py index 081201689..ae2ed1cfb 100644 --- a/evaluation/evaluation.py +++ b/evaluation/evaluation.py @@ -522,14 +522,19 @@ def evaluate(gold, predict, db_dir, etype, kmaps,query_path): p_sql = p_str if etype in ["all", "exec"]: - exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql) - if not exec_score: + result = eval_exec_match(db, p_str, g_str, p_sql, g_sql) + #exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql) + if not result["equal"]: element={} element["query"]=questions[index] element["gold_sql"]=g_str element["pred_sql"]=p_str + if "p_res_map" in result: + element["p_res_map"]=result["p_res_map"] + if "q_res_map" in result: + element["q_res_map"]=result["q_res_map"] log_list.append(element) - if exec_score: + if result["equal"]: scores[hardness]['exec'] += 1.0 scores['all']['exec'] += 1.0 @@ -609,6 +614,7 @@ def eval_exec_match(db, p_str, g_str, pred, gold): return 1 if the values between prediction and gold are matching in the corresponding index. Currently not support multiple col_unit(pairs). """ + result={} conn = sqlite3.connect(db) cursor = conn.cursor() try: @@ -618,8 +624,10 @@ def eval_exec_match(db, p_str, g_str, pred, gold): for index in range(0,len(p_fields)): p_fields[index]=re.sub("t\d+.", "",p_fields[index].replace("`","").lower()) p_res = cursor.fetchall() - except: - return False + except Exception as e: + logging.info(e) + result["equal"]=False + return result cursor.execute(g_str) q_res = cursor.fetchall() @@ -635,9 +643,15 @@ def eval_exec_match(db, p_str, g_str, pred, gold): g_fields = parse_sql(g_str) - #print("p_res_map:{}".format(res_map(p_res, p_fields))) - #print("q_res_map:{}".format(res_map(q_res, g_fields))) - return res_map(p_res, p_fields) == res_map(q_res, g_fields) + p_res_map=res_map(p_res, p_fields) + q_res_map=res_map(q_res, g_fields) + # print("p_res_map:{}".format(p_res_map)) + # print("q_res_map:{}".format(q_res_map)) + result["equal"]=(p_res_map==q_res_map) + result["p_res_map"]=json.dumps(p_res_map, ensure_ascii=False) + result["q_res_map"]=json.dumps(q_res_map, ensure_ascii=False) + return result + #return res_map(p_res, p_fields) == res_map(q_res, g_fields) def parse_sql(sql): # 使用 sqlparse 库解析 SQL 查询语句 diff --git a/launchers/standalone/src/main/resources/application-local.yaml b/launchers/standalone/src/main/resources/application-local.yaml index 815e408a2..b4bf0c8d3 100644 --- a/launchers/standalone/src/main/resources/application-local.yaml +++ b/launchers/standalone/src/main/resources/application-local.yaml @@ -34,6 +34,11 @@ authentication: time: threshold: 100 +dimension: + topn: 20 +metric: + topn: 20 + mybatis: mapper-locations=classpath:mappers/custom/*.xml,classpath*:/mappers/*.xml diff --git a/launchers/standalone/src/main/resources/db/data-h2.sql b/launchers/standalone/src/main/resources/db/data-h2.sql index fb8c36f23..95b86d684 100644 --- a/launchers/standalone/src/main/resources/db/data-h2.sql +++ b/launchers/standalone/src/main/resources/db/data-h2.sql @@ -1118,22 +1118,22 @@ insert into company(imp_date,company_id,company_name,headquarter_address,company insert into company(imp_date,company_id,company_name,headquarter_address,company_established_time,founder,ceo,annual_turnover,employee_count) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_134','北京京东世纪贸易有限公司','北京','1998','刘强东','刘强东',28800000000,179000); insert into company(imp_date,company_id,company_name,headquarter_address,company_established_time,founder,ceo,annual_turnover,employee_count) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_135','网易公司','杭州','1997','丁磊','丁磊',67500000000,20000); -insert into brand(imp_date,brand_id,brand_name,brand_established_time,company_id,legal_representative,registered_capital) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_136','阿里云','2009年9月10日','item_enterprise_13_134','张勇',50000000); -insert into brand(imp_date,brand_id,brand_name,brand_established_time,company_id,legal_representative,registered_capital) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_137','天猫','2012年1月11日','item_enterprise_13_134','张勇',100000000); -insert into brand(imp_date,brand_id,brand_name,brand_established_time,company_id,legal_representative,registered_capital) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_138','腾讯游戏','2003','item_enterprise_13_131','马化腾',50000000); -insert into brand(imp_date,brand_id,brand_name,brand_established_time,company_id,legal_representative,registered_capital) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_139','度小满','2018','item_enterprise_13_132','朱光',100000000); +insert into brand(imp_date,brand_id,brand_name,brand_established_time,company_id,legal_representative,registered_capital) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_136','阿里云','2009年9月10日','item_enterprise_13_132','张勇',50000000); +insert into brand(imp_date,brand_id,brand_name,brand_established_time,company_id,legal_representative,registered_capital) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_137','天猫','2012年1月11日','item_enterprise_13_132','张勇',100000000); +insert into brand(imp_date,brand_id,brand_name,brand_established_time,company_id,legal_representative,registered_capital) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_138','腾讯游戏','2003','item_enterprise_13_133','马化腾',50000000); +insert into brand(imp_date,brand_id,brand_name,brand_established_time,company_id,legal_representative,registered_capital) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_139','度小满','2018','item_enterprise_13_131','朱光',100000000); insert into brand(imp_date,brand_id,brand_name,brand_established_time,company_id,legal_representative,registered_capital) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_140','京东金融','2017','item_enterprise_13_134','刘强东',100000000); insert into company_revenue(imp_date,company_id,brand_id,revenue_proportion,profit_proportion,expenditure_proportion) VALUES ( DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_131','item_enterprise_13_139',10,10,30); -insert into company_revenue(imp_date,company_id,brand_id,revenue_proportion,profit_proportion,expenditure_proportion) VALUES ( DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_134','item_enterprise_13_138',80,80,60); -insert into company_revenue(imp_date,company_id,brand_id,revenue_proportion,profit_proportion,expenditure_proportion) VALUES ( DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_135','item_enterprise_13_139',80,80,60); -insert into company_revenue(imp_date,company_id,brand_id,revenue_proportion,profit_proportion,expenditure_proportion) VALUES ( DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_131','item_enterprise_13_137',80,80,60); -insert into company_revenue(imp_date,company_id,brand_id,revenue_proportion,profit_proportion,expenditure_proportion) VALUES ( DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_135','item_enterprise_13_137',10,10,30); +insert into company_revenue(imp_date,company_id,brand_id,revenue_proportion,profit_proportion,expenditure_proportion) VALUES ( DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_133','item_enterprise_13_138',80,80,60); +insert into company_revenue(imp_date,company_id,brand_id,revenue_proportion,profit_proportion,expenditure_proportion) VALUES ( DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_134','item_enterprise_13_140',80,80,60); +insert into company_revenue(imp_date,company_id,brand_id,revenue_proportion,profit_proportion,expenditure_proportion) VALUES ( DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_132','item_enterprise_13_137',80,80,60); +insert into company_revenue(imp_date,company_id,brand_id,revenue_proportion,profit_proportion,expenditure_proportion) VALUES ( DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_132','item_enterprise_13_136',10,10,30); insert into company_brand_revenue(imp_date,year_time,brand_id,revenue,profit,revenue_growth_year_on_year,profit_growth_year_on_year) VALUES (DATEADD('DAY', -1, CURRENT_DATE()), '2018','item_enterprise_13_138',500000000,-300000000,10,-10); insert into company_brand_revenue(imp_date,year_time,brand_id,revenue,profit,revenue_growth_year_on_year,profit_growth_year_on_year) VALUES (DATEADD('DAY', -1, CURRENT_DATE()), '2019','item_enterprise_13_136',100000000000,50000000000,100,50); insert into company_brand_revenue(imp_date,year_time,brand_id,revenue,profit,revenue_growth_year_on_year,profit_growth_year_on_year) VALUES ( DATEADD('DAY', -1, CURRENT_DATE()),'2018','item_enterprise_13_137',100000000000,50000000000,100,-10); insert into company_brand_revenue(imp_date,year_time,brand_id,revenue,profit,revenue_growth_year_on_year,profit_growth_year_on_year) VALUES (DATEADD('DAY', -1, CURRENT_DATE()), '2018','item_enterprise_13_139',500000000,50000000000,10,50); -insert into company_brand_revenue(imp_date,year_time,brand_id,revenue,profit,revenue_growth_year_on_year,profit_growth_year_on_year) VALUES ( DATEADD('DAY', -1, CURRENT_DATE()),'2018','item_enterprise_13_138',100000000000,-300000000,10,50); +insert into company_brand_revenue(imp_date,year_time,brand_id,revenue,profit,revenue_growth_year_on_year,profit_growth_year_on_year) VALUES ( DATEADD('DAY', -1, CURRENT_DATE()),'2018','item_enterprise_13_140',100000000000,-300000000,10,50); -- benchmark