Skip to content

test_mdn_model

TestMDN

Bases: unittest.TestCase

Unit tests for the MDN class.

Source code in uncertaintyplayground/tests/test_mdn_model.py (lines 5–34):
class TestMDN(unittest.TestCase):
    """Tests for the class MDN"""

    def setUp(self):
        """Set up a test fixture with input_dim = 20, dense1_units  = 10, n_gaussians = 3"""
        self.input_dim = 20
        self.dense1_units = 10
        self.n_gaussians = 3
        self.mdn = MDN(input_dim=self.input_dim, n_gaussians=self.n_gaussians, dense1_units=self.dense1_units)

    def test_init(self):
        """Test that the MDN is initialized properly"""
        self.assertEqual(self.mdn.z_h[0].in_features, self.input_dim)
        self.assertEqual(self.mdn.z_h[0].out_features, self.dense1_units)
        self.assertEqual(self.mdn.z_pi.in_features, self.dense1_units)
        self.assertEqual(self.mdn.z_pi.out_features, self.n_gaussians)

    def test_forward(self):
        """Test the forward function with a tensor of shape (1, 20)"""
        x = torch.rand((1, self.input_dim))
        pi, mu, sigma = self.mdn.forward(x)
        self.assertEqual(pi.shape, (1, self.n_gaussians))
        self.assertEqual(mu.shape, (1, self.n_gaussians))
        self.assertEqual(sigma.shape, (1, self.n_gaussians))

    def test_sample(self):
        """Test the sample function with a tensor of shape (1, 20)"""
        x = torch.rand((1, self.input_dim))
        sample = self.mdn.sample(x)
        self.assertEqual(sample.shape, (1,))

setUp()

Set up a test fixture with input_dim = 20, dense1_units = 10, n_gaussians = 3

Source code in uncertaintyplayground/tests/test_mdn_model.py (lines 8–13):
def setUp(self):
    """Build a fresh MDN fixture (input_dim=20, dense1_units=10, n_gaussians=3)."""
    # Tuple assignment binds left to right, matching the original order.
    self.input_dim, self.dense1_units, self.n_gaussians = 20, 10, 3
    self.mdn = MDN(
        input_dim=self.input_dim,
        n_gaussians=self.n_gaussians,
        dense1_units=self.dense1_units,
    )

test_forward()

Test the forward function with a tensor of shape (1, 20)

Source code in uncertaintyplayground/tests/test_mdn_model.py (lines 22–28):
def test_forward(self):
    """Test the forward function with a tensor of shape (1, 20)"""
    x = torch.rand((1, self.input_dim))
    pi, mu, sigma = self.mdn.forward(x)
    self.assertEqual(pi.shape, (1, self.n_gaussians))
    self.assertEqual(mu.shape, (1, self.n_gaussians))
    self.assertEqual(sigma.shape, (1, self.n_gaussians))

test_init()

Test that the MDN is initialized properly

Source code in uncertaintyplayground/tests/test_mdn_model.py (lines 15–20):
def test_init(self):
    """Test that the MDN is initialized properly"""
    self.assertEqual(self.mdn.z_h[0].in_features, self.input_dim)
    self.assertEqual(self.mdn.z_h[0].out_features, self.dense1_units)
    self.assertEqual(self.mdn.z_pi.in_features, self.dense1_units)
    self.assertEqual(self.mdn.z_pi.out_features, self.n_gaussians)

test_sample()

Test the sample function with a tensor of shape (1, 20)

Source code in uncertaintyplayground/tests/test_mdn_model.py (lines 30–34):
def test_sample(self):
    """Test the sample function with a tensor of shape (1, 20)"""
    x = torch.rand((1, self.input_dim))
    sample = self.mdn.sample(x)
    self.assertEqual(sample.shape, (1,))