"""EventStream.evaluation.general_generative_evaluation: generate trajectories from a pretrained EST model."""

import dataclasses
import os
from contextlib import nullcontext
from datetime import datetime
from multiprocessing import Pool
from pathlib import Path
from typing import Any

import lightning as L
import omegaconf
import polars as pl
import torch
import torch.multiprocessing

from ..data.config import PytorchDatasetConfig, SeqPaddingSide
from ..data.pytorch_dataset import PytorchDataset
from ..data.types import PytorchBatch
from ..transformer.conditionally_independent_model import (
    CIPPTForGenerativeSequenceModeling,
)
from ..transformer.config import (
    OptimizationConfig,
    StructuredEventProcessingMode,
    StructuredTransformerConfig,
)
from ..transformer.nested_attention_model import NAPPTForGenerativeSequenceModeling
from ..utils import hydra_dataclass, task_wrapper


class ESTForTrajectoryGeneration(L.LightningModule):
    """A PyTorch Lightning Module that samples generated trajectories from a pretrained EST model.

    Generation happens in `predict_step`: each input batch is continued by sampling
    ``num_samples`` times, generating up to ``max_new_events`` new events per sample. Both values are
    read from ``config.task_specific_params``.
    """

    def __init__(
        self,
        config: StructuredTransformerConfig | dict[str, Any],
        pretrained_weights_fp: Path,
    ):
        """Initializes the Lightning Module.

        Args:
            config (`Union[StructuredTransformerConfig, Dict[str, Any]]`):
                The configuration for the underlying model. Should be in the dedicated
                `StructuredTransformerConfig` class or be a dictionary parseable as such. Its
                `task_specific_params` must contain ``"num_samples"`` and ``"max_new_events"``.
            pretrained_weights_fp (`Path`):
                Location of the pretrained weights, passed to the model class's `from_pretrained`.

        Raises:
            ValueError: If `pretrained_weights_fp` is None.
            KeyError: If `config.task_specific_params` lacks ``"num_samples"`` or ``"max_new_events"``.
        """
        super().__init__()

        # Validate before doing any other work so a misconfigured run fails fast.
        if pretrained_weights_fp is None:
            raise ValueError("pretrained_weights_fp must be specified")

        # If the configurations are dictionaries, convert them to class objects. They may be passed as
        # dictionaries when the lightning module is loaded from a checkpoint, so we need to support
        # this functionality.
        if isinstance(config, dict):
            config = StructuredTransformerConfig(**config)
        self.config = config

        self.num_samples = config.task_specific_params["num_samples"]
        self.max_new_events = config.task_specific_params["max_new_events"]

        self.save_hyperparameters({"config": config.to_dict()})

        # The architecture is chosen by the structured-event processing mode of the config.
        if self.config.structured_event_processing_mode == StructuredEventProcessingMode.NESTED_ATTENTION:
            model_cls = NAPPTForGenerativeSequenceModeling
        else:
            model_cls = CIPPTForGenerativeSequenceModeling
        self.model = model_cls.from_pretrained(pretrained_weights_fp, config=config)

    def predict_step(self, batch: PytorchBatch, batch_idx: int) -> list[PytorchBatch]:
        """Prediction step: samples ``self.num_samples`` continuations of ``batch``.

        This does not write anything; callers are responsible for persisting the returned batches.

        Args:
            batch: The input batch to continue.
            batch_idx: Index of the batch (unused; part of the Lightning `predict_step` signature).

        Returns:
            The generated, sample-expanded batch split back apart via
            ``split_repeated_batch(self.num_samples)`` (presumably one `PytorchBatch` per sample).
        """
        generated_expanded_batch = self.model.generate(
            batch,
            max_new_events=self.max_new_events,
            do_sample=True,
            return_dict_in_generate=False,
            output_scores=False,
            num_return_sequences=self.num_samples,
            output_attentions=False,
            output_hidden_states=False,
            use_cache=True,
        )
        return generated_expanded_batch.split_repeated_batch(self.num_samples)


@hydra_dataclass
class GenerateConfig:
    """Configuration for trajectory generation from a saved EST model directory.

    `__post_init__` does the following:
    - normalizes paths to `Path`;
    - derives unset locations from `load_from_model_dir`;
    - loads the saved data config (``data_config.json``) and model config (``config.json``);
    - applies the requested overrides;
    - reconciles the data sequence length with the number of events to generate, so that
      ``config.max_seq_len - data_config.max_seq_len == max_new_events``.
    """

    # Directory containing ``data_config.json``, ``config.json`` and, by default, ``pretrained_weights``.
    load_from_model_dir: str | Path = omegaconf.MISSING
    # Random seed for the run.
    seed: int = 1
    # Defaults to ``load_from_model_dir / "pretrained_weights"``.
    pretrained_weights_fp: Path | None = None
    # Defaults to ``load_from_model_dir / "finetuning" / task_df_name`` when `task_df_name` is set,
    # otherwise to `load_from_model_dir`. Converted to `Path` when given as a string.
    save_dir: str | None = None
    # NOTE(review): not read by __post_init__; confirm whether any caller uses it.
    do_overwrite: bool = False
    optimization_config: OptimizationConfig = dataclasses.field(default_factory=lambda: OptimizationConfig())
    # When set, copied onto the loaded data config's `task_df_name`.
    task_df_name: str | None = None
    # Attribute overrides applied to the loaded data config (a "task_df_name" entry is ignored).
    # None is treated as "no overrides".
    data_config_overrides: dict[str, Any] | None = dataclasses.field(
        default_factory=lambda: {
            "seq_padding_side": SeqPaddingSide.LEFT,
            "do_include_start_time_min": True,
            "do_include_subsequence_indices": True,
            "do_include_subject_id": True,
        }
    )
    # Keyword arguments for the trainer; `default_root_dir` defaults to ``save_dir / "model_checkpoints"``.
    trainer_config: dict[str, Any] = dataclasses.field(
        default_factory=lambda: {
            "accelerator": "auto",
            "devices": "auto",
            "detect_anomaly": False,
            "default_root_dir": None,
        }
    )
    # Merged into the model config's `task_specific_params`. "max_new_events" may be left MISSING/None,
    # in which case it is inferred from the model and data sequence lengths.
    task_specific_params: dict[str, Any] = dataclasses.field(
        default_factory=lambda: {
            "num_samples": omegaconf.MISSING,
            "max_new_events": omegaconf.MISSING,
        }
    )
    # Attribute overrides applied to the loaded model config.
    config_overrides: dict[str, Any] = dataclasses.field(default_factory=lambda: {})
    # If > 1, number of worker processes used when converting generated batches to DataFrames.
    parallelize_conversion: int | None = None

    def __post_init__(self):
        """Resolves derived paths, loads the saved configs, and applies overrides.

        Raises:
            ValueError: If `load_from_model_dir` is unset, `task_specific_params` is None, the implied
                number of new events is not positive, or an explicit ``max_new_events`` disagrees with
                the sequence lengths.
        """
        if isinstance(self.save_dir, str):
            self.save_dir = Path(self.save_dir)

        if self.load_from_model_dir in (None, omegaconf.MISSING):
            raise ValueError("Must load from a model!")
        if isinstance(self.load_from_model_dir, str):
            self.load_from_model_dir = Path(self.load_from_model_dir)

        if self.pretrained_weights_fp is None:
            self.pretrained_weights_fp = self.load_from_model_dir / "pretrained_weights"
        elif isinstance(self.pretrained_weights_fp, str):
            self.pretrained_weights_fp = Path(self.pretrained_weights_fp)

        if self.save_dir is None:
            if self.task_df_name is not None:
                self.save_dir = self.load_from_model_dir / "finetuning" / self.task_df_name
            else:
                self.save_dir = self.load_from_model_dir

        if self.trainer_config.get("default_root_dir", None) is None:
            self.trainer_config["default_root_dir"] = self.save_dir / "model_checkpoints"

        data_config_fp = self.load_from_model_dir / "data_config.json"
        print(f"Loading data_config from {data_config_fp}")
        self.data_config = PytorchDatasetConfig.from_json_file(data_config_fp)

        if self.task_df_name is not None:
            self.data_config.task_df_name = self.task_df_name

        # The field is typed as optional; treat None as "no overrides" rather than crashing below.
        data_config_overrides = self.data_config_overrides or {}
        for param, val in data_config_overrides.items():
            if param == "task_df_name":
                print(
                    f"WARNING: task_df_name is set in data_config_overrides to {val}! "
                    f"Original is {self.task_df_name}. Ignoring data_config_overrides..."
                )
                continue
            print(f"Overwriting {param} in data_config from {getattr(self.data_config, param)} to {val}")
            setattr(self.data_config, param, val)

        config_fp = self.load_from_model_dir / "config.json"
        print(f"Loading config from {config_fp}")
        self.config = StructuredTransformerConfig.from_json_file(config_fp)

        for param, val in self.config_overrides.items():
            print(f"Overwriting {param} in config from {getattr(self.config, param)} to {val}")
            setattr(self.config, param, val)

        if self.task_specific_params is None:
            raise ValueError("Must specify num samples to generate")

        # "max_new_events" defaults to omegaconf.MISSING; both MISSING and None mean "infer it".
        # Checking only `is not None` here would subtract the MISSING sentinel string from an int.
        max_new_events = self.task_specific_params.get("max_new_events", None)
        max_new_events_is_set = max_new_events not in (omegaconf.MISSING, None)

        # Unless the data sequence length is overridden explicitly, shorten it so the requested number
        # of generated events fits within the model's max_seq_len.
        if data_config_overrides.get("max_seq_len", None) is None and max_new_events_is_set:
            self.data_config.max_seq_len = self.config.max_seq_len - max_new_events

        implied_max_new_events = self.config.max_seq_len - self.data_config.max_seq_len
        if implied_max_new_events <= 0:
            raise ValueError("Implied to not be generating any new events!")

        if self.config.task_specific_params is None:
            self.config.task_specific_params = {}
        self.config.task_specific_params.update(self.task_specific_params)

        if not max_new_events_is_set:
            self.config.task_specific_params["max_new_events"] = implied_max_new_events

        # Explicit check (not `assert`) so it still runs under `python -O`.
        if self.config.task_specific_params["max_new_events"] != implied_max_new_events:
            raise ValueError(
                f"max_new_events ({self.config.task_specific_params['max_new_events']}) is inconsistent "
                f"with config.max_seq_len ({self.config.max_seq_len}) - data_config.max_seq_len "
                f"({self.data_config.max_seq_len}) = {implied_max_new_events}."
            )


def _write_split_trajectories(
    trajectories: list[list[PytorchBatch]] | None,
    split_dir: Path,
    local_rank: str,
    parallelize_conversion: int | None,
) -> None:
    """Writes one parquet file per generated sample for a single data split.

    Args:
        trajectories: Output of `trainer.predict`: one entry per input batch, each being the list of
            per-sample generated batches returned by `ESTForTrajectoryGeneration.predict_step`.
        split_dir: Directory receiving ``sample_{i}_local_rank_{local_rank}.parquet`` files.
        local_rank: Rank tag embedded in the output filenames.
        parallelize_conversion: If > 1, the number of worker processes used to convert batches to
            DataFrames; otherwise conversion runs serially in this process.
    """
    if not trajectories:
        print(f"WARNING: No trajectories were generated for {split_dir}; nothing to write.")
        return

    split_dir.mkdir(exist_ok=True, parents=True)

    use_pool = parallelize_conversion is not None and parallelize_conversion > 1
    # One pool per split (instead of one per sample) avoids re-spawning workers for every sample.
    with (Pool(parallelize_conversion) if use_pool else nullcontext()) as pool:
        # `zip(*trajectories)` transposes batch-major output into sample-major groups: the i-th tuple
        # holds the i-th sample's generated batch from every input batch.
        for samp_idx, gen_batches in enumerate(zip(*trajectories)):
            out_fp = split_dir / f"sample_{samp_idx}_local_rank_{local_rank}.parquet"

            st_convert = datetime.now()
            print(f"Converting to DFs for sample {samp_idx}...")
            if pool is not None:
                dfs = pool.map(PytorchBatch.convert_to_DL_DF, gen_batches)
            else:
                dfs = [B.convert_to_DL_DF() for B in gen_batches]
            print(f"Conversion done in {datetime.now() - st_convert}")

            st_write = datetime.now()
            print(f"Writing DF to {out_fp}...")
            pl.concat(dfs).write_parquet(out_fp)
            print(f"Writing done in {datetime.now() - st_write}")


@task_wrapper
def generate_trajectories(cfg: GenerateConfig):
    """Generates trajectories for the tuning and held-out splits and writes them to parquet.

    Outputs go to ``cfg.save_dir / "generated_trajectories" / <split> /
    sample_{i}_local_rank_{LOCAL_RANK}.parquet``, one file per generated sample.

    Args:
        cfg: The resolved generation configuration (see `GenerateConfig.__post_init__`).
    """
    L.seed_everything(cfg.seed)
    torch.multiprocessing.set_sharing_strategy("file_system")

    tuning_pyd = PytorchDataset(cfg.data_config, split="tuning")
    held_out_pyd = PytorchDataset(cfg.data_config, split="held_out")

    config = cfg.config
    batch_size = cfg.optimization_config.validation_batch_size
    num_dataloader_workers = cfg.optimization_config.num_dataloader_workers

    # `set_to_dataset` may overwrite these fields; restore the values loaded from the model directory
    # so the model keeps the sequence length and inter-event-time statistics it was configured with.
    preserved_attrs = ("max_seq_len", "mean_log_inter_event_time_min", "std_log_inter_event_time_min")
    preserved_values = {attr: getattr(config, attr) for attr in preserved_attrs}
    config.set_to_dataset(tuning_pyd)
    for attr, val in preserved_values.items():
        setattr(config, attr, val)

    output_dir = cfg.save_dir / "generated_trajectories"

    # Model
    LM = ESTForTrajectoryGeneration(
        config=config,
        pretrained_weights_fp=cfg.pretrained_weights_fp,
    )

    def make_dataloader(pyd: PytorchDataset) -> torch.utils.data.DataLoader:
        # Unshuffled so output order follows dataset order.
        return torch.utils.data.DataLoader(
            pyd,
            batch_size=batch_size,
            num_workers=num_dataloader_workers,
            collate_fn=pyd.collate,
            shuffle=False,
        )

    tuning_dataloader = make_dataloader(tuning_pyd)
    held_out_dataloader = make_dataloader(held_out_pyd)

    trainer = L.Trainer(**cfg.trainer_config)

    tuning_trajectories = trainer.predict(model=LM, dataloaders=tuning_dataloader)
    # Read after the first `predict` call, as before, in case the launcher sets LOCAL_RANK lazily.
    local_rank = os.environ.get("LOCAL_RANK", "0")
    _write_split_trajectories(
        tuning_trajectories, output_dir / "tuning", local_rank, cfg.parallelize_conversion
    )

    held_out_trajectories = trainer.predict(model=LM, dataloaders=held_out_dataloader)
    _write_split_trajectories(
        held_out_trajectories, output_dir / "held_out", local_rank, cfg.parallelize_conversion
    )