Effective Data Handling with Custom PyTorch Dataset Classes

Daniel O'Keefe
Sep 25, 2023

Training machine learning models involves carefully managing and processing your training data, and handling all of that inside your training logic can quickly get complicated. In this article, we’ll look at building custom PyTorch Dataset classes to handle data during training, which decouples our data processing code from the code we use to train our models. We’ll use two examples: one for a straightforward tabular dataset, and another for a more complex case where we load and preprocess images from disk on demand, so the full dataset never has to fit into memory.

Link to both examples in full:
https://github.com/DanOKeefe/pytorch-custom-datasets

Custom Dataset for the Titanic Dataset

1. Dataset Initialization
Let’s start with a simple dataset, the Titanic dataset, which contains passenger records and whether each passenger survived. We load the entire dataset into memory when the class is initialized, and we pass in arguments to enable K-fold cross-validation.
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import KFold
from sklearn.preprocessing import LabelEncoder

class TitanicData(Dataset):
    """
    Custom dataset class for handling the Titanic survivor dataset
    """
    def __init__(self, df_path, current_fold, num_fold=5):
        """
        Load the Titanic dataset and perform preprocessing.

        Args:
            df_path (str): Path to the Titanic dataset CSV file.
            current_fold (int): The current fold of the dataset.
            num_fold (int): The total number of folds to split the dataset into.
        """
        super().__init__()

        self.df = self.preprocess_data(df_path)
        self.num_fold = num_fold
        self.current_fold = current_fold

        # Use KFold to split the dataset into 'num_fold' folds
        self.kf = KFold(n_splits=num_fold, shuffle=True, random_state=42)

2. Data Loading and Preprocessing
The preprocess_data method loads the Titanic dataset from a CSV file. It fills in missing values, drops the sparsely populated Cabin column, label-encodes the categorical Sex and Embarked columns, and casts the Survived label to an integer.

    def preprocess_data(self, df_path):
        """
        Reads the Titanic dataset CSV file and performs preprocessing.

        Args:
            df_path (str): Path to the Titanic dataset CSV file.

        Returns:
            DataFrame: A Pandas DataFrame containing the preprocessed Titanic dataset.
        """
        df = pd.read_csv(df_path)

        # Handle missing values. Assign the filled column back instead of calling
        # fillna(..., inplace=True) on df['col'], which pandas 2.x deprecates
        # (chained assignment) and which stops modifying df under Copy-on-Write.
        median_age = df['Age'].median()
        df['Age'] = df['Age'].fillna(median_age)

        mode_embarked = df['Embarked'].mode()[0]
        df['Embarked'] = df['Embarked'].fillna(mode_embarked)

        median_fare = df['Fare'].median()
        df['Fare'] = df['Fare'].fillna(median_fare)

        mode_pclass = df['Pclass'].mode()[0]
        df['Pclass'] = df['Pclass'].fillna(mode_pclass)

        df.drop(columns=['Cabin'], inplace=True)
        df.dropna(subset=['Embarked'], inplace=True)

        # Create LabelEncoder instances for 'Sex' and 'Embarked' columns
        sex_label_encoder = LabelEncoder()
        embarked_label_encoder = LabelEncoder()

        # Label encode the categorical columns
        df['Sex'] = sex_label_encoder.fit_transform(df['Sex'])
        df['Embarked'] = embarked_label_encoder.fit_transform(df['Embarked'])

        # Make sure the binary 'Survived' label is stored as an integer
        df['Survived'] = df['Survived'].astype(int)

        return df
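To see what these preprocessing steps produce, here is a minimal sketch on a hand-made four-row DataFrame. The values are made up and are not taken from titanic.csv. It applies the same fill and label-encoding steps. Note that LabelEncoder assigns integers in sorted order of the category names, so "female" becomes 0 and "male" becomes 1.

```python
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Tiny illustrative frame (made-up values) with one missing Age and one missing Embarked
df = pd.DataFrame({
    "Age": [22.0, None, 38.0, 26.0],
    "Sex": ["male", "female", "female", "male"],
    "Embarked": ["S", None, "S", "C"],
})

# Fill missing values by assigning the result back to the column
df["Age"] = df["Age"].fillna(df["Age"].median())                  # median of 22, 38, 26 is 26
df["Embarked"] = df["Embarked"].fillna(df["Embarked"].mode()[0])  # most common port is "S"

# LabelEncoder maps each sorted unique value to 0, 1, 2, ...
sex_encoder = LabelEncoder()
df["Sex"] = sex_encoder.fit_transform(df["Sex"])
embarked_encoder = LabelEncoder()
df["Embarked"] = embarked_encoder.fit_transform(df["Embarked"])

print(df)
print(list(sex_encoder.classes_), list(embarked_encoder.classes_))
```

Because the encoders are created inside preprocess_data, the integer codes depend only on which categories appear in the CSV, not on row order.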

3. Data Retrieval
The __getitem__ method retrieves individual data samples, where features and labels are prepared and returned on demand.

    def __getitem__(self, idx):
        """
        Retrieves the features and label at the given index.

        Args:
            idx (int): The index of the dataset element to retrieve.

        Returns:
            dict: A dictionary containing the features and label of the dataset at the given index.
        """

        # Extract passenger data
        row = self.df.iloc[idx]
        pclass = row['Pclass']  # Passenger class
        sex = row['Sex']
        age = row['Age']
        sibsp = row['SibSp']  # Number of siblings/spouses aboard
        parch = row['Parch']  # Number of parents/children aboard
        fare = row['Fare']
        embarked = row['Embarked']  # Port of embarkation

        # Create feature tensor
        features = torch.tensor([pclass, sex, age, sibsp, parch, fare, embarked], dtype=torch.float)

        # Create label tensor
        label = torch.tensor(row['Survived'], dtype=torch.long)

        return {
            'features': features,
            'label': label,
        }
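Because __getitem__ returns a dictionary, PyTorch’s default collate function batches each key separately and stacks the per-sample tensors. The sketch below uses a hypothetical ToyDictDataset filled with random tensors in place of the Titanic rows, so it runs without the CSV, and shows the resulting batch shapes.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDictDataset(Dataset):
    """Hypothetical stand-in that returns the same dict layout as TitanicData."""

    def __init__(self, n_samples=10, n_features=7):
        self.features = torch.randn(n_samples, n_features)
        self.labels = torch.randint(0, 2, (n_samples,))  # int64 labels, like torch.long

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # One sample: a 1-D feature vector and a scalar label tensor
        return {"features": self.features[idx], "label": self.labels[idx]}

loader = DataLoader(ToyDictDataset(), batch_size=4, shuffle=False)
batch = next(iter(loader))

# The default collate_fn keeps the dict keys and stacks the values along a new batch dimension
print(batch["features"].shape)  # torch.Size([4, 7])
print(batch["label"].shape)     # torch.Size([4])
```

The labels come out with shape (batch_size,). That is why the training loop later reshapes them with .view(-1, 1) to match the model’s (batch_size, 1) output.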

4. K-Fold Cross Validation Support
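The get_splits method uses the KFold object created in __init__ to select the training and validation indices for the current fold. The _get_subset helper then wraps each set of indices in a torch.utils.data.Subset, which is a view of the dataset restricted to those indices.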

    def get_splits(self):
        """
        Splits the dataset into training and validation subsets.

        Returns:
            tuple: A tuple containing the training and validation subsets.
        """

        fold_data = list(self.kf.split(self.df))
        train_indices, val_indices = fold_data[self.current_fold]

        train_data = self._get_subset(train_indices)
        val_data = self._get_subset(val_indices)

        return train_data, val_data

    def _get_subset(self, indices):
        """
        Returns a Subset of the dataset at the given indices.

        Args:
            indices (list): A list of indices specifying the subset of the dataset to return.

        Returns:
            Subset: A Subset of the dataset at the given indices.
        """
        return Subset(self, indices)
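Fixing random_state matters here. The training loop builds a new TitanicData object for every fold, and a fixed seed makes KFold produce the same shuffled folds each time, so fold k in one instance matches fold k in another. The sketch below checks this on ten dummy row indices instead of the real DataFrame, and uses a plain Python list as the dataset wrapped by Subset.

```python
import numpy as np
from sklearn.model_selection import KFold
from torch.utils.data import Subset

n_rows = 10
rows = np.arange(n_rows)  # dummy stand-in for the DataFrame rows

folds = list(KFold(n_splits=5, shuffle=True, random_state=42).split(rows))
# A second KFold with the same seed reproduces exactly the same folds
folds_again = list(KFold(n_splits=5, shuffle=True, random_state=42).split(rows))

# Across all folds, every row lands in exactly one validation set
all_val = np.concatenate([val_idx for _, val_idx in folds])
print(sorted(all_val.tolist()))

# Subset maps position i to dataset[indices[i]]
train_idx, val_idx = folds[0]
data = list(range(100, 100 + n_rows))
val_subset = Subset(data, val_idx)
print(len(val_subset), val_subset[0])
```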

5. Implementing __len__() for Dataset Length
The __len__() method returns the total number of samples in the dataset.

    def __len__(self):
        """
        Returns the length of the dataset.
        """
        return len(self.df)
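DataLoader relies on __len__. Its default sampler iterates over range(len(dataset)), and len(loader) is the number of batches, rounded up unless drop_last=True. The training loop below divides the summed loss by len(train_loader) to get an average per batch. Here is a small sketch with a TensorDataset standing in for TitanicData:

```python
import math
import torch
from torch.utils.data import DataLoader, TensorDataset

# Ten one-feature samples; TensorDataset's __len__ is the size of the first dimension
dataset = TensorDataset(torch.arange(10, dtype=torch.float32).unsqueeze(1))
loader = DataLoader(dataset, batch_size=4)

# Each batch from a TensorDataset is a list holding one stacked tensor
batch_sizes = [batch[0].shape[0] for batch in loader]
print(len(dataset), len(loader), batch_sizes)  # 10 3 [4, 4, 2]

# With drop_last=True the incomplete final batch is discarded
print(len(DataLoader(dataset, batch_size=4, drop_last=True)))  # 2
```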

6. Using the Dataset Class in our Training Loop
Now that we have our dataset class, let’s use it to train a machine learning model.

First, define the model architecture and hyperparameter configuration.

class SimpleFeedForwardNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        # Return raw logits: BCEWithLogitsLoss applies the sigmoid internally,
        # so calling torch.sigmoid here as well would apply it twice
        return x

class Config:
    input_size = 7
    hidden_size = 32
    output_size = 1
    learning_rate = 0.003
    num_epochs = 256
    batch_size = 64
    num_fold = 3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

criterion = nn.BCEWithLogitsLoss()

In each fold, create an instance of the TitanicData class, specifying which fold you’re on, and split the data into training and validation sets with .get_splits(). Then wrap each subset in PyTorch’s DataLoader class, which lets us easily retrieve minibatches of data. Inside each epoch, we iterate through the DataLoaders to train the model and evaluate performance on the validation set.

# Set up storage for validation results and models trained on each fold
# to evaluate and store the performance of our models during cross-validation

fold_results = []  # Store validation results for each fold
fold_models = []  # Store the model trained on each fold

# Iterate through each fold, creating a new data handler
# for the current fold using our custom dataset class.
# This data handler contains the data preprocessing and
# splitting functionality.

for fold in range(Config.num_fold):
    data_handler = TitanicData(
        df_path='titanic.csv',
        current_fold=fold,
        num_fold=Config.num_fold
    )

    # For each fold, split the dataset into training and validation subsets.
    train_data, val_data = data_handler.get_splits()

    # Data loaders load and batch the data during training.
    # We configure data loaders for both the training and validation subsets.
    train_loader = DataLoader(train_data, batch_size=Config.batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=Config.batch_size)

    # Create a new instance of the model for each fold
    model = SimpleFeedForwardNN(Config.input_size, Config.hidden_size, Config.output_size)
    model.to(device)

    # Define the optimizer
    optimizer = optim.Adam(model.parameters(), lr=Config.learning_rate)

    # Training loop. Iterate through mini-batches of data
    for epoch in range(Config.num_epochs):
        model.train()
        total_loss = 0.0
        for batch in train_loader:
            # Retrieve features and labels from the current batch of training data
            features = batch['features'].to(device)
            labels = batch['label'].float().view(-1, 1).to(device)  # Reshape labels to (batch_size, 1)

            # Forward pass (outputs are logits)
            outputs = model(features)

            # Calculate the loss using BCEWithLogitsLoss
            loss = criterion(outputs, labels)

            # Backpropagation and optimization with gradient clipping
            optimizer.zero_grad()
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Add gradient clipping
            optimizer.step()

            total_loss += loss.item()

        average_loss = total_loss / len(train_loader)

        # Validation for the current fold
        model.eval()
        val_total_loss = 0.0
        all_labels = []
        all_predictions = []
        with torch.no_grad():
            for batch in val_loader:
                # Retrieve features and labels from the current batch of validation data
                features = batch['features'].to(device)
                labels = batch['label'].float().view(-1, 1).to(device)  # Convert labels to float for BCE loss

                # Forward pass (outputs are logits)
                outputs = model(features)

                val_loss = criterion(outputs, labels)
                val_total_loss += val_loss.item()

                # Convert logits to probabilities, then threshold at 0.5
                predictions = (torch.sigmoid(outputs) > 0.5).float()

                all_labels.extend(labels.view(-1).tolist())
                all_predictions.extend(predictions.view(-1).tolist())

        average_val_loss = val_total_loss / len(val_loader)

        # Calculate accuracy for the current fold
        accuracy = accuracy_score(all_labels, all_predictions)
        print(f'Fold [{fold + 1}/{Config.num_fold}] - Epoch [{epoch + 1}/{Config.num_epochs}] - Loss: {average_loss:.4f} - Validation Loss: {average_val_loss:.4f} - Validation Accuracy: {accuracy:.4f}')

    # Store validation results for the current fold (accuracy after the final epoch)
    fold_results.append(accuracy)

    # Save the model for the current fold
    fold_models.append(model)

# Calculate and print the average validation accuracy across all folds
average_accuracy = sum(fold_results) / len(fold_results)
print(f'Average Validation Accuracy: {average_accuracy:.4f} across {Config.num_fold} folds')
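The model returns raw logits because nn.BCEWithLogitsLoss applies the sigmoid itself. The sketch below checks two things on random logits. First, that this loss matches nn.BCELoss applied to sigmoid probabilities. Second, that thresholding a probability at 0.5 is the same as thresholding the logit at 0.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(8, 1)                     # raw model outputs
targets = torch.randint(0, 2, (8, 1)).float()  # binary labels as floats, shape (8, 1)

# BCEWithLogitsLoss applies the sigmoid internally...
loss_with_logits = nn.BCEWithLogitsLoss()(logits, targets)
# ...so it agrees with BCELoss computed on sigmoid probabilities
loss_on_probs = nn.BCELoss()(torch.sigmoid(logits), targets)
print(loss_with_logits.item(), loss_on_probs.item())

# sigmoid(x) > 0.5 exactly when x > 0
logit_preds = logits > 0
prob_preds = torch.sigmoid(logits) > 0.5
print(torch.equal(logit_preds, prob_preds))
```

The PyTorch documentation describes BCEWithLogitsLoss as more numerically stable than a separate sigmoid followed by BCELoss. That makes returning logits from the model the better choice.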

Custom Dataset for Image Data
For more complex cases, like working with image data, you can modify this approach to load and preprocess images within the __getitem__() method. Each image is then read from disk only when it is requested, instead of loading the entire dataset into memory upfront, so you can train on datasets that do not fit in RAM.

Let’s go through another example, this time working with the Food-101 dataset. It contains images across 101 different types of food, with 1,000 images of each type.

1. Download the Food-101 Dataset
Use the torchvision library to download and extract the Food-101 dataset.
from torchvision.datasets import Food101

dataset = Food101(root='.', download=True)
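Once the archive is downloaded, each line of meta/train.txt and meta/test.txt names one image as <class_name>/<image_id>, without the .jpg extension. The sketch below parses lines in that format into (path, class) pairs and builds a label map in order of first appearance, in the same way _load_data does. Several details are assumptions for illustration:

- The image IDs are made up.
- The food-101/images root assumes torchvision’s default extraction folder.
- parse_split_file and build_label_map are hypothetical helper names.

```python
import os

def parse_split_file(lines, image_root):
    """Turn '<class>/<image_id>' lines into (image_path, class_name) pairs."""
    samples = []
    for line in lines:
        relative = line.strip()
        if not relative:
            continue  # skip blank lines
        class_name = relative.split("/")[0]
        samples.append((os.path.join(image_root, relative + ".jpg"), class_name))
    return samples

def build_label_map(samples):
    """Assign integer ids to class names in order of first appearance."""
    label_to_int = {}
    for _, class_name in samples:
        if class_name not in label_to_int:
            label_to_int[class_name] = len(label_to_int)
    return label_to_int

# Made-up lines in the meta/train.txt format
lines = ["apple_pie/101\n", "apple_pie/102\n", "baby_back_ribs/201\n", "\n"]
samples = parse_split_file(lines, os.path.join("food-101", "images"))
label_to_int = build_label_map(samples)
print(samples[0])
print(label_to_int)  # {'apple_pie': 0, 'baby_back_ribs': 1}
```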

2. Dataset Initialization
Initialize the dataset by specifying the path to the directory containing the data. Food-101 comes already split into train and test sets. That metadata is stored in food-101/meta/train.txt and food-101/meta/test.txt, because torchvision extracts the archive into a food-101 folder under root. We will still split the training data into a training set and a validation set. Notice that for each item we retrieve, the Dataset class loads an image from disk and applies the image transforms to it.

import os
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch import optim
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from tqdm import tqdm
from PIL import Image

# Normalization statistics expected by torchvision's ImageNet-pretrained weights
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

class Food101Dataset(Dataset):
    def __init__(self, data_dir, train_txt_path, test_txt_path, val_split_ratio=0.2):
        self.data_dir = data_dir
        self.train_txt_path = train_txt_path
        self.test_txt_path = test_txt_path
        self.val_split_ratio = val_split_ratio
        # Random augmentations are applied to training images only
        self.train_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
        ])
        # Validation and test images get the same preprocessing without randomness
        self.eval_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
        ])

        self.train_data, self.val_data, self.test_data, self.label_to_int = self._load_data(train_txt_path, test_txt_path)

    def _load_data(self, train_txt_path, test_txt_path):
        train_data = []
        test_data = []

        label_to_int = {}
        int_label = 0

        with open(os.path.join(self.data_dir, train_txt_path), 'r') as f:
            for line in f:
                filename = line.strip() + '.jpg'
                label = filename.split('/')[0]
                if label not in label_to_int:
                    label_to_int[label] = int_label
                    int_label += 1
                image_path = os.path.join(self.data_dir, 'images', filename)
                train_data.append((image_path, label))

        with open(os.path.join(self.data_dir, test_txt_path), 'r') as f:
            for line in f:
                filename = line.strip() + '.jpg'
                label = filename.split('/')[0]
                image_path = os.path.join(self.data_dir, 'images', filename)
                test_data.append((image_path, label))

        # Split train_data into train and validation sets using train_test_split
        train_data, val_data = train_test_split(train_data, test_size=self.val_split_ratio, random_state=42)

        return train_data, val_data, test_data, label_to_int

    def __len__(self):
        return len(self.train_data) + len(self.val_data) + len(self.test_data)

    def __getitem__(self, idx):
        # Map the global index to a split and an index within that split
        if idx < len(self.train_data):
            data_source = self.train_data
            transform = self.train_transform
        elif idx < len(self.train_data) + len(self.val_data):
            data_source = self.val_data
            transform = self.eval_transform
            idx -= len(self.train_data)
        else:
            data_source = self.test_data
            transform = self.eval_transform
            idx -= (len(self.train_data) + len(self.val_data))

        image_path, label = data_source[idx]
        # The image is loaded from disk only when this sample is requested
        image = Image.open(image_path).convert('RGB')
        image = transform(image)

        # Convert label to integer using label_to_int mapping
        label = torch.tensor(self.label_to_int[label], dtype=torch.int64)

        return {
            'image': image,
            'label': label,
        }

    def get_splits(self):
        n_train = len(self.train_data)
        n_val = len(self.val_data)
        train_subset = Subset(self, list(range(n_train)))
        val_subset = Subset(self, list(range(n_train, n_train + n_val)))
        test_subset = Subset(self, list(range(n_train + n_val, len(self))))
        return train_subset, val_subset, test_subset
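Food101Dataset puts three lists behind one index space and translates a global index back into a split and a local index. The hypothetical locate helper below sketches the same offset arithmetic with bisect over cumulative sizes. PyTorch’s built-in torch.utils.data.ConcatDataset uses the same approach if you would rather combine three separate datasets.

```python
import bisect
from itertools import accumulate

from torch.utils.data import ConcatDataset

def locate(global_idx, split_sizes):
    """Hypothetical helper: map an index over concatenated splits to (split_number, local_index)."""
    boundaries = list(accumulate(split_sizes))  # sizes [3, 2, 4] -> boundaries [3, 5, 9]
    if not 0 <= global_idx < boundaries[-1]:
        raise IndexError(global_idx)
    split = bisect.bisect_right(boundaries, global_idx)
    start = boundaries[split - 1] if split > 0 else 0
    return split, global_idx - start

sizes = [3, 2, 4]  # e.g. train, val, test
print([locate(i, sizes) for i in range(sum(sizes))])

# ConcatDataset resolves indices the same way
combined = ConcatDataset([["a", "b", "c"], ["d", "e"], ["f", "g", "h", "i"]])
print(combined[3])  # 'd', the first item of the second split
```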

Use a pretrained vision transformer as our starting point. This acts as a good feature extractor for images. Replace the final classification head of the model with a linear layer with output size 101, matching the number of food categories in the Food-101 dataset. Fine-tune only the weights of this newly added layer, keeping the rest of the pretrained model unchanged.

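device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters for this example. num_classes matches the 101 Food-101
# categories; the other values are placeholders (assumed, not the article's
# settings) to tune for your hardware.
class Config:
    num_classes = 101
    batch_size = 64
    learning_rate = 0.003
    num_epochs = 10

# Load the pre-trained ViT-B/16 model with ImageNet weights
# (the weights= argument replaces the deprecated pretrained=True)
model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)

in_features = model.heads.head.in_features
classifier = nn.Linear(in_features=in_features, out_features=Config.num_classes)
model.heads.head = classifier

# Freeze every parameter, then unfreeze only the new classification head
for param in model.parameters():
    param.requires_grad = False
model.heads.head.weight.requires_grad = True
model.heads.head.bias.requires_grad = True
model.to(device)

criterion = nn.CrossEntropyLoss()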
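Freezing works the same way on any nn.Module. Downloading the ViT weights is heavy, so the sketch below uses a tiny two-part nn.Sequential as a stand-in; backbone and head are hypothetical names. It shows that after the head is replaced and everything else is frozen, only the new layer’s parameters remain trainable.

```python
import torch
import torch.nn as nn

# Small stand-in for the ViT: a "backbone" followed by a replaceable "head"
model = nn.Sequential()
model.add_module("backbone", nn.Sequential(nn.Linear(16, 32), nn.ReLU()))
model.add_module("head", nn.Linear(32, 10))

# Swap in a new head sized to the target classes (101 for Food-101)
model.head = nn.Linear(32, 101)

# Freeze everything, then unfreeze only the new head
for param in model.parameters():
    param.requires_grad = False
for param in model.head.parameters():
    param.requires_grad = True

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(trainable, total)  # 3333 3877: head is 32*101 + 101, backbone adds 16*32 + 32
```

You could also pass only the head’s parameters to the optimizer. When you pass model.parameters() instead, Adam simply skips the frozen parameters, because their gradients stay None.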

Create data loaders for the training, validation, and test sets.

data_handler = Food101Dataset(data_dir='food-101', train_txt_path='meta/train.txt', test_txt_path='meta/test.txt')

train_data, val_data, test_data = data_handler.get_splits()

# Create data loaders using the batch_size from the Config class.
# Only the training loader needs shuffling.
train_loader = DataLoader(train_data, batch_size=Config.batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=Config.batch_size, shuffle=False)
test_loader = DataLoader(test_data, batch_size=Config.batch_size, shuffle=False)

optimizer = optim.Adam(model.parameters(), lr=Config.learning_rate)

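Next, write the training loop. It follows the same pattern as the Titanic example and adds early stopping: training halts once the validation loss has increased in two consecutive epochs.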

from collections import deque

# Keep the two most recent validation losses
val_loss_history = deque(maxlen=2)

# Count how many epochs in a row the validation loss has increased
consecutive_increases = 0

for epoch in range(Config.num_epochs):
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
    model.train()
    total_loss = 0.0

    for batch_idx, batch in progress_bar:
        # Retrieve images and labels from the current batch
        features = batch['image'].to(device)
        labels = batch['label'].to(device)

        outputs = model(features)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        progress_bar.set_description(f"Batch {batch_idx + 1}/{len(train_loader)}, Loss: {loss.item():.4f}")

    progress_bar.close()
    average_loss = total_loss / len(train_loader)

    model.eval()
    val_total_loss = 0.0
    all_labels = []
    all_predictions = []

    with torch.no_grad():
        for batch in val_loader:
            features = batch['image'].to(device)
            labels = batch['label'].to(device)

            outputs = model(features)

            val_loss = criterion(outputs, labels)
            val_total_loss += val_loss.item()

            predictions = torch.argmax(outputs, dim=1)

            all_labels.extend(labels.tolist())
            all_predictions.extend(predictions.tolist())

    average_val_loss = val_total_loss / len(val_loader)

    accuracy = accuracy_score(all_labels, all_predictions)
    print(f'Epoch [{epoch + 1}/{Config.num_epochs}] - Loss: {average_loss:.4f} - Validation Loss: {average_val_loss:.4f} - Validation Accuracy: {accuracy:.4f}')

    # Append the validation loss to the history
    val_loss_history.append(average_val_loss)

    # Stop once the validation loss has increased in two consecutive epochs
    if len(val_loss_history) == 2 and val_loss_history[0] < val_loss_history[1]:
        consecutive_increases += 1
        if consecutive_increases >= 2:
            print('Validation loss has increased twice in a row. Stopping training.')
            break
    else:
        consecutive_increases = 0  # Reset counter
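The early-stopping rule is easier to check in isolation. The hypothetical stopping_epoch function below replays a list of validation losses through the same deque-and-counter logic. It returns the epoch (0-based) at which the loop above would break, or None if training would run to completion.

```python
from collections import deque

def stopping_epoch(val_losses, patience=2):
    """Hypothetical helper mirroring the loop above: stop after `patience`
    consecutive epoch-over-epoch increases in validation loss."""
    history = deque(maxlen=2)
    consecutive_increases = 0
    for epoch, loss in enumerate(val_losses):
        history.append(loss)
        if len(history) == 2 and history[0] < history[1]:
            consecutive_increases += 1
            if consecutive_increases >= patience:
                return epoch
        else:
            consecutive_increases = 0  # any non-increase resets the count
    return None

print(stopping_epoch([1.0, 0.8, 0.9, 1.1]))             # 3: losses rose at epochs 2 and 3
print(stopping_epoch([1.0, 0.9, 1.0, 0.8, 0.85, 0.7]))  # None: increases never run two in a row
```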

Assess the model’s performance on the test set using the test_loader created earlier, and view the classification report, loss, and accuracy on the test set.

import torch
from sklearn.metrics import classification_report, accuracy_score

def evaluate_on_test_data(model, test_loader, criterion, num_classes, device):
    model.eval()
    total_loss = 0.0
    all_labels = []
    all_predictions = []

    with torch.no_grad():
        for batch in test_loader:
            features = batch['image'].to(device)
            labels = batch['label'].to(device)

            outputs = model(features)

            loss = criterion(outputs, labels)
            total_loss += loss.item()

            # The model returns logits; argmax picks the predicted class directly
            # (applying softmax first would not change which class scores highest)
            predictions = torch.argmax(outputs, dim=1)

            all_labels.extend(labels.tolist())
            all_predictions.extend(predictions.tolist())

    average_loss = total_loss / len(test_loader)

    # Compute the classification report over all num_classes classes
    classification_rep = classification_report(
        all_labels,
        all_predictions,
        labels=list(range(num_classes)),
        target_names=[f'Class {i}' for i in range(num_classes)],
        zero_division=0,
    )

    accuracy = accuracy_score(all_labels, all_predictions)

    print(f'Test Loss: {average_loss:.4f} - Test Accuracy: {accuracy:.4f}')
    print('Classification Report:\n', classification_rep)

evaluate_on_test_data(model, test_loader, criterion, Config.num_classes, device)
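The evaluation code picks classes with torch.argmax on the raw logits. Softmax is monotonic within each row, so applying it first would not change the predicted class. The sketch below checks that on a fixed, made-up logits tensor and then computes accuracy and a small classification report in the same way the evaluation function does.

```python
import torch
from sklearn.metrics import accuracy_score, classification_report

# Made-up logits for 6 samples over 4 hypothetical classes
logits = torch.tensor([
    [2.0, 0.1, -1.0, 0.3],
    [0.2, 1.5, -0.3, 0.0],
    [-0.5, 0.0, 3.0, 1.0],
    [0.1, 0.9, 0.4, -2.0],
    [0.0, -1.0, 0.2, 2.2],
    [1.0, 1.2, 0.3, 0.5],
])
labels = [0, 1, 2, 2, 3, 1]

predictions = torch.argmax(logits, dim=1)
# Softmax preserves the ordering within each row, so the argmax is unchanged
same_class = torch.equal(predictions, torch.argmax(torch.softmax(logits, dim=1), dim=1))

accuracy = accuracy_score(labels, predictions.tolist())  # one of six predictions is wrong
print(predictions.tolist(), same_class, round(accuracy, 4))
print(classification_report(labels, predictions.tolist(), labels=[0, 1, 2, 3],
                            target_names=[f"Class {i}" for i in range(4)], zero_division=0))
```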

Conclusion
Custom PyTorch Dataset classes provide a flexible and efficient way to manage data in machine learning training scripts. Whether you’re working with straightforward tabular data or complex image datasets, knowing how to create and use these classes keeps data loading and preprocessing out of your training loop.

Link to both examples in full:
https://github.com/DanOKeefe/pytorch-custom-datasets
