Source code for src.explore.summary

"""
Functions to summarise data.
"""

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt


[docs]def plot_hierarchical_labels( all_data: dict, label_type: str, label_level: int, plot: bool = True, fig_size: tuple = (15, 7), ) -> np.ndarray: """Bar plots for hierarchical labels. 0 is the least granular level while 3 is the most granular level. Parameters ---------- all_data: Dictionary of numpy arrays which contains all data. label_type: Can be "train" or "test". label_level: Integer denoting the label level. Possible values are 0, 1, 2 and 3. plot: If True, plots the graph. Else, returns the label array. fig_size: optional The size of the plot. Returns ------- np.ndarray A numpy array with labels of the given level. Raises ------ ValueError If `label_level` is not 0, 1, 2, or 3. """ if label_level not in [0, 1, 2, 3]: raise ValueError("label_level has to be 0, 1, 2, or 3.") if label_type == "train": labels = all_data["train_labels"][:, label_level] elif label_type == "test": labels = all_data["test_labels"][:, label_level] else: raise ValueError("label_type can only be 'train' or 'test'.") if plot: plt.figure(figsize=fig_size) sns.countplot(y=labels) plt.title(f"{label_type.capitalize()} Labels: Level-{label_level}") plt.xlabel("Label") else: return labels