mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement) add corrector additional information switch and evaluation time cost (#743)
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
package com.tencent.supersonic.chat.core.corrector;
|
package com.tencent.supersonic.chat.core.corrector;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
@@ -12,6 +13,7 @@ import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
|||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
|
import org.springframework.core.env.Environment;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
@@ -77,7 +79,13 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
|||||||
protected void addFieldsToSelect(SemanticParseInfo semanticParseInfo, String correctS2SQL) {
|
protected void addFieldsToSelect(SemanticParseInfo semanticParseInfo, String correctS2SQL) {
|
||||||
Set<String> selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL));
|
Set<String> selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL));
|
||||||
Set<String> needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL));
|
Set<String> needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL));
|
||||||
//needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL));
|
|
||||||
|
//decide whether add order by expression field to select
|
||||||
|
Environment environment = ContextUtils.getBean(Environment.class);
|
||||||
|
String correctorAdditionalInfo = environment.getProperty("corrector.additional.information");
|
||||||
|
if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) {
|
||||||
|
needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL));
|
||||||
|
}
|
||||||
|
|
||||||
// If there is no aggregate function in the S2SQL statement and
|
// If there is no aggregate function in the S2SQL statement and
|
||||||
// there is a data field in 'WHERE' statement, add the field to the 'SELECT' statement.
|
// there is a data field in 'WHERE' statement, add the field to the 'SELECT' statement.
|
||||||
|
|||||||
@@ -3,11 +3,14 @@ package com.tencent.supersonic.chat.core.corrector;
|
|||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectFunctionHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectFunctionHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.sf.jsqlparser.expression.Expression;
|
import net.sf.jsqlparser.expression.Expression;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.core.env.Environment;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -26,8 +29,12 @@ public class HavingCorrector extends BaseSemanticCorrector {
|
|||||||
//add aggregate to all metric
|
//add aggregate to all metric
|
||||||
addHaving(queryContext, semanticParseInfo);
|
addHaving(queryContext, semanticParseInfo);
|
||||||
|
|
||||||
//add having expression filed to select
|
//decide whether add having expression field to select
|
||||||
//addHavingToSelect(semanticParseInfo);
|
Environment environment = ContextUtils.getBean(Environment.class);
|
||||||
|
String correctorAdditionalInfo = environment.getProperty("corrector.additional.information");
|
||||||
|
if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) {
|
||||||
|
addHavingToSelect(semanticParseInfo);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -65,12 +65,19 @@ def get_pred_result():
|
|||||||
questions=read_query(input_path)
|
questions=read_query(input_path)
|
||||||
pred_sql_list=[]
|
pred_sql_list=[]
|
||||||
default_sql="select * from tablea "
|
default_sql="select * from tablea "
|
||||||
|
time_cost=[]
|
||||||
for i in range(0,len(questions)):
|
for i in range(0,len(questions)):
|
||||||
|
start_time = time.time()
|
||||||
pred_sql=get_pred_sql(questions[i],url,agent_id,chat_id,authorization,default_sql)
|
pred_sql=get_pred_sql(questions[i],url,agent_id,chat_id,authorization,default_sql)
|
||||||
|
end_time = time.time()
|
||||||
|
cost='%.3f'%(end_time-start_time)
|
||||||
|
time_cost.append(cost)
|
||||||
pred_sql_list.append(pred_sql)
|
pred_sql_list.append(pred_sql)
|
||||||
time.sleep(60)
|
time.sleep(60)
|
||||||
write_sql(pred_sql_path, pred_sql_list)
|
write_sql(pred_sql_path, pred_sql_list)
|
||||||
|
|
||||||
|
return [float(cost) for cost in time_cost]
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print("pred")
|
print("pred")
|
||||||
|
|
||||||
|
|||||||
@@ -482,7 +482,7 @@ def print_scores(scores, etype):
|
|||||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
|
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
|
||||||
|
|
||||||
|
|
||||||
def evaluate(gold, predict, db_dir, etype, kmaps,query_path):
|
def evaluate(gold, predict, db_dir, etype, kmaps,query_path,time_cost):
|
||||||
with open(gold) as f:
|
with open(gold) as f:
|
||||||
glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
|
glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
|
||||||
|
|
||||||
@@ -597,7 +597,11 @@ def evaluate(gold, predict, db_dir, etype, kmaps,query_path):
|
|||||||
scores[level]['partial'][type_]['f1'] = \
|
scores[level]['partial'][type_]['f1'] = \
|
||||||
2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
|
2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
|
||||||
scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])
|
scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])
|
||||||
|
cost_dic = {}
|
||||||
|
cost_dic["max_time"] = max(time_cost)
|
||||||
|
cost_dic["min_time"] = min(time_cost)
|
||||||
|
cost_dic["avg_time"] = sum(time_cost)/len(time_cost)
|
||||||
|
log_list.append(cost_dic)
|
||||||
print_scores(scores, etype)
|
print_scores(scores, etype)
|
||||||
print(scores['all']['exec'])
|
print(scores['all']['exec'])
|
||||||
current_directory = os.path.dirname(os.path.abspath(__file__))
|
current_directory = os.path.dirname(os.path.abspath(__file__))
|
||||||
@@ -608,7 +612,6 @@ def evaluate(gold, predict, db_dir, etype, kmaps,query_path):
|
|||||||
with open(file_name, 'w') as json_file:
|
with open(file_name, 'w') as json_file:
|
||||||
json.dump(log_list, json_file, indent=4, ensure_ascii=False)
|
json.dump(log_list, json_file, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
def eval_exec_match(db, p_str, g_str, pred, gold):
|
def eval_exec_match(db, p_str, g_str, pred, gold):
|
||||||
"""
|
"""
|
||||||
return 1 if the values between prediction and gold are matching
|
return 1 if the values between prediction and gold are matching
|
||||||
@@ -890,7 +893,7 @@ def build_foreign_key_map_from_json(table):
|
|||||||
tables[entry['db_id']] = build_foreign_key_map(entry)
|
tables[entry['db_id']] = build_foreign_key_map(entry)
|
||||||
return tables
|
return tables
|
||||||
|
|
||||||
def get_evaluation_result():
|
def get_evaluation_result(time_cost):
|
||||||
current_directory = os.path.dirname(os.path.abspath(__file__))
|
current_directory = os.path.dirname(os.path.abspath(__file__))
|
||||||
config_file=current_directory+"/config/config.yaml"
|
config_file=current_directory+"/config/config.yaml"
|
||||||
with open(config_file, 'r') as file:
|
with open(config_file, 'r') as file:
|
||||||
@@ -905,7 +908,7 @@ def get_evaluation_result():
|
|||||||
etype="exec"
|
etype="exec"
|
||||||
kmaps = build_foreign_key_map_from_json(table)
|
kmaps = build_foreign_key_map_from_json(table)
|
||||||
|
|
||||||
evaluate(gold, pred, db_dir, etype, kmaps,query_path)
|
evaluate(gold, pred, db_dir, etype, kmaps,query_path,time_cost)
|
||||||
|
|
||||||
def remove_unused_file():
|
def remove_unused_file():
|
||||||
current_directory = os.path.dirname(os.path.abspath(__file__))
|
current_directory = os.path.dirname(os.path.abspath(__file__))
|
||||||
@@ -927,8 +930,8 @@ def remove_unused_file():
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
build_table()
|
build_table()
|
||||||
get_pred_result()
|
time_cost=get_pred_result()
|
||||||
get_evaluation_result()
|
get_evaluation_result(time_cost)
|
||||||
remove_unused_file()
|
remove_unused_file()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -46,4 +46,4 @@ com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor=\
|
|||||||
com.tencent.supersonic.chat.server.processor.execute.MetricRatioProcessor
|
com.tencent.supersonic.chat.server.processor.execute.MetricRatioProcessor
|
||||||
|
|
||||||
com.tencent.supersonic.common.util.embedding.S2EmbeddingStore=\
|
com.tencent.supersonic.common.util.embedding.S2EmbeddingStore=\
|
||||||
com.tencent.supersonic.common.util.embedding.InMemoryS2EmbeddingStore
|
com.tencent.supersonic.common.util.embedding.InMemoryS2EmbeddingStore
|
||||||
|
|||||||
@@ -42,6 +42,9 @@ metric:
|
|||||||
mybatis:
|
mybatis:
|
||||||
mapper-locations=classpath:mappers/custom/*.xml,classpath*:/mappers/*.xml
|
mapper-locations=classpath:mappers/custom/*.xml,classpath*:/mappers/*.xml
|
||||||
|
|
||||||
|
corrector:
|
||||||
|
additional:
|
||||||
|
information: true
|
||||||
|
|
||||||
llm:
|
llm:
|
||||||
parser:
|
parser:
|
||||||
|
|||||||
Reference in New Issue
Block a user