EventStream.transformer.utils module

EventStream.transformer.utils.expand_indexed_regression(X: Tensor, idx: Tensor, vocab_size: int)

Expands sparse values X with indices idx into a dense representation.

Parameters:
X: Tensor

A tensor of shape […, # of observed values] containing observed values. Shape must match that of idx.

idx: Tensor

A tensor of shape […, # of observed values] containing indices of observed values. Each index must be in the range [0, vocab_size). Shape must match that of X.

vocab_size: int

The size of the vocabulary to expand into. Indices in idx are indexes into this vocabulary.

Returns:

A dense tensor of shape […, vocab_size] such that, for every position i along the last dimension, the output at index idx[..., i] is X[..., i]; all other entries are 0.

Examples

>>> import torch
>>> X = torch.FloatTensor([[1, 2, 3], [4, 5, 6]])
>>> idx = torch.LongTensor([[0, 1, 2], [1, 3, 0]])
>>> vocab_size = 5
>>> expand_indexed_regression(X, idx, vocab_size)
tensor([[1., 2., 3., 0., 0.],
        [6., 4., 0., 5., 0.]])
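The expansion described above amounts to a scatter along the last dimension. The sketch below illustrates that idea and is not necessarily how expand_indexed_regression is implemented. The function name expand_indexed_regression_sketch is hypothetical, and the sketch assumes the indices within each row are unique, since Tensor.scatter picks an arbitrary value when indices repeat.

```python
import torch


def expand_indexed_regression_sketch(
    X: torch.Tensor, idx: torch.Tensor, vocab_size: int
) -> torch.Tensor:
    """Hypothetical sketch: scatter X into a zero tensor whose last dim is vocab_size."""
    if X.shape != idx.shape:
        raise ValueError(f"X {X.shape} and idx {idx.shape} must have the same shape")
    # Keep every leading dimension and replace the last one with vocab_size.
    out = torch.zeros(*X.shape[:-1], vocab_size, dtype=X.dtype, device=X.device)
    # out[..., idx[..., i]] = X[..., i]; positions that are never indexed stay 0.
    # Assumes indices within a row are unique (duplicates pick an arbitrary value).
    return out.scatter(-1, idx, X)


X = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
idx = torch.tensor([[0, 1, 2], [1, 3, 0]])  # int64, as scatter requires
dense = expand_indexed_regression_sketch(X, idx, vocab_size=5)  # same result as the example above
```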
EventStream.transformer.utils.idx_distribution(D: Distribution, index: int | slice | ellipsis | Sequence[int | slice | ellipsis]) → Distribution

Slices a torch Distribution so its outputs are of the appropriate shape.

Torch distributions return tensors of a fixed shape from sample(). This function slices a distribution's internal parameters to produce a new distribution whose samples have the sliced shape. In many (but not all) cases, passing the same index you would apply to the output of sample() yields a distribution with the desired shape. Only some distribution types are supported. Adapted from https://github.com/pytorch/pytorch/issues/52625 (retrieved 2-16-22 at 12:40 ET).

Parameters:
D: Distribution

The distribution to slice.

index: int | slice | ellipsis | Sequence[int | slice | ellipsis]

The index or slice to apply to the parameters.

Returns:

The sliced distribution.

Raises:

IndexError – If the index is invalid for the distribution.

Examples

>>> import torch
>>> logits_tensor = torch.Tensor([[1, 2, -3], [4, 1, 0]])
>>> D = torch.distributions.Bernoulli(logits=logits_tensor)
>>> D.sample().shape
torch.Size([2, 3])
>>> D2 = idx_distribution(D, (slice(None), slice(None, 1)))
>>> D2.sample().shape
torch.Size([2, 1])
>>> D2.logits
tensor([[1.],
        [4.]])
>>> probs_tensor = torch.FloatTensor([[0.1, 0.2, 0.7], [0.2, 0.8, 0.0]])
>>> D = torch.distributions.Categorical(probs=probs_tensor)
>>> D.sample().shape
torch.Size([2])
>>> D2 = idx_distribution(D, 1)
>>> D2.sample().shape
torch.Size([])
>>> # Round because distributions transform their probs parameter internally (e.g., by normalizing it), which introduces small floating-point errors
>>> D2.probs.round(decimals=1)
tensor([0.2000, 0.8000, 0.0000])
>>> D2 = idx_distribution(D, (slice(None), 2))
Traceback (most recent call last):
    ...
IndexError: Failed to slice probs of shape torch.Size([2, 3]) with (slice(None, None, None), 2) + (:,) * 1 = (slice(None, None, None), 2, slice(None, None, None))
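A simplified illustration of the parameter-slicing idea follows. The helper slice_distribution_sketch is hypothetical and is not the code adapted from the linked issue: it supports only Bernoulli and Categorical and always rebuilds the distribution from its logits. A Categorical's parameters carry one trailing category dimension beyond its batch shape, so the sketch appends a full slice to the index, analogous to the `+ (:,) * 1` shown in the error message above.

```python
import torch
from torch.distributions import Bernoulli, Categorical, Distribution


def slice_distribution_sketch(D: Distribution, index) -> Distribution:
    """Hypothetical sketch: index the batch dimensions of a Bernoulli or Categorical."""
    if not isinstance(index, tuple):
        index = (index,)
    if isinstance(D, Bernoulli):
        # Bernoulli logits have exactly the batch shape, so the index applies directly.
        return Bernoulli(logits=D.logits[index])
    if isinstance(D, Categorical):
        # Categorical logits have shape batch_shape + (num_categories,); keep the
        # trailing category dimension intact by appending a full slice.
        return Categorical(logits=D.logits[index + (slice(None),)])
    raise NotImplementedError(f"Unsupported distribution: {type(D).__name__}")


D = Categorical(probs=torch.tensor([[0.1, 0.2, 0.7], [0.2, 0.8, 0.0]]))
D2 = slice_distribution_sketch(D, 1)  # a single categorical over 3 classes
```

Because the category dimension is always preserved, an index such as (slice(None), 2) on a batch of shape [2] has one entry too many and raises IndexError, mirroring the example above.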
EventStream.transformer.utils.safe_masked_max(X: Tensor, mask: BoolTensor) → Tensor

Returns a safe max over the last dimension of X, considering only the elements where mask is True.

This function takes the max over all elements of the last dimension of X where mask is True. mask can take one of two forms:

  • An element-wise mask, in which case it must have the same shape as X.

  • A column-wise mask, in which case it must have the same shape as X with the second-to-last dimension omitted. This form is useful when you wish to, for example, take the maximum of a network's hidden states over the sequence length while respecting an event mask.

If mask is uniformly False for a row, the output is zero.

Parameters:
X: Tensor

A tensor of shape […, # of rows, # of columns] containing elements to take the max over.

mask: BoolTensor

A Boolean tensor either of shape […, # of rows, # of columns] or […, # of columns] containing a mask indicating which elements can be considered for the max.

Returns:

A tensor of shape […, # of rows] containing the max over the last dimension of X, considering only the elements where mask is True. If mask is uniformly False for a row, the output for that row is zero.

Raises:

AssertionError – If mask is not the correct shape for either mode.

Examples

>>> import torch
>>> # An element-wise mask
>>> X = torch.FloatTensor([[1, 2, 3], [4, 5, 6]])
>>> mask = torch.BoolTensor([[True, True, False], [False, False, False]])
>>> safe_masked_max(X, mask)
tensor([2., 0.])
>>> # A column-wise mask, with a batch dimension.
>>> X = torch.FloatTensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
>>> mask = torch.BoolTensor([[False, True, False], [True, False, True]])
>>> safe_masked_max(X, mask)
tensor([[ 2.,  5.],
        [ 9., 12.]])
>>> X = torch.FloatTensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
>>> mask = torch.BoolTensor([[[False, True], [True, True]]])
>>> safe_masked_max(X, mask)
Traceback (most recent call last):
    ...
AssertionError: mask torch.Size([1, 2, 2]) must be the same shape as X torch.Size([2, 2, 3]) or the same shape as X excluding the second to last dimension
>>> X = torch.FloatTensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
>>> mask = torch.BoolTensor([[False, True], [True, True]])
>>> safe_masked_max(X, mask)
Traceback (most recent call last):
    ...
AssertionError: mask torch.Size([2, 2]) must be the same shape as X torch.Size([2, 2, 3]) or the same shape as X excluding the second to last dimension
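The masking rules above can be sketched as follows, assuming masked-out entries are replaced with the smallest finite value of X's floating-point dtype before taking the max. This is a hypothetical illustration (safe_masked_max_sketch is not part of the module), not necessarily the library's implementation.

```python
import torch


def safe_masked_max_sketch(X: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of a masked max over the last dimension of a float tensor X."""
    if mask.shape != X.shape:
        # Column-wise mask: it must match X without its second-to-last (row) dimension.
        expected = tuple(X.shape[:-2]) + tuple(X.shape[-1:])
        assert tuple(mask.shape) == expected, f"mask {mask.shape} does not fit X {X.shape}"
        # Reuse the same column mask for every row.
        mask = mask.unsqueeze(-2).expand_as(X)
    # Masked-out entries get the dtype's smallest finite value so they never win the max.
    filled = X.masked_fill(~mask, torch.finfo(X.dtype).min)
    maxes = filled.max(dim=-1).values
    # Rows with no True entries return 0 instead of the fill value.
    return torch.where(mask.any(dim=-1), maxes, torch.zeros_like(maxes))


neg = torch.tensor([[-3.0, -1.0, -2.0]])
result = safe_masked_max_sketch(neg, torch.tensor([[True, False, True]]))  # tensor([-2.])
```

Filling with the dtype minimum, rather than multiplying X by the mask, is what keeps the result correct when every unmasked value is negative.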
EventStream.transformer.utils.safe_weighted_avg(X: Tensor, weights: Tensor) → tuple[Tensor, Tensor]

Returns the weighted average over the last dimension of X, weighted by weights, together with the sum of the weights.

Parameters:
X: Tensor

A tensor of shape […, # of rows, # of columns] containing elements to take the weighted average over.

weights: Tensor

The weights for the weighted average. Must be >= 0, and can take one of two forms:

  • Element-wise, in which case it must have the same shape as X.

  • Column-wise, in which case it must have the same shape as X with the second-to-last dimension omitted. This form is useful when you wish to, for example, average a network's hidden states over the sequence length while respecting an event mask.

Returns:

A tuple of two tensors:

  • The weighted average over the last dimension of X, weighted by weights. Wherever the weights along that dimension sum to 0, the output is zero.

  • The sum of the weights along the last dimension (the denominator of the weighted average).

Return type:

tuple[Tensor, Tensor]

Raises:

AssertionError – If weights contains negative elements or has an invalid shape.

Examples

>>> import torch
>>> X = torch.FloatTensor([[1, 2, 3], [4, 5, 6]])
>>> weights = torch.FloatTensor([[1, 2, 3], [4, 5, 6]])
>>> safe_weighted_avg(X, weights)
(tensor([2.3333, 5.1333]), tensor([ 6., 15.]))
>>> X = torch.FloatTensor([[1, 2, 3], [4, 5, 6]])
>>> weights = torch.FloatTensor([[0, 0, 0], [1, 0, 0]])
>>> safe_weighted_avg(X, weights)
(tensor([0., 4.]), tensor([0., 1.]))
>>> X = torch.FloatTensor([[1, 2, 3], [4, 5, 6]])
>>> weights = torch.FloatTensor([[0, 0, 0], [-1, 0, 0]])
>>> safe_weighted_avg(X, weights)
Traceback (most recent call last):
    ...
AssertionError: weights should be >= 0. Got min -1.0
>>> X = torch.FloatTensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
>>> weights = torch.FloatTensor([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]])
>>> safe_weighted_avg(X, weights)
Traceback (most recent call last):
    ...
AssertionError: weights torch.Size([2, 3, 2]) must be the same shape as X torch.Size([2, 2, 3]) or the same shape as X excluding the second to last dimension
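A hypothetical sketch of the same computation follows (safe_weighted_avg_sketch is not part of the module). It assumes column-wise weights are broadcast across rows in the same way as a column-wise mask in safe_masked_max, and it is not necessarily how the library computes the result.

```python
import torch


def safe_weighted_avg_sketch(
    X: torch.Tensor, weights: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical sketch of a zero-safe weighted average over the last dimension."""
    assert (weights >= 0).all(), f"weights should be >= 0. Got min {weights.min()}"
    if weights.shape != X.shape:
        # Column-wise weights: broadcast the same weights across every row of X.
        expected = tuple(X.shape[:-2]) + tuple(X.shape[-1:])
        assert tuple(weights.shape) == expected, f"weights {weights.shape} does not fit X {X.shape}"
        weights = weights.unsqueeze(-2).expand_as(X)
    denom = weights.sum(dim=-1)
    numer = (X * weights).sum(dim=-1)
    # Divide by 1 where the denominator is 0 so 0 / 0 is never computed,
    # then zero out those positions.
    safe_denom = torch.where(denom > 0, denom, torch.ones_like(denom))
    avg = torch.where(denom > 0, numer / safe_denom, torch.zeros_like(numer))
    return avg, denom


X = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
avg, denom = safe_weighted_avg_sketch(X, torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]))
```

Dividing by a placeholder denominator, rather than computing numer / denom and masking the result afterwards, matters during training: torch.where passes a zero gradient to the unselected branch, and multiplying that zero by the NaN derivative of 0 / 0 still yields NaN.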
EventStream.transformer.utils.str_summary(T: Tensor)

Returns a string summary of a tensor for debugging purposes.

Parameters:
T: Tensor

The tensor to summarize.

Returns:

A string summary of the tensor, documenting the tensor’s shape, dtype, and the range of values it contains.

Examples

>>> import torch
>>> T = torch.FloatTensor([[[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]])
>>> str_summary(T)
'shape: (1, 2, 5), type: torch.float32, range: 1-10'
>>> T = torch.LongTensor([[2, 3, -4, 5], [6, 7, 9, -10]])
>>> str_summary(T)
'shape: (2, 4), type: torch.int64, range: -10-9'
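The output format in the examples can be reproduced with a short f-string. The sketch below is a hypothetical reconstruction (str_summary_sketch is not part of the module); formatting the bounds with Python's 'g' presentation type is an assumption chosen because it matches the examples. Note that 'g' switches to scientific notation for values with more than six significant digits, and that a negative minimum makes the hyphenated range read ambiguously, as in '-10-9'.

```python
import torch


def str_summary_sketch(T: torch.Tensor) -> str:
    """Hypothetical sketch of a one-line tensor summary for debugging."""
    # .item() converts the 0-dim min/max tensors to Python numbers.
    lo, hi = T.min().item(), T.max().item()
    # 'g' drops trailing zeros, so 1.0 renders as "1"; ints are converted to float first.
    return f"shape: {tuple(T.shape)}, type: {T.dtype}, range: {lo:g}-{hi:g}"


summary = str_summary_sketch(torch.tensor([[2, 3, -4, 5], [6, 7, 9, -10]]))
```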
EventStream.transformer.utils.weighted_loss(loss_per_event: Tensor, event_mask: Tensor) → Tensor

Returns the average, across subjects with at least one event, of each subject's average per-event loss.

Given a tensor loss_per_event of shape [# subjects, # events] containing per-event loss values for each subject, and a tensor event_mask containing binary indicators of whether each event is present, this function first averages each subject's loss over its present events, then averages those per-subject values across subjects, excluding subjects who have no events.

Parameters:
loss_per_event: Tensor

A tensor of shape [# subjects, # events] containing loss values per event.

event_mask: Tensor

A tensor of shape [# subjects, # events] containing binary indicators of whether an event was present or not.

Returns:

A scalar tensor (shape []) containing the average, across subjects, of each subject's average per-event loss, excluding subjects who have no events. If no subject has any events, returns 0.

Examples

>>> import torch
>>> loss_per_event = torch.FloatTensor([[1, 2, 3], [4, 5, 6]])
>>> event_mask = torch.FloatTensor([[1, 1, 1], [1, 0, 0]])
>>> weighted_loss(loss_per_event, event_mask)
tensor(3.)
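The two-stage average described above can be sketched as follows. weighted_loss_sketch is a hypothetical name, and this is an illustration consistent with the example rather than the library's code: the inner mean uses only each subject's present events, and the outer mean counts every remaining subject equally.

```python
import torch


def weighted_loss_sketch(loss_per_event: torch.Tensor, event_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: mean over subjects of each subject's masked mean loss."""
    mask = event_mask.to(loss_per_event.dtype)
    events_per_subject = mask.sum(dim=-1)
    has_events = events_per_subject > 0
    if not has_events.any():
        # No subject has any events: return a scalar zero, as documented.
        return torch.zeros((), dtype=loss_per_event.dtype, device=loss_per_event.device)
    # Average each subject's loss over its present events only.
    per_subject = (loss_per_event * mask).sum(dim=-1)[has_events] / events_per_subject[has_events]
    # Subjects without events were dropped above, so each kept subject counts equally.
    return per_subject.mean()


loss = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
# Subject 1 has no events, so only subject 0's mean, (1 + 2) / 2, is averaged.
print(weighted_loss_sketch(loss, torch.tensor([[1.0, 1.0, 0.0], [0.0, 0.0, 0.0]])))  # tensor(1.5000)
```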