Skip to content

test_svgp_predplot

TestSVGPRPlots

Bases: unittest.TestCase

This test class provides unit tests for the compare_distributions_svgpr and plot_results_grid functions for the SVGPR model.

Source code in uncertaintyplayground/tests/test_svgp_predplot.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
class TestSVGPRPlots(unittest.TestCase):
    """
    Smoke tests for ``compare_distributions_svgpr`` and ``plot_results_grid``
    driven by a trained sparse variational GP regression (SVGPR) model.

    The tests pass as long as the plotting calls complete without raising;
    they make no assertions about the rendered figures.
    """

    def setUp(self):
        """
        Seed the RNGs, build a three-mode toy dataset and train an SVGPR on it.

        unittest calls this before every test method, so each test receives a
        freshly trained ``self.svgpr_trainer``.
        """
        # Mixture spec for the targets; assumed to match the MDN plotting tests.
        self.modes = [
            dict(mean=mu, std_dev=sigma, weight=share)
            for mu, sigma, share in (
                (-3.0, 0.5, 0.3),
                (0.0, 1.0, 0.4),
                (3.0, 0.7, 0.3),
            )
        ]

        # Fix both torch and numpy RNGs so the data and training are repeatable.
        torch.manual_seed(1)
        np.random.seed(42)

        self.num_samples = 1000
        feature_count = 20
        # The same uniform features serve as training inputs and as the
        # instances plotted by the tests below.
        self.X_test = np.random.rand(self.num_samples, feature_count)
        self.Y_test = generate_multi_modal_data(self.num_samples, self.modes)

        # Keep the attribute bound before training, as a few epochs suffice for a smoke test.
        self.svgpr_trainer = trainer = SparseGPTrainer(
            self.X_test, self.Y_test, num_epochs=5, lr=0.01
        )
        trainer.train()

    def test_compare_distributions_svgpr(self):
        """
        Run ``compare_distributions_svgpr`` on one observation, with its true target.
        """
        row = 900

        # NOTE(review): DisablePlotDisplay presumably suppresses figure windows — confirm.
        with DisablePlotDisplay():
            compare_distributions_svgpr(
                self.svgpr_trainer,
                x_instance=self.X_test[row, :],
                y_actual=self.Y_test[row],
            )

    def test_plot_results_grid_svgpr(self):
        """
        Run ``plot_results_grid`` with ``compare_distributions_svgpr`` as the
        per-panel plotter, first with the true targets and then without them.
        """
        rows = [900, 100]

        # Passing None for the targets exercises the no-ground-truth path.
        for y_values in (self.Y_test, None):
            with DisablePlotDisplay():
                plot_results_grid(
                    self.svgpr_trainer,
                    compare_distributions_svgpr,
                    self.X_test,
                    rows,
                    y_values,
                )

setUp()

Sets up the testing environment for each test method.

Source code in uncertaintyplayground/tests/test_svgp_predplot.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def setUp(self):
    """
    Seed the RNGs, build a three-mode toy dataset and train an SVGPR on it.

    unittest calls this before every test method, so each test receives a
    freshly trained ``self.svgpr_trainer``.
    """
    # Mixture spec for the targets; assumed to match the MDN plotting tests.
    self.modes = [
        dict(mean=mu, std_dev=sigma, weight=share)
        for mu, sigma, share in (
            (-3.0, 0.5, 0.3),
            (0.0, 1.0, 0.4),
            (3.0, 0.7, 0.3),
        )
    ]

    # Fix both torch and numpy RNGs so the data and training are repeatable.
    torch.manual_seed(1)
    np.random.seed(42)

    self.num_samples = 1000
    feature_count = 20
    # The same uniform features serve as training inputs and as plotted instances.
    self.X_test = np.random.rand(self.num_samples, feature_count)
    self.Y_test = generate_multi_modal_data(self.num_samples, self.modes)

    # Keep the attribute bound before training, as a few epochs suffice for a smoke test.
    self.svgpr_trainer = trainer = SparseGPTrainer(
        self.X_test, self.Y_test, num_epochs=5, lr=0.01
    )
    trainer.train()

test_compare_distributions_svgpr()

Tests the compare_distributions_svgpr function for the SVGPR model.

Source code in uncertaintyplayground/tests/test_svgp_predplot.py
36
37
38
39
40
41
42
43
def test_compare_distributions_svgpr(self):
    """
    Run ``compare_distributions_svgpr`` on one observation, with its true target.
    """
    row = 900

    # NOTE(review): DisablePlotDisplay presumably suppresses figure windows — confirm.
    with DisablePlotDisplay():
        compare_distributions_svgpr(
            self.svgpr_trainer,
            x_instance=self.X_test[row, :],
            y_actual=self.Y_test[row],
        )

test_plot_results_grid_svgpr()

Tests the plot_results_grid function with the SVGPR model.

Source code in uncertaintyplayground/tests/test_svgp_predplot.py
45
46
47
48
49
50
51
52
53
54
55
56
57
def test_plot_results_grid_svgpr(self):
    """
    Run ``plot_results_grid`` with ``compare_distributions_svgpr`` as the
    per-panel plotter, first with the true targets and then without them.
    """
    rows = [900, 100]

    # Passing None for the targets exercises the no-ground-truth path.
    for y_values in (self.Y_test, None):
        with DisablePlotDisplay():
            plot_results_grid(
                self.svgpr_trainer,
                compare_distributions_svgpr,
                self.X_test,
                rows,
                y_values,
            )