Distributed Training Script Adaptation (Single Device)

The following figure shows a typical distributed NPU training setup. Each TensorFlow process exclusively manages one NPU training card, and the TensorFlow processes in the cluster are synchronized through the collective communication APIs provided by CANN. Collective communication is the only thing that distinguishes distributed NPU training from single-NPU training.

The TF Adapter treats single-device training as distributed NPU training with only one worker, so the same training script can be used for both single-NPU and distributed NPU training.

Compared with single-NPU training, distributed NPU training requires the following additional adaptation steps:

  1. Synchronizing initial values of variables between workers

    In TF2 eager execution, variables are initialized as soon as the model is built. Because each worker initializes its variables independently, the initial values can differ between workers and must be synchronized before training starts.

    After the model is built, pass the variables to be synchronized to the npu.distribute.broadcast API. You can use the model.trainable_variables property to obtain the variables that need to be synchronized.
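    To make the effect of the broadcast concrete, the following NumPy sketch simulates four workers inside a single process. It does not call the npu_device API: simulate_broadcast and its root_rank parameter are hypothetical names chosen for illustration, and in a real script npu.distribute.broadcast performs the synchronization across processes.

```python
import numpy as np

def simulate_broadcast(worker_values, root_rank=0):
    """Give every simulated worker a copy of the root worker's variable values.

    Hypothetical helper for illustration only; not part of npu_device.
    """
    root_values = worker_values[root_rank]
    return [[v.copy() for v in root_values] for _ in worker_values]

# Each simulated worker initializes the same two variables with its own seed,
# as independent TF2 processes would if nothing synchronized them.
num_workers = 4
workers = []
for rank in range(num_workers):
    rng = np.random.default_rng(seed=rank)
    workers.append([rng.normal(size=(3, 2)), rng.normal(size=(2,))])

print(np.allclose(workers[0][0], workers[1][0]))  # False: initial values differ

# After the simulated broadcast, every worker holds worker 0's values.
synced = simulate_broadcast(workers, root_rank=0)
print(all(np.array_equal(w[i], workers[0][i]) for w in synced for i in range(2)))  # True
```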

  2. Aggregating gradients across workers

    Each worker computes gradients on its own share of the data. These gradients must be aggregated across all workers before the variables are updated, so that every worker applies the same update and the model replicas stay in sync.

    • If the original script computes and applies gradients in separate steps (for example, with tf.GradientTape.gradient and optimizer.apply_gradients), call the npu.distribute.all_reduce API to aggregate the gradients. The call takes the worker's local gradients and the reduction operation (usually mean).
    • If the original script computes and applies gradients in a single call (for example, optimizer.minimize or model.fit), wrap the optimizer with npu.distribute.npu_distributed_keras_optimizer_wrapper so that the gradients are aggregated automatically.
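    The following NumPy sketch simulates mean reduction of gradients for a linear model with a mean-squared-error loss, again with all workers in one process. mse_grad and simulate_all_reduce_mean are hypothetical helpers, not npu_device APIs. The sketch shows why mean reduction is the usual choice: with equal-sized shards, the average of the local gradients equals the gradient computed on the whole global batch.

```python
import numpy as np

def mse_grad(w, x, y):
    """Gradient of mean((x @ w - y) ** 2) with respect to w."""
    return 2.0 * x.T @ (x @ w - y) / len(x)

def simulate_all_reduce_mean(grads):
    """Hypothetical stand-in for a mean all-reduce: every worker gets the average."""
    mean = np.mean(grads, axis=0)
    return [mean.copy() for _ in grads]

rng = np.random.default_rng(0)
x = rng.normal(size=(64, 3))
y = x @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.normal(size=64)
w = np.zeros(3)  # identical on all workers after the broadcast step

num_workers = 4
x_shards = np.split(x, num_workers)  # 64 samples -> 4 equal shards of 16
y_shards = np.split(y, num_workers)

# Each worker computes a gradient on its own shard only.
local_grads = [mse_grad(w, xs, ys) for xs, ys in zip(x_shards, y_shards)]
reduced = simulate_all_reduce_mean(local_grads)

# With equal-sized shards, the averaged gradient equals the full-batch gradient.
print(np.allclose(reduced[0], mse_grad(w, x, y)))  # True

# Every worker applies the same update, so the replicas stay identical.
learning_rate = 0.1
new_weights = [w - learning_rate * g for g in reduced]
print(all(np.array_equal(nw, new_weights[0]) for nw in new_weights))  # True
```

    Because every worker starts from the same broadcast weights and receives the same averaged gradient, the replicas remain identical after each step.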

  3. Sharding the dataset across workers

    In distributed training, each worker should train on a different subset of the samples, so that together the workers cover the full distribution of the training dataset instead of processing the same samples repeatedly. For example, to train a model on an 8-NPU cluster, a typical strategy is to split the dataset into eight equal parts and assign the first 1/8 of the elements to the first NPU, the second 1/8 to the second NPU, and so on, with the last 1/8 going to the last NPU.
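    The index arithmetic of this contiguous split can be sketched as follows. shard_bounds is a hypothetical helper; when the sample count is not divisible by the number of workers, it gives the first workers one extra sample each. For comparison, round_robin_indices mirrors the rule used by tf.data.Dataset.shard, which keeps every num_shards-th element starting at index rather than a contiguous block. The sketch makes no assumption about which scheme npu.distribute.shard_and_rebatch_dataset uses internally.

```python
def shard_bounds(num_samples, num_workers, rank):
    """Return the half-open range [start, end) of the contiguous block for `rank`.

    Block sizes differ by at most one when num_samples % num_workers != 0.
    """
    base, extra = divmod(num_samples, num_workers)
    start = rank * base + min(rank, extra)
    end = start + base + (1 if rank < extra else 0)
    return start, end

def round_robin_indices(num_samples, num_workers, rank):
    """Indices kept by `rank` under the rule i % num_workers == rank (as in tf.data's shard)."""
    return list(range(rank, num_samples, num_workers))

# 8 NPUs, 60000 samples: worker 0 gets [0, 7500), worker 1 gets [7500, 15000), ...
for rank in range(8):
    print(rank, shard_bounds(60000, 8, rank))
```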

    • If the dataset is a tf.data.Dataset, use the npu.distribute.shard_and_rebatch_dataset API provided by the TF Adapter. It takes the dataset to be sharded and the cluster's global batch size, and returns the sharded dataset and the per-worker batch size.
       # Import npu_device at the beginning of the script to use npu.distribute.shard_and_rebatch_dataset.
       import logging
       import npu_device as npu

       # Excerpt from the input pipeline: dataset, batch_size, and input_context are defined by the surrounding code.
       if input_context:
         logging.info(
             'Sharding the dataset: input_pipeline_id=%d num_input_pipelines=%d',
             input_context.input_pipeline_id, input_context.num_input_pipelines)
         # Original shard logic. The original script trains on a single device, so this branch does not shard the data.
         dataset = dataset.shard(input_context.num_input_pipelines, input_context.input_pipeline_id)
       # Shard logic added for the NPU: shard the dataset and divide the global batch size by the number of workers.
       dataset, batch_size = npu.distribute.shard_and_rebatch_dataset(dataset, batch_size)
      
    • If the dataset is loaded from NumPy arrays, use NumPy to shard the arrays and divide the global batch size. For example:
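      import os

      import numpy as np
      from tensorflow import keras

      # args is assumed to hold the script's command-line arguments (data_path, rank_size, device_id, batch_size).
      (x_train, _), (x_test, _) = keras.datasets.mnist.load_data(os.path.join(args.data_path, 'mnist.npz'))

      # Evenly divide the dataset based on the number of devices.
      # np.split raises ValueError if the sample count is not divisible by rank_size.
      x_trains = np.split(x_train, args.rank_size)
      # Obtain the dataset shard by device ID.
      x_train = x_trains[args.device_id]
      x_tests = np.split(x_test, args.rank_size)
      x_test = x_tests[args.device_id]
      # Shard the global batch: each device processes batch_size // rank_size samples per step.
      batch_size = args.batch_size // args.rank_size

      mnist_digits = np.concatenate([x_train, x_test], axis=0)
      mnist_digits = np.expand_dims(mnist_digits, -1).astype("float32") / 255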
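    The NumPy example above uses np.split, which raises a ValueError when the number of samples is not divisible by rank_size. A minimal sketch of a more tolerant variant, using a hypothetical helper shard_numpy and small stand-in data instead of MNIST, switches to np.array_split and checks that the global batch size divides evenly:

```python
import numpy as np

def shard_numpy(x, global_batch_size, rank_size, rank):
    """Return this worker's shard of x and the per-worker batch size.

    Hypothetical helper. np.array_split accepts lengths that are not divisible
    by rank_size (shard sizes then differ by at most one), whereas np.split raises.
    """
    if global_batch_size % rank_size != 0:
        raise ValueError("global batch size must be divisible by rank_size")
    shard = np.array_split(x, rank_size)[rank]
    return shard, global_batch_size // rank_size

# Stand-in data: 10 fake 28x28 "images" instead of MNIST.
x = np.arange(10 * 28 * 28).reshape(10, 28, 28)
shard, per_worker_batch = shard_numpy(x, global_batch_size=8, rank_size=4, rank=1)
print(shard.shape, per_worker_batch)  # (3, 28, 28) 2

try:
    np.split(x, 4)  # 10 samples cannot be divided into 4 equal parts
except ValueError as err:
    print("np.split failed:", err)
```

    Raising on an indivisible global batch size is a design choice of this sketch; the original example rounds down silently with //.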