CenterNet检测模型在迁移到昇腾AI处理器上训练后,发现Loss收敛效果较GPU明显变差。迁移前后Loss曲线如下图所示。蓝色为GPU训练曲线,收敛正常;红线为NPU训练曲线,收敛效果变差。
在PyTorch模型迁移与训练过程中,出现loss变化收敛不一致,通过dump比对工具进行定位。检测是否存在单算子与标杆数据(GPU/CPU)存在精度差异,进行数据dump,分析dump数据比对结果,从而定位出单算子的适配与实现问题。
请参考ptdbg_ascend工具使用说明安装ptdbg_ascend精度工具,并准备已适配NPU环境的CenterNet训练工程,点击获取链接-ModelZoo,下载CenterNet模型脚本。
cd src vi main_npu_8p.py
from ptdbg_ascend import register_hook, overflow_check, seed_all, set_dump_path, set_dump_switch, acc_cmp_dump
def main(): seed_all() ...
def main(): ... model = create_model(opt.arch, opt.heads, opt.head_conv, opt.load_local_weights, opt.local_weights_path) model = model.to(opt.device) #npu ... if opt.precision_mode == 'must_keep_origin_dtype': optimizer = torch.optim.Adam(model.parameters(), opt.lr) model, optimizer = amp.initialize(model, optimizer, opt_level="O0", combine_grad=False) ###npu else: optimizer = apex.optimizers.NpuFusedAdam(model.parameters(), opt.lr) model, optimizer = amp.initialize(model, optimizer, opt_level="O1",loss_scale=19.0,combine_grad=True) ###npu set_dump_path("./dump_data_new/npu") # 设置dump路径,最终数据保存在此路径下 register_hook(model, acc_cmp_dump) # 添加hook函数和数据比对dump开关 start_epoch = 0 ...
set_dump_switch("ON") for epoch in range(start_epoch + 1, opt.num_epochs + 1): ... logger.write('\n') set_dump_switch("OFF")
#NPU训练命令 bash ./test/train_full_1p.sh
模型训练结束后,数据会落盘到输出目录。样例输出目录如下图。
from ptdbg_ascend import compare dump_result_param={ "npu_pkl_path": "/home/torch_test/dump_data_new/npu/ptdbg_dump_v2.0/rank0/dump.pkl", "bench_pkl_path": "/home/torch_test/dump_data_new/cpu/ptdbg_dump_v2.0/rank0/dump.pkl", "npu_dump_data_dir": "/home/torch_test/dump_data_new/npu/ptdbg_dump_v2.0/rank0/dump", "bench_dump_data_dir": "/home/torch_test/dump_data_new/cpu/ptdbg_dump_v2.0/rank0/dump" } compare(dump_result_param, "./out")
python3 compare.py
比对完成后会在指定的输出目录中生成对比结果文件“compare_result_timestamp.csv”,示例文件如下所示。
比对结果为csv类型文件,会包含两份Dump数据的比对信息,包括tensor shape信息,数据类型,余弦相似度,最大绝对误差,最大值,最小值,平均值等统计信息。
通常情况下,我们通过余弦相似度,来判断API的输入或输出是否存在问题。大部分API输入输出的余弦相似度要求大于0.99。我们以此为标准分析CenterNet在昇腾AI处理器上的比对结果,流程如下:
工具内部定义的不达标规则如下: