图像分类场景ResNet50
ResNet50模型训练参数及范围
ResNet50模型训练的参数名、类型、取值范围、默认值及说明如表1所示。
参数名 |
类型 |
取值范围 |
默认值 |
说明 |
---|---|---|---|---|
--train_dataset_path |
str |
- |
None |
训练数据集路径。 |
--pretrained_ckpt_path |
str |
- |
"pre_trained_ckpt" |
预训练模型存放路径。 |
--train_output_path |
str |
- |
"train_output_path" |
训练结果输出路径。 |
--device_num |
int |
当前仅支持单卡,取值为1 |
1 |
训练NPU的数量。 |
--device_id |
int |
[0,7] |
0 |
训练NPU的id。 |
--run_eval |
bool |
True 或 False |
True |
训练过程中是否评估。 |
--batch_size |
int |
|
16 |
训练batch数。 |
--epoch_size |
int |
[1, 10000] |
10 |
训练次数。 |
--init_lr |
float |
(0, 1) |
0.0001 |
训练学习率。 |
--enable_modelarts |
bool |
True 或 False |
False |
是否使用ModelArts进行训练。 |
--data_url |
str |
- |
None |
在ModelArts训练时,数据集路径。 |
--train_output_url |
str |
- |
None |
在ModelArts训练时,输出路径。 |
--checkpoint_url |
str |
- |
None |
在ModelArts训练时,预训练模型路径。 |
--net_width |
int |
[64, 608] 且为16的整数倍 |
224 |
网络宽度。 |
--net_height |
int |
[64, 608] 且为16的整数倍 |
224 |
网络高度。 |
ResNet50模型训练命令参考
ResNet50模型训练的命令参考如下:
python3 model_train.py --train_dataset_path={train_dataset_path} --train_output_path=./output_path --epoch_size=10 --batch_size=16 --init_lr=0.0001 --device_num=1 --device_id=0 --run_eval=True
ResNet50模型训练过程存在随机性,最终以评估精度为准。训练精度结果如图1所示。
训练结束后日志信息参考如图2所示。
训练结束后会在“--train_output_path”参数指定的输出目录中生成.ckpt、.a310.om、.a310p.om和.air格式的模型文件。
评估参数及范围
ResNet50模型评估的参数名、类型、取值范围、默认值及说明如表2所示。
参数名 |
类型 |
取值范围 |
默认值 |
说明 |
---|---|---|---|---|
--eval_dataset_path |
str |
- |
None |
评估数据集路径。 |
--eval_ckpt_path |
str |
- |
None |
评估ckpt获取路径。 |
--eval_output_path |
str |
- |
"eval_output_path" |
评估结果输出路径。 |
--device_id |
int |
[0,7] |
0 |
评估NPU的编号。 |
--enable_modelarts |
bool |
True 或 False |
False |
是否使用ModelArts进行训练。 |
--data_url |
str |
- |
None |
在ModelArts训练时,数据集路径。 |
--eval_output_url |
str |
- |
None |
在ModelArts训练时,输出路径。 |
--checkpoint_url |
str |
- |
None |
在ModelArts训练时,预训练模型路径。 |
评估命令参考
ResNet50模型评估的启动参考以下命令。
python3 model_eval.py --eval_dataset_path={eval_dataset_path} --eval_ckpt_path=./output_ path --eval_output_path=./eval_result --device_id=0
采用训练输出的ckpt,来评估模型的精度值,如图3所示。
在评估目录下会生成如图4所示的文件及目录,其中“ng_images”、“ok_images”文件夹,以及“statistics.csv”文件,其中“ng_images”文件夹存放评估不正确的每张图片、“ok_images”文件夹存放评估正确的每张图片,“statistics.csv”存放评估结果的类别名、标签数量、检出个数、正检个数、精确率、召回率和精度值结果。