EventStream.transformer.zero_shot_labeler module
- class EventStream.transformer.zero_shot_labeler.Labeler(config: StructuredTransformerConfig)
Bases: ABC
A base class for zero-shot labeler functors.
Zero-shot labeler functors are used to enable users to run zero-shot evaluation over novel fine-tuning tasks. To produce a zero-shot labeler, users must:
Sub-class this base class in a new file.
Implement the __call__ method; this method must take as input a batch object batch, which will contain newly generated data, and an integral input_seq_len parameter which gives how long the input sequence was prior to generation. It must return a tuple of tensors: first, a torch.LongTensor containing one-hot classification labels that are implied by the generated sequences in the batch elements, and second, a torch.BoolTensor which indicates, for each sample in the batch, whether a label could be produced for that sample.
Copy the file containing this labeler class into the task directory with the name ${task_df_name}_labeler.py.
You can then use built-in zero-shot evaluation utilities on that task and your labeler will automatically be used to evaluate zero-shot performance via unsupervised generation.
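The file-naming step above can be sketched with the standard library alone. This is a minimal illustration, not part of EventStream: the helper `install_labeler`, the source file name `my_labeler.py`, the directory name `task_dir`, and the task name `readmission_30d` are all hypothetical, and the actual location of the task directory depends on your dataset setup.

```python
import shutil
import tempfile
from pathlib import Path


def install_labeler(labeler_src: Path, task_dir: Path, task_df_name: str) -> Path:
    """Hypothetical helper: copy a labeler file to <task_dir>/<task_df_name>_labeler.py."""
    dest = task_dir / f"{task_df_name}_labeler.py"
    task_dir.mkdir(parents=True, exist_ok=True)
    shutil.copyfile(labeler_src, dest)
    return dest


# Demonstration in a throwaway directory (all paths and names are made up).
with tempfile.TemporaryDirectory() as tmp:
    tmp_path = Path(tmp)
    src = tmp_path / "my_labeler.py"
    src.write_text("# labeler class definition goes here\n")

    dest = install_labeler(src, tmp_path / "task_dir", "readmission_30d")
    dest_name = dest.name
    dest_exists = dest.exists()
```

The only convention taken from the text is the destination file name, `${task_df_name}_labeler.py`; everything else is scaffolding for the demonstration.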
- config
The StructuredTransformerConfig config object defining the model being used. This holds information about vocabulary elements, index maps (which are important for deciphering batch data into categories), etc.
- abstract __call__(batch: PytorchBatch, input_seq_len: int) → tuple[LongTensor, BoolTensor]
The core labeling method of the class. Must be overridden by subclasses.
- Parameters:
- batch: PytorchBatch
The PyTorch Batch, containing the initial raw input data (left padded) followed by the newly generated data.
- input_seq_len: int
The number of events (including padding) on the left side of the batch that were the original raw input, rather than the newly generated data. E.g., batch[:, :input_seq_len] is just events that were in the original input, and batch[:, input_seq_len:] is just the newly generated events.
- Returns:
- torch.LongTensor: The classification labels (in one-hot, [batch_size x vocab_size] format) that the labeler has generated in response to the input batch.
- torch.BoolTensor: A boolean tensor of shape [batch_size] indicating whether each sample in the original input could be parsed into a label (True) or not (False).
- Return type:
tuple[torch.LongTensor, torch.BoolTensor]
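To make the contract of __call__ concrete, here is a minimal sketch of a labeler for a toy binary task: "does a target vocabulary index appear in any generated event?". It is written under several assumptions so that it runs without EventStream installed: `ToyBatch` and `ToyLabelerBase` are hypothetical stand-ins for PytorchBatch and Labeler, the field names `event_mask` and `dynamic_indices` and their shapes are assumed for illustration, and the class axis of the labels is taken to have one entry per task class (two here). A real labeler would subclass Labeler and read index maps from its config instead.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass

import torch
import torch.nn.functional as F


@dataclass
class ToyBatch:
    """Hypothetical stand-in for PytorchBatch with only the fields used below."""

    event_mask: torch.Tensor       # bool, [batch_size, seq_len]; True marks a real event
    dynamic_indices: torch.Tensor  # long, [batch_size, seq_len, n_codes_per_event]


class ToyLabelerBase(ABC):
    """Stand-in that mirrors the documented Labeler interface."""

    def __init__(self, config):
        self.config = config

    @abstractmethod
    def __call__(self, batch, input_seq_len: int):
        ...


class CodeOccursLabeler(ToyLabelerBase):
    """Class 1 if `target_index` occurs in a generated event, else class 0."""

    def __init__(self, config, target_index: int):
        super().__init__(config)
        self.target_index = target_index

    def __call__(self, batch, input_seq_len: int):
        # Keep only the newly generated events (everything right of the input).
        gen_mask = batch.event_mask[:, input_seq_len:]
        gen_indices = batch.dynamic_indices[:, input_seq_len:]

        # A hit is a real (unmasked) generated event containing the target index.
        hit = (gen_indices == self.target_index).any(dim=-1) & gen_mask
        class_idx = hit.any(dim=-1).long()

        labels = F.one_hot(class_idx, num_classes=2)  # long, [batch_size, 2]
        has_label = gen_mask.any(dim=-1)              # bool, [batch_size]
        return labels, has_label


# Three samples, four events each; the first two events are the original input.
batch = ToyBatch(
    event_mask=torch.tensor([
        [False, True, True, True],
        [True, True, True, True],
        [True, True, False, False],
    ]),
    dynamic_indices=torch.tensor([
        [[0, 0], [3, 4], [5, 7], [2, 0]],
        [[1, 2], [3, 4], [2, 6], [2, 0]],
        [[1, 7], [3, 4], [0, 0], [0, 0]],
    ]),
)

labeler = CodeOccursLabeler(config=None, target_index=7)
labels, has_label = labeler(batch, input_seq_len=2)
```

In this toy batch, the third sample contains the target index only in its original input and has no real generated events, so the sketch marks it as unlabeled in the boolean tensor rather than counting the input occurrence.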