# Source code for EventStream.transformer.zero_shot_labeler
import abc
import torch
from ..data.types import PytorchBatch
from .config import StructuredTransformerConfig
[docs]
class Labeler(abc.ABC):
    """A base class for zero-shot labeler functors.

    Zero-shot labeler functors are used to enable users to run zero-shot evaluation over novel fine-tuning
    tasks. To produce a zero-shot labeler, users must:

    1. Sub-class this base class in a new file.
    2. Implement the `__call__` method; this method must take as input a batch object `batch`, which will
       contain newly generated data, and an integral `input_seq_len` parameter which gives how long the input
       sequence was prior to generation. It must return a tuple of tensors -- first, a `torch.LongTensor`
       containing one-hot classification labels that are implied by the generated sequences in the batch
       elements, and second, a `torch.BoolTensor` which indicates for each element of the generated set of
       labels whether or not a label was able to be produced for that sample.
    3. Copy the file containing this labeler class into the task directory with the name
       `${task_df_name}_labeler.py`.

    You can then use built-in zero-shot evaluation utilities on that task and your labeler will automatically
    be used to evaluate zero-shot performance via unsupervised generation.

    Attributes:
        config: The `StructuredTransformerConfig` config object defining the model being used. This holds
            information about vocabulary elements, index maps (which are important to decipher batch data
            into categories), etc.

    .. automethod:: __call__
    """

    # Annotations naming project types are quoted (forward references) so they are not evaluated when
    # the class body executes; `typing.get_type_hints` still resolves them via the module's imports.
    def __init__(self, config: "StructuredTransformerConfig"):
        """Store the model config for use by subclass `__call__` implementations.

        Args:
            config: The config object describing the model whose generations will be labeled.
        """
        self.config = config

    @abc.abstractmethod
    def __call__(
        self, batch: "PytorchBatch", input_seq_len: int
    ) -> tuple[torch.LongTensor, torch.BoolTensor]:
        """The core labeling method of the class. Must be overridden by a subclass.

        Args:
            batch: The PyTorch batch, containing both the initial raw input data (left padded), followed by
                the newly generated data.
            input_seq_len: The number of events (including padding) on the left side of the batch that were
                the original raw input, rather than the newly generated data. E.g., `batch[:, :input_seq_len]`
                is just the events that were in the original input, and `batch[:, input_seq_len:]` is just
                the newly generated events.

        Returns:
            torch.LongTensor: The classification labels (in one-hot, [batch_size x vocab_size] format) that
                the labeler has generated in response to the input batch.
            torch.BoolTensor: A boolean tensor of shape [batch_size] indicating whether or not each sample
                in the original input was able to be parsed into a label (`True`) or not (`False`).

        Raises:
            NotImplementedError: Always, in this base implementation; subclasses must override it.
        """
        raise NotImplementedError("Must be overwritten by a subclass!")