EventStream.transformer.generative_layers module
This module implements the time-to-event (TTE) and regression generative emission layers used in the model.
- class EventStream.transformer.generative_layers.ExponentialTTELayer(in_dim: int)
Bases: Module
A class that outputs an exponential distribution for time-to-event.
This layer projects the input tensor to obtain the implied exponential distribution.
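As an illustration of "project the input, get an exponential distribution", here is a minimal sketch under stated assumptions. The class name `ExponentialTTESketch` is hypothetical, and the use of a single linear projection followed by softplus to keep the rate positive is an assumed design, not necessarily what `ExponentialTTELayer` does internally.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Exponential


class ExponentialTTESketch(nn.Module):
    """Hypothetical sketch: maps each input vector to the rate of an exponential TTE distribution."""

    def __init__(self, in_dim: int):
        super().__init__()
        # One scalar per position: the pre-activation rate.
        self.proj = nn.Linear(in_dim, 1)

    def forward(self, X: torch.Tensor) -> Exponential:
        # X: (batch_size, sequence_length, in_dim) -> rate: (batch_size, sequence_length)
        # Assumption: softplus plus a small epsilon keeps the rate strictly positive.
        rate = F.softplus(self.proj(X)).squeeze(-1) + 1e-6
        return Exponential(rate=rate)


# Usage: one exponential distribution per (batch, sequence) position.
layer = ExponentialTTESketch(in_dim=16)
dist = layer(torch.randn(4, 10, 16))
print(dist.mean.shape)  # torch.Size([4, 10])
```

The returned object is a standard `torch.distributions` distribution, so a training loss for observed inter-event times can be written as the negative of `dist.log_prob(times)`.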
- class EventStream.transformer.generative_layers.GaussianIndexedRegressionLayer(n_regression_targets: int, in_dim: int)
Bases: Module
This module implements an indexed, probabilistic regression layer.
This module outputs `(proj @ X).gather(2, idx)`: it projects the input tensor and then subselects the result down to just the set of regression targets `idx` that are needed.
- Parameters:
- n_regression_targets: int
The total number of regression targets the layer can predict.
- in_dim: int
The dimension of the input tensor.
- forward(X: Tensor, idx: LongTensor | None = None) → Normal
Forward pass.
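- Parameters:
- X: Tensor
The input tensor.
- idx: LongTensor | None
The indices of the regression targets to predict, of shape `(batch_size, sequence_length, num_predictions)`. If None, predictions are made for all `n_regression_targets` targets.
- Returns:
The `torch.distributions.normal.Normal` distribution with parameters `self.proj(X)` on the indices specified by `idx`, which has output shape `(batch_size, sequence_length, num_predictions)`. If `idx` is None, it instead contains predictions for all indices and has shape `(batch_size, sequence_length, n_regression_targets)`.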
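The gather-based subselection described above can be sketched as follows. This is a minimal, hypothetical module (`GaussianIndexedRegressionSketch`), not the library's implementation: it assumes the projection emits a mean and a pre-activation standard deviation per target, and that softplus keeps the standard deviation positive. The key point it shows is that `gather(2, idx)` returns a tensor shaped like `idx`, which is why the output has `num_predictions` in its last dimension.

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal


class GaussianIndexedRegressionSketch(nn.Module):
    """Hypothetical sketch of an indexed Gaussian regression head."""

    def __init__(self, n_regression_targets: int, in_dim: int):
        super().__init__()
        # Assumption: a mean and a raw (pre-softplus) std for every regression target.
        self.proj = nn.Linear(in_dim, 2 * n_regression_targets)

    def forward(self, X: torch.Tensor, idx: Optional[torch.Tensor] = None) -> Normal:
        # X: (batch_size, sequence_length, in_dim)
        Z = self.proj(X)  # (batch_size, sequence_length, 2 * n_regression_targets)
        mean, std_raw = Z.chunk(2, dim=-1)  # each (batch_size, sequence_length, n_regression_targets)
        std = F.softplus(std_raw) + 1e-6  # strictly positive std (assumed transform)
        if idx is not None:
            # idx: (batch_size, sequence_length, num_predictions) of int64 target indices.
            # gather along dim 2 keeps only the requested targets at each position.
            mean = mean.gather(2, idx)
            std = std.gather(2, idx)
        return Normal(loc=mean, scale=std)


# Usage: 5 possible targets, but only targets 0 and 4 are requested at every position.
layer = GaussianIndexedRegressionSketch(n_regression_targets=5, in_dim=8)
X = torch.randn(2, 3, 8)
idx = torch.tensor([[[0, 4]] * 3] * 2)  # shape (2, 3, 2)
print(layer(X).mean.shape)       # torch.Size([2, 3, 5])
print(layer(X, idx).mean.shape)  # torch.Size([2, 3, 2])
```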
- class EventStream.transformer.generative_layers.GaussianRegressionLayer(in_dim: int)
Bases: Module
This module implements a probabilistic regression layer.
Given an input `X`, this module predicts a probabilistic (Gaussian) regression output for each input in `X` for a single regression target.
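For the single-target case, a minimal sketch looks like the following. The name `GaussianRegressionSketch` is hypothetical, and both the two-output projection (mean and raw standard deviation) and the choice to drop the trailing target dimension are assumptions made for illustration; the real layer may keep a trailing dimension of size 1.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal


class GaussianRegressionSketch(nn.Module):
    """Hypothetical sketch: one Gaussian prediction per input position, single target."""

    def __init__(self, in_dim: int):
        super().__init__()
        # Assumption: output 0 is the mean, output 1 is a raw (pre-softplus) std.
        self.proj = nn.Linear(in_dim, 2)

    def forward(self, X: torch.Tensor) -> Normal:
        # X: (batch_size, sequence_length, in_dim) -> mean, std: (batch_size, sequence_length)
        mean, std_raw = self.proj(X).unbind(dim=-1)
        std = F.softplus(std_raw) + 1e-6  # keep std strictly positive
        return Normal(loc=mean, scale=std)


layer = GaussianRegressionSketch(in_dim=8)
dist = layer(torch.randn(2, 3, 8))
print(dist.mean.shape)  # torch.Size([2, 3])
```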
- class EventStream.transformer.generative_layers.LogNormalMixtureTTELayer(in_dim: int, num_components: int, mean_log_inter_time: float = 0.0, std_log_inter_time: float = 1.0)
Bases: Module
A class that outputs a mixture-of-lognormal distribution for time-to-event.
This layer projects the input tensor to obtain the parameters of a specific lognormal mixture distribution.
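For reference, the standard density of a mixture of $K$ lognormal components (here $K$ corresponds to `num_components`) over an inter-event time $\tau$ is shown below. The symbols $w_k$, $\mu_k$, $\sigma_k$ (mixture weights, and the mean and standard deviation of $\log \tau$ under component $k$) are illustrative names; how the layer maps its input to them is not specified here.

```latex
p(\tau) = \sum_{k=1}^{K} w_k \,
  \frac{1}{\tau \, \sigma_k \sqrt{2\pi}}
  \exp\!\left( -\frac{(\log \tau - \mu_k)^2}{2 \sigma_k^2} \right),
\qquad \tau > 0, \quad w_k \ge 0, \quad \sum_{k=1}^{K} w_k = 1
```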
- Parameters:
- in_dim: int
The dimension of the input tensor.
- num_components: int
The number of lognormal components in the mixture distribution.
- mean_log_inter_time: float = 0.0
The mean of the log of the inter-event times. Used to initialize the mean of the log of the output distribution. Defaults to 0.0.
- std_log_inter_time: float = 1.0
The standard deviation of the log of the inter-event times. Used to initialize the standard deviation of the log of the output distribution. Defaults to 1.0.
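One way to turn these parameters into a distribution is sketched below with `torch.distributions.MixtureSameFamily`. The class name `LogNormalMixtureTTESketch` is hypothetical, and the way `mean_log_inter_time` and `std_log_inter_time` are applied (shifting and scaling the raw projection so that near-zero outputs put every component at the empirical log inter-event time statistics) is an assumed scheme, not necessarily the one `LogNormalMixtureTTELayer` uses.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical, LogNormal, MixtureSameFamily


class LogNormalMixtureTTESketch(nn.Module):
    """Hypothetical sketch of a mixture-of-lognormal time-to-event head."""

    def __init__(self, in_dim: int, num_components: int,
                 mean_log_inter_time: float = 0.0, std_log_inter_time: float = 1.0):
        super().__init__()
        # Per component: one mixture logit, one raw log-space location, one raw log-scale.
        self.proj = nn.Linear(in_dim, 3 * num_components)
        self.mean_log_inter_time = mean_log_inter_time
        self.std_log_inter_time = std_log_inter_time

    def forward(self, X: torch.Tensor) -> MixtureSameFamily:
        # X: (batch_size, sequence_length, in_dim)
        logits, raw_locs, raw_log_scales = self.proj(X).chunk(3, dim=-1)
        # Assumed scheme: zero raw outputs give loc = mean_log_inter_time and
        # scale = std_log_inter_time for every component.
        locs = raw_locs * self.std_log_inter_time + self.mean_log_inter_time
        scales = raw_log_scales.exp() * self.std_log_inter_time
        mixture = Categorical(logits=logits)                # batch_shape (B, S), K categories
        components = LogNormal(loc=locs, scale=scales)      # batch_shape (B, S, K)
        return MixtureSameFamily(mixture, components)       # batch_shape (B, S)


layer = LogNormalMixtureTTESketch(in_dim=8, num_components=3,
                                  mean_log_inter_time=1.0, std_log_inter_time=0.5)
dist = layer(torch.randn(2, 4, 8))
print(dist.batch_shape)  # torch.Size([2, 4])
```

Mixing lognormals rather than using a single one lets the head represent multi-modal inter-event time distributions while keeping `log_prob` and `sample` available through the standard `torch.distributions` interface.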