import inspect
from collections.abc import Sequence
from typing import Union
import torch
# A single element of an index expression: an integer position, a slice, or `...` (Ellipsis).
VALID_INDEX_T = Union[int, slice, type(Ellipsis)]
# An index selection: either one index element or a sequence of them. Used to annotate the `index`
# argument of `idx_distribution`.
INDEX_SELECT_T = Union[VALID_INDEX_T, Sequence[VALID_INDEX_T]]
[docs]
def str_summary(T: torch.Tensor) -> str:
    """Returns a string summary of a tensor for debugging purposes.

    Args:
        T: The tensor to summarize.

    Returns:
        A string summary of the tensor, documenting the tensor's shape, dtype, and the range of values it
        contains. A tensor with no elements has no minimum or maximum, so its range is reported as
        ``N/A`` rather than raising.

    Examples:
        >>> import torch
        >>> T = torch.FloatTensor([[[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]])
        >>> str_summary(T)
        'shape: (1, 2, 5), type: torch.float32, range: 1-10'
        >>> T = torch.LongTensor([[2, 3, -4, 5], [6, 7, 9, -10]])
        >>> str_summary(T)
        'shape: (2, 4), type: torch.int64, range: -10-9'
        >>> str_summary(torch.zeros(0, 3))
        'shape: (0, 3), type: torch.float32, range: N/A'
    """
    header = f"shape: {tuple(T.shape)}, type: {T.dtype}"

    # `min()` / `max()` raise on tensors with zero elements; a debugging helper should still be able to
    # describe such a tensor.
    if T.numel() == 0:
        return f"{header}, range: N/A"

    return f"{header}, range: {T.min():n}-{T.max():n}"
[docs]
def expand_indexed_regression(X: torch.Tensor, idx: torch.Tensor, vocab_size: int):
    """Scatters the sparse values `X` located at indices `idx` into a dense, zero-filled tensor.

    Args:
        X: Observed values, of shape [..., # of observed values]. Must have the same shape as `idx`.
        idx: Positions of the observed values within the vocabulary, of shape
            [..., # of observed values]. Every entry must lie in [0, `vocab_size`). Must have the same
            shape as `X`.
        vocab_size: Size of the last dimension of the dense output; `idx` indexes into this dimension.

    Returns:
        A tensor of shape [..., `vocab_size`] with the dtype and device of `X`. Along the last dimension,
        position `idx[i]` holds `X[i]` and every position not named in `idx` holds 0.

    Examples:
        >>> import torch
        >>> X = torch.FloatTensor([[1, 2, 3], [4, 5, 6]])
        >>> idx = torch.LongTensor([[0, 1, 2], [1, 3, 0]])
        >>> vocab_size = 5
        >>> expand_indexed_regression(X, idx, vocab_size)
        tensor([[1., 2., 3., 0., 0.],
                [6., 4., 0., 5., 0.]])
    """
    # Leading (batch) dimensions come from `idx`; only the last dimension is widened to the vocabulary.
    dense_shape = (*idx.shape[:-1], vocab_size)
    dense = torch.zeros(dense_shape, dtype=X.dtype, device=X.device)
    return dense.scatter(dim=-1, index=idx, src=X)
[docs]
def safe_masked_max(X: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
    """Returns a safe max over the last dimension of `X` respecting the mask `mask`.

    This function takes the max over all elements of the last dimension of `X` where `mask` is True. `mask`
    can take one of two forms:

      * An element-wise mask, in which case it must have the same shape as `X`.
      * A column-wise mask, in which case it must have the same shape as `X` excluding the second to last
        dimension, which should be omitted. This case is used when you wish to, for example, take the
        maximum of the hidden states of a network over the sequence length, while respecting an event
        mask.

    If `mask` is uniformly False for a row, the output for that row is zero. Rows with at least one
    unmasked element return exactly what `Tensor.max` produces over the unmasked elements, so genuine
    infinite values in `X` are returned unchanged rather than being clamped or zeroed.

    Args:
        X: A tensor of shape [..., # of rows, # of columns] containing elements to take the max over.
        mask: A Boolean tensor either of shape [..., # of rows, # of columns] or [..., # of columns]
            containing a mask indicating which elements can be considered for the max.

    Returns:
        A tensor of shape [..., # of rows] containing the max over the last dimension of `X` respecting
        the mask `mask`. If `mask` is uniformly False for a row, the output is zero.

    Raises:
        AssertionError: If `mask` is not the correct shape for either mode.

    Examples:
        >>> import torch
        >>> # An element-wise mask
        >>> X = torch.FloatTensor([[1, 2, 3], [4, 5, 6]])
        >>> mask = torch.BoolTensor([[True, True, False], [False, False, False]])
        >>> safe_masked_max(X, mask)
        tensor([2., 0.])
        >>> # A column-wise mask, with a batch dimension.
        >>> X = torch.FloatTensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
        >>> mask = torch.BoolTensor([[False, True, False], [True, False, True]])
        >>> safe_masked_max(X, mask)
        tensor([[ 2.,  5.],
                [ 9., 12.]])
        >>> # Genuine infinities at unmasked positions are preserved.
        >>> X = torch.FloatTensor([[1, float("inf")], [float("-inf"), float("-inf")]])
        >>> mask = torch.BoolTensor([[True, True], [True, True]])
        >>> safe_masked_max(X, mask)
        tensor([inf, -inf])
        >>> X = torch.FloatTensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
        >>> mask = torch.BoolTensor([[[False, True], [True, True]]])
        >>> safe_masked_max(X, mask)
        Traceback (most recent call last):
            ...
        AssertionError: mask torch.Size([1, 2, 2]) must be the same shape as X torch.Size([2, 2, 3]) or the same shape as X excluding the second to last dimension
        >>> X = torch.FloatTensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
        >>> mask = torch.BoolTensor([[False, True], [True, True]])
        >>> safe_masked_max(X, mask)
        Traceback (most recent call last):
            ...
        AssertionError: mask torch.Size([2, 2]) must be the same shape as X torch.Size([2, 2, 3]) or the same shape as X excluding the second to last dimension
    """
    shape_err_string = (
        f"mask {mask.shape} must be the same shape as X {X.shape} "
        "or the same shape as X excluding the second to last dimension"
    )

    if len(mask.shape) < len(X.shape):
        # Column-wise mask: insert the missing row dimension and broadcast it across all rows.
        try:
            mask = mask.unsqueeze(-2).expand_as(X)
        except RuntimeError as e:
            raise AssertionError(shape_err_string) from e
    else:
        torch._assert(mask.shape == X.shape, shape_err_string)

    # Masked-out elements are replaced by -inf so they can never win the max.
    masked_X = torch.where(mask, X, -float("inf"))
    maxes = masked_X.max(-1)[0]

    # Zero out only the rows that had no unmasked element. Deciding this from the mask (rather than by
    # post-processing non-finite values in `maxes`) keeps genuine +/-inf results from `X` intact.
    row_has_unmasked = mask.any(dim=-1)
    return torch.where(row_has_unmasked, maxes, torch.zeros_like(maxes))
[docs]
def safe_weighted_avg(X: torch.Tensor, weights: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Computes a weighted average over the last dimension of `X` that is safe against zero total weight.

    Args:
        X: A tensor containing elements to take the weighted average over.
        weights: Non-negative weights for the average, given in one of two forms:

              * Element-wise: the same shape as `X`.
              * Column-wise: the shape of `X` with the second to last dimension omitted. The weights are
                then shared across that dimension, e.g. to average the hidden states of a network over the
                sequence length while respecting an event mask.

    Returns:
        A tuple of two tensors, each shaped like `X` without its last dimension:

          * The weighted average over the last dimension of `X`. Wherever the weights sum to 0, the
            average is reported as 0.
          * The sum of the weights over the last dimension (the denominator of the weighted average),
            as a float tensor.

    Raises:
        AssertionError: If `weights` contains negative elements or has an invalid shape.

    Examples:
        >>> import torch
        >>> X = torch.FloatTensor([[1, 2, 3], [4, 5, 6]])
        >>> weights = torch.FloatTensor([[1, 2, 3], [4, 5, 6]])
        >>> safe_weighted_avg(X, weights)
        (tensor([2.3333, 5.1333]), tensor([ 6., 15.]))
        >>> X = torch.FloatTensor([[1, 2, 3], [4, 5, 6]])
        >>> weights = torch.FloatTensor([[0, 0, 0], [1, 0, 0]])
        >>> safe_weighted_avg(X, weights)
        (tensor([0., 4.]), tensor([0., 1.]))
        >>> X = torch.FloatTensor([[1, 2, 3], [4, 5, 6]])
        >>> weights = torch.FloatTensor([[0, 0, 0], [-1, 0, 0]])
        >>> safe_weighted_avg(X, weights)
        Traceback (most recent call last):
            ...
        AssertionError: weights should be >= 0. Got min -1.0
        >>> X = torch.FloatTensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
        >>> weights = torch.FloatTensor([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]])
        >>> safe_weighted_avg(X, weights)
        Traceback (most recent call last):
            ...
        AssertionError: weights torch.Size([2, 3, 2]) must be the same shape as X torch.Size([2, 2, 3]) or the same shape as X excluding the second to last dimension
    """
    # Sign check comes first, before any shape handling.
    torch._assert(
        (weights >= 0).all(),
        f"weights should be >= 0. Got min {torch.min(weights)}",
    )

    bad_shape_msg = (
        f"weights {weights.shape} must be the same shape as X {X.shape} "
        "or the same shape as X excluding the second to last dimension"
    )
    if len(weights.shape) >= len(X.shape):
        torch._assert(weights.shape == X.shape, bad_shape_msg)
    else:
        # Column-wise weights: add the missing dimension and broadcast it to the shape of `X`.
        try:
            weights = weights.unsqueeze(-2).expand_as(X)
        except RuntimeError as e:
            raise AssertionError(bad_shape_msg) from e

    float_weights = weights.float()
    weight_total = float_weights.sum(dim=-1)
    has_weight = weight_total > 0

    # Divide by 1 wherever the total weight is 0 so the division never produces NaN/inf; those entries
    # are overwritten with 0 below anyway.
    divisor = torch.where(has_weight, weight_total, torch.ones_like(weight_total))
    weighted_sum = (X * float_weights).sum(dim=-1)
    average = torch.where(has_weight, weighted_sum / divisor, torch.zeros_like(weight_total))

    return average, weight_total
[docs]
def weighted_loss(loss_per_event: torch.Tensor, event_mask: torch.Tensor) -> torch.Tensor:
    """Averages per-event losses within each subject, then averages across subjects that have events.

    The reduction happens in two stages. First, each subject's losses are averaged over that subject's
    events, weighted by `event_mask`. Second, those per-subject averages are averaged over all subjects
    whose event weights sum to more than zero, so subjects with no events do not affect the result.

    Args:
        loss_per_event: A tensor of shape [# subjects, # events] containing loss values per event.
        event_mask: A tensor of shape [# subjects, # events] containing binary indicators of whether an
            event was present or not.

    Returns:
        A tensor of shape [] containing the average, over subjects with at least one event, of each
        subject's average per-event loss. If no subjects have any events, returns 0.

    Examples:
        >>> import torch
        >>> loss_per_event = torch.FloatTensor([[1, 2, 3], [4, 5, 6]])
        >>> event_mask = torch.FloatTensor([[1, 1, 1], [1, 0, 0]])
        >>> weighted_loss(loss_per_event, event_mask)
        tensor(3.)
    """
    # Stage 1: per-subject mean loss, plus how much event weight each subject carries.
    per_subject_loss, per_subject_events = safe_weighted_avg(loss_per_event, event_mask)

    # Stage 2: every subject with any events counts equally; subjects without events get weight 0.
    subject_has_events = per_subject_events > 0
    overall_loss, _ = safe_weighted_avg(per_subject_loss, subject_has_events)
    return overall_loss
# Distribution classes for which `idx_distribution` must not pass both `probs` and `logits` to the
# constructor: when both were sliced from the source distribution, `probs` is dropped and only `logits`
# is forwarded. The tuple is used directly as the second argument of `isinstance`.
_PROBS_LOGITS_NOT_BOTH_DISTRIBUTIONS: tuple[type[torch.distributions.Distribution], ...] = (
    torch.distributions.Bernoulli,
    torch.distributions.Binomial,
    torch.distributions.Categorical,
    torch.distributions.ContinuousBernoulli,
    torch.distributions.Multinomial,
    torch.distributions.RelaxedBernoulli,
)
[docs]
def idx_distribution(
    D: torch.distributions.Distribution,
    index: INDEX_SELECT_T,
) -> torch.distributions.Distribution:
    """Builds a new distribution whose parameters are the indexed/sliced parameters of `D`.

    Torch distributions output tensors of consistent shape upon sample(). This function indexes the
    parameters stored on `D` and constructs a new distribution from the results, so that samples from the
    new distribution have the sliced shape. In many cases, but not all, passing the slice/index you would
    apply to an output of sample() as `index` yields a distribution of the desired shape.

    Handling depends on the kind of distribution:

      * Objects defining `__getitem__` are simply indexed with `index` (as a tuple).
      * `MixtureSameFamily`: the mixture and component distributions are each sliced recursively.
      * `TransformedDistribution`: the base distribution is sliced recursively and the transforms are
        reused as-is; each transform's `sign` must be -1 or 1.
      * Anything else: every parameter named in `D.arg_constraints` is indexed with `index` followed by
        one full slice per event dimension of that parameter's constraint, and the results are passed to
        the constructor of `type(D)`.

    Only works with select distributions.

    Sourced from: https://github.com/pytorch/pytorch/issues/52625 on 2-16-22 at 12:40 ET.

    Args:
        D: The distribution to slice.
        index: The index or slice to apply to the parameters. A non-tuple value is wrapped in a 1-tuple.

    Returns:
        The sliced distribution. Except for objects sliced through their own `__getitem__`, the result is
        constructed with argument validation disabled and then inherits `D._validate_args` if `D` has
        that attribute.

    Raises:
        IndexError: If the index is invalid for the distribution.

    Examples:
        >>> import torch
        >>> logits_tensor = torch.Tensor([[1, 2, -3], [4, 1, 0]])
        >>> D = torch.distributions.Bernoulli(logits=logits_tensor)
        >>> D.sample().shape
        torch.Size([2, 3])
        >>> D2 = idx_distribution(D, (slice(None), slice(None, 1)))
        >>> D2.sample().shape
        torch.Size([2, 1])
        >>> D2.logits
        tensor([[1.],
                [4.]])
        >>> probs_tensor = torch.FloatTensor([[0.1, 0.2, 0.7], [0.2, 0.8, 0.0]])
        >>> D = torch.distributions.Categorical(probs=probs_tensor)
        >>> D.sample().shape
        torch.Size([2])
        >>> D2 = idx_distribution(D, 1)
        >>> D2.sample().shape
        torch.Size([])
        >>> # We have to round because distributions modify their probs params which yields precision errors
        >>> D2.probs.round(decimals=1)
        tensor([0.2000, 0.8000, 0.0000])
        >>> D2 = idx_distribution(D, (slice(None), 2))
        Traceback (most recent call last):
            ...
        IndexError: Failed to slice probs of shape torch.Size([2, 3]) with (slice(None, None, None), 2) + (:,) * 1 = (slice(None, None, None), 2, slice(None, None, None))
    """
    selector = index if isinstance(index, tuple) else (index,)

    # Custom distributions that define their own indexing take care of slicing themselves.
    if hasattr(D, "__getitem__"):
        return D[selector]

    # Mixture and transformed distributions wrap other distributions, so they are rebuilt around
    # recursively-sliced inner distributions rather than from their own parameters.
    if isinstance(D, torch.distributions.MixtureSameFamily):
        mixture = D.mixture_distribution
        components = D.component_distribution
        sliced = torch.distributions.MixtureSameFamily(
            mixture_distribution=idx_distribution(mixture, selector),
            component_distribution=idx_distribution(components, selector),
            validate_args=False,
        )
    elif isinstance(D, torch.distributions.TransformedDistribution):
        transforms = D.transforms
        for transform in transforms:
            assert transform.sign in (-1, 1)  # Asserts transforms are univariate bij
        base = D.base_dist
        sliced = torch.distributions.TransformedDistribution(
            transforms=transforms,
            base_distribution=idx_distribution(base, selector),
            validate_args=False,
        )
    else:
        full_slice = (slice(None),)
        kwargs = {}
        for name, constraint in D.arg_constraints.items():
            # Event dimensions trail the batch dimensions and are always kept whole.
            param_selector = selector + full_slice * constraint.event_dim
            try:
                kwargs[name] = getattr(D, name)[param_selector]
            except IndexError as e:
                raise IndexError(
                    f"Failed to slice {name} of shape {getattr(D, name).shape} with "
                    f"{selector} + (:,) * {constraint.event_dim} = {param_selector}"
                ) from e

        dist_cls = type(D)
        if "validate_args" in inspect.signature(dist_cls).parameters:
            kwargs["validate_args"] = False

        # When both parameterizations were sliced for one of the listed classes, forward only `logits`.
        has_both = "probs" in kwargs and "logits" in kwargs
        if has_both and isinstance(D, _PROBS_LOGITS_NOT_BOTH_DISTRIBUTIONS):
            del kwargs["probs"]

        sliced = dist_cls(**kwargs)

    # Construction above ran with validation off; restore the source distribution's setting.
    if hasattr(D, "_validate_args"):
        sliced._validate_args = D._validate_args

    return sliced