EventStream.transformer.zero_shot_labeler module

class EventStream.transformer.zero_shot_labeler.Labeler(config: StructuredTransformerConfig)

Bases: ABC

A base class for zero-shot labeler functors.

Zero-shot labeler functors let users run zero-shot evaluation on novel fine-tuning tasks. To produce a zero-shot labeler, users must:

  1. Sub-class this base class in a new file.

  2. Implement the __call__ method. It must take two inputs: a batch object, batch, which contains the original input followed by the newly generated data, and an integer input_seq_len giving the length of the input sequence prior to generation. It must return a tuple of two tensors: first, a torch.LongTensor of one-hot classification labels implied by the generated sequences in each batch element; second, a torch.BoolTensor indicating, for each element, whether a label could be produced for that sample.

  3. Copy the file containing this labeler class into the task directory with the name ${task_df_name}_labeler.py.

You can then use the built-in zero-shot evaluation utilities on that task, and your labeler will automatically be used to evaluate zero-shot performance via unsupervised generation.
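As a minimal sketch of steps 1 and 2, the labeler below assumes a toy batch that exposes only a [batch_size, seq_len] tensor of integer event codes (ToyBatch is hypothetical and stands in for PytorchBatch) and a hypothetical binary task: label 1 if event code 7 appears among the generated events. Because EventStream may not be installed, the sketch mirrors the documented Labeler interface locally instead of importing it; in a real task file you would subclass EventStream.transformer.zero_shot_labeler.Labeler directly.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass

import torch


class Labeler(ABC):
    """Local stand-in mirroring the documented interface; in practice, subclass
    EventStream.transformer.zero_shot_labeler.Labeler instead."""

    def __init__(self, config):
        self.config = config

    @abstractmethod
    def __call__(self, batch, input_seq_len: int) -> tuple[torch.Tensor, torch.Tensor]:
        ...


@dataclass
class ToyBatch:
    """Hypothetical stand-in for PytorchBatch: [batch_size, seq_len] event codes."""

    event_codes: torch.Tensor


class CodeSeenLabeler(Labeler):
    """Hypothetical binary task: 1 if TARGET_CODE appears in the generated events."""

    TARGET_CODE = 7  # hypothetical; a real labeler would derive indices from self.config
    PAD_CODE = 0  # hypothetical padding index

    def __call__(self, batch, input_seq_len):
        # Keep only the newly generated events (everything after the original input).
        generated = batch.event_codes[:, input_seq_len:]
        seen = (generated == self.TARGET_CODE).any(dim=1)
        # One-hot labels of shape [batch_size, 2]; one_hot returns an int64 (Long) tensor.
        labels = torch.nn.functional.one_hot(seen.long(), num_classes=2)
        # A sample is labelable only if at least one non-padding event was generated.
        valid = (generated != self.PAD_CODE).any(dim=1)
        return labels, valid


batch = ToyBatch(
    event_codes=torch.tensor(
        [
            [3, 4, 7, 0],  # generated part [7, 0] contains the target code
            [3, 4, 5, 6],  # generated part [5, 6] does not
            [3, 4, 0, 0],  # nothing was generated (all padding)
        ]
    )
)
labels, valid = CodeSeenLabeler(config=None)(batch, input_seq_len=2)
print(labels.tolist())  # [[0, 1], [1, 0], [1, 0]]
print(valid.tolist())  # [True, True, False]
```

In this sketch, rows whose validity flag is False still receive a placeholder label row so the output keeps a fixed [batch_size, 2] shape; code consuming the labels should ignore those rows.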

config

The StructuredTransformerConfig object defining the model being used. It holds information about vocabulary elements, index maps (which are needed to decode batch data into categories), and so on.

abstract __call__(batch: PytorchBatch, input_seq_len: int) → tuple[LongTensor, BoolTensor]

The core labeling method of the class. Must be overridden by subclasses.

Parameters:
batch: PytorchBatch

The PyTorch Batch, containing both the initial raw input data (left padded), followed by the newly generated data.

input_seq_len: int

The number of events (including padding) at the left of the batch that belong to the original raw input rather than to the newly generated data. E.g., batch[:, :input_seq_len] contains only the events from the original input, and batch[:, input_seq_len:] contains only the newly generated events.
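A small illustration of this split on a plain tensor, assuming integer event codes where 0 marks padding (both the codes and the padding value are hypothetical). Because input_seq_len is the padded width of the original input, the split point is the same for every row even when individual samples had shorter histories.

```python
import torch

input_seq_len = 3
# Two samples: the first had only 2 real input events, so it is left padded with 0.
event_codes = torch.tensor(
    [
        [0, 11, 12, 21, 22],
        [10, 11, 12, 23, 24],
    ]
)

original = event_codes[:, :input_seq_len]  # original (left-padded) input events
generated = event_codes[:, input_seq_len:]  # newly generated events

# Per-sample count of real (non-padding) input events.
real_input_lengths = (original != 0).sum(dim=1)

print(original.tolist())  # [[0, 11, 12], [10, 11, 12]]
print(generated.tolist())  # [[21, 22], [23, 24]]
print(real_input_lengths.tolist())  # [2, 3]
```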

Returns:

torch.LongTensor: The classification labels (in one-hot, [batch_size x vocab_size] format) that the labeler has generated in response to the input batch.

torch.BoolTensor: A boolean tensor of shape [batch_size] indicating whether each sample in the original input could be parsed into a label (True) or not (False).

Return type:

tuple[torch.LongTensor, torch.BoolTensor]
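The BoolTensor lets downstream code score only the samples the labeler could resolve. The function below is a hypothetical consumer, not the built-in zero-shot evaluation utilities, sketching how the two returned tensors might be combined: it computes accuracy over labelable samples plus the fraction of samples that received a label. The targets tensor of true class indices is an assumption for illustration.

```python
import torch


def masked_accuracy(
    labels_one_hot: torch.Tensor, valid: torch.Tensor, targets: torch.Tensor
) -> tuple[float, float]:
    """Return (accuracy over valid samples, fraction of samples that are valid)."""
    coverage = valid.float().mean().item()
    if not valid.any():
        return float("nan"), coverage
    # Convert one-hot rows back to class indices, then keep only labelable samples.
    predicted = labels_one_hot.argmax(dim=-1)
    accuracy = (predicted[valid] == targets[valid]).float().mean().item()
    return accuracy, coverage


labels = torch.tensor([[0, 1], [1, 0], [1, 0], [0, 1]])
valid = torch.tensor([True, True, False, True])
targets = torch.tensor([1, 1, 0, 0])
accuracy, coverage = masked_accuracy(labels, valid, targets)
print(round(accuracy, 4), coverage)  # 0.3333 0.75
```

Reporting coverage alongside accuracy in this sketch keeps a labeler that silently refuses most samples from looking artificially accurate.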