From f50a3157d504b2f0fee86a994bbc9cd32085794e Mon Sep 17 00:00:00 2001 From: mainmain <57514971+mainmainer@users.noreply.github.com> Date: Fri, 23 Feb 2024 15:05:13 +0800 Subject: [PATCH] (improvement)add corrector additional information switch and evaluation time cost (#743) --- .../core/corrector/BaseSemanticCorrector.java | 10 +++++++++- .../chat/core/corrector/HavingCorrector.java | 11 +++++++++-- evaluation/build_pred_result.py | 7 +++++++ evaluation/evaluation.py | 17 ++++++++++------- .../main/resources/META-INF/spring.factories | 2 +- .../src/main/resources/application-local.yaml | 3 +++ 6 files changed, 39 insertions(+), 11 deletions(-) diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/BaseSemanticCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/BaseSemanticCorrector.java index 48011fe23..09a8beaa3 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/BaseSemanticCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/BaseSemanticCorrector.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.chat.core.corrector; +import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticSchema; @@ -12,6 +13,7 @@ import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; +import org.springframework.core.env.Environment; import org.springframework.util.CollectionUtils; import java.util.ArrayList; @@ -77,7 +79,13 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { protected void addFieldsToSelect(SemanticParseInfo semanticParseInfo, String correctS2SQL) { Set selectFields = new 
HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL)); Set needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL)); - //needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL)); + + //decide whether add order by expression field to select + Environment environment = ContextUtils.getBean(Environment.class); + String correctorAdditionalInfo = environment.getProperty("corrector.additional.information"); + if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) { + needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL)); + } // If there is no aggregate function in the S2SQL statement and // there is a data field in 'WHERE' statement, add the field to the 'SELECT' statement. diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/HavingCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/HavingCorrector.java index ba4e48e0d..e09d68ce2 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/HavingCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/HavingCorrector.java @@ -3,11 +3,14 @@ package com.tencent.supersonic.chat.core.corrector; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.core.pojo.QueryContext; +import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlSelectFunctionHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.expression.Expression; +import org.apache.commons.lang3.StringUtils; +import org.springframework.core.env.Environment; import org.springframework.util.CollectionUtils; import java.util.List; @@ -26,8 +29,12 @@ public class 
HavingCorrector extends BaseSemanticCorrector { //add aggregate to all metric addHaving(queryContext, semanticParseInfo); - //add having expression filed to select - //addHavingToSelect(semanticParseInfo); + //decide whether add having expression field to select + Environment environment = ContextUtils.getBean(Environment.class); + String correctorAdditionalInfo = environment.getProperty("corrector.additional.information"); + if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) { + addHavingToSelect(semanticParseInfo); + } } diff --git a/evaluation/build_pred_result.py b/evaluation/build_pred_result.py index e5ef4c18f..d21324b28 100644 --- a/evaluation/build_pred_result.py +++ b/evaluation/build_pred_result.py @@ -65,12 +65,19 @@ def get_pred_result(): questions=read_query(input_path) pred_sql_list=[] default_sql="select * from tablea " + time_cost=[] for i in range(0,len(questions)): + start_time = time.time() pred_sql=get_pred_sql(questions[i],url,agent_id,chat_id,authorization,default_sql) + end_time = time.time() + cost='%.3f'%(end_time-start_time) + time_cost.append(cost) pred_sql_list.append(pred_sql) time.sleep(60) write_sql(pred_sql_path, pred_sql_list) + return [float(cost) for cost in time_cost] + if __name__ == "__main__": print("pred") diff --git a/evaluation/evaluation.py b/evaluation/evaluation.py index ae2ed1cfb..1e0246579 100644 --- a/evaluation/evaluation.py +++ b/evaluation/evaluation.py @@ -482,7 +482,7 @@ def print_scores(scores, etype): print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores)) -def evaluate(gold, predict, db_dir, etype, kmaps,query_path): +def evaluate(gold, predict, db_dir, etype, kmaps,query_path,time_cost): with open(gold) as f: glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0] @@ -597,7 +597,11 @@ def evaluate(gold, predict, db_dir, etype, kmaps,query_path): scores[level]['partial'][type_]['f1'] = \ 2.0 * 
scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / ( scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc']) - + cost_dic = {} + cost_dic["max_time"] = max(time_cost) + cost_dic["min_time"] = min(time_cost) + cost_dic["avg_time"] = sum(time_cost)/len(time_cost) + log_list.append(cost_dic) print_scores(scores, etype) print(scores['all']['exec']) current_directory = os.path.dirname(os.path.abspath(__file__)) @@ -608,7 +612,6 @@ def evaluate(gold, predict, db_dir, etype, kmaps,query_path): with open(file_name, 'w') as json_file: json.dump(log_list, json_file, indent=4, ensure_ascii=False) - def eval_exec_match(db, p_str, g_str, pred, gold): """ return 1 if the values between prediction and gold are matching @@ -890,7 +893,7 @@ def build_foreign_key_map_from_json(table): tables[entry['db_id']] = build_foreign_key_map(entry) return tables -def get_evaluation_result(): +def get_evaluation_result(time_cost): current_directory = os.path.dirname(os.path.abspath(__file__)) config_file=current_directory+"/config/config.yaml" with open(config_file, 'r') as file: @@ -905,7 +908,7 @@ def get_evaluation_result(): etype="exec" kmaps = build_foreign_key_map_from_json(table) - evaluate(gold, pred, db_dir, etype, kmaps,query_path) + evaluate(gold, pred, db_dir, etype, kmaps,query_path,time_cost) def remove_unused_file(): current_directory = os.path.dirname(os.path.abspath(__file__)) @@ -927,8 +930,8 @@ def remove_unused_file(): if __name__ == "__main__": build_table() - get_pred_result() - get_evaluation_result() + time_cost=get_pred_result() + get_evaluation_result(time_cost) remove_unused_file() diff --git a/launchers/standalone/src/main/resources/META-INF/spring.factories b/launchers/standalone/src/main/resources/META-INF/spring.factories index a8fbd7053..8cf1b54fd 100644 --- a/launchers/standalone/src/main/resources/META-INF/spring.factories +++ b/launchers/standalone/src/main/resources/META-INF/spring.factories @@ -46,4 
+46,4 @@ com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor=\ com.tencent.supersonic.chat.server.processor.execute.MetricRatioProcessor com.tencent.supersonic.common.util.embedding.S2EmbeddingStore=\ - com.tencent.supersonic.common.util.embedding.InMemoryS2EmbeddingStore \ No newline at end of file + com.tencent.supersonic.common.util.embedding.InMemoryS2EmbeddingStore diff --git a/launchers/standalone/src/main/resources/application-local.yaml b/launchers/standalone/src/main/resources/application-local.yaml index b4bf0c8d3..1acba84ae 100644 --- a/launchers/standalone/src/main/resources/application-local.yaml +++ b/launchers/standalone/src/main/resources/application-local.yaml @@ -42,6 +42,9 @@ metric: mybatis: mapper-locations=classpath:mappers/custom/*.xml,classpath*:/mappers/*.xml +corrector: + additional: + information: true llm: parser: