EventStream.transformer.lightning_modules.fine_tuning module

class EventStream.transformer.lightning_modules.fine_tuning.ESTForStreamClassificationLM(config: StructuredTransformerConfig | dict[str, Any], optimization_config: OptimizationConfig | dict[str, Any], pretrained_weights_fp: Path | str | None = None, do_debug_mode: bool = True)

Bases: LightningModule

A PyTorch Lightning Module for an ESTForStreamClassification model.

build_metrics()

Builds the torchmetrics metrics used to track model performance.

configure_optimizers()

Configures the optimizer and learning-rate scheduler.

Currently, this module uses the AdamW optimizer with a configurable weight_decay. The learning rate warms up from 0, step by step, to the configurable self.optimization_config.init_lr, then decays polynomially as specified by self.optimization_config.
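A minimal sketch of the schedule described above, written as a pure function of the optimizer step. It assumes a linear warmup and the standard polynomial-decay formula (the same shape as the Hugging Face `get_polynomial_decay_schedule_with_warmup` schedule). The helper `lr_at_step` and every parameter other than init_lr are hypothetical and may not match the actual fields of OptimizationConfig.

```python
def lr_at_step(
    step: int,
    init_lr: float,
    num_warmup_steps: int,
    num_training_steps: int,
    end_lr: float = 0.0,
    power: float = 1.0,
) -> float:
    """Learning rate after `step` optimizer steps (hypothetical helper).

    Warmup: ramps from 0 to init_lr over num_warmup_steps (linear ramp assumed).
    Decay: polynomial from init_lr down to end_lr over the remaining steps.
    """
    # Warmup phase: scale init_lr by the fraction of warmup completed.
    if num_warmup_steps > 0 and step < num_warmup_steps:
        return init_lr * step / num_warmup_steps
    # Past the end of training: hold at the final learning rate.
    if step >= num_training_steps:
        return end_lr
    # Decay phase: fraction of the decay window still remaining, raised to `power`.
    decay_steps = num_training_steps - num_warmup_steps
    remaining = 1.0 - (step - num_warmup_steps) / decay_steps
    return (init_lr - end_lr) * remaining**power + end_lr


# Sample the schedule at a few steps (hypothetical settings).
schedule = [
    lr_at_step(s, init_lr=1e-3, num_warmup_steps=10, num_training_steps=100)
    for s in (0, 5, 10, 55, 100)
]
```

One way to attach such a function to an AdamW optimizer created with `lr=init_lr` is `torch.optim.lr_scheduler.LambdaLR` with `lr_lambda=lambda s: lr_at_step(s, ...) / init_lr`, since LambdaLR multiplies the base learning rate by the lambda's value; the scheduler would then be stepped once per optimizer step rather than once per epoch.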

log_metrics(results: StreamClassificationModelOutput, skip_metrics: Sequence[str], prefix: str)

Logs metric results for a given output result.

Parameters:
results: StreamClassificationModelOutput

The results to assess across the suite of metrics.

skip_metrics: Sequence[str]

Metrics to skip. Entries are partial rather than full metric names: any metric whose name contains an element of skip_metrics is skipped. For example, if skip_metrics = ['AUROC', 'AUPRC'], then a metric named 'macro_AUROC' or 'micro_AUPRC' would be skipped, whereas a metric named 'weighted_accuracy' would not.

prefix: str

The prefix to use when logging metric results, typically 'train', 'tuning', or 'held_out'.
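The substring rule for skip_metrics can be illustrated with a minimal sketch. `should_skip` and `filter_metrics` are hypothetical helpers, not functions of this module, and the metric values are made up; only the matching rule and the example names come from the description above.

```python
from __future__ import annotations

from collections.abc import Sequence


def should_skip(metric_name: str, skip_metrics: Sequence[str]) -> bool:
    """True if any entry of skip_metrics appears as a substring of metric_name."""
    return any(partial in metric_name for partial in skip_metrics)


def filter_metrics(
    metric_values: dict[str, float], skip_metrics: Sequence[str]
) -> dict[str, float]:
    """Keep only the metrics that are not matched by skip_metrics."""
    return {
        name: value
        for name, value in metric_values.items()
        if not should_skip(name, skip_metrics)
    }


# Made-up values; names follow the skip_metrics example above.
kept = filter_metrics(
    {"macro_AUROC": 0.81, "micro_AUPRC": 0.42, "weighted_accuracy": 0.77},
    skip_metrics=["AUROC", "AUPRC"],
)
# kept == {"weighted_accuracy": 0.77}
```

Matching on substrings lets a single entry such as 'AUROC' cover every averaging variant ('macro_', 'micro_', 'weighted_') without listing each full name.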

save_pretrained(model_dir: Path)

test_step(batch, batch_idx)

Test step.

Differs from training only in that it does not skip metrics.

training_step(batch, batch_idx)

Training step.

Skips logging all AUROC, AUPRC, and per_class metrics to save compute.

validation_step(batch, batch_idx)

Validation step.

Differs from training only in that it does not skip metrics.
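The training, validation, and test steps differ mainly in which skip_metrics and prefix they pass when logging. The sketch below is hypothetical: the STAGE_SETTINGS table, the "prefix_name" key format, and the mapping of validation to 'tuning' and test to 'held_out' are assumptions, not this module's code. Only the training skip list and the "no skipping" behaviour of validation and test come from the docstrings above.

```python
from __future__ import annotations

from collections.abc import Iterable

# Hypothetical per-stage settings: (logging prefix, partial names to skip).
# Training skips AUROC, AUPRC, and per_class metrics; the other stages skip nothing.
# The prefixes assigned to validation and test are assumptions.
STAGE_SETTINGS: dict[str, tuple[str, tuple[str, ...]]] = {
    "train": ("train", ("AUROC", "AUPRC", "per_class")),
    "validation": ("tuning", ()),
    "test": ("held_out", ()),
}


def logged_metric_names(stage: str, metric_names: Iterable[str]) -> list[str]:
    """Return the prefixed names of the metrics a given stage would log."""
    prefix, skip_metrics = STAGE_SETTINGS[stage]
    return [
        f"{prefix}_{name}"
        for name in metric_names
        if not any(partial in name for partial in skip_metrics)
    ]


# Illustrative metric names, not necessarily the ones build_metrics creates.
names = ["macro_AUROC", "per_class_accuracy", "weighted_accuracy"]
train_logged = logged_metric_names("train", names)
# train_logged == ["train_weighted_accuracy"]
```

Skipping the ranking metrics during training keeps per-step logging cheap, while validation and test still report the full suite.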

class EventStream.transformer.lightning_modules.fine_tuning.FinetuneConfig(experiment_dir: str | pathlib.Path | None = '${load_from_model_dir}/finetuning', load_from_model_dir: str | pathlib.Path | None = '???', task_df_name: str | None = '???', pretrained_weights_fp: pathlib.Path | str | None = '${load_from_model_dir}/pretrained_weights', save_dir: str | None = '${experiment_dir}/${task_df_name}/subset_size_${data_config.train_subset_size}/subset_seed_${data_config.train_subset_seed}/${now:%Y-%m-%d_%H-%M-%S}', wandb_logger_kwargs: dict[str, typing.Any] = <factory>, wandb_experiment_config_kwargs: dict[str, typing.Any] = <factory>, do_overwrite: bool = False, seed: int = 1, config: dict[str, typing.Any] = <factory>, optimization_config: EventStream.transformer.config.OptimizationConfig = <factory>, data_config: dict[str, typing.Any] | None = <factory>, trainer_config: dict[str, typing.Any] = <factory>, do_use_filesystem_sharing: bool = True)

Bases: object

config : dict[str, Any]
data_config : dict[str, Any] | None
do_overwrite : bool = False
do_use_filesystem_sharing : bool = True
experiment_dir : str | Path | None = '${load_from_model_dir}/finetuning'
load_from_model_dir : str | Path | None = '???'
optimization_config : OptimizationConfig
pretrained_weights_fp : Path | str | None = '${load_from_model_dir}/pretrained_weights'
save_dir : str | None = '${experiment_dir}/${task_df_name}/subset_size_${data_config.train_subset_size}/subset_seed_${data_config.train_subset_seed}/${now:%Y-%m-%d_%H-%M-%S}'
seed : int = 1
task_df_name : str | None = '???'
trainer_config : dict[str, Any]
wandb_experiment_config_kwargs : dict[str, Any]
wandb_logger_kwargs : dict[str, Any]
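Several FinetuneConfig defaults are interpolation templates. In OmegaConf, '???' marks a mandatory value that must be supplied (here load_from_model_dir and task_df_name), `${...}` refers to another config field, and `${now:...}` is Hydra's resolver that formats the current time with a strftime pattern. The sketch below resolves the default save_dir template by hand, assuming experiment_dir keeps its default; `resolve_save_dir` and all argument values are hypothetical and only illustrate the resulting directory layout.

```python
from __future__ import annotations

from datetime import datetime
from pathlib import Path


def resolve_save_dir(
    load_from_model_dir: str,
    task_df_name: str,
    train_subset_size: int | float | str,
    train_subset_seed: int,
    now: datetime,
) -> Path:
    """Hand-resolve the default save_dir template (illustration only)."""
    # Default: experiment_dir = '${load_from_model_dir}/finetuning'
    experiment_dir = Path(load_from_model_dir) / "finetuning"
    # Default: '${experiment_dir}/${task_df_name}/subset_size_.../subset_seed_.../${now:...}'
    return (
        experiment_dir
        / task_df_name
        / f"subset_size_{train_subset_size}"
        / f"subset_seed_{train_subset_seed}"
        / now.strftime("%Y-%m-%d_%H-%M-%S")
    )


# Hypothetical inputs.
save_dir = resolve_save_dir(
    "/models/pretrain_run", "readmission_30d", 1000, 1, datetime(2024, 1, 2, 3, 4, 5)
)
```

Because the timestamp is the last path component, repeated runs with the same task and subset settings land in sibling directories instead of overwriting each other. pretrained_weights_fp resolves in the same way, to `<load_from_model_dir>/pretrained_weights`.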

EventStream.transformer.lightning_modules.fine_tuning.train(cfg: FinetuneConfig)

Runs the end-to-end training procedure for the fine-tuning model.

Parameters:
cfg: FinetuneConfig

The fine-tuning configuration object, specifying the cohort and task to fine-tune on and the pretrained model to fine-tune from.