src.ml.train.train_model

train_model(model, epochs, train_loader, val_loader, optimizer, criterion, device, save_model_path)

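Training loop. Trains model on train_loader for the given number of epochs, evaluates it on val_loader, and saves the trained weights to save_model_path.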

Parameters
  • model – PyTorch model object.

  • epochs – Number of training epochs (full passes over the training dataset).

  • train_loader – DataLoader for the training dataset.

  • val_loader – DataLoader for the validation dataset.

  • optimizer – PyTorch optimizer.

  • criterion – PyTorch loss function.

  • device – Device to train on, either CUDA or CPU (e.g. "cuda" or "cpu").

  • save_model_path – Path to save the trained model weights.
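For illustration, arguments of the types listed above might be built as follows. This is a minimal sketch assuming a toy two-class classification task; the synthetic tensors, the nn.Linear model, the batch size, the learning rate, and the "model.pt" path are hypothetical choices, not taken from this project.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical synthetic data: 100 samples with 8 features and 2 classes.
torch.manual_seed(0)
features = torch.randn(100, 8)
labels = torch.randint(0, 2, (100,))
train_ds = TensorDataset(features[:80], labels[:80])  # 80 training samples
val_ds = TensorDataset(features[80:], labels[80:])    # 20 validation samples

train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=16)

# Hypothetical model: one linear layer producing 2 logits per sample.
model = nn.Linear(8, 2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
epochs = 5
save_model_path = "model.pt"  # hypothetical output path
```

Shuffling only the training loader is a common convention; validation order does not affect the computed metrics.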

Returns

A tuple (loss_stats, acc_stats) of dictionaries holding the recorded loss and accuracy statistics.

Return type

tuple
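The loop described above could look roughly like the sketch below. It is not this project's implementation: it assumes a classification model that outputs logits of shape (batch, num_classes), integer class targets, a loss with mean reduction (as in nn.CrossEntropyLoss), and that loss_stats and acc_stats map the hypothetical keys "train" and "val" to one value per epoch. Saving model.state_dict() rather than the whole model object is also an assumption. The helper _run_epoch is hypothetical.

```python
import os
import tempfile

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """One pass over loader (hypothetical helper).

    Updates weights only when optimizer is given. Returns the mean loss
    per sample and the accuracy over the whole pass.
    """
    training = optimizer is not None
    model.train(training)  # toggles dropout/batch-norm behaviour
    total_loss, correct, seen = 0.0, 0, 0
    # Gradients are tracked only in the training phase.
    with torch.set_grad_enabled(training):
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batch = targets.size(0)
            total_loss += loss.item() * batch  # undo the per-batch mean
            correct += (outputs.argmax(dim=1) == targets).sum().item()
            seen += batch
    return total_loss / seen, correct / seen


def train_model(model, epochs, train_loader, val_loader, optimizer,
                criterion, device, save_model_path):
    # Assumed layout: per-epoch lists keyed by split name.
    loss_stats = {"train": [], "val": []}
    acc_stats = {"train": [], "val": []}
    model.to(device)  # in-place; a no-op if the model is already there
    for _ in range(epochs):
        train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, device, optimizer)
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)
        loss_stats["train"].append(train_loss)
        loss_stats["val"].append(val_loss)
        acc_stats["train"].append(train_acc)
        acc_stats["val"].append(val_acc)
    # Persist only the learned parameters.
    torch.save(model.state_dict(), save_model_path)
    return loss_stats, acc_stats


# Usage on hypothetical toy data: 4 features, 3 classes.
torch.manual_seed(0)
x, y = torch.randn(64, 4), torch.randint(0, 3, (64,))
train_loader = DataLoader(TensorDataset(x[:48], y[:48]), batch_size=8, shuffle=True)
val_loader = DataLoader(TensorDataset(x[48:], y[48:]), batch_size=8)
model = nn.Linear(4, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
save_path = os.path.join(tempfile.mkdtemp(), "model.pt")
loss_stats, acc_stats = train_model(model, 3, train_loader, val_loader, optimizer,
                                    nn.CrossEntropyLoss(), torch.device("cpu"), save_path)
```

Reusing one helper for both phases keeps the training and validation metrics computed identically; the only differences are the model mode, gradient tracking, and whether the optimizer steps.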