# Source code for EventStream.transformer.fine_tuning_model

"""A model for fine-tuning on classification tasks."""
import torch

from ..data.types import PytorchBatch
from .config import StructuredEventProcessingMode, StructuredTransformerConfig
from .model_output import StreamClassificationModelOutput
from .transformer import (
    ConditionallyIndependentPointProcessTransformer,
    NestedAttentionPointProcessTransformer,
    StructuredTransformerPreTrainedModel,
)
from .utils import safe_masked_max, safe_weighted_avg


[docs] class ESTForStreamClassification(StructuredTransformerPreTrainedModel): """A model for fine-tuning on classification tasks. Args: config: The model configuration class to use. This must contain the relevant fine-tuning task information (e.g., `num_labels`, `finetuning_task`, `pooling_method`, and `id2label`). """ def __init__( self, config: StructuredTransformerConfig, ): super().__init__(config) self.task = config.finetuning_task if self._uses_dep_graph: self.encoder = NestedAttentionPointProcessTransformer(config) else: self.encoder = ConditionallyIndependentPointProcessTransformer(config) self.pooling_method = config.task_specific_params["pooling_method"] is_binary = config.id2label == {0: False, 1: True} if is_binary: assert config.num_labels == 2 self.logit_layer = torch.nn.Linear(config.hidden_size, 1) self.criteria = torch.nn.BCEWithLogitsLoss() else: self.logit_layer = torch.nn.Linear(config.hidden_size, config.num_labels) self.criteria = torch.nn.CrossEntropyLoss() # Initialize weights and apply final processing self.post_init() @property def _uses_dep_graph(self): return self.config.structured_event_processing_mode == StructuredEventProcessingMode.NESTED_ATTENTION
[docs] def forward(self, batch: PytorchBatch, **kwargs) -> StreamClassificationModelOutput: """Runs the forward pass through the fine-tuning label prediction. Args: batch: The batch of data to model. Returns: A `StreamClassificationModelOutput` object capturing loss, predictions, and labels for the fine-tuning task in question. """ encoded = self.encoder(batch, **kwargs).last_hidden_state event_encoded = encoded[:, :, -1, :] if self._uses_dep_graph else encoded # `event_encoded` is of shape [batch X seq X hidden_dim]. For pooling, I want to put the sequence # dimension as last, so we'll transpose. event_encoded = event_encoded.transpose(1, 2) match self.pooling_method: case "cls": stream_encoded = event_encoded[:, :, 0] case "last": stream_encoded = event_encoded[:, :, -1] case "max": stream_encoded = safe_masked_max(event_encoded, batch["event_mask"]) case "mean": stream_encoded, _ = safe_weighted_avg(event_encoded, batch["event_mask"]) case _: raise ValueError(f"{self.pooling_method} is not a supported pooling method.") logits = self.logit_layer(stream_encoded).squeeze(-1) labels = batch["stream_labels"][self.task] loss = self.criteria(logits, labels) return StreamClassificationModelOutput( loss=loss, preds=logits, labels=labels, )