Source code for EventStream.transformer.lightning_modules.embedding

"""A PyTorch Lightning runnable module for getting embeddings from a pre-trained ESGPT model."""

import os
from pathlib import Path
from typing import Any

import lightning as L
import torch
from loguru import logger

from ...data.pytorch_dataset import PytorchDataset
from ..config import StructuredEventProcessingMode, StructuredTransformerConfig
from ..transformer import (
    ConditionallyIndependentPointProcessTransformer,
    NestedAttentionPointProcessTransformer,
    StructuredTransformerPreTrainedModel,
)
from ..utils import safe_masked_max, safe_weighted_avg
from .fine_tuning import FinetuneConfig


[docs] class EmbeddingsOnlyModel(StructuredTransformerPreTrainedModel): """A "model" which simply retrieves the final embeddings from the encoder and returns them.""" def __init__(self, config: StructuredTransformerConfig): super().__init__(config) if self.config.structured_event_processing_mode == StructuredEventProcessingMode.NESTED_ATTENTION: self.encoder = NestedAttentionPointProcessTransformer(config=config) else: self.encoder = ConditionallyIndependentPointProcessTransformer(config)
[docs] def forward(self, *args, **kwargs): return self.encoder(*args, **kwargs)
[docs] class ESTForEmbedding(L.LightningModule): """A PyTorch Lightning Module for extracting embeddings only model.""" def __init__( self, config: StructuredTransformerConfig | dict[str, Any], pretrained_weights_fp: Path, ): """Initializes the Lightning Module. Args: config (`Union[StructuredEventstreamTransformerConfig, Dict[str, Any]]`): The configuration for the underlying `StructuredTransformerForGenerativeSequenceModeling` model. Should be in the dedicated `StructuredTransformerConfig` class or be a dictionary parseable as such. pretrained_weights_fp (`pathlib.Path`): The path to the pre-trained model that should be loaded whose embeddings will be retrieved. """ super().__init__() # If the configurations are dictionaries, convert them to class objects. They may be passed as # dictionaries when the lightning module is loaded from a checkpoint, so we need to support # this functionality. if type(config) is dict: config = StructuredTransformerConfig(**config) self.config = config self.uses_dep_graph = ( self.config.structured_event_processing_mode == StructuredEventProcessingMode.NESTED_ATTENTION ) self.pooling_method = config.task_specific_params["pooling_method"] self.model = EmbeddingsOnlyModel.from_pretrained(pretrained_weights_fp, config=config)
[docs] def predict_step(self, batch, batch_idx): """Retrieves the embeddings and returns them.""" encoded = self.model(batch).last_hidden_state event_encoded = encoded[:, :, -1, :] if self.uses_dep_graph else encoded # `event_encoded` is of shape [batch X seq X hidden_dim]. For pooling, I want to put the sequence # dimension as last, so we'll transpose. event_encoded = event_encoded.transpose(1, 2) match self.pooling_method: case "last": return event_encoded[:, :, -1] case "max": return safe_masked_max(event_encoded, batch["event_mask"]) case "mean": return safe_weighted_avg(event_encoded, batch["event_mask"])[0] case "none": return event_encoded case _: raise ValueError(f"{self.pooling_method} is not a supported pooling method.")
[docs] def get_embeddings(cfg: FinetuneConfig): """Gets embeddings. Writes embeddings to ``cfg.load_from_model_dir / "embeddings" / cfg.task_df_name / "{split}_embeddings.pt"``. Args: cfg: The fine-tuning configuration object specifying the cohort for which and model from which you wish to get embeddings. """ torch.multiprocessing.set_sharing_strategy("file_system") train_pyd = PytorchDataset(cfg.data_config, split="train") tuning_pyd = PytorchDataset(cfg.data_config, split="tuning") held_out_pyd = PytorchDataset(cfg.data_config, split="held_out") config = cfg.config cfg.data_config optimization_config = cfg.optimization_config config.set_to_dataset(train_pyd) # Model LM = ESTForEmbedding(config, pretrained_weights_fp=cfg.pretrained_weights_fp) # Setting up torch dataloader train_dataloader = torch.utils.data.DataLoader( train_pyd, batch_size=optimization_config.validation_batch_size, num_workers=optimization_config.num_dataloader_workers, collate_fn=train_pyd.collate, shuffle=False, ) tuning_dataloader = torch.utils.data.DataLoader( tuning_pyd, batch_size=optimization_config.validation_batch_size, num_workers=optimization_config.num_dataloader_workers, collate_fn=tuning_pyd.collate, shuffle=False, ) held_out_dataloader = torch.utils.data.DataLoader( held_out_pyd, batch_size=optimization_config.validation_batch_size, num_workers=optimization_config.num_dataloader_workers, collate_fn=held_out_pyd.collate, shuffle=False, ) trainer = L.Trainer(**cfg.trainer_config) for sp, dataloader in ( ("train", train_dataloader), ("tuning", tuning_dataloader), ("held_out", held_out_dataloader), ): # Getting Embeddings model embeddings = torch.cat(trainer.predict(LM, dataloader), 0) embeddings_fp = cfg.load_from_model_dir / "embeddings" / cfg.task_df_name / f"{sp}_embeddings.pt" if os.environ.get("LOCAL_RANK", "0") == "0": if embeddings_fp.is_file() and not cfg.do_overwrite: logger.info( f"Embeddings already exist at {embeddings_fp}. To overwrite, set `do_overwrite=True`." ) else: logger.info(f"Saving {sp} embeddings to {embeddings_fp}.") embeddings_fp.parent.mkdir(exist_ok=True, parents=True) torch.save(embeddings, embeddings_fp)