EventStream.transformer.lightning_modules.generative_modeling module

class EventStream.transformer.lightning_modules.generative_modeling.ESTForGenerativeSequenceModelingLM(config: StructuredTransformerConfig | dict[str, Any], optimization_config: OptimizationConfig | dict[str, Any], metrics_config: MetricsConfig | dict[str, Any], pretrained_weights_fp: Path | None = None)[source]

Bases: LightningModule

A PyTorch Lightning Module for an ESTForGenerativeSequenceModeling model.

CLASSIFICATION = {DataModality.MULTI_LABEL_CLASSIFICATION, DataModality.SINGLE_LABEL_CLASSIFICATION}
TRAIN_SKIP_METRICS = ('AUROC', 'AUPRC', 'per_class')
build_metrics()[source]

Build the various torchmetrics we’ll use to track performance.

configure_optimizers()[source]

Configures optimizer and learning rate scheduler.

Currently, this module uses the AdamW optimizer with a configurable weight_decay. The learning rate warms up per step from 0 to the configurable self.optimization_config.init_lr, then undergoes polynomial decay as specified via self.optimization_config.
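
The schedule described above can be illustrated with a small stand-alone function. This is a minimal sketch, not the module's implementation: the function name lr_at_step and the parameters end_lr, warmup_steps, total_steps, and power are hypothetical, and the real values come from self.optimization_config.

```python
def lr_at_step(
    step: int,
    init_lr: float,
    end_lr: float,
    warmup_steps: int,
    total_steps: int,
    power: float = 1.0,
) -> float:
    """Learning rate at `step` under linear warmup then polynomial decay.

    All parameter names are illustrative assumptions, not the names used
    by OptimizationConfig.
    """
    if step < warmup_steps:
        # Linear per-step warmup from 0 up to init_lr.
        return init_lr * step / max(1, warmup_steps)
    if step >= total_steps:
        # After the schedule ends, hold the final learning rate.
        return end_lr
    # Fraction of the decay phase still remaining, in (0, 1].
    remaining = 1 - (step - warmup_steps) / (total_steps - warmup_steps)
    return (init_lr - end_lr) * remaining**power + end_lr


if __name__ == "__main__":
    for s in (0, 5, 10, 55, 100):
        print(s, lr_at_step(s, init_lr=1e-3, end_lr=1e-5, warmup_steps=10, total_steps=100))
```

With power=1.0 the decay phase is linear; larger values make the rate fall faster early in the decay phase.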

log_metrics(results: GenerativeSequenceModelOutput, split: Split)[source]

Logs metric results for a given output result.

Parameters:
results: GenerativeSequenceModelOutput

The results to assess across the suite of metrics.

split: Split

The split that should be used when logging metric results.

log_tte_metrics(results: GenerativeSequenceModelOutput, split: Split)[source]
save_pretrained(model_dir: Path)[source]
test_step(batch: PytorchBatch, batch_idx: int)[source]

Test step.

Differs from training only in that it does not skip metrics.

training_step(batch: PytorchBatch, batch_idx: int) → Tensor[source]

Training step.

Skips logging all AUROC, AUPRC, and per_class metrics to save compute.
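
One way such skipping could work is sketched below. The helper should_log, the substring-matching rule, and the example metric names are all assumptions made for illustration; only the contents of TRAIN_SKIP_METRICS are taken from the class attribute documented above.

```python
# Copied from the documented class attribute.
TRAIN_SKIP_METRICS = ("AUROC", "AUPRC", "per_class")


def should_log(metric_name: str, split: str, skip: tuple[str, ...] = TRAIN_SKIP_METRICS) -> bool:
    """Return True if a metric should be logged for the given split.

    Hypothetical helper: assumes a metric is skipped during training
    whenever its name contains one of the `skip` substrings.
    """
    if split != "train":
        # Validation and test log every metric.
        return True
    return not any(token in metric_name for token in skip)


if __name__ == "__main__":
    # Metric names here are made up for the example.
    for name in ("macro_AUROC", "weighted_AUPRC", "accuracy"):
        print(name, should_log(name, "train"), should_log(name, "tuning"))
```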

validation_step(batch: PytorchBatch, batch_idx: int)[source]

Validation step.

Differs from training only in that it does not skip metrics.

class EventStream.transformer.lightning_modules.generative_modeling.PretrainConfig(do_overwrite: bool = False, seed: int = 1, config: dict[str, typing.Any] = <factory>, optimization_config: EventStream.transformer.config.OptimizationConfig = <factory>, data_config: EventStream.data.config.PytorchDatasetConfig = <factory>, pretraining_metrics_config: EventStream.transformer.config.MetricsConfig = <factory>, final_validation_metrics_config: EventStream.transformer.config.MetricsConfig = <factory>, trainer_config: dict[str, typing.Any] = <factory>, experiment_dir: str = '???', save_dir: str = '${experiment_dir}/pretrain/${now:%Y-%m-%d_%H-%M-%S}', wandb_logger_kwargs: dict[str, typing.Any] = <factory>, wandb_experiment_config_kwargs: dict[str, typing.Any] = <factory>, do_final_validation_on_metrics: bool = True, do_use_filesystem_sharing: bool = True)[source]

Bases: object

config : dict[str, Any]
data_config : PytorchDatasetConfig
do_final_validation_on_metrics : bool = True
do_overwrite : bool = False
do_use_filesystem_sharing : bool = True
experiment_dir : str = '???'
final_validation_metrics_config : MetricsConfig
optimization_config : OptimizationConfig
pretraining_metrics_config : MetricsConfig
save_dir : str = '${experiment_dir}/pretrain/${now:%Y-%m-%d_%H-%M-%S}'
seed : int = 1
trainer_config : dict[str, Any]
wandb_experiment_config_kwargs : dict[str, Any]
wandb_logger_kwargs : dict[str, Any]
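
The defaults for experiment_dir and save_dir use configuration-interpolation syntax: '???' is the OmegaConf marker for a mandatory value that must be supplied, and '${now:...}' is a Hydra resolver that formats the current time. The plain-Python sketch below only imitates how the default save_dir would resolve; the function resolve_save_dir is hypothetical, and in practice the configuration framework performs this substitution.

```python
from datetime import datetime


def resolve_save_dir(experiment_dir: str, now: datetime) -> str:
    """Imitate resolving '${experiment_dir}/pretrain/${now:%Y-%m-%d_%H-%M-%S}'.

    Hypothetical helper for illustration only; it is not part of EventStream.
    """
    if experiment_dir == "???":
        # '???' marks a required value that the user has not provided.
        raise ValueError("experiment_dir is mandatory and must be set")
    return f"{experiment_dir}/pretrain/{now.strftime('%Y-%m-%d_%H-%M-%S')}"


if __name__ == "__main__":
    print(resolve_save_dir("/tmp/my_experiment", datetime.now()))
```

Because the timestamp is part of the path, each run with the default save_dir gets its own directory under experiment_dir/pretrain.
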
EventStream.transformer.lightning_modules.generative_modeling.train(cfg: PretrainConfig)[source]

Runs the end-to-end training procedure for the pre-training model.

Parameters:
cfg: PretrainConfig

The pre-training config defining the generative modeling task.