mdn_model

MDN

Bases: nn.Module

Mixture Density Network (MDN) model.

This model represents a mixture density network for modeling and predicting multi-modal distributions.

Parameters:

    input_dim (int, required): Number of predictors for the first layer of the neural network.
    n_gaussians (int, required): Number of Gaussian components in the mixture.
    dense1_units (int, default 10): Number of neurons in the first dense layer.
    prediction_method (str, default 'max_weight_sample'): Method for predicting the output distribution. Options are:
        - 'max_weight_mean': Choose the component with the highest weight and return its mean.
        - 'max_weight_sample': Choose a component from the mixture and sample from it.
        - 'average_sample': Draw multiple samples and take the average.

Attributes:

    z_h (nn.Sequential): Hidden layer of the neural network.
    z_pi (nn.Linear): Linear layer for predicting mixture weights.
    z_mu (nn.Linear): Linear layer for predicting Gaussian means.
    z_sigma (nn.Linear): Linear layer for predicting Gaussian standard deviations.
    prediction_method (str): Method for predicting the output distribution.

Source code in uncertaintyplayground/models/mdn_model.py
class MDN(nn.Module):
    """
    Mixture Density Network (MDN) model.

    This model represents a mixture density network for modeling and predicting multi-modal distributions.

    Args:
        input_dim (int): Number of predictors for the first layer of the neural network.
        n_gaussians (int): Number of Gaussian components in the mixture.
        dense1_units (int): Number of neurons in the first dense layer. Default is 10.
        prediction_method (str): Method for predicting the output distribution. Options are:
                                 - 'max_weight_mean': Choose the component with the highest weight and return the mean.
                                 - 'max_weight_sample': Choose a component from the mixture and sample from it.
                                 - 'average_sample': Draw multiple samples and take the average.

    Attributes:
        z_h (nn.Sequential): Hidden layer of the neural network.
        z_pi (nn.Linear): Linear layer for predicting mixture weights.
        z_mu (nn.Linear): Linear layer for predicting Gaussian means.
        z_sigma (nn.Linear): Linear layer for predicting Gaussian standard deviations.
        prediction_method (str): Method for predicting the output distribution.
    """

    def __init__(self, input_dim, n_gaussians, dense1_units = 10, prediction_method='max_weight_sample'):
        super(MDN, self).__init__()
        self.z_h = nn.Sequential(
            nn.Linear(input_dim,dense1_units),
            nn.Tanh()
        )
        self.z_pi = nn.Linear(dense1_units, n_gaussians)
        self.z_mu = nn.Linear(dense1_units, n_gaussians)
        self.z_sigma = nn.Linear(dense1_units, n_gaussians)
        self.prediction_method = prediction_method

    def forward(self, x):
        """
        Forward pass of the MDN model.

        Computes the parameters (pi, mu, sigma) of the output distribution given the input.

        Args:
            x (tensor): Input tensor of shape (batch_size, num_features).

        Returns:
            tuple: A tuple containing the predicted mixture weights, means, and standard deviations.
        """
        z_h = self.z_h(x)
        pi = F.softmax(self.z_pi(z_h), -1)
        mu = self.z_mu(z_h)
        sigma = torch.exp(self.z_sigma(z_h))
        return pi, mu, sigma

    def sample(self, x, num_samples=100):
        """
        Generate samples from the output distribution given the input.

        Args:
            x (tensor): Input tensor of shape (batch_size, num_features).
            num_samples (int): Number of samples to generate. Default is 100.

        Returns:
            tensor: A tensor of shape (batch_size,) containing the generated samples.
        """
        pi, mu, sigma = self.forward(x)

        if self.prediction_method == 'max_weight_mean':
            # Choose component with the highest weight
            pis = torch.argmax(pi, dim=1)
            # Return the mean of the chosen component
            sample = mu[torch.arange(mu.size(0)), pis]

        elif self.prediction_method == 'max_weight_sample':
            # Choose component from the mixture
            categorical = torch.distributions.Categorical(pi)
            pis = list(categorical.sample().data)
            # Sample from the chosen component
            sample = Variable(sigma.data.new(sigma.size(0)).normal_())
            for i in range(sigma.size(0)):
                sample[i] = sample[i] * sigma[i, pis[i]] + mu[i, pis[i]]

        elif self.prediction_method == 'average_sample':
            # Draw multiple samples and take the average
            samples = []
            for _ in range(num_samples):
                # Choose component from the mixture
                categorical = torch.distributions.Categorical(pi)
                pis = list(categorical.sample().data)
                # Sample from the chosen component
                sample = Variable(sigma.data.new(sigma.size(0)).normal_())
                for i in range(sigma.size(0)):
                    sample[i] = sample[i] * sigma[i, pis[i]] + mu[i, pis[i]]
                samples.append(sample)
            sample = torch.mean(torch.stack(samples), dim=0)

        else:
            raise ValueError(f"Invalid prediction method: {self.prediction_method}")

        return sample
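
A minimal usage sketch, not taken from the package's own documentation: it builds an MDN for 5 predictors with 3 mixture components and pushes a synthetic batch through it. The import path is assumed from the source path shown above (uncertaintyplayground/models/mdn_model.py).

import torch
from uncertaintyplayground.models.mdn_model import MDN

# Toy batch: 32 observations with 5 predictors each
x = torch.randn(32, 5)

# 3-component mixture, 16 hidden units, sampling from the highest-weight component
model = MDN(input_dim=5, n_gaussians=3, dense1_units=16, prediction_method='max_weight_sample')

pi, mu, sigma = model(x)  # mixture parameters, each of shape (32, 3)
preds = model.sample(x)   # point predictions of shape (32,)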

forward(x)

Forward pass of the MDN model.

Computes the parameters (pi, mu, sigma) of the output distribution given the input.

Parameters:

    x (tensor): Input tensor of shape (batch_size, num_features). Required.

Returns:

    tuple: A tuple containing the predicted mixture weights, means, and standard deviations.

Source code in uncertaintyplayground/models/mdn_model.py
def forward(self, x):
    """
    Forward pass of the MDN model.

    Computes the parameters (pi, mu, sigma) of the output distribution given the input.

    Args:
        x (tensor): Input tensor of shape (batch_size, num_features).

    Returns:
        tuple: A tuple containing the predicted mixture weights, means, and standard deviations.
    """
    z_h = self.z_h(x)
    pi = F.softmax(self.z_pi(z_h), -1)
    mu = self.z_mu(z_h)
    sigma = torch.exp(self.z_sigma(z_h))
    return pi, mu, sigma
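
As a quick sanity check on the outputs, the sketch below (same assumed import as above, synthetic input) verifies the shapes and the constraints that forward imposes: pi rows are softmax-normalized and sigma is strictly positive.

import torch
from uncertaintyplayground.models.mdn_model import MDN

model = MDN(input_dim=5, n_gaussians=3)
x = torch.randn(8, 5)

pi, mu, sigma = model(x)        # calls forward under the hood
print(pi.shape)                 # torch.Size([8, 3]); same for mu and sigma
print(pi.sum(dim=1))            # each row sums to 1 (softmax over components)
print(bool((sigma > 0).all()))  # True: sigma is exp of a linear output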

sample(x, num_samples=100)

Generate samples from the output distribution given the input.

Parameters:

    x (tensor): Input tensor of shape (batch_size, num_features). Required.
    num_samples (int, default 100): Number of samples to generate.

Returns:

    tensor: A tensor of shape (batch_size,) containing the generated samples.

Source code in uncertaintyplayground/models/mdn_model.py
def sample(self, x, num_samples=100):
    """
    Generate samples from the output distribution given the input.

    Args:
        x (tensor): Input tensor of shape (batch_size, num_features).
        num_samples (int): Number of samples to generate. Default is 100.

    Returns:
        tensor: A tensor of shape (batch_size,) containing the generated samples.
    """
    pi, mu, sigma = self.forward(x)

    if self.prediction_method == 'max_weight_mean':
        # Choose component with the highest weight
        pis = torch.argmax(pi, dim=1)
        # Return the mean of the chosen component
        sample = mu[torch.arange(mu.size(0)), pis]

    elif self.prediction_method == 'max_weight_sample':
        # Choose component from the mixture
        categorical = torch.distributions.Categorical(pi)
        pis = list(categorical.sample().data)
        # Sample from the chosen component
        sample = Variable(sigma.data.new(sigma.size(0)).normal_())
        for i in range(sigma.size(0)):
            sample[i] = sample[i] * sigma[i, pis[i]] + mu[i, pis[i]]

    elif self.prediction_method == 'average_sample':
        # Draw multiple samples and take the average
        samples = []
        for _ in range(num_samples):
            # Choose component from the mixture
            categorical = torch.distributions.Categorical(pi)
            pis = list(categorical.sample().data)
            # Sample from the chosen component
            sample = Variable(sigma.data.new(sigma.size(0)).normal_())
            for i in range(sigma.size(0)):
                sample[i] = sample[i] * sigma[i, pis[i]] + mu[i, pis[i]]
            samples.append(sample)
        sample = torch.mean(torch.stack(samples), dim=0)

    else:
        raise ValueError(f"Invalid prediction method: {self.prediction_method}")

    return sample
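
A hypothetical comparison of the three prediction methods on the same input. Since sample() reads self.prediction_method, the attribute can be set at construction time or reassigned afterwards; the data and model sizes below are made up for illustration.

import torch
from uncertaintyplayground.models.mdn_model import MDN

model = MDN(input_dim=5, n_gaussians=3, prediction_method='max_weight_mean')
x = torch.randn(4, 5)

deterministic = model.sample(x)             # mean of the highest-weight component, shape (4,)

model.prediction_method = 'max_weight_sample'
single_draw = model.sample(x)               # one random draw per observation

model.prediction_method = 'average_sample'
averaged = model.sample(x, num_samples=50)  # average of 50 draws, approaching the mixture mean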

mdn_loss(y, mu, sigma, pi)

Compute the MDN loss.

Calculates the negative log-likelihood of the target variable given the predicted parameters of the mixture.

Parameters:

    y (tensor): Target tensor of shape (batch_size,). Required.
    mu (tensor): Predicted means tensor of shape (batch_size, n_gaussians). Required.
    sigma (tensor): Predicted standard deviations tensor of shape (batch_size, n_gaussians). Required.
    pi (tensor): Predicted mixture weights tensor of shape (batch_size, n_gaussians). Required.

Returns:

    tensor: The computed loss.

Source code in uncertaintyplayground/models/mdn_model.py
def mdn_loss(y, mu, sigma, pi):
    """
    Compute the MDN loss.

    Calculates the negative log-likelihood of the target variable given the predicted parameters of the mixture.

    Args:
        y (tensor): Target tensor of shape (batch_size,).
        mu (tensor): Predicted means tensor of shape (batch_size, n_gaussians).
        sigma (tensor): Predicted standard deviations tensor of shape (batch_size, n_gaussians).
        pi (tensor): Predicted mixture weights tensor of shape (batch_size, n_gaussians).

    Returns:
        tensor: The computed loss.
    """
    m = Normal(loc=mu, scale=sigma)
    log_prob = m.log_prob(y.unsqueeze(1))
    log_mix = torch.log(pi) + log_prob
    return -torch.logsumexp(log_mix, dim=1).mean()
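
A minimal training-step sketch with synthetic data; MDN and mdn_loss are assumed to be importable together from the source path shown above. It minimizes the negative log-likelihood returned by mdn_loss with Adam.

import torch
from uncertaintyplayground.models.mdn_model import MDN, mdn_loss

# Synthetic regression data: 128 observations, 5 predictors
x = torch.randn(128, 5)
y = torch.randn(128)

model = MDN(input_dim=5, n_gaussians=3)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(100):
    optimizer.zero_grad()
    pi, mu, sigma = model(x)
    loss = mdn_loss(y, mu, sigma, pi)  # negative log-likelihood of the mixture
    loss.backward()
    optimizer.step()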