Skip to content

test_early_stopping

TestEarlyStopping

Bases: unittest.TestCase

Unit tests for EarlyStopping class.

Source code in uncertaintyplayground/tests/test_early_stopping.py (lines 6–38):
class TestEarlyStopping(unittest.TestCase):
    """
    Unit tests for the EarlyStopping class.

    Every test works on a fresh ``EarlyStopping`` instance, constructed with
    no arguments in ``setUp`` and stored as ``self.early_stopping``.
    """

    def setUp(self):
        """
        Build a new, default-constructed EarlyStopping before each test.
        """
        stopper = EarlyStopping()
        self.early_stopping = stopper

    def test_init(self):
        """
        A default EarlyStopping is the right type, has a zero counter and
        an infinite best validation metric.
        """
        stopper = self.early_stopping
        self.assertIsInstance(stopper, EarlyStopping)
        self.assertEqual(stopper.counter, 0)
        self.assertEqual(stopper.best_val_metric, np.inf)

    def test_call(self):
        """
        Calling EarlyStopping with successive metrics updates its state.

        The first metric (10) becomes the best value with the counter at 0.
        A larger metric (20) leaves the best value at 10 and bumps the
        counter to 1.
        """
        net = torch.nn.Linear(1, 1)  # smallest possible model to hand over
        stopper = self.early_stopping

        # Each row: metric fed in, expected best metric, expected counter.
        expected_states = (
            (10, 10, 0),
            (20, 10, 1),
        )
        for metric, best_expected, counter_expected in expected_states:
            stopper(metric, net)
            self.assertEqual(stopper.best_val_metric, best_expected)
            self.assertEqual(stopper.counter, counter_expected)

setUp()

Test fixture setup method.

Source code in uncertaintyplayground/tests/test_early_stopping.py (lines 11–15):
def setUp(self):
    """
    Build a new, default-constructed EarlyStopping before each test.
    """
    stopper = EarlyStopping()
    self.early_stopping = stopper

test_call()

Test case for call method in EarlyStopping.

Source code in uncertaintyplayground/tests/test_early_stopping.py (lines 25–38):
def test_call(self):
    """
    Calling EarlyStopping with successive metrics updates its state.

    The first metric (10) becomes the best value with the counter at 0.
    A larger metric (20) leaves the best value at 10 and bumps the
    counter to 1.
    """
    net = torch.nn.Linear(1, 1)  # smallest possible model to hand over
    stopper = self.early_stopping

    # Each row: metric fed in, expected best metric, expected counter.
    expected_states = (
        (10, 10, 0),
        (20, 10, 1),
    )
    for metric, best_expected, counter_expected in expected_states:
        stopper(metric, net)
        self.assertEqual(stopper.best_val_metric, best_expected)
        self.assertEqual(stopper.counter, counter_expected)

test_init()

Test case for EarlyStopping initialization.

Source code in uncertaintyplayground/tests/test_early_stopping.py (lines 17–23):
def test_init(self):
    """
    A default EarlyStopping is the right type, has a zero counter and
    an infinite best validation metric.
    """
    stopper = self.early_stopping
    self.assertIsInstance(stopper, EarlyStopping)
    self.assertEqual(stopper.counter, 0)
    self.assertEqual(stopper.best_val_metric, np.inf)