Adapting to Distributed Setup
This step lets you run single-device training and distributed training with the same script. Note that the distributed adaptation has no impact on the single-device training process. Porting to distributed training involves the following steps:
- Synchronizing initial values of variables between workers
In official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py, synchronize the trainable variables across workers by inserting a call to the npu.distribute.broadcast API.
```python
with distribute_utils.get_strategy_scope(strategy):
  # Model creation
  runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback, per_epoch_steps)
  # Variable synchronization
  npu.distribute.broadcast(runnable.model.trainable_variables)
```
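Conceptually, the broadcast makes every worker start training from the root worker's initial weights, so all replicas are identical before the first step. A minimal pure-Python sketch of that semantics (an illustration only, not the real npu_device implementation):

```python
# Illustration: each worker holds its own randomly initialized variables;
# "broadcast" overwrites every worker's variables with the root worker's copy.
def broadcast(worker_vars, root=0):
    """Return per-worker variable lists, all replaced by the root's values."""
    root_vars = [list(v) for v in worker_vars[root]]  # copy root's variables
    return [[list(v) for v in root_vars] for _ in worker_vars]

# Two workers that were initialized to different values.
workers = [[[0.1, 0.2]], [[0.9, 0.8]]]
workers = broadcast(workers, root=0)
# After the broadcast, both workers hold the root worker's values.
```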
- Aggregating gradients between workers
Find the train_step function in the official/vision/image_classification/resnet/resnet_runnable.py script.
```python
def train_step(self, iterator):
  """See base class."""

  def step_fn(inputs):
    """Function to run on the device."""
    images, labels = inputs
    with tf.GradientTape() as tape:
      logits = self.model(images, training=True)
      prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, logits)
      loss = tf.reduce_sum(prediction_loss) * (1.0 / self.flags_obj.batch_size)
      num_replicas = self.strategy.num_replicas_in_sync
      l2_weight_decay = 1e-4
      if self.flags_obj.single_l2_loss_op:
        l2_loss = l2_weight_decay * 2 * tf.add_n([
            tf.nn.l2_loss(v)
            for v in self.model.trainable_variables
            if 'bn' not in v.name
        ])
        loss += (l2_loss / num_replicas)
      else:
        loss += (tf.reduce_sum(self.model.losses) / num_replicas)

    grad_utils.minimize_using_explicit_allreduce(
        tape, self.optimizer, loss, self.model.trainable_variables)
    self.train_loss.update_state(loss)
    self.train_accuracy.update_state(labels, logits)
```
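Note the two different divisors in step_fn: per-example prediction losses are divided by the global batch size, while the per-model L2 term is divided by num_replicas. Under a sum all-reduce across replicas, this yields exactly one global mean loss plus one L2 term. A small arithmetic sketch with hypothetical values:

```python
# Sketch: with a SUM all-reduce, each replica must contribute 1/num_replicas
# of any per-model (not per-example) term, and per-example losses must be
# divided by the *global* batch size, not the per-replica batch size.
global_batch_size = 8
num_replicas = 2
per_replica_batch = global_batch_size // num_replicas

# Hypothetical per-example prediction losses on each replica.
replica_losses = [[1.0] * per_replica_batch, [3.0] * per_replica_batch]
l2_loss = 0.5  # a single per-model regularization term

# What each replica computes locally, mirroring step_fn above.
local = [sum(l) / global_batch_size + l2_loss / num_replicas
         for l in replica_losses]

# After a SUM all-reduce: mean prediction loss over all examples + one l2 term.
total = sum(local)
expected = sum(sum(l) for l in replica_losses) / global_batch_size + l2_loss
assert abs(total - expected) < 1e-9
```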
In the source TF2 script, the minimize_using_explicit_allreduce function hides the distribution-strategy setup, and the function that performs gradient aggregation is implemented in official/staging/training/grad_utils.py.
```python
def _filter_and_allreduce_gradients(grads_and_vars,
                                    allreduce_precision="float32",
                                    bytes_per_pack=0):
```
The TF Adapter requires training to be launched in single-device form. Because single-device training does not involve aggregation, the original gradient aggregation code is never executed. Add the NPU gradient aggregation to this function by inserting the npu.distribute.all_reduce API. Modify the official/staging/training/grad_utils.py file as follows:
```python
# Import npu at the beginning of the script to use the npu.distribute.all_reduce API.
import npu_device as npu

def _filter_and_allreduce_gradients(grads_and_vars,
                                    allreduce_precision="float32",
                                    bytes_per_pack=0):
  ...
  # The original script uses the SUM strategy.
  allreduced_grads = tf.distribute.get_strategy(  # pylint: disable=protected-access
  ).extended._replica_ctx_all_reduce(tf.distribute.ReduceOp.SUM, grads, hints)
  if allreduce_precision == "float16":
    allreduced_grads = [tf.cast(grad, "float32") for grad in allreduced_grads]
  # Gradient aggregation added for NPU adaptation. Keep the reduction consistent
  # with the original script, that is, "sum".
  allreduced_grads = npu.distribute.all_reduce(allreduced_grads, reduction="sum")
  return allreduced_grads, variables
```
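The "sum" reduction means that every worker contributes its local gradients and every worker receives the element-wise sum. A pure-Python sketch of that semantics (an illustration of what the npu.distribute.all_reduce call is expected to compute, not its actual implementation):

```python
# Illustration of a "sum" all-reduce: each worker supplies a list of gradient
# values; afterwards every worker holds the element-wise sum of all workers'
# gradients.
def all_reduce_sum(per_worker_grads):
    """Simulate a SUM all-reduce across workers."""
    summed = [sum(vals) for vals in zip(*per_worker_grads)]
    return [list(summed) for _ in per_worker_grads]  # every worker gets the sum

grads = all_reduce_sum([[1.0, 2.0], [3.0, 4.0]])
# Both workers now hold [4.0, 6.0].
```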
- Sharding the dataset across workers
Find the preprocessing function in official/vision/image_classification/resnet/resnet_runnable.py.
```python
# Synthetic data. Ignore this branch.
if self.flags_obj.use_synthetic_data:
  self.input_fn = common.get_synth_input_fn(
      height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
      width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
      num_channels=imagenet_preprocessing.NUM_CHANNELS,
      num_classes=imagenet_preprocessing.NUM_CLASSES,
      dtype=self.dtype,
      drop_remainder=True)
else:
  # Actual preprocessing method
  self.input_fn = imagenet_preprocessing.input_fn
Add the following code to official/vision/image_classification/resnet/imagenet_preprocessing.py:
```python
# Import npu at the beginning of the script to use the
# npu.distribute.shard_and_rebatch_dataset API.
import npu_device as npu

if input_context:
  logging.info(
      'Sharding the dataset: input_pipeline_id=%d num_input_pipelines=%d',
      input_context.input_pipeline_id, input_context.num_input_pipelines)
  # Original shard logic. Sharding is not performed here, because training is
  # launched in single-device mode.
  dataset = dataset.shard(input_context.num_input_pipelines,
                          input_context.input_pipeline_id)
# Shard logic added for NPU adaptation. The dataset and the global batch size
# are sharded based on the number of devices in the cluster.
dataset, batch_size = npu.distribute.shard_and_rebatch_dataset(dataset, batch_size)
```
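Shard-and-rebatch gives each worker a disjoint slice of the data and divides the global batch size by the number of workers, so the combined per-step batch across all workers still equals the original global batch. A minimal pure-Python sketch of that behavior (an illustration only; the real API operates on tf.data datasets):

```python
# Illustration: round-robin sharding of a dataset plus rebatching the global
# batch size down to a per-device batch size.
def shard_and_rebatch(dataset, global_batch_size, num_devices, device_id):
    """Return this device's shard and its per-device batch size."""
    shard = dataset[device_id::num_devices]           # disjoint round-robin shard
    per_device_batch = global_batch_size // num_devices
    return shard, per_device_batch

data = list(range(8))
shard0, bs0 = shard_and_rebatch(data, global_batch_size=4, num_devices=2, device_id=0)
shard1, bs1 = shard_and_rebatch(data, global_batch_size=4, num_devices=2, device_id=1)
# The two shards are disjoint and together cover the full dataset.
```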
After the preceding steps are performed, the porting for distributed training is complete.
Whether these APIs take effect depends on whether the environment variables for NPU distributed training are set. In single-device training, they do not take effect.
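This no-op behavior is what makes one script serve both modes: when the distributed environment variables are absent, the added calls return their inputs unchanged. A hypothetical sketch of such gating (RANK_SIZE is used here only as an assumed example variable; consult the TF Adapter documentation for the authoritative names):

```python
import os

# Hypothetical sketch: a collective op that degrades to a no-op when the
# (assumed) RANK_SIZE environment variable is unset or indicates one device.
def all_reduce(grads, reduction="sum"):
    if int(os.environ.get("RANK_SIZE", "1")) <= 1:
        return grads  # single device: nothing to aggregate
    raise NotImplementedError("a real implementation would reduce across devices")

os.environ.pop("RANK_SIZE", None)   # simulate a single-device run
result = all_reduce([1.0, 2.0])
# The gradients pass through untouched.
```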