mdn_trainer

MDNTrainer

Bases: BaseTrainer

Trainer for the Mixture Density Network (MDN) model.

This class handles the training process for the MDN model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `dense1_units` | `int` | Number of hidden units in the first layer of the neural network. | `20` |
| `n_gaussians` | `int` | Number of Gaussian components in the mixture. | `5` |
| `**kwargs` |  | Additional arguments passed to the `BaseTrainer`. | `{}` |

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `n_gaussians` | `int` | Number of Gaussian components in the mixture. |
| `model` | `MDN` | The MDN model. |
| `optimizer` | `torch.optim.Optimizer` | The optimizer for model training. |
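
A minimal usage sketch follows. `BaseTrainer`'s constructor is not documented on this page, so the keyword names passed through `**kwargs` below (`num_epochs`, `patience`, `lr`, `optimizer_fn_name`, `dtype`) are assumptions inferred from the attributes the source code reads (`self.num_epochs`, `self.patience`, `self.lr`, `self.optimizer_fn_name`, `self.dtype`); check them against the actual `BaseTrainer` signature.

```python
import numpy as np
import torch
from uncertaintyplayground.trainers.mdn_trainer import MDNTrainer

# Toy regression data: 500 samples, 10 features.
X = np.random.rand(500, 10).astype(np.float32)
y = np.random.rand(500).astype(np.float32)

trainer = MDNTrainer(
    X,                         # assumed: training inputs (BaseTrainer)
    y,                         # assumed: training targets (BaseTrainer)
    dense1_units=20,           # hidden units in the first layer
    n_gaussians=5,             # Gaussian components in the mixture
    num_epochs=100,            # assumed BaseTrainer kwargs from here on
    patience=10,
    lr=1e-3,
    optimizer_fn_name="Adam",  # resolved via getattr(torch.optim, ...)
    dtype=torch.float32,
)
trainer.train()
pi, mu, sigma, samples = trainer.predict_with_uncertainty(X[:5])
```

Note that `optimizer_fn_name` must name a class in `torch.optim`, since the constructor resolves it with `getattr(torch.optim, self.optimizer_fn_name)`.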

Source code in uncertaintyplayground/trainers/mdn_trainer.py
```python
class MDNTrainer(BaseTrainer):
    """
    Trainer for the Mixture Density Network (MDN) model.

    This class handles the training process for the MDN model.

    Args:
        dense1_units (int): Number of hidden units in the first layer of the neural network.
        n_gaussians (int): Number of Gaussian components in the mixture.
        **kwargs: Additional arguments passed to the BaseTrainer.

    Attributes:
        n_gaussians (int): Number of Gaussian components in the mixture.
        model (MDN): The MDN model.
        optimizer (torch.optim.Optimizer): The optimizer for model training.
    """

    def __init__(self, *args, dense1_units=20, n_gaussians=5, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_gaussians = n_gaussians

        self.model = MDN(input_dim=self.X.shape[1], n_gaussians=self.n_gaussians, dense1_units=dense1_units).to(device=self.device)
        if self.dtype == torch.float64:
            self.model = self.model.double()  # Convert model parameters to float64
        optimizer_fn = getattr(torch.optim, self.optimizer_fn_name)
        self.optimizer = optimizer_fn(self.model.parameters(), lr=self.lr)
        print(f"Model device: {next(self.model.parameters()).device}")
        print(f"Data device: {next(iter(self.train_loader))[0].device}")

    def train(self):
        """
        Train the MDN model.
        """
        self.model.train()
        early_stopping = EarlyStopping(patience=self.patience, compare_fn=lambda x, y: x < y)

        for epoch in range(self.num_epochs):
            for X_batch, y_batch, weights_batch in self.train_loader:
                X_batch, y_batch = X_batch.to(self.device), y_batch.to(self.device)

                self.optimizer.zero_grad()

                pi, mu, sigma = self.model(X_batch)
                loss = mdn_loss(y_batch, mu, sigma, pi)

                loss.backward()
                self.optimizer.step()

            self.model.eval()
            with torch.no_grad():
                pi, mu, sigma = self.model(self.X_val.to(self.device))
                val_loss = mdn_loss(self.y_val.to(self.device), mu, sigma, pi)

            self.model.train()

            print(
                f"Epoch {epoch + 1}/{self.num_epochs}, Training Loss: {loss.item():.3f}, "
                f"Validation Loss: {val_loss.item():.3f}"
            )

            should_stop = early_stopping(val_loss.item(), self.model)

            if should_stop:
                print(f"Early stopping after {epoch + 1} epochs")
                break

        if early_stopping.best_model_state is not None:
            self.model.load_state_dict(early_stopping.best_model_state)
            self.model.eval()

    def predict_with_uncertainty(self, X):
        """
        Predict the output distribution given input data.

        Args:
            X (np.ndarray or torch.Tensor): Input data of shape (num_samples, num_features).

        Returns:
            tuple: A tuple containing the predicted mixture weights, means, standard deviations, and samples.
        """
        self.model.eval()

        # Convert numpy array to PyTorch tensor if necessary
        if isinstance(X, np.ndarray):
            X = torch.from_numpy(X).to(self.device)

        # Check if X is a single instance and add an extra dimension if necessary
        if X.ndim == 1:
            X = torch.unsqueeze(X, 0)

        with torch.no_grad():
            pi, mu, sigma = self.model(X)
            sample = self.model.sample(X, num_samples=1000)

        return pi.cpu().numpy(), mu.cpu().numpy(), sigma.cpu().numpy(), sample.cpu().numpy()
```

predict_with_uncertainty(X)

Predict the output distribution given input data.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `X` | `np.ndarray` or `torch.Tensor` | Input data of shape `(num_samples, num_features)`. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `tuple` | A tuple containing the predicted mixture weights, means, standard deviations, and samples. |
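
The returned arrays are easy to post-process. A minimal sketch, assuming `pi`, `mu`, and `sigma` each have shape `(batch_size, n_gaussians)` (the usual MDN layout; not stated on this page) and reusing `trainer` from the usage sketch above:

```python
import numpy as np

# Assumed shapes: pi, mu, sigma -> (batch_size, n_gaussians).
# The shape of `samples` depends on MDN.sample and is not documented here.
pi, mu, sigma, samples = trainer.predict_with_uncertainty(X_new)  # X_new: hypothetical inputs

# Mixture mean: E[y | x] = sum_k pi_k * mu_k
point_pred = np.sum(pi * mu, axis=1)

# Law of total variance for a Gaussian mixture:
# Var[y | x] = sum_k pi_k * (sigma_k^2 + mu_k^2) - E[y | x]^2
variance = np.sum(pi * (sigma ** 2 + mu ** 2), axis=1) - point_pred ** 2
std = np.sqrt(variance)
```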

Source code in uncertaintyplayground/trainers/mdn_trainer.py
```python
def predict_with_uncertainty(self, X):
    """
    Predict the output distribution given input data.

    Args:
        X (np.ndarray or torch.Tensor): Input data of shape (num_samples, num_features).

    Returns:
        tuple: A tuple containing the predicted mixture weights, means, standard deviations, and samples.
    """
    self.model.eval()

    # Convert numpy array to PyTorch tensor if necessary
    if isinstance(X, np.ndarray):
        X = torch.from_numpy(X).to(self.device)

    # Check if X is a single instance and add an extra dimension if necessary
    if X.ndim == 1:
        X = torch.unsqueeze(X, 0)

    with torch.no_grad():
        pi, mu, sigma = self.model(X)
        sample = self.model.sample(X, num_samples=1000)

    return pi.cpu().numpy(), mu.cpu().numpy(), sigma.cpu().numpy(), sample.cpu().numpy()
```

train()

Train the MDN model.
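
The loss minimized here, `mdn_loss`, is not reproduced on this page. For a standard mixture density network (Bishop, 1994), it is presumably the negative log-likelihood of the predicted Gaussian mixture:

$$
\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} \log \sum_{k=1}^{K} \pi_k(x_i)\, \mathcal{N}\!\bigl(y_i \mid \mu_k(x_i), \sigma_k^2(x_i)\bigr)
$$

where $K$ is `n_gaussians` and $\pi_k$, $\mu_k$, $\sigma_k$ are the three heads the model returns for each input. The same function is evaluated on the validation set each epoch; `EarlyStopping` uses `compare_fn=lambda x, y: x < y` (lower is better), and the best weights are restored once training stops.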

Source code in uncertaintyplayground/trainers/mdn_trainer.py
```python
def train(self):
    """
    Train the MDN model.
    """
    self.model.train()
    early_stopping = EarlyStopping(patience=self.patience, compare_fn=lambda x, y: x < y)

    for epoch in range(self.num_epochs):
        for X_batch, y_batch, weights_batch in self.train_loader:
            X_batch, y_batch = X_batch.to(self.device), y_batch.to(self.device)

            self.optimizer.zero_grad()

            pi, mu, sigma = self.model(X_batch)
            loss = mdn_loss(y_batch, mu, sigma, pi)

            loss.backward()
            self.optimizer.step()

        self.model.eval()
        with torch.no_grad():
            pi, mu, sigma = self.model(self.X_val.to(self.device))
            val_loss = mdn_loss(self.y_val.to(self.device), mu, sigma, pi)

        self.model.train()

        print(
            f"Epoch {epoch + 1}/{self.num_epochs}, Training Loss: {loss.item():.3f}, "
            f"Validation Loss: {val_loss.item():.3f}"
        )

        should_stop = early_stopping(val_loss.item(), self.model)

        if should_stop:
            print(f"Early stopping after {epoch + 1} epochs")
            break

    if early_stopping.best_model_state is not None:
        self.model.load_state_dict(early_stopping.best_model_state)
        self.model.eval()
```