src.ml.utils.get_class_weights

get_class_weights(y_data)

Get balanced weights per class for cross-entropy loss.

Parameters

y_data – NumPy array of output class labels.

Returns

A tensor with class weights.

Return type

torch.Tensor
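
The implementation is not shown on this page, so the following is only a minimal sketch of one common way to compute "balanced" class weights: each class is weighted by n_samples / (n_classes * class_count). The helper name balanced_class_weights is hypothetical, and the sketch returns a NumPy array rather than a tensor; the actual get_class_weights may use a different formula.

```python
import numpy as np


def balanced_class_weights(y_data: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for get_class_weights (assumed formula)."""
    # Count how many samples carry each distinct label.
    classes, counts = np.unique(y_data, return_counts=True)
    # Balanced heuristic: n_samples / (n_classes * count_per_class),
    # so rarer classes receive proportionally larger weights.
    return y_data.size / (classes.size * counts)


# Three samples of class 0 and one of class 1: class 1 gets the larger weight.
y = np.array([0, 0, 0, 1])
weights = balanced_class_weights(y)
print(weights)
```

Under these assumptions, wrapping the result as torch.tensor(weights, dtype=torch.float32) would give a tensor suitable for the weight argument of torch.nn.CrossEntropyLoss. Note that a class absent from y_data gets no entry in this sketch.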
