Scenarios Where Data Is Loaded in torch.utils.data.DataLoader Mode
torch.utils.data.DataLoader is a utility class in PyTorch for loading data. It splits a dataset into mini-batches for training, testing, validation, and other tasks. Check whether your model script loads its dataset through torch.utils.data.DataLoader. The sample code is as follows:
```python
import torch
from torchvision import datasets, transforms

# Define data transformation.
transform = transforms.Compose([
    transforms.ToTensor(),                # Transform images into tensors.
    transforms.Normalize((0.5,), (0.5,))  # Normalize the images.
])

# Load the MNIST dataset.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Create data loaders.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

# Use the data loader to iterate over samples.
for images, labels in train_loader:
    # Code for training a model
    ...
```
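To make the batching behavior concrete without downloading MNIST, the following minimal sketch uses a small synthetic TensorDataset (an illustrative assumption, not part of the original script). It shows how DataLoader divides samples into batches and one simple way to confirm that an object in a script is a DataLoader. The tensor sizes and batch size are arbitrary choices for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical synthetic dataset standing in for MNIST:
# 10 samples, each a 3-element feature vector, with labels 0..9.
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10)
dataset = TensorDataset(features, labels)

# batch_size=4 splits the 10 samples into batches of 4, 4, and 2.
# shuffle=False keeps the sample order deterministic for this illustration;
# num_workers=0 loads data in the main process.
loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0)

# One way to confirm that a script's data pipeline uses DataLoader.
is_dataloader = isinstance(loader, DataLoader)
print(is_dataloader)  # True

batch_sizes = []
for batch_features, batch_labels in loader:
    batch_sizes.append(batch_labels.shape[0])
    print(batch_features.shape, batch_labels.tolist())

print(batch_sizes)  # [4, 4, 2]
print(len(loader))  # 3 batches in total

# With drop_last=True, the incomplete final batch of 2 samples is discarded.
loader_drop_last = DataLoader(dataset, batch_size=4, shuffle=False, drop_last=True)
print(len(loader_drop_last))  # 2
```

By default (drop_last=False), the last batch may be smaller than batch_size, which matters if downstream code assumes a fixed batch shape.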