Source code for src.ml.train

"""
Training loop.
"""

from tqdm.notebook import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from src.ml.utils import calc_multi_acc, print_log


def train_model(
    model: nn.Module,
    epochs: int,
    train_loader: DataLoader,
    val_loader: DataLoader,
    optimizer,
    criterion,
    device: str,
    save_model_path: str,
):
    """Train ``model`` and evaluate it on the validation set after every epoch.

    Each epoch runs one optimisation pass over ``train_loader`` followed by a
    gradient-free pass over ``val_loader``. Loss and accuracy are averaged
    over the number of batches seen in each pass, recorded, and reported via
    ``print_log``.

    Parameters
    ----------
    model:
        PyTorch model object.
    epochs:
        Integer that denotes the number of training iterations.
    train_loader:
        Dataloader for train dataset.
    val_loader:
        Dataloader for validation dataset.
    optimizer:
        PyTorch optimizer.
    criterion:
        PyTorch loss function, called as ``criterion(predictions, targets)``.
    device:
        CUDA/CPU identifier the batches are moved to.
    save_model_path:
        Path to save the trained model weights. Nothing is saved when this
        value is falsy (e.g. ``""`` or ``None``).

    Returns
    -------
    tuple
        A tuple of dictionaries ``(loss_stats, acc_stats)``. Each dictionary
        has the keys ``"train"`` and ``"val"`` mapping to a list holding one
        batch-averaged value per epoch.

    Raises
    ------
    ValueError
        If ``epochs`` is not an integer, or if a loader yields no batches
        during an epoch (an average cannot be computed).
    """
    # bool is a subclass of int, so it is rejected explicitly.
    if isinstance(epochs, bool) or not isinstance(epochs, int):
        raise ValueError("epochs has to be an integer.")

    acc_stats = {"train": [], "val": []}
    loss_stats = {"train": [], "val": []}

    for e in tqdm(range(1, epochs + 1)):
        # ---- training pass ----
        train_epoch_loss = 0.0
        train_epoch_acc = 0.0
        train_batches = 0

        model.train()
        for x_train_batch, y_train_batch in train_loader:
            x_train_batch = x_train_batch.to(device)
            y_train_batch = y_train_batch.to(device)

            optimizer.zero_grad()

            y_train_pred = model(x_train_batch)
            train_loss = criterion(y_train_pred, y_train_batch)
            # NOTE(review): calc_multi_acc is assumed to return a tensor,
            # since .item() is called on its result below.
            train_acc = calc_multi_acc(y_train_pred, y_train_batch)

            train_loss.backward()
            optimizer.step()

            train_epoch_loss += train_loss.item()
            train_epoch_acc += train_acc.item()
            train_batches += 1

        # ---- validation pass ----
        val_epoch_loss = 0.0
        val_epoch_acc = 0.0
        val_batches = 0

        model.eval()
        with torch.no_grad():
            for x_val_batch, y_val_batch in val_loader:
                x_val_batch = x_val_batch.to(device)
                y_val_batch = y_val_batch.to(device)

                # The raw model output is used, exactly as in the training
                # pass. Squeezing it would also drop the batch dimension of
                # a single-sample batch and mismatch the target shape.
                y_val_pred = model(x_val_batch)
                val_loss = criterion(y_val_pred, y_val_batch)
                val_acc = calc_multi_acc(y_val_pred, y_val_batch)

                val_epoch_loss += val_loss.item()
                val_epoch_acc += val_acc.item()
                val_batches += 1

        # Batches are counted while iterating instead of using len(loader),
        # so loaders without __len__ work and an empty loader is reported
        # clearly rather than as a ZeroDivisionError.
        if train_batches == 0:
            raise ValueError("train_loader did not yield any batches.")
        if val_batches == 0:
            raise ValueError("val_loader did not yield any batches.")

        avg_train_epoch_loss = train_epoch_loss / train_batches
        avg_train_epoch_acc = train_epoch_acc / train_batches
        avg_val_epoch_loss = val_epoch_loss / val_batches
        avg_val_epoch_acc = val_epoch_acc / val_batches

        loss_stats["train"].append(avg_train_epoch_loss)
        loss_stats["val"].append(avg_val_epoch_loss)
        acc_stats["train"].append(avg_train_epoch_acc)
        acc_stats["val"].append(avg_val_epoch_acc)

        print_log(
            e,
            epochs,
            avg_train_epoch_loss,
            avg_val_epoch_loss,
            avg_train_epoch_acc,
            avg_val_epoch_acc,
        )

    # Only the weights (state dict) are saved, once, after the final epoch.
    if save_model_path:
        torch.save(model.state_dict(), save_model_path)

    return loss_stats, acc_stats