svgp_trainer
SparseGPTrainer
Bases: BaseTrainer
Trains an SVGP model using specified parameters and early stopping.
Attributes:
Name | Type | Description |
---|---|---|
num_inducing_points | int | Number of inducing points for the SVGP. |
model | SVGP | The Stochastic Variational Gaussian Process model. |
likelihood | gpytorch.likelihoods.GaussianLikelihood | The likelihood of the model. |
Parameters:
Name | Type | Description | Default |
---|---|---|---|
X | array-like | The input features. | required |
y | array-like | The target outputs. | required |
num_inducing_points | int | Number of inducing points to use in the SVGP model. | 100 |
sample_weights | array-like | Sample weights for each data point. | None |
test_size | float | Fraction of the dataset to be used as test data. | 0.2 |
random_state | int | Random seed for reproducible results. | 42 |
num_epochs | int | Maximum number of training epochs. | 50 |
batch_size | int | Batch size for training. | 256 |
optimizer_fn_name | str | Name of the optimizer to use. | "Adam" |
lr | float | Learning rate for the optimizer. | 0.01 |
use_scheduler | bool | Whether to use a learning rate scheduler. | False |
patience | int | Number of epochs with no improvement before training stops early. | 10 |
dtype | torch.dtype | The dtype to use for input tensors. | torch.float32 |
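The `num_epochs` and `patience` parameters together describe an early-stopping rule. The sketch below shows one common form of that rule in plain Python. It is an illustration only, not the code in `svgp_trainer.py`: the function name `early_stopping_epoch` and its `val_losses` input are hypothetical, and the trainer may track improvement differently.

```python
def early_stopping_epoch(val_losses, patience=10):
    """Return the 1-based epoch at which training would stop.

    val_losses is a hypothetical list of per-epoch validation losses.
    Training stops once `patience` consecutive epochs fail to improve on
    the best loss seen so far; otherwise it runs through every epoch.
    """
    best_loss = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            # New best loss: reset the patience counter.
            best_loss = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return len(val_losses)


# The best loss (0.8) occurs at epoch 2; epochs 3, 4 and 5 do not improve on it.
stopped_at = early_stopping_epoch([1.0, 0.8, 0.9, 0.85, 0.95, 0.7], patience=3)
print(stopped_at)  # 5
```

With the default `patience` of 10 and `num_epochs` of 50, a rule like this would let training end before epoch 50 only after ten consecutive epochs without improvement.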
Source code in uncertaintyplayground/trainers/svgp_trainer.py
predict_with_uncertainty(X)
Predicts the mean and variance of the output distribution given input tensor X.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
X | tensor | Input tensor of shape (num_samples, num_features). | required |
Returns:
Type | Description |
---|---|
tuple | A tuple (mean, variance) of the output distribution, both of shape (num_samples,). |
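The returned mean and variance can be turned into approximate prediction intervals. The sketch below assumes the predictive distribution is Gaussian, which is plausible given the `GaussianLikelihood` but is an assumption here. The helper `gaussian_interval` is hypothetical and not part of the library; it takes plain Python sequences, so tensors returned by `predict_with_uncertainty` would first need converting (for example with `.tolist()`).

```python
import math


def gaussian_interval(means, variances, z=1.96):
    """Return a list of (lower, upper) bounds, one per prediction.

    Assumes each prediction is Gaussian with the given mean and variance.
    z=1.96 corresponds to an approximate 95% interval.
    """
    intervals = []
    for mean, var in zip(means, variances):
        # Clamp tiny negative variances that can arise from numerical error.
        std = math.sqrt(max(var, 0.0))
        intervals.append((mean - z * std, mean + z * std))
    return intervals


# Hypothetical predictions for two samples.
bounds = gaussian_interval([0.0, 2.0], [1.0, 4.0], z=1.0)
print(bounds)  # [(-1.0, 1.0), (0.0, 4.0)]
```

Note that the method returns the variance, not the standard deviation, so the square root is needed before scaling by `z`.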
Source code in uncertaintyplayground/trainers/svgp_trainer.py