src.ml.utils.get_latent_emb_per_class
- get_latent_emb_per_class(model, dataloader, agg_class, class2idx, idx2class)
Get latent embeddings grouped by output class label.
- Parameters
model – Trained torch model.
dataloader – Dataloader object.
agg_class – Whether to average across tensors of the same class.
class2idx – Maps each class to its integer index.
idx2class – Maps each integer index to its class.
- Returns
A dictionary with keys as output labels and values as a list of tensors.
- Return type
dict
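The source of this function is not shown here, so the following is only a minimal sketch of the grouping behaviour described above, not the library's implementation. The name `group_embeddings_by_class` is hypothetical, and the sketch assumes that calling `model` on a batch returns one embedding per sample, that `dataloader` yields `(inputs, labels)` pairs, and that `agg_class` is a boolean flag. `class2idx` is left out because this sketch only needs the reverse mapping.

```python
from collections import defaultdict


def group_embeddings_by_class(model, dataloader, idx2class, agg_class=False):
    """Hypothetical sketch: collect per-sample embeddings under their class label.

    Assumptions (not confirmed by the documentation above):
      - model(inputs) returns an iterable with one embedding per sample
      - dataloader yields (inputs, labels) batches
      - agg_class is a boolean; when true, each class is reduced to its mean
    """
    grouped = defaultdict(list)

    for inputs, labels in dataloader:
        # With a real torch model, this call would normally sit inside
        # model.eval() and torch.no_grad() to avoid tracking gradients.
        embeddings = model(inputs)

        for emb, label in zip(embeddings, labels):
            # int() also works on a zero-dimensional torch tensor.
            grouped[idx2class[int(label)]].append(emb)

    if agg_class:
        # Replace each class's list with a single-element list holding the mean.
        return {cls: [sum(embs) / len(embs)] for cls, embs in grouped.items()}

    return dict(grouped)
```

Because the sketch relies only on iteration, addition, and division, it behaves the same with torch tensors as with plain numbers, which makes it easy to try without a trained model.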