src.ml.utils.get_latent_emb_per_class

get_latent_emb_per_class(model, dataloader, agg_class, class2idx, idx2class)[source]

Get latent embeddings grouped by output class label.

Parameters
  • model – Trained torch model.

  • dataloader – Dataloader object.

  • agg_class – Whether to average across tensors of the same class.

  • class2idx – Maps class labels to integer indices.

  • idx2class – Maps integer indices to class labels.

Returns

A dictionary whose keys are output labels and whose values are lists of tensors.

Return type

dict
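
Example

The following is a minimal sketch of the grouping behaviour described above, not the actual implementation of get_latent_emb_per_class. It is written in plain Python so it runs without torch: the hypothetical names latent_emb_per_class_sketch, toy_model and toy_loader are illustrative only, lists of floats stand in for tensors, and it assumes that agg_class is a boolean flag that reduces each class to a single element-wise mean.

```python
def latent_emb_per_class_sketch(model, dataloader, agg_class, class2idx, idx2class):
    """Hypothetical stand-in that groups per-sample embeddings by class label."""
    # Start with one empty list per known class so every label appears as a key.
    grouped = {name: [] for name in class2idx}

    for inputs, labels in dataloader:
        # Assumption: calling the model on a batch yields one embedding per sample.
        embeddings = model(inputs)
        for emb, label_idx in zip(embeddings, labels):
            grouped[idx2class[int(label_idx)]].append(emb)

    if agg_class:
        # Assumption: aggregation is an element-wise mean over each class,
        # kept in a one-element list so every value is still a list.
        for name, embs in grouped.items():
            if embs:
                grouped[name] = [[sum(col) / len(embs) for col in zip(*embs)]]

    return grouped


# Toy stand-ins: a "model" that maps x to [x, 2x] and two labelled batches.
class2idx = {"cat": 0, "dog": 1}
idx2class = {0: "cat", 1: "dog"}
toy_model = lambda batch: [[x, 2 * x] for x in batch]
toy_loader = [([1.0, 3.0], [0, 1]), ([5.0], [0])]

per_class = latent_emb_per_class_sketch(toy_model, toy_loader, False, class2idx, idx2class)
averaged = latent_emb_per_class_sketch(toy_model, toy_loader, True, class2idx, idx2class)
```

With real inputs, model would be a trained torch module and dataloader a torch DataLoader, and the values in the returned dictionary would be tensors rather than lists of floats.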