下载
中文
注册
mBART-PyTorch

概述

简述

MBART 是一种序列到序列去噪自动编码器,使用 BART 目标在多种语言的大规模单语语料库上进行预训练。mBART 是通过对多种语言的全文进行去噪来预训练完整序列到序列模型的首批方法之一,而以前的方法只关注编码器、解码器或重建文本的一部分。

  • 参考实现:

    [object Object]
  • 适配昇腾 AI 处理器的实现:

    [object Object]

准备训练环境

准备环境

  • 当前模型支持的 PyTorch 版本和已知三方库依赖如下表所示。

    表 1 版本支持表

    Torch_Version 三方库依赖版本
    PyTorch 1.5 -
    PyTorch 1.8 -
  • 环境准备指导。

    请参考《Pytorch框架训练环境准备》。

  • 安装依赖。

    [object Object]

准备数据集

方法一. 下载已预处理好的数据集

  1. 下载train_data.tar
  2. tar -xvf train_data.tar
  3. 将数据集放于工程根目录下,其目录结构如下:
[object Object]

说明: 该数据集的训练过程脚本只作为一种参考示例。

方法二. 下载数据集并自行处理

1. 分词处理

  1. 下载原始数据集并放于在源码包根目录下新建的“src_data/”目录下,以en_ro数据集为例。
  2. 下载并安装SPM
[object Object]

2. 数据预处理

[object Object]

获取预训练模型

  1. 下载mbart.CC25.tar.gz

  2. tar -xzvf mbart.CC25.tar.gz

  3. 将模型放于工程根目录下,其目录结构如下:

[object Object]

开始训练

训练模型

  1. 进入解压后的源码包根目录。

    [object Object]
  2. 运行训练脚本。

    该模型支持单机单卡训练和单机8卡训练。

    • 单机单卡训练

      启动单卡训练。

      [object Object]

      data_path为数据集路径,若训练en_ro数据集,路径写到en_ro;若训练en_de数据集,路径写到en_de ,同时需要将训练脚本中dropout的参数设置为0.1,target-lang设置为de_DE

    • 单机8卡训练

      启动8卡训练。

      [object Object]

      data_path为数据集路径,若训练en_ro数据集,路径写到en_ro;若训练en_de数据集,路径写到en_de ,同时需要将训练脚本中dropout的参数设置为0.1,total-num-update与max-update设置为300000,target-lang设置为de_DE

    模型训练脚本参数说明如下。

    [object Object]

    训练完成后,权重文件保存在当前路径下,并输出模型训练精度和性能信息。

训练结果展示

表 2 en_ro数据集训练结果展示表

NAME Acc@1 FPS Epochs AMP_Type Torch_Version
8p-竞品V - 39281.96 - - 1.8
8p-NPU 37.4 36171.24 - - 1.8

表 3 en_de数据集训练结果展示表

NAME Acc@1 FPS Epochs AMP_Type Torch_Version
8p-竞品V - 38365.15 - - 1.8
8p-NPU 32.5 35320.3 - - 1.8

说明: 由于该模型默认开启二进制,所以在性能测试时,需要安装二进制包

版本说明

变更

2022.12.14:首次发布。

FAQ

无。

使用模型资源和服务前,请您仔细阅读并理解透彻 《昇腾深度学习模型许可协议 3.0》