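Following the training script's logic, locate the data preprocessing file: official/vision/image_classification/resnet/imagenet_preprocessing.py.
Inside the data-reading function "input_fn", set drop_remainder to True. Assigning it at the top of the function body, as shown below, overrides whatever value the caller passes. This completes this migration point.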
def input_fn(is_training,
             data_dir,
             batch_size,
             dtype=tf.float32,
             datasets_num_private_threads=None,
             parse_record_fn=parse_record,
             input_context=None,
             drop_remainder=False,
             tf_data_experimental_slack=False,
             training_dataset_cache=False,
             filenames=None):
  ……
  Returns:
    A dataset that can be used for iteration.
  """
  # Migration point: force drop_remainder to True, overriding the argument,
  # so that the final partial batch is discarded and every batch has the
  # same, statically known batch size.
  drop_remainder = True
  if filenames is None:
    filenames = get_filenames(is_training, data_dir)
  dataset = tf.data.Dataset.from_tensor_slices(filenames)
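To illustrate what the flag changes, here is a minimal pure-Python sketch that mimics how tf.data.Dataset.batch treats the final partial batch. The helper batch() is hypothetical and written only for illustration. It is not part of TensorFlow or of imagenet_preprocessing.py. In tf.data itself, drop_remainder=True also makes the batch dimension statically known: tf.data.Dataset.range(10).batch(4, drop_remainder=True).element_spec has shape (4,), whereas without the flag it is (None,).

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def batch(items: Sequence[T], batch_size: int, drop_remainder: bool = False) -> List[List[T]]:
    """Split items into consecutive batches, mimicking the semantics of
    tf.data.Dataset.batch. Hypothetical helper for illustration only."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    batches = [list(items[i:i + batch_size]) for i in range(0, len(items), batch_size)]
    # With drop_remainder=True, a trailing batch smaller than batch_size is
    # discarded, so every remaining batch has exactly batch_size elements.
    if drop_remainder and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches


if __name__ == "__main__":
    records = list(range(10))
    print([len(b) for b in batch(records, 4)])                       # [4, 4, 2]
    print([len(b) for b in batch(records, 4, drop_remainder=True)])  # [4, 4]
```

With 10 records and a batch size of 4, the default keeps a final batch of 2, while drop_remainder=True drops it. That uniform batch size is what the forced assignment in input_fn guarantees.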