EventStream.transformer.model_output module

Classes and utilities for model output layers.

EventStream.transformer.model_output.BERNOULLI_DIST_T

The type of a Bernoulli distribution.

EventStream.transformer.model_output.CATEGORICAL_DIST_T

The type of a categorical distribution.

EventStream.transformer.model_output.REGRESSION_DIST_T

The type of a regression distribution.

class EventStream.transformer.model_output.GenerativeOutputLayerBase(config: StructuredTransformerConfig)

Bases: Module

A base class for the output layer of a generative model.

This class is responsible for constructing the time-to-event (TTE) layer based on the TTE_generation_layer_type in the given config, along with the observation and classification layers. It also sets up the loss criteria for observation and classification. It does not define a forward method that calls these helper methods; subclasses implement forward themselves, depending on how their encoded state is structured.

This class should not be instantiated directly. Instead, use one of the derived classes.

Parameters:
config: StructuredTransformerConfig

A configuration object of type StructuredTransformerConfig.

Raises:
  • ValueError – If the TTE_generation_layer_type in the config is not valid.

  • ValueError – If any measurements are duplicated in the regression layers.
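
The duplicate-measurement check under Raises can be pictured with a small sketch. It assumes, for illustration only, that the regression layers are described by a plain mapping from a layer name to the measurements it predicts; both that mapping and the helper name check_no_duplicate_regression_measurements are hypothetical, and the real class derives this information from StructuredTransformerConfig.

```python
from collections import Counter


def check_no_duplicate_regression_measurements(layers: dict[str, list[str]]) -> None:
    """Hypothetical helper: raise if a measurement is assigned to more than one regression layer.

    Assumed input layout: {layer name: [measurement names predicted by that layer]}.
    """
    counts = Counter(m for measurements in layers.values() for m in measurements)
    duplicated = sorted(m for m, n in counts.items() if n > 1)
    if duplicated:
        raise ValueError(f"Measurements duplicated across regression layers: {duplicated}")


# Passes silently: every measurement belongs to exactly one layer.
check_no_duplicate_regression_measurements({"univariate": ["age"], "multivariate": ["labs"]})
```

Counting across all layers at once (rather than comparing layers pairwise) reports every duplicated measurement in a single error.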

get_TTE_outputs(batch: PytorchBatch, encoded: FloatTensor, is_generation: bool = False) → tuple[FloatTensor, Distribution, FloatTensor]

Produces time-to-event predictions and log likelihoods (not NLLs!) for the model.

Parameters:
batch: PytorchBatch

The batch of data for which the time-to-event predictions are desired.

encoded: FloatTensor

The final encodings used to predict the time from the event at a position to the subsequent event. This tensor is of shape (batch size X sequence length X hidden dim).

is_generation: bool = False

A boolean indicating whether the function is being used for generation. Defaults to False. If True, only the predicted distribution is returned, as that is all that generative use-cases need.

Returns:

A tuple containing the following items:

  • TTE_LL: A torch scalar containing the average log-likelihood of the observed time-to-events given the predicted distribution.

  • TTE_dist: The predicted torch Distribution modelling the time-to-event.

  • TTE_true: A tensor containing the observed time between events for each batch element.

Return type:

tuple[FloatTensor, Distribution, FloatTensor]

Raises:
  • ValueError – If NaNs are found in TTE_obs_mask_exp, TTE_true_exp or TTE_LL, or if there is no observed time-to-event for >= 1 patient in the batch.
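
The following is a minimal sketch of how an average time-to-event log-likelihood (not an NLL) can be computed from a predicted distribution while ignoring positions without an observed time-to-event. The helper name, the exponential distribution, and the per-sequence-then-batch averaging are assumptions for illustration; the exact reduction used by get_TTE_outputs is not specified here.

```python
import torch
from torch.distributions import Distribution, Exponential


def average_tte_log_likelihood(
    tte_dist: Distribution, tte_true: torch.Tensor, tte_obs_mask: torch.Tensor
) -> torch.Tensor:
    """Hypothetical helper: mean log-likelihood of the observed time-to-events.

    tte_true and tte_obs_mask have shape (batch_size, sequence_length).
    """
    n_obs = tte_obs_mask.sum(dim=-1)
    if (n_obs == 0).any():
        raise ValueError("No observed time-to-event for >= 1 patient in the batch.")
    # Replace unobserved entries with a valid value so log_prob is defined everywhere.
    safe_tte = torch.where(tte_obs_mask, tte_true, torch.ones_like(tte_true))
    ll = tte_dist.log_prob(safe_tte) * tte_obs_mask.float()  # zero out unobserved positions
    if torch.isnan(ll).any():
        raise ValueError("NaNs found in the time-to-event log-likelihood.")
    per_sequence = ll.sum(dim=-1) / n_obs  # mean over observed events in each sequence
    return per_sequence.mean()  # average over batch elements


# Illustrative exponential TTE distribution with rate 1 at every position.
dist = Exponential(rate=torch.ones(2, 3))
tte_true = torch.tensor([[1.0, 2.0, 0.0], [0.5, 0.0, 0.0]])
tte_obs_mask = torch.tensor([[True, True, False], [True, False, False]])
print(average_tte_log_likelihood(dist, tte_true, tte_obs_mask))
```

Substituting a placeholder before calling log_prob, instead of masking afterwards only, keeps invalid values (for example negative padding) from tripping the distribution's argument validation.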

get_classification_outputs(batch: PytorchBatch, encoded: FloatTensor, valid_measurements: set[str]) → tuple[dict[str, FloatTensor], dict[str, tuple[None, Bernoulli] | tuple[Bernoulli, Categorical]], dict[str, LongTensor | FloatTensor]]

Produces classification predictions and losses for the model.

Parameters:
batch: PytorchBatch

The batch of data for which the classification predictions are desired.

encoded: FloatTensor

The final encodings to be used to predict for each position in the sequence. For example, the vector encoded[i][j] (which is of size hidden_dim) is not the summary encoding of the batch element at batch index i and sequence index j, but rather is the input to be used to form classification predictions corresponding to batch element i at sequence position j.

valid_measurements: set[str]

The classification measurements in the batch that should be predicted from encoded.

Returns:

The following three dictionaries

  1. classification_losses_by_measurement: A dictionary from measurement to scalar tensors consisting of the average NLL of the data given the classification model. Averaging happens via the following procedure:

    • For multi-label measurements:

      1. NLL is averaged over labels per sequence event, for all unmasked sequence events (as in theory any event could have observed labels for binary multi-label predictions). TODO(mmd): this should likely be specific to events with certain event types.

      2. NLL is macro-averaged across unmasked sequence events per batch element.

      3. NLL is macro-averaged across batch elements.

    • For single-label measurements:

      1. NLL is computed on any event that has a label for that task. TODO(mmd): Maybe should be conditioned on specific event types too?

      2. Within each sequence, NLL is macro-averaged across the events that have a label for that task. Sequences without any such events receive a loss of zero.

      3. NLL is macro-averaged across batch elements.

  2. classification_dists_by_measurement: A dictionary from measurement to classification distributions of shape [batch_size X sequence_length X vocabulary_size] or [batch_size X sequence_length] reflecting the probabilities for each event for that measurement. Returns scores for all events, even those that are masked, including the final event.

  3. classification_labels_by_measurement: A dictionary from measurement to tensors of one of two types:

    • For multi-label measurements, returns FloatTensors of shape [batch_size X sequence_length X vocabulary_size] containing binary labels for each vocabulary element for each event.

    • For single-label measurements, returns LongTensors of shape [batch_size, sequence_length] containing label indices for each event with that task observed, otherwise contains zeros.
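
The two averaging procedures described above can be sketched as follows. This is a minimal illustration, not the library's implementation: it assumes the model emits raw logits, uses binary cross-entropy for multi-label measurements and cross-entropy for single-label ones, and the function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def multi_label_nll(logits, labels, event_mask):
    """logits, labels: (batch, seq, vocab); event_mask: (batch, seq) bool."""
    per_label = F.binary_cross_entropy_with_logits(logits, labels, reduction="none")
    per_event = per_label.mean(dim=-1)  # 1. average over labels per event
    mask = event_mask.float()
    # 2. macro-average over unmasked events per sequence (clamp avoids division by zero)
    per_seq = (per_event * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1)
    return per_seq.mean()  # 3. macro-average over batch elements


def single_label_nll(logits, labels, is_labeled):
    """logits: (batch, seq, vocab); labels: (batch, seq) long; is_labeled: (batch, seq) bool."""
    # cross_entropy expects the class dimension second: (batch, vocab, seq).
    per_event = F.cross_entropy(logits.transpose(1, 2), labels, reduction="none")
    w = is_labeled.float()
    n_labeled = w.sum(dim=-1)
    # 2. mean over labelled events per sequence; sequences with no labels get a loss of zero
    per_seq = torch.where(
        n_labeled > 0,
        (per_event * w).sum(dim=-1) / n_labeled.clamp(min=1),
        torch.zeros_like(n_labeled),
    )
    return per_seq.mean()  # 3. macro-average over batch elements
```

With all-zero logits, every multi-label term equals log 2 and every single-label term equals log(vocab_size), which makes the effect of each averaging step easy to check by hand.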

get_regression_outputs(batch: PytorchBatch, encoded: FloatTensor, valid_measurements: set[str], is_generation: bool = False) → tuple[dict[str, FloatTensor], dict[str, Distribution], dict[str, FloatTensor], dict[str, LongTensor]]

Produces regression predictions and losses for the model.

Parameters:
batch: PytorchBatch

The batch of data for which the regression predictions are desired.

encoded: FloatTensor

The final encodings (of shape batch_size X sequence_length X hidden_dim) to be used to predict for each position in the sequence. For example, the vector encoded[i][j] (which is of size hidden_dim) is not the summary encoding of the batch element at batch index i and sequence index j, but rather is the input to be used to form regression predictions corresponding to batch element i at sequence position j.

valid_measurements: set[str]

The regression measurements in the batch that should be predicted from encoded.

Returns:

  • regression_loss_values: A dictionary from measurement to scalar tensors consisting of the average NLL of the data given the regression model. Averaging happens via the following procedure:

    1. NLL is averaged over data elements of the correct measurement per event. TODO(mmd): This is likely a bit wrong; if a regression task has no observed value, that should be taken into account here but I don’t think it is currently.

    2. Per-event NLLs are averaged over unmasked events with labels per batch element.

    3. NLL is macro-averaged over the batch.

  • regression_dists: A dictionary from measurement to torch distributions modelling the regression targets for each data element in each event. In particular, samples from these distributions will have shape [batch_size, sequence_length, num_data_elements_per_event], such that sample[i][j][k] will correspond to a prediction for the regression target indexed by batch['dynamic_indices'][i][j][k].

  • regression_labels: A dictionary from measurement to tensors of shape [batch_size, sequence_length, num_data_elements_per_event] containing regression targets for each data element, or 0 if that regression target is unobserved.

  • regression_indices: A dictionary from measurement to tensors of shape [batch_size, sequence_length, num_data_elements_per_event] containing the integer index of the regression component observed in that position, or 0 if that regression target is unobserved. E.g., if we have 200 laboratory tests that we are regressing over, these indices state to which laboratory test results the values in regression_labels correspond.

Return type:

Four dictionaries
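
A minimal sketch of how regression_indices can select per-element predictions and how the three-step regression averaging might look. It assumes, hypothetically, that the model predicts a Normal mean and scale for every vocabulary element of the measurement and that a values_mask marks observed targets. Unlike what the TODO above describes for the current implementation, this sketch excludes unobserved targets explicitly.

```python
import torch
from torch.distributions import Normal


def regression_nll(loc, scale, indices, labels, values_mask, event_mask):
    """Hypothetical helper.

    loc, scale: (batch, seq, vocab) predicted Normal parameters per vocabulary element.
    indices, labels, values_mask: (batch, seq, n_elements); event_mask: (batch, seq) bool.
    """
    # Select the parameters of the regression component observed at each data element.
    dist = Normal(torch.gather(loc, -1, indices), torch.gather(scale, -1, indices))
    nll = -dist.log_prob(labels) * values_mask.float()
    n_obs = values_mask.sum(dim=-1)
    per_event = nll.sum(dim=-1) / n_obs.clamp(min=1)  # 1. mean over observed elements per event
    has_label = event_mask & (n_obs > 0)
    n_events = has_label.sum(dim=-1)
    # 2. mean over unmasked events with labels in each sequence
    per_seq = (per_event * has_label.float()).sum(dim=-1) / n_events.clamp(min=1)
    return per_seq.mean(), dist  # 3. macro-average over the batch
```

Samples from the returned distribution have shape (batch, seq, n_elements), so sample[i][j][k] lines up with indices[i][j][k], mirroring the regression_dists description.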

class EventStream.transformer.model_output.GenerativeSequenceModelLabels(classification: dict[str, LongTensor] | None = None, regression: dict[str, FloatTensor] | None = None, regression_indices: dict[str, LongTensor] | None = None, time_to_event: FloatTensor | None = None)

Bases: ModelOutput

Contains the labels for the GenerativeSequenceModel head.

The labels are split by task type. Single-label classification task labels will have shape batch X seq and have raw integer labels, whereas multi-label classification task labels will have shape batch X seq X vocab size and have binary indicators for each label.

Parameters:
classification: dict[str, LongTensor] | None = None

The classification task labels.

regression: dict[str, FloatTensor] | None = None

The regression task labels.

regression_indices: dict[str, LongTensor] | None = None

The indices for the regression task.

time_to_event: FloatTensor | None = None

The time-to-event task labels.

classification : dict[str, LongTensor] | None = None
regression : dict[str, FloatTensor] | None = None
regression_indices : dict[str, LongTensor] | None = None
time_to_event : FloatTensor | None = None

class EventStream.transformer.model_output.GenerativeSequenceModelLosses(classification: dict[str, FloatTensor] | None = None, regression: dict[str, FloatTensor] | None = None, time_to_event: FloatTensor | None = None)

Bases: ModelOutput

Holds the losses of a generative sequence model.

This class manages the losses from a generative sequence model, which can include classification, regression, and time-to-event losses.

Parameters:
classification: dict[str, FloatTensor] | None = None

Losses for the classification task.

regression: dict[str, FloatTensor] | None = None

Losses for the regression task.

time_to_event: FloatTensor | None = None

Loss for the time-to-event task.

classification : dict[str, FloatTensor] | None = None
regression : dict[str, FloatTensor] | None = None
time_to_event : FloatTensor | None = None

class EventStream.transformer.model_output.GenerativeSequenceModelOutput(loss: FloatTensor, losses: GenerativeSequenceModelLosses | None = None, preds: GenerativeSequenceModelPredictions | None = None, labels: GenerativeSequenceModelLabels | None = None, event_mask: BoolTensor | None = None, dynamic_values_mask: BoolTensor | None = None, past_key_values: tuple[tuple[FloatTensor]] | None = None, hidden_states: tuple[FloatTensor] | None = None, attentions: tuple[FloatTensor] | None = None)

Bases: ModelOutput

Contains all GenerativeSequenceModel outputs.

The outputs include losses, predictions, labels, and masks, among others.

Parameters:
loss: FloatTensor

The overall model loss.

losses: GenerativeSequenceModelLosses | None = None

The specific model losses by task type.

preds: GenerativeSequenceModelPredictions | None = None

The model predictions.

labels: GenerativeSequenceModelLabels | None = None

The model labels.

event_mask: BoolTensor | None = None

A boolean tensor representing the event mask.

dynamic_values_mask: BoolTensor | None = None

A boolean tensor representing the dynamic values mask.

past_key_values: tuple[tuple[FloatTensor]] | None = None

The past key values from the model.

hidden_states: tuple[FloatTensor] | None = None

The hidden states from the model.

attentions: tuple[FloatTensor] | None = None

The attentions from the model.

attentions : tuple[FloatTensor] | None = None
dynamic_values_mask : BoolTensor | None = None
event_mask : BoolTensor | None = None
hidden_states : tuple[FloatTensor] | None = None
labels : GenerativeSequenceModelLabels | None = None
loss : FloatTensor
losses : GenerativeSequenceModelLosses | None = None
past_key_values : tuple[tuple[FloatTensor]] | None = None
preds : GenerativeSequenceModelPredictions | None = None

class EventStream.transformer.model_output.GenerativeSequenceModelPredictions(classification: dict[str, tuple[None, Bernoulli] | tuple[Bernoulli, Categorical]] | None = None, regression: dict[str, tuple[None, Normal] | tuple[Bernoulli, Normal]] | None = None, regression_indices: dict[str, LongTensor] | None = None, time_to_event: Distribution | None = None)

Bases: ModelOutput, NestedIndexableMixin

Contains the predictions for the GenerativeSequenceModel head.

Parameters:
classification: dict[str, tuple[None, Bernoulli] | tuple[Bernoulli, Categorical]] | None = None

The predicted classification task results.

regression: dict[str, tuple[None, Normal] | tuple[Bernoulli, Normal]] | None = None

The predicted regression task results.

regression_indices: dict[str, LongTensor] | None = None

The predicted indices for the regression task.

time_to_event: Distribution | None = None

The predicted time-to-event results.

classification : dict[str, tuple[None, Bernoulli] | tuple[Bernoulli, Categorical]] | None = None
regression : dict[str, tuple[None, Normal] | tuple[Bernoulli, Normal]] | None = None
regression_indices : dict[str, LongTensor] | None = None

sample(event_mask: BoolTensor) → GenerativeSequenceModelSamples

Generates a sample from the contained predictions.

Parameters:
event_mask: BoolTensor

A boolean tensor representing the event mask. This is used only to provide a source for the sampled event’s mask (which is copied from the last position along the sequence dimension of this input).

Returns:

A sample from the GenerativeSequenceModel.

Raises:

ValueError – If the classification or regression distributions are malformed or unrecognized.
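
A minimal sketch of sampling one classification measurement from the distribution pairs listed above, and of taking the new event mask from the last sequence position. Treating the first tuple element as a Bernoulli over whether the measurement is observed at all is an assumption made for illustration, and the helper names are hypothetical.

```python
from typing import Optional

import torch
from torch.distributions import Bernoulli, Categorical, Distribution


def sample_classification(dists: tuple[Optional[Distribution], Distribution]) -> torch.Tensor:
    """Hypothetical helper for a (None, Bernoulli) or (Bernoulli, Categorical) pair."""
    is_observed_dist, label_dist = dists
    sample = label_dist.sample()
    if is_observed_dist is not None:
        # Assumption: zero out the label when the measurement is sampled as unobserved.
        sample = sample * is_observed_dist.sample().long()
    return sample


def sampled_event_mask(event_mask: torch.Tensor) -> torch.Tensor:
    """The new event's mask is copied from the last position along the sequence dimension."""
    return event_mask[:, -1]


# Usage: a single-label measurement for a batch of two events (deterministic probabilities).
pair = (
    Bernoulli(probs=torch.tensor([1.0, 0.0])),
    Categorical(probs=torch.tensor([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]])),
)
print(sample_classification(pair))
```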

time_to_event : Distribution | None = None

class EventStream.transformer.model_output.GenerativeSequenceModelSamples(event_mask: BoolTensor | None = None, time_to_event: FloatTensor | None = None, classification: dict[str, LongTensor] | None = None, regression: dict[str, FloatTensor] | None = None, regression_indices: dict[str, LongTensor] | None = None)

Bases: ModelOutput

A single sample (event) of a generative sequence model.

Parameters:
event_mask: BoolTensor | None = None

A boolean tensor of shape [batch_size,] indicating whether events exist.

time_to_event: FloatTensor | None = None

A float tensor of shape [batch_size,]. Is 0 if the event does not exist, otherwise quantifies the time between the prior event and this event in the series.

classification: dict[str, LongTensor] | None = None

A dictionary with measurements as keys and tensors as values. The value tensor has shape [batch_size,] if the measurement is single-label classification or [batch_size, vocab_size] if it is multi-label classification. For single-label measurements, the tensor contains the predicted class index for this event (starting at 0, i.e., without the global vocabulary offset); for multi-label measurements, it contains per-label binary indicators. If the event is not present, all predictions are zero.

regression: dict[str, FloatTensor] | None = None

A dictionary with keys as measurements and values as tensors. Shape of value tensor is [batch_size,] if measurement is univariate or [batch_size, n_regression_targets] if measurement is multivariate. The tensor contains the floating-point predictions for that measurement. If an event is not present, predictions will be zero. Predictions are ordered in accordance with the index-labels (starting at zero) for the data-type vocabulary contained in regression_indices. If regression_indices is None, predictions span the entire vocabulary in vocabulary order.

regression_indices: dict[str, LongTensor] | None = None

A dictionary with measurements as keys and tensors as values. The value tensor has shape [batch_size, n_regression_targets] and contains, for each data type, the indices for which regression contains predictions. If None, regression predictions correspond to the entire vocabulary in vocabulary order.

append_to_batch(batch: PytorchBatch, config: StructuredTransformerConfig) → PytorchBatch

Appends a new batch element to the input batch.

This function first constructs a new batch element from the current object, and then appends it to the given batch. It adjusts the time delta and event mask of the batch accordingly, and ensures that the dynamic data elements of the batch and the new element are of the same dimensions by applying padding as needed.

Parameters:
batch: PytorchBatch

The PytorchBatch object to which the new element will be added.

config: StructuredTransformerConfig

A StructuredTransformerConfig object containing configuration data.

Returns:

A new PytorchBatch object, which includes the original data plus the appended new batch element.
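
A minimal sketch of the append step for two of the affected tensors, the event mask and the dynamic indices. It assumes that the data-element dimension is right-padded with zeros to a common width before concatenating along the sequence dimension; the function name and argument layout are hypothetical, and the time-delta adjustment mentioned above is not shown.

```python
import torch
import torch.nn.functional as F


def append_event(event_mask, dynamic_indices, new_event_mask, new_dynamic_indices):
    """Hypothetical helper.

    event_mask: (batch, seq) bool; dynamic_indices: (batch, seq, n_old) long.
    new_event_mask: (batch,) bool; new_dynamic_indices: (batch, n_new) long.
    """
    n_old, n_new = dynamic_indices.shape[-1], new_dynamic_indices.shape[-1]
    width = max(n_old, n_new)
    # Right-pad the narrower tensor's last dimension with zeros (the padding index).
    dynamic_indices = F.pad(dynamic_indices, (0, width - n_old))
    new_dynamic_indices = F.pad(new_dynamic_indices, (0, width - n_new))
    return (
        torch.cat([event_mask, new_event_mask.unsqueeze(1)], dim=1),
        torch.cat([dynamic_indices, new_dynamic_indices.unsqueeze(1)], dim=1),
    )
```

The same pad-then-concatenate pattern would apply to the measurement indices, values, and values mask.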

classification : dict[str, LongTensor] | None = None
event_mask : BoolTensor | None = None

format_updates_to_last_batch_event(batch: PytorchBatch, config: StructuredTransformerConfig, measurements_to_build: set[str | tuple[str, MeasIndexGroupOptions]] | None = None) → tuple[LongTensor, LongTensor, FloatTensor, BoolTensor]

Generates a new batch element from the prediction sample in this object.

This function is used for generation. It dynamically builds various elements such as indices, values, types, and values_mask based on the given configuration and measurements.

Parameters:
batch: PytorchBatch

The PytorchBatch object.

config: StructuredTransformerConfig

The structured transformer configuration object.

measurements_to_build: set[str | tuple[str, MeasIndexGroupOptions]] | None = None

The set of measurement index groups to build. If None, all are built.

Returns:

A tuple containing four tensors: the new dynamic indices, the new dynamic measurement indices, the new dynamic values, and the new dynamic values mask.

Return type:

tuple[LongTensor, LongTensor, FloatTensor, BoolTensor]

Raises:
  • ValueError – If a measurement is missing from the config’s vocab_offsets_by_measurement, if the shape of a prediction does not match the expected shape, or if a prediction is greater than or equal to the vocab size.

  • RuntimeError – If the indices cannot be gathered due to a mismatch in shape or values.
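
The conversion from per-measurement predictions (class indices starting at 0) to global vocabulary indices, together with the vocab-size check listed under Raises, can be sketched as below. This assumes that a measurement's global indices are its local indices shifted by its entry in vocab_offsets_by_measurement; the helper names are hypothetical.

```python
import torch


def single_label_to_global(preds: torch.Tensor, vocab_offset: int, vocab_size: int) -> torch.Tensor:
    """preds: (batch,) local class indices -> (batch,) global vocabulary indices."""
    if (preds >= vocab_size).any():
        raise ValueError("Prediction is greater than or equal to the vocab size.")
    return preds + vocab_offset


def multi_label_to_global(preds: torch.Tensor, vocab_offset: int) -> torch.Tensor:
    """preds: (batch, vocab_size) binary labels -> (batch, vocab_size) global indices, 0 where absent."""
    global_idx = torch.arange(preds.shape[-1]) + vocab_offset
    # Keep the global index where the label is present; otherwise use the padding index 0.
    return torch.where(preds.bool(), global_idx, torch.zeros_like(global_idx))
```

The zeros left behind by absent multi-label entries are the kind of padding that strip_unused_indices removes.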

static pad_data_elements(batch: PytorchBatch, new_dynamic_indices: LongTensor, new_dynamic_measurement_indices: LongTensor, new_dynamic_values: FloatTensor, new_dynamic_values_mask: BoolTensor)

Pads the dimensions of the new batch elements to match the old ones.

This static method adjusts the shape of the given new dynamic data elements (indices, measurement indices, values, and values mask) to match the shape of those in the given batch. It achieves this by padding the shorter one of the new and old data elements with zeros (for LongTensors and FloatTensors) or with False (for BoolTensors).

Parameters:
batch: PytorchBatch

A PytorchBatch object whose data element dimensions are to be matched.

new_dynamic_indices: LongTensor

The indices tensor to be resized.

new_dynamic_measurement_indices: LongTensor

The measurement indices tensor to be resized.

new_dynamic_values: FloatTensor

The values tensor to be resized.

new_dynamic_values_mask: BoolTensor

The values mask tensor to be resized.

Returns:

A tuple of two tuples. The first inner tuple contains the possibly-padded dynamic data elements of the given batch. The second inner tuple contains the possibly-padded new dynamic data elements.

Examples

>>> import torch
>>> batch = PytorchBatch(
...     dynamic_indices=torch.tensor([
...         [[1, 2, 3], [4, 5, 6]],
...         [[7, 8, 9], [10, 11, 12]]
...     ]),
...     dynamic_measurement_indices=torch.tensor([
...         [[1, 2, 3], [4, 5, 6]],
...         [[7, 8, 9], [10, 11, 12]]
...     ]),
...     dynamic_values=torch.tensor([
...         [[1., 2., 3.], [4., 5., 6.]],
...         [[7., 8., 9.], [10., 11., 12.]]
...     ]),
...     dynamic_values_mask=torch.tensor([
...         [[True, True, True], [True, True, True]],
...         [[True, True, True], [True, True, True]]
...     ])
... )
>>> new_dynamic_indices = torch.tensor([
...     [[1, 2], [3, 4]],
...     [[5, 6], [7, 8]]
... ])
>>> new_dynamic_measurement_indices = torch.tensor([
...     [[1, 2], [3, 4]],
...     [[5, 6], [7, 8]]
... ])
>>> new_dynamic_values = torch.tensor([
...     [[1., 2.], [3., 4.]],
...     [[5., 6.], [7., 8.]]
... ])
>>> new_dynamic_values_mask = torch.tensor([
...     [[True, True], [True, True]],
...     [[True, True], [True, True]]
... ])
>>> out = GenerativeSequenceModelSamples.pad_data_elements(
...     batch,
...     new_dynamic_indices,
...     new_dynamic_measurement_indices,
...     new_dynamic_values,
...     new_dynamic_values_mask
... )
>>> len(out)
2
>>> for tensor_tuple in out:
...     print(len(tensor_tuple))
...     for tensor in tensor_tuple:
...         print(tensor)
4
tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]]])
tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]]])
tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.]],

        [[ 7.,  8.,  9.],
         [10., 11., 12.]]])
tensor([[[True, True, True],
         [True, True, True]],

        [[True, True, True],
         [True, True, True]]])
4
tensor([[[1, 2, 0],
         [3, 4, 0]],

        [[5, 6, 0],
         [7, 8, 0]]])
tensor([[[1, 2, 0],
         [3, 4, 0]],

        [[5, 6, 0],
         [7, 8, 0]]])
tensor([[[1., 2., 0.],
         [3., 4., 0.]],

        [[5., 6., 0.],
         [7., 8., 0.]]])
tensor([[[ True,  True, False],
         [ True,  True, False]],

        [[ True,  True, False],
         [ True,  True, False]]])
>>> batch = PytorchBatch(
...     dynamic_indices=torch.tensor([
...         [[1], [4]],
...         [[7], [10]]
...     ]),
...     dynamic_measurement_indices=torch.tensor([
...         [[1], [4]],
...         [[7], [10]]
...     ]),
...     dynamic_values=torch.tensor([
...         [[1.], [4.]],
...         [[7.], [10.]]
...     ]),
...     dynamic_values_mask=torch.tensor([
...         [[True], [True]],
...         [[True], [True]]
...     ])
... )
>>> new_dynamic_indices = torch.tensor([
...     [[1, 2], [3, 4]],
...     [[5, 6], [7, 8]]
... ])
>>> new_dynamic_measurement_indices = torch.tensor([
...     [[1, 2], [3, 4]],
...     [[5, 6], [7, 8]]
... ])
>>> new_dynamic_values = torch.tensor([
...     [[1., 2.], [3., 4.]],
...     [[5., 6.], [7., 8.]]
... ])
>>> new_dynamic_values_mask = torch.tensor([
...     [[True, True], [True, True]],
...     [[True, True], [True, True]]
... ])
>>> out = GenerativeSequenceModelSamples.pad_data_elements(
...     batch,
...     new_dynamic_indices,
...     new_dynamic_measurement_indices,
...     new_dynamic_values,
...     new_dynamic_values_mask
... )
>>> len(out)
2
>>> for tensor_tuple in out:
...     print(len(tensor_tuple))
...     for tensor in tensor_tuple:
...         print(tensor)
4
tensor([[[ 1,  0],
         [ 4,  0]],

        [[ 7,  0],
         [10,  0]]])
tensor([[[ 1,  0],
         [ 4,  0]],

        [[ 7,  0],
         [10,  0]]])
tensor([[[ 1.,  0.],
         [ 4.,  0.]],

        [[ 7.,  0.],
         [10.,  0.]]])
tensor([[[ True, False],
         [ True, False]],

        [[ True, False],
         [ True, False]]])
4
tensor([[[1, 2],
         [3, 4]],

        [[5, 6],
         [7, 8]]])
tensor([[[1, 2],
         [3, 4]],

        [[5, 6],
         [7, 8]]])
tensor([[[1., 2.],
         [3., 4.]],

        [[5., 6.],
         [7., 8.]]])
tensor([[[True, True],
         [True, True]],

        [[True, True],
         [True, True]]])

regression : dict[str, FloatTensor] | None = None
regression_indices : dict[str, LongTensor] | None = None
time_to_event : FloatTensor | None = None

update_last_event_data(batch: PytorchBatch, config: StructuredTransformerConfig, measurements_to_fill: set[str | tuple[str, MeasIndexGroupOptions]] | None = None) → PytorchBatch

Updates the last batch element with data from the current object.

This method modifies the last batch element in the given PytorchBatch object, based on the data available in the current object. The measurements that will be filled in the batch element are determined by the configuration and the ‘measurements_to_fill’ argument.

Parameters:
batch: PytorchBatch

The PytorchBatch object containing the batch element to be updated.

config: StructuredTransformerConfig

A StructuredTransformerConfig object containing configuration data.

measurements_to_fill: set[str | tuple[str, MeasIndexGroupOptions]] | None = None

A set of MEAS_INDEX_GROUP_T that specifies which measurements to fill. If not specified, all dynamic measurements from the config that are not dropped will be filled.

Raises:

ValueError – If ‘time’ is included in the ‘measurements_to_fill’ set.

Returns:

A new PytorchBatch object that includes the updated batch element.

class EventStream.transformer.model_output.NestedIndexableMixin

Bases: object

Mixin for indexable nested elements.

Provides a way to slice through nested indexable elements, using a static method and an instance method for slicing. This will index through dictionaries, tuples, torch distributions, and naturally indexable objects. Inputs of None will likewise return None. This assumes that inheriting classes can be mapped to plain dictionaries via dataclasses.asdict.

slice(idx: int | slice | ellipsis | Sequence[int | slice | ellipsis])

Performs joint index selection on the nested elements.

Parameters:
idx: int | slice | ellipsis | Sequence[int | slice | ellipsis]

The indices to be selected.

Returns:

An instance of the class indexed to the appropriate parameters.
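
A minimal pure-Python sketch of the recursive slicing described above. It handles None, dictionaries, tuples, and naturally indexable objects, but omits the torch distribution case. The class names are hypothetical, and the sketch reads fields with dataclasses.fields rather than dataclasses.asdict to avoid asdict's recursive deep copy.

```python
import dataclasses
from typing import Any, Optional


class NestedIndexableSketch:
    """Hypothetical mixin: jointly index every field of a dataclass."""

    @staticmethod
    def _nested_slice(val: Any, idx: Any) -> Any:
        if val is None:
            return None
        if isinstance(val, dict):
            return {k: NestedIndexableSketch._nested_slice(v, idx) for k, v in val.items()}
        if isinstance(val, tuple):
            return tuple(NestedIndexableSketch._nested_slice(v, idx) for v in val)
        return val[idx]  # naturally indexable objects, e.g. lists or tensors

    def slice(self, idx: Any):
        fields = {f.name: getattr(self, f.name) for f in dataclasses.fields(self)}
        return type(self)(**{k: self._nested_slice(v, idx) for k, v in fields.items()})


@dataclasses.dataclass
class ToyPredictions(NestedIndexableSketch):
    classification: Optional[dict] = None
    regression: Optional[tuple] = None
    time_to_event: Optional[list] = None


preds = ToyPredictions({"labs": [1, 2, 3]}, ([4, 5, 6], None), None)
print(preds.slice(slice(0, 2)))
```

Tuples are treated as containers rather than indexed directly, which matches pairs such as (None, Bernoulli) where each member must be sliced on its own.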

class EventStream.transformer.model_output.StreamClassificationModelOutput(loss: FloatTensor, preds: FloatTensor | None = None, labels: LongTensor | FloatTensor | None = None)

Bases: ModelOutput

Contains all outputs for the Stream Classification Model.

Parameters:
loss: FloatTensor

The overall model loss.

preds: FloatTensor | None = None

The model predictions.

labels: LongTensor | FloatTensor | None = None

The model labels.

labels : LongTensor | FloatTensor = None
loss : FloatTensor
preds : FloatTensor = None

class EventStream.transformer.model_output.TransformerOutputWithPast(last_hidden_state: FloatTensor | None = None, past_key_values: tuple[tuple[FloatTensor]] | dict[str, tuple[FloatTensor]] | None = None, hidden_states: tuple[FloatTensor] | None = None, attentions: tuple[FloatTensor] | None = None)

Bases: ModelOutput

Holds output data from a transformer model.

This class is designed to manage output data from a transformer model, which may include last hidden state, past key values, hidden states, and attentions.

Parameters:
last_hidden_state: FloatTensor | None = None

The last hidden state from the model.

past_key_values: tuple[tuple[FloatTensor]] | dict[str, tuple[FloatTensor]] | None = None

The past key values from the model.

hidden_states: tuple[FloatTensor] | None = None

The hidden states from the model.

attentions: tuple[FloatTensor] | None = None

The attentions from the model.

attentions : tuple[FloatTensor] | None = None
hidden_states : tuple[FloatTensor] | None = None
last_hidden_state : FloatTensor = None
past_key_values : tuple[tuple[FloatTensor]] | dict[str, tuple[FloatTensor]] | None = None

EventStream.transformer.model_output.get_event_types(dynamic_measurement_indices: LongTensor, dynamic_indices: LongTensor, event_type_measurement_idx: int, event_type_vocab_offset: int) → LongTensor

Identifies the event types from given dynamic measurements and indices.

Parameters:
dynamic_measurement_indices: LongTensor

The measurement index of each dynamic data element, of shape [batch_size, sequence_length, num_data_elements].

dynamic_indices: LongTensor

The dynamic (vocabulary) indices of each data element, with the same shape as dynamic_measurement_indices.

event_type_measurement_idx: int

The measurement index of the event-type measurement.

event_type_vocab_offset: int

The vocabulary offset of the event-type measurement; it is subtracted from the matching dynamic indices to obtain event-type indices.

Returns:

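A tensor of shape [batch_size, sequence_length] holding the event-type index of each event, or 0 for events with no event-type data element.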

Raises:

AssertionError – If there is more than one event type per event.

Examples

>>> import torch
>>> dynamic_measurement_indices = torch.LongTensor([
...     [[1, 2, 2, 2], [1, 2, 2, 0], [2, 2, 1, 0], [2, 1, 0, 0]],
...     [[1, 0, 0, 0], [3, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
... ])
>>> dynamic_indices = torch.LongTensor([
...     [[1, 11, 14, 18], [3, 11, 12, 0], [11, 10, 2, 0], [15, 8, 0, 0]],
...     [[3, 0, 0, 0], [31, 9, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
... ])
>>> event_type_measurement_idx = 1
>>> event_type_vocab_offset = 1
>>> print(get_event_types(
...     dynamic_measurement_indices=dynamic_measurement_indices,
...     dynamic_indices=dynamic_indices,
...     event_type_measurement_idx=event_type_measurement_idx,
...     event_type_vocab_offset=event_type_vocab_offset,
... ))
tensor([[0, 2, 1, 7],
        [2, 8, 0, 0]])
>>> dynamic_measurement_indices = torch.LongTensor([
...     [[1, 1, 2, 2], [1, 2, 2, 0], [2, 2, 1, 0], [2, 1, 0, 0]],
...     [[1, 0, 0, 0], [3, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
... ])
>>> dynamic_indices = torch.LongTensor([
...     [[1, 4, 14, 18], [3, 11, 12, 0], [11, 10, 2, 0], [15, 8, 0, 0]],
...     [[3, 0, 0, 0], [31, 9, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
... ])
>>> get_event_types(
...     dynamic_measurement_indices=dynamic_measurement_indices,
...     dynamic_indices=dynamic_indices,
...     event_type_measurement_idx=event_type_measurement_idx,
...     event_type_vocab_offset=event_type_vocab_offset,
... )
Traceback (most recent call last):
    ...
AssertionError: Got 2 event types per event!
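
The behaviour shown in these examples can be reproduced with a short sketch: select the data element whose measurement index equals event_type_measurement_idx, subtract the vocabulary offset, and return 0 for events without such an element. This is an illustrative reimplementation under those assumptions, not the module's own code; the name get_event_types_sketch is hypothetical.

```python
import torch


def get_event_types_sketch(
    dynamic_measurement_indices: torch.Tensor,
    dynamic_indices: torch.Tensor,
    event_type_measurement_idx: int,
    event_type_vocab_offset: int,
) -> torch.Tensor:
    is_event_type = dynamic_measurement_indices == event_type_measurement_idx
    n_per_event = is_event_type.sum(dim=-1)
    assert (n_per_event <= 1).all(), f"Got {n_per_event.max().item()} event types per event!"
    # Keep only the event-type element (offset removed); every other position contributes 0.
    event_types = torch.where(
        is_event_type,
        dynamic_indices - event_type_vocab_offset,
        torch.zeros_like(dynamic_indices),
    )
    return event_types.sum(dim=-1)
```

Summing over the data-element dimension is safe only because the assertion guarantees at most one non-zero term per event.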

EventStream.transformer.model_output.strip_unused_indices(dynamic_indices, *other_tensors)

Rearranges dynamic_indices and other passed tensors to minimize the number of padding (0) indices.

For each slice of dynamic_indices in the last dimension, this function re-arranges the elements of that slice (in dynamic_indices and all other passed tensors) such that the maximum number of zero-indices are removed and all non-zero indices are at the front of the tensor. This is used during generation, when newly generated elements may fill up the end of the tensor and may have zeros in them which we want to remove to minimize the size of the output tensors.

Parameters:
dynamic_indices

The indices to be evaluated. These are not the dynamic indices input to the model, but rather those output during generation for a new event, so the tensor has shape (batch, num_dynamic_measurements).

*other_tensors

Additional tensors to be re-arranged identically to dynamic_indices. All such tensors must have the same shape as dynamic_indices.

Returns:

The processed indices or a tuple of processed tensors.

Examples

>>> import torch
>>> dynamic_indices = torch.LongTensor([
...     [1, 11, 0, 18], [3, 0, 12, 0], [0, 0, 2, 0], [15, 8, 0, 0],
... ])
>>> dynamic_measurement_indices = torch.LongTensor([
...     [1, 2, 3, 4], [1, 2, 3, 0], [2, 2, 1, 0], [2, 1, 0, 0],
... ])
>>> for T in strip_unused_indices(dynamic_indices, dynamic_measurement_indices):
...     print(T)
tensor([[ 1, 11, 18],
        [ 3, 12,  0],
        [ 2,  0,  0],
        [15,  8,  0]])
tensor([[1, 2, 4],
        [1, 3, 0],
        [1, 0, 0],
        [2, 1, 0]])
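
One way to implement the rearrangement described above is a stable sort that moves non-zero indices to the front of each row, followed by truncation to the largest number of non-zero entries in any row. The sketch below, with the hypothetical name strip_unused_indices_sketch, reproduces the example output. It zeroes the other tensors wherever dynamic_indices is zero, which matches the example, but it is an illustration rather than the module's code.

```python
import torch


def strip_unused_indices_sketch(dynamic_indices: torch.Tensor, *other_tensors: torch.Tensor):
    is_present = dynamic_indices != 0
    # Stable sort on "is padding" (0 for present, 1 for padding) keeps present entries in order.
    _, order = torch.sort((~is_present).long(), dim=-1, stable=True)
    keep = int(is_present.sum(dim=-1).max())  # widest row after compaction
    mask = torch.gather(is_present, -1, order)[..., :keep]

    def compact(t: torch.Tensor) -> torch.Tensor:
        # Reorder like dynamic_indices, truncate, and zero out positions that were padding.
        return torch.gather(t, -1, order)[..., :keep].masked_fill(~mask, 0)

    if not other_tensors:
        return compact(dynamic_indices)
    return tuple(compact(t) for t in (dynamic_indices, *other_tensors))
```

A stable sort is the important choice here: an unstable sort could reorder the surviving indices within a row and break their alignment with the original data-element order.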