昇腾社区首页
中文
注册

语义分割场景Unet++/Unet

Unet和Unet++两个网络模型在训练和评估时使用方式相似,仅网络结构设置不同,故将两者在同一套代码中实现。

文件夹“unet_mindspore”“unet_nested_mindspore”中除“on_platform/plat_cfg.yaml”文件有差异外,其余均相同。

Unet++/Unet训练参数及范围

Unet和Unet++模型训练的参数名、类型、取值范围、默认值及说明如表1所示。

表1 Unet++/Unet训练参数信息(unet_mindspore/model_train.py)

参数名

类型

取值范围

默认值

说明

--device_id

int

[0,7]

0

训练NPU的编号。

--device_num

int

当前仅支持单卡,取值为1。

1

训练NPU的数量。

--epoch_size

int

[1, 10000]

200

训练次数。

--batch_size

int

小于或等于数据集中的图片个数。

[1, 512]

16

训练batch数。

--pretrained_ckpt_path

str

-

“”

预训练模型存放路径。

“”代表不加载预训练模型。

--input_width

str

[1, 960],且是16的倍数

"256"

模型的输入宽度。

--input_height

str

[1, 960],且是16的倍数

"192"

模型的输入高度。

--init_lr

float

(0, 1)

0.0003

训练学习率。

--train_dataset_path

str

-

None

训练数据集路径。

--train_output_path

str

-

"train_output_path"

训练结果输出路径。

--run_eval

bool

True 或 False

True

训练过程中是否验证。

--model

str

"unet" 或 "unet_nested"

unet

模型名称配置,“unet”表示选择Unet模型,“unet_nested”表示选择Unet++模型。

--split_ratio

float

(0, 1)

0.8

边训边评估中数据集的切分比例。

Unet++/Unet训练命令参考

Unet++/Unet模型训练的启动参考以下命令执行。

UNet支持传入pretrain ckpt和train from scratch两种方式,默认使用train from scratch(不加载预训练模型)。推荐优先使用不加载预训练模型的方式进行训练,如用户需要进行finetune可以将之前训练好的模型路径通过“--pretrained_ckpt_path”参数传入训练脚本。

python3 model_train.py --train_dataset_path={train_dataset_path} --train_output_path=./unetpp_output_dir --epoch_size=50 --batch_size=8 --input_width=256 --input_height=192 --init_lr=0.0005 --device_num=1 --device_id=1 --model=unet_nested --run_eval=True

模型训练过程存在随机性,最终以评估精度为准。训练精度参考如图1所示。

图1 Unet++/Unet训练精度结果

模型训练结束后日志信息参考如图2所示。

图2 Unet++/Unet训练完成信息

模型训练结束后会在“--train_output_path”参数指定的输出目录中生成.ckpt、.a310.om、.a310p.om和.air格式的模型文件。

Unet++/Unet评估参数及范围

Unet++/Unet模型评估的参数名、类型、取值范围、默认值及说明如表2所示。

表2 Unet++/Unet模型评估参数信息(unet_mindspore/model_eval.py)

参数名

类型

取值范围

默认值

说明

--eval_dataset_path

str

-

None

评估数据集路径。

--eval_ckpt_path

str

-

None

评估ckpt获取路径,需要指定到训练的输出路径。

--device_id

int

[0,7]

0

训练NPU的编号。

--eval_output_path

str

-

"eval_output_path"

评估结果输出路径。

--model

str

"unet" 或 "unet_nested"

"unet"

模型名称配置。

“unet”表示选择Unet模型,“unet_nested”表示选择Unet++模型。

Unet++/Unet评估命令参考

Unet++/Unet模型评估的启动参考以下命令执行。

python3 model_eval.py --eval_dataset_path={eval_dataset_path} --eval_ckpt_path=./output_dir --eval_output_path=./eval_result --device_id=0 --model=unet_nested

采用训练输出的ckpt,来评估模型的精度值,参考如图3所示。

图3 Unet++/Unet评估结果

在评估目录下会生成如图4所示的文件及目录,其中“all_images”文件夹存放评估的每张图片结果、“every_images_statistics.csv”存放每张图片对应的IOU和dice精度结果,“statistics.csv”存放平均IOU和dice精度结果。

图4 Unet++/Unet评估生成目录