Adapting to Distributed Setup

This step lets you run single-device training and distributed training with the same script; the distributed adaptation has no impact on the single-device training process. The porting workflow for distributed training consists of the following steps:

  1. Synchronizing initial values of variables between workers

    In official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py, synchronize the trainable variables across workers by inserting a call to the npu.distribute.broadcast API after the model is created.

    # Import npu at the beginning of the script to use the npu.distribute.broadcast API.
    import npu_device as npu

    with distribute_utils.get_strategy_scope(strategy):
      # Model creation
      runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback, per_epoch_steps)
    # Synchronize the initial variable values across workers.
    npu.distribute.broadcast(runnable.model.trainable_variables)
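
    For reference, here is a minimal, self-contained sketch of the same pattern outside the ResNet scripts. The model is a placeholder, and the sketch only assumes what the call above shows: npu.distribute.broadcast accepts a list of variables and synchronizes their values across workers (it also assumes the NPU device has already been initialized elsewhere in the script).

    import tensorflow as tf
    import npu_device as npu

    # Every worker builds the model with its own random initial weights.
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(128, activation="relu", input_shape=(784,)),
        tf.keras.layers.Dense(10),
    ])

    # Broadcast the variables so that all workers start from identical weights.
    npu.distribute.broadcast(model.trainable_variables)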
    
  2. Aggregating gradients between workers

    Open the official/vision/image_classification/resnet/resnet_runnable.py script and locate the train_step function:

    def train_step(self, iterator):
      """See base class."""
    
      def step_fn(inputs):
        """Function to run on the device."""
        images, labels = inputs
        with tf.GradientTape() as tape:
          logits = self.model(images, training=True)
    
          prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
              labels, logits)
          loss = tf.reduce_sum(prediction_loss) * (1.0 /
                                                   self.flags_obj.batch_size)
          num_replicas = self.strategy.num_replicas_in_sync
          l2_weight_decay = 1e-4
          if self.flags_obj.single_l2_loss_op:
            l2_loss = l2_weight_decay * 2 * tf.add_n([
                tf.nn.l2_loss(v)
                for v in self.model.trainable_variables
                if 'bn' not in v.name
            ])
    
            loss += (l2_loss / num_replicas)
          else:
            loss += (tf.reduce_sum(self.model.losses) / num_replicas)
    
        grad_utils.minimize_using_explicit_allreduce(
            tape, self.optimizer, loss, self.model.trainable_variables)
        self.train_loss.update_state(loss)
        self.train_accuracy.update_state(labels, logits)
    

    In the original TF2 script, the minimize_using_explicit_allreduce function hides the details of the distribution strategy; the function that actually performs gradient aggregation is implemented in official/staging/training/grad_utils.py.

    def _filter_and_allreduce_gradients(grads_and_vars,
                                        allreduce_precision="float32",
                                        bytes_per_pack=0):
    

    The TF Adapter requires that training be started in single-CPU mode. Because single-device training does not involve gradient aggregation, the original aggregation code is never executed. In this function, add NPU gradient aggregation by inserting a call to the npu.distribute.all_reduce API. Add the following lines to the official/staging/training/grad_utils.py file:

    # Import npu at the beginning of the script to use the npu.distribute.all_reduce API.
    import npu_device as npu

    def _filter_and_allreduce_gradients(grads_and_vars,
                                        allreduce_precision="float32",
                                        bytes_per_pack=0):
    ... ...

      # The original script uses the SUM reduction strategy.
      allreduced_grads = tf.distribute.get_strategy(  # pylint: disable=protected-access
      ).extended._replica_ctx_all_reduce(tf.distribute.ReduceOp.SUM, grads, hints)
      if allreduce_precision == "float16":
        allreduced_grads = [tf.cast(grad, "float32") for grad in allreduced_grads]

      # Gradient aggregation added for NPU adaptation. Keep the reduction consistent with the original script, that is, "sum".
      allreduced_grads = npu.distribute.all_reduce(allreduced_grads, reduction="sum")

      return allreduced_grads, variables
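
    As a standalone illustration, the sketch below applies the same aggregation pattern to a generic custom training step: gradients are computed locally, summed across workers with npu.distribute.all_reduce, and then applied. The model, optimizer, and global batch size are placeholders; the only API assumed from the snippet above is npu.distribute.all_reduce(grads, reduction="sum").

    import tensorflow as tf
    import npu_device as npu

    GLOBAL_BATCH_SIZE = 256  # Total batch size across all workers (illustrative).

    model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

    @tf.function
    def train_step(images, labels):
      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
            labels, logits, from_logits=True)
        # Divide by the global batch size so that summing gradients across
        # workers reproduces the full-batch gradient, as in the ResNet script.
        loss = tf.reduce_sum(per_example_loss) / GLOBAL_BATCH_SIZE
      grads = tape.gradient(loss, model.trainable_variables)
      # Aggregate gradients across all workers; "sum" matches the original script.
      grads = npu.distribute.all_reduce(grads, reduction="sum")
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
      return loss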
    
  3. Sharding the dataset across workers

    Find the code that selects the preprocessing function in official/vision/image_classification/resnet/resnet_runnable.py:

        if self.flags_obj.use_synthetic_data:
          # Synthetic (fake) data. Ignore this branch.
          self.input_fn = common.get_synth_input_fn(
              height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
              width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
              num_channels=imagenet_preprocessing.NUM_CHANNELS,
              num_classes=imagenet_preprocessing.NUM_CLASSES,
              dtype=self.dtype,
              drop_remainder=True)
        else:
          # Actual preprocessing method used for real ImageNet data.
          self.input_fn = imagenet_preprocessing.input_fn
    

    Add the following code to official/vision/image_classification/resnet/imagenet_preprocessing.py:

    # Import npu at the beginning of the script to use the npu.distribute.shard_and_rebatch_dataset API.
    import npu_device as npu

      if input_context:
        logging.info(
            'Sharding the dataset: input_pipeline_id=%d num_input_pipelines=%d',
            input_context.input_pipeline_id, input_context.num_input_pipelines)
        # Original shard logic. Sharding is not performed here, because training is launched in single-CPU mode.
        dataset = dataset.shard(input_context.num_input_pipelines, input_context.input_pipeline_id)
      # Shard logic added for NPU adaptation. The dataset and the global batch size are sharded based on the number of devices in the cluster.
      dataset, batch_size = npu.distribute.shard_and_rebatch_dataset(dataset, batch_size)
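
    For reference, a minimal sketch of the sharding call on a toy dataset is shown below. It only assumes the behavior visible in the snippet above: npu.distribute.shard_and_rebatch_dataset takes a dataset and the global batch size and returns the worker-local shard together with the adjusted batch size.

    import tensorflow as tf
    import npu_device as npu

    global_batch_size = 256
    dataset = tf.data.Dataset.range(10000)

    # Each worker keeps only its own shard of the dataset, and the global
    # batch size is adjusted to the per-device batch size.
    dataset, local_batch_size = npu.distribute.shard_and_rebatch_dataset(
        dataset, global_batch_size)

    dataset = dataset.batch(local_batch_size, drop_remainder=True)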
    

After the preceding steps are performed, the porting for distributed training is complete.

Whether these APIs take effect depends on whether the environment variables for NPU distributed training are set. In single-device training, these APIs have no effect.
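
As a rough illustration of that gating, the snippet below only logs whether the script is running in a distributed NPU setup before training starts. The RANK_SIZE variable name is an assumption based on common Ascend distributed configurations, not something defined in this guide; check your environment's documentation for the exact variables to set.

    import os

    # RANK_SIZE is assumed to hold the number of devices in the cluster when
    # the NPU distributed environment variables are configured (assumption;
    # adjust the name to match your environment).
    rank_size = int(os.environ.get("RANK_SIZE", "1"))

    if rank_size > 1:
      print("Distributed NPU training with %d devices; the npu.distribute "
            "APIs will take effect." % rank_size)
    else:
      print("Single-device training; the npu.distribute APIs do not take effect.")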