Source code for EventStream.transformer.generative_layers

"""This module implements the TTE and regression generative emission layers used in the model."""
import torch
from pytorch_lognormal_mixture import LogNormalMixtureDistribution


class LogNormalMixtureTTELayer(torch.nn.Module):
    """Time-to-event emission head that produces a mixture of log-normal distributions.

    A single linear projection maps every input vector to the three parameter groups of the mixture
    (component locations, log-scales and log-weights). The inter-event-time statistics supplied at
    construction are stored and handed unchanged to every distribution this layer emits.

    Args:
        in_dim: Size of the last dimension of the input tensor.
        num_components: Number of log-normal components in the mixture.
        mean_log_inter_time: Mean of the log inter-event times; forwarded to the output distribution.
            Defaults to 0.0.
        std_log_inter_time: Standard deviation of the log inter-event times; forwarded to the output
            distribution. Defaults to 1.0.
    """

    def __init__(
        self,
        in_dim: int,
        num_components: int,
        mean_log_inter_time: float = 0.0,
        std_log_inter_time: float = 1.0,
    ):
        super().__init__()

        self.mean_log_inter_time = mean_log_inter_time
        self.std_log_inter_time = std_log_inter_time

        # Each mixture component needs one loc, one log-scale and one log-weight, hence three projected
        # values per component.
        n_outputs = 3 * num_components
        self.proj = torch.nn.Linear(in_dim, n_outputs)

    def forward(self, T: torch.Tensor) -> LogNormalMixtureDistribution:
        """Project `T` and wrap the result in a log-normal mixture distribution.

        Args:
            T: The input tensor; its last dimension must be `in_dim`.

        Returns:
            A `LogNormalMixtureDistribution` whose locs, log-scales and log-weights are read from
            `self.proj(T)`. Each parameter tensor has shape `T.shape[:-1] + (num_components,)`.
        """
        projected = self.proj(T)

        # The projection interleaves the parameters along the last axis as
        # (loc_0, log_scale_0, log_weight_0, loc_1, ...), so a stride-3 slice at offsets 0, 1 and 2 pulls
        # out each group.
        locs, log_scales, log_weights = (projected[..., offset::3] for offset in range(3))

        return LogNormalMixtureDistribution(
            locs=locs,
            log_scales=log_scales,
            log_weights=log_weights,
            mean_log_inter_time=self.mean_log_inter_time,
            std_log_inter_time=self.std_log_inter_time,
        )
class ExponentialTTELayer(torch.nn.Module):
    """A class that outputs an exponential distribution for time-to-event.

    The input is projected to a single value per position, which is then mapped to a strictly positive rate
    for a `torch.distributions.exponential.Exponential` distribution.

    Args:
        in_dim: The dimensionality of the input.
    """

    def __init__(self, in_dim: int):
        super().__init__()
        self.proj = torch.nn.Linear(in_dim, 1)

    def forward(self, T: torch.Tensor) -> torch.distributions.exponential.Exponential:
        """Forward pass.

        Args:
            T: The input tensor; its last dimension must be `in_dim`.

        Returns:
            An `Exponential` distribution whose rate is derived from `self.proj(T)`. The trailing singleton
            dimension of the projection is squeezed away, so the rate has shape `T.shape[:-1]` (e.g.
            `(batch_size, sequence_length)` for a 3-D input).
        """
        projected = self.proj(T)

        # torch.nn.functional.elu has image (-1, inf), so `elu(.) + 1` is non-negative but can round to
        # exactly 0 for very negative inputs. The rate must be > 0, so we add the smallest positive normal
        # value of the dtype. That value is taken from the *projected* tensor's dtype, not the input's: if
        # the projection is computed in a lower precision than `T` (e.g. under autocast), the input dtype's
        # `tiny` would underflow to 0 when added and the buffer would be lost.
        rate = torch.nn.functional.elu(projected) + 1 + torch.finfo(projected.dtype).tiny

        # The rate currently has a trailing dimension of size 1 (from the 1-unit projection); drop it.
        rate = rate.squeeze(dim=-1)

        return torch.distributions.exponential.Exponential(rate=rate)
[docs] class GaussianIndexedRegressionLayer(torch.nn.Module): """This module implements an indexed, probabilistic regression layer. This module outputs `(proj @ X).gather(2, idx)` after projecting the input tensor and subselecting those down to just the set of regression targets `idx` that are needed. Args: n_regression_targets: How many regression targets there are. in_dim: The input dimensionality. """ def __init__(self, n_regression_targets: int, in_dim: int): super().__init__() # We multiply `n_regression_targets` by 2 because we need both mean and standard deviation outputs. self.proj = torch.nn.Linear(in_dim, n_regression_targets * 2)
[docs] def forward( self, X: torch.Tensor, idx: torch.LongTensor | None = None ) -> torch.distributions.normal.Normal: """Forward pass. Args: X: The input tensor. idx: The indices of the regression targets to output. If None, then all regression targets are predicted. Returns: The `torch.distributions.normal.Normal` distribution with parameters `self.proj(X)` on indices specified by `idx`, which will have output shape `(batch_size, sequence_length, num_predictions)`, unless `idx` is None in which case it will have predictions for all indices and have shape `(batch_size, sequence_length, n_regression_targets)`. """ Z = self.proj(X) Z_mean = Z[..., 0::2] # torch.nn.functional.elu has idxmage (-1, 1), but we need our std parameter to be > 0. So we need to # add 1 to the output here. To ensure validity given numerical imprecision, we also add a buffer given # by the smallest possible positive value permissible given the type of `T`. Z_std = torch.nn.functional.elu(Z[..., 1::2]) + 1 + torch.finfo(X.dtype).tiny if idx is None: return torch.distributions.normal.Normal(loc=Z_mean, scale=Z_std) mean = Z_mean.gather(-1, idx) std = Z_std.gather(-1, idx) return torch.distributions.normal.Normal(loc=mean, scale=std)
class GaussianRegressionLayer(torch.nn.Module):
    """This module implements a probabilistic regression layer.

    Given an input `X`, this module predicts probabilistic regression outputs for each input in `X` for one
    regression target.

    Args:
        in_dim: The input dimensionality.
    """

    def __init__(self, in_dim: int):
        super().__init__()

        # Two outputs for the single regression target: one for the mean, one for the standard deviation.
        self.proj = torch.nn.Linear(in_dim, 2)

    def forward(self, X: torch.Tensor) -> torch.distributions.normal.Normal:
        """Forward pass.

        Args:
            X: The input tensor of shape `(batch_size, sequence_length, in_dim)`.

        Returns:
            The `torch.distributions.normal.Normal` distribution with parameters derived from
            `self.proj(X)`, which will have output shape `(batch_size, sequence_length, 1)`.
        """
        Z = self.proj(X)

        # Column 0 of the projection is the mean, column 1 the raw (pre-activation) standard deviation. The
        # strided slices keep the trailing dimension of size 1.
        Z_mean = Z[..., 0::2]

        # torch.nn.functional.elu has image (-1, inf), so `elu(.) + 1` is non-negative but can round to
        # exactly 0 for very negative inputs. The std must be > 0, so we add the smallest positive normal
        # value of the dtype. That value is taken from the dtype of `Z`, not of `X`: if the projection is
        # computed in a lower precision than `X` (e.g. under autocast), the input dtype's `tiny` would
        # underflow to 0 when added and the buffer would be lost.
        Z_std = torch.nn.functional.elu(Z[..., 1::2]) + 1 + torch.finfo(Z.dtype).tiny

        return torch.distributions.normal.Normal(loc=Z_mean, scale=Z_std)