Skip to content

grid_predplot

DisablePlotDisplay

Context manager to disable the display of matplotlib plots.

Source code in uncertaintyplayground/predplot/grid_predplot.py
47
48
49
50
51
52
53
54
55
56
57
class DisablePlotDisplay:
    """
    Context manager to disable the display of matplotlib plots.

    On entry, pyplot's interactive mode is switched off so figures are not
    displayed as they are drawn. On exit, the current figure is closed and the
    interactive mode that was active before entry is restored, so using this
    manager does not permanently change global pyplot state.
    """

    def __enter__(self):
        # Remember the caller's setting so __exit__ can put it back afterwards.
        self._was_interactive = plt.isinteractive()
        plt.ioff()  # Turn off interactive mode
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            plt.close()  # Close the current figure
        finally:
            # Restore interactive mode only if it was on before entering;
            # getattr guards against __exit__ being called without __enter__.
            if getattr(self, "_was_interactive", False):
                plt.ion()
        # Do not suppress exceptions raised inside the with-block.
        return False

plot_results_grid(trainer, compare_func, X_test, indices, Y_test=None, ncols=2, dtype=np.float32)

Plot a grid of comparison plots, one subplot per test instance.

Parameters:

Name Type Description Default
trainer object

The trained MDNTrainer or SparseGPTrainer instance.

required
compare_func function

Function to compare distributions (compare_distributions for MDN or compare_distributions_svgpr for SVGPR).

required
X_test np.ndarray

The test input data of shape (num_samples, num_features).

required
indices list

The indices of the instances to plot.

required
ncols int

Number of columns in the grid. Default is 2.

2
Y_test np.ndarray

The test target data of shape (num_samples,). Default is None.

None
dtype np.dtype

Data type to use for plotting. Default is np.float32.

np.float32

Returns:

Type Description

None

Source code in uncertaintyplayground/predplot/grid_predplot.py
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def _as_dtype(value, dtype):
    """Return a copy of ``value`` cast to ``dtype``; scalars come back as numpy scalars."""
    # astype always copies; [()] unwraps a 0-d result into a numpy scalar and is a
    # no-op view for arrays, matching ndarray/np-scalar .astype for numpy input
    # while also accepting plain Python numbers and lists.
    return np.asarray(value).astype(dtype)[()]


def plot_results_grid(trainer, compare_func, X_test, indices, Y_test=None, ncols=2, dtype=np.float32):
    """
    Plot a grid of comparison plots, one subplot per test instance.

    Args:
        trainer (object): The trained MDNTrainer or SparseGPTrainer instance.
        compare_func (function): Function to compare distributions (compare_distributions for MDN
            or compare_distributions_svgpr for SVGPR). It is called as
            ``compare_func(trainer, x_instance, y_actual, ax=ax)``.
        X_test (np.ndarray): The test input data of shape (num_samples, num_features).
        indices (list): The indices of the instances to plot. Must contain at least one index.
        Y_test (np.ndarray, optional): The test target data of shape (num_samples,). Default is None,
            in which case ``None`` is passed to ``compare_func`` as the actual value.
        ncols (int, optional): Number of columns in the grid. Default is 2.
        dtype (np.dtype, optional): Data type to use for plotting. Default is np.float32.

    Returns:
        None

    Raises:
        ValueError: If ``indices`` is empty.
    """
    indices = list(indices)
    num_instances = len(indices)
    if num_instances == 0:
        raise ValueError("indices must contain at least one test instance index")

    nrows = -(-num_instances // ncols)  # ceiling division

    # squeeze=False guarantees a 2-D array of Axes even for a 1x1 grid,
    # so iterating over the flattened axes works for every grid shape.
    fig, axes = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(5 * ncols, 5 * nrows), squeeze=False
    )
    flat_axes = axes.ravel()

    for i, ax in zip(indices, flat_axes):
        x_instance = _as_dtype(X_test[i], dtype)
        y_actual = None if Y_test is None else _as_dtype(Y_test[i], dtype)
        compare_func(trainer, x_instance, y_actual, ax=ax)
        ax.set_title(f"Test Instance: {i}")
        ax.set_xlabel("Value")
        ax.set_ylabel("Density")
        ax.legend()
        ax.grid(axis='y', alpha=0.75)

    # Remove empty subplots left over when the grid is not completely filled.
    for ax in flat_axes[num_instances:]:
        ax.remove()

    fig.tight_layout()
    plt.show()