Skip to content

test_mdn_predplot

TestMDNPlots

Bases: unittest.TestCase

This test class provides unit tests for the compare_distributions_mdn and plot_results_grid functions for MDN model.

Source code in uncertaintyplayground/tests/test_mdn_predplot.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
class TestMDNPlots(unittest.TestCase):
    """
    This test class provides unit tests for the compare_distributions_mdn and plot_results_grid functions for MDN model.
    """

    def setUp(self):
        """
        Sets up the testing environment for each test method.
        """
        self.modes = [
            {'mean': -3.0, 'std_dev': 0.5, 'weight': 0.3},
            {'mean': 0.0, 'std_dev': 1.0, 'weight': 0.4},
            {'mean': 3.0, 'std_dev': 0.7, 'weight': 0.3}
        ]

        torch.manual_seed(1)
        np.random.seed(42)
        self.num_samples = 1000
        self.X_test = np.random.rand(self.num_samples, 20)
        self.Y_test = generate_multi_modal_data(self.num_samples, self.modes)

        self.mdn_trainer = MDNTrainer(self.X_test, self.Y_test, num_epochs=5, lr=0.01, dense1_units=5, n_gaussians=3)
        self.mdn_trainer.train()

    def test_compare_distributions_mdn(self):
        """
        Tests the compare_distributions function for the MDN model.
        """
        index_instance = 900

        with DisablePlotDisplay():
            compare_distributions_mdn(self.mdn_trainer, x_instance = self.X_test[index_instance, :], y_actual=self.Y_test[index_instance])

    def test_plot_results_grid_mdn(self):
        """
        Tests the plot_results_grid function with MDN model.
        """
        indices = [900, 100]  # Example indices

        # Testing with existing actual values
        with DisablePlotDisplay():
            plot_results_grid(self.mdn_trainer, compare_distributions_mdn, self.X_test, indices, self.Y_test)
        # Testing with no y values
        with DisablePlotDisplay():
            plot_results_grid(self.mdn_trainer, compare_distributions_mdn, self.X_test, indices, None)

setUp()

Sets up the testing environment for each test method.

Source code in uncertaintyplayground/tests/test_mdn_predplot.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def setUp(self):
    """
    Sets up the testing environment for each test method.
    """
    self.modes = [
        {'mean': -3.0, 'std_dev': 0.5, 'weight': 0.3},
        {'mean': 0.0, 'std_dev': 1.0, 'weight': 0.4},
        {'mean': 3.0, 'std_dev': 0.7, 'weight': 0.3}
    ]

    torch.manual_seed(1)
    np.random.seed(42)
    self.num_samples = 1000
    self.X_test = np.random.rand(self.num_samples, 20)
    self.Y_test = generate_multi_modal_data(self.num_samples, self.modes)

    self.mdn_trainer = MDNTrainer(self.X_test, self.Y_test, num_epochs=5, lr=0.01, dense1_units=5, n_gaussians=3)
    self.mdn_trainer.train()

test_compare_distributions_mdn()

Tests the compare_distributions function for the MDN model.

Source code in uncertaintyplayground/tests/test_mdn_predplot.py
35
36
37
38
39
40
41
42
def test_compare_distributions_mdn(self):
    """
    Tests the compare_distributions function for the MDN model.
    """
    index_instance = 900

    with DisablePlotDisplay():
        compare_distributions_mdn(self.mdn_trainer, x_instance = self.X_test[index_instance, :], y_actual=self.Y_test[index_instance])

test_plot_results_grid_mdn()

Tests the plot_results_grid function with MDN model.

Source code in uncertaintyplayground/tests/test_mdn_predplot.py
44
45
46
47
48
49
50
51
52
53
54
55
def test_plot_results_grid_mdn(self):
    """
    Tests the plot_results_grid function with MDN model.
    """
    indices = [900, 100]  # Example indices

    # Testing with existing actual values
    with DisablePlotDisplay():
        plot_results_grid(self.mdn_trainer, compare_distributions_mdn, self.X_test, indices, self.Y_test)
    # Testing with no y values
    with DisablePlotDisplay():
        plot_results_grid(self.mdn_trainer, compare_distributions_mdn, self.X_test, indices, None)