src.ml.train.train_model
- train_model(model, epochs, train_loader, val_loader, optimizer, criterion, device, save_model_path)
Training loop.
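A minimal usage sketch follows. The source only documents the signature and the returned pair of dictionaries, so the body of `train_model` below is a hypothetical stand-in rather than the project's implementation. The `"train"`/`"val"` keys, the per-epoch lists, the accuracy computation, and the synthetic data are all assumptions made to keep the example self-contained.

```python
import os
import tempfile

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_model(model, epochs, train_loader, val_loader, optimizer,
                criterion, device, save_model_path):
    """Hypothetical stand-in with the documented signature (not the project's code)."""
    # Assumed layout: one list of per-epoch values for each split.
    loss_stats = {"train": [], "val": []}
    acc_stats = {"train": [], "val": []}
    model.to(device)

    def run_epoch(loader, training):
        model.train(training)
        total_loss, correct, seen = 0.0, 0, 0
        # Gradients are tracked only during the training pass.
        with torch.set_grad_enabled(training):
            for inputs, targets in loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                if training:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                total_loss += loss.item() * targets.size(0)
                correct += (outputs.argmax(dim=1) == targets).sum().item()
                seen += targets.size(0)
        return total_loss / seen, correct / seen

    for _ in range(epochs):
        for split, loader in (("train", train_loader), ("val", val_loader)):
            loss, acc = run_epoch(loader, training=(split == "train"))
            loss_stats[split].append(loss)
            acc_stats[split].append(acc)

    # Save only the weights, matching the description of save_model_path.
    torch.save(model.state_dict(), save_model_path)
    return loss_stats, acc_stats


# Synthetic two-class data, used only to make the call runnable.
torch.manual_seed(0)
features = torch.randn(64, 4)
labels = (features.sum(dim=1) > 0).long()
train_loader = DataLoader(TensorDataset(features[:48], labels[:48]),
                          batch_size=16, shuffle=True)
val_loader = DataLoader(TensorDataset(features[48:], labels[48:]), batch_size=16)

model = nn.Linear(4, 2)
device = torch.device("cpu")  # swap for torch.device("cuda") when a GPU is available
save_model_path = os.path.join(tempfile.mkdtemp(), "model.pt")

loss_stats, acc_stats = train_model(
    model=model,
    epochs=3,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=torch.optim.SGD(model.parameters(), lr=0.1),
    criterion=nn.CrossEntropyLoss(),
    device=device,
    save_model_path=save_model_path,
)
```

With the real function imported from `src.ml.train`, only the call at the bottom would be needed. The actual keys of `loss_stats` and `acc_stats` may differ from the ones assumed here.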
- Parameters
model – PyTorch model object.
epochs – Integer number of training epochs (full passes over the training dataset).
train_loader – Dataloader for train dataset.
val_loader – Dataloader for validation dataset.
optimizer – PyTorch optimizer.
criterion – PyTorch loss function.
device – Device on which to run training (CUDA or CPU).
save_model_path – Path to save the trained model weights.
- Returns
A tuple of dictionaries (loss_stats, acc_stats).
- Return type