EventStream.transformer.lightning_modules.generative_modeling module
- class EventStream.transformer.lightning_modules.generative_modeling.ESTForGenerativeSequenceModelingLM(config: StructuredTransformerConfig | dict[str, Any], optimization_config: OptimizationConfig | dict[str, Any], metrics_config: MetricsConfig | dict[str, Any], pretrained_weights_fp: Path | None = None)
Bases: LightningModule
A PyTorch Lightning Module for an ESTForGenerativeSequenceModeling model.
- CLASSIFICATION = {DataModality.MULTI_LABEL_CLASSIFICATION, DataModality.SINGLE_LABEL_CLASSIFICATION}
- TRAIN_SKIP_METRICS = ('AUROC', 'AUPRC', 'per_class')
- configure_optimizers()
Configures the optimizer and learning rate scheduler.
Currently this module uses the AdamW optimizer with a configurable weight_decay. The learning rate warms up from 0, on a per-step basis, to the configurable self.optimization_config.init_lr, then undergoes polynomial decay as specified via self.optimization_config.
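As an illustration of the schedule described above, the following is a minimal pure-Python sketch of a per-step linear warmup from 0 followed by polynomial decay. It is not the module's own code: the function name lr_at_step and the parameters end_lr, num_warmup_steps, num_training_steps, and power are hypothetical names chosen for this example, and the actual fields of OptimizationConfig may differ.

```python
def lr_at_step(
    step: int,
    init_lr: float,
    end_lr: float,
    num_warmup_steps: int,
    num_training_steps: int,
    power: float = 1.0,
) -> float:
    """Learning rate at a given optimizer step (illustrative sketch only).

    All argument names other than ``init_lr`` are assumptions made for
    this example, not names taken from OptimizationConfig.
    """
    if step < num_warmup_steps:
        # Linear warmup: 0 at step 0, init_lr at the end of warmup.
        return init_lr * step / max(1, num_warmup_steps)
    if step >= num_training_steps:
        # After the schedule ends, hold the final learning rate.
        return end_lr
    # Polynomial decay from init_lr down to end_lr over the remaining steps.
    decay_steps = num_training_steps - num_warmup_steps
    pct_remaining = 1.0 - (step - num_warmup_steps) / decay_steps
    return (init_lr - end_lr) * pct_remaining**power + end_lr


if __name__ == "__main__":
    for s in (0, 50, 100, 550, 1000):
        print(s, lr_at_step(s, 1e-3, 1e-5, 100, 1000))
```

With power set to 1.0 the decay is linear; larger values make the learning rate fall faster early in the decay phase.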
- log_metrics(results: GenerativeSequenceModelOutput, split: Split)
Logs metric results for a given output result.
- Parameters:
- results: GenerativeSequenceModelOutput
The results to assess across the suite of metrics.
- split: Split
The split that should be used when logging metric results.
- log_tte_metrics(results: GenerativeSequenceModelOutput, split: Split)
- test_step(batch: PytorchBatch, batch_idx: int)
Test step.
Differs from training only in that it does not skip metrics.
- training_step(batch: PytorchBatch, batch_idx: int) -> Tensor
Training step.
Skips logging of all AUROC, AUPRC, and per_class metrics to save compute.
- validation_step(batch: PytorchBatch, batch_idx: int)
Validation step.
Differs from training only in that it does not skip metrics.
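The difference between the training step and the validation and test steps can be sketched as a filter over metric names. The example below assumes that a metric is skipped when its name contains any entry of TRAIN_SKIP_METRICS as a substring; that matching rule, the helper metrics_to_log, and the metric names are all hypothetical and serve only to illustrate the idea.

```python
# Copied from the class attribute documented above.
TRAIN_SKIP_METRICS = ("AUROC", "AUPRC", "per_class")


def metrics_to_log(metric_names, skip_metrics=()):
    """Return the metric names that would be logged (illustrative sketch).

    Assumption: a metric is skipped if any skip entry occurs in its name.
    """
    return [
        name
        for name in metric_names
        if not any(to_skip in name for to_skip in skip_metrics)
    ]


# Hypothetical metric names, invented for this example.
names = ["loss", "dx_AUROC", "dx_AUPRC", "dx_accuracy", "dx_per_class_accuracy"]

train_names = metrics_to_log(names, TRAIN_SKIP_METRICS)  # training step
eval_names = metrics_to_log(names)  # validation and test steps skip nothing

print(train_names)
print(eval_names)
```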
- class EventStream.transformer.lightning_modules.generative_modeling.PretrainConfig(do_overwrite: bool = False, seed: int = 1, config: dict[str, typing.Any] = <factory>, optimization_config: EventStream.transformer.config.OptimizationConfig = <factory>, data_config: EventStream.data.config.PytorchDatasetConfig = <factory>, pretraining_metrics_config: EventStream.transformer.config.MetricsConfig = <factory>, final_validation_metrics_config: EventStream.transformer.config.MetricsConfig = <factory>, trainer_config: dict[str, typing.Any] = <factory>, experiment_dir: str = '???', save_dir: str = '${experiment_dir}/pretrain/${now:%Y-%m-%d_%H-%M-%S}', wandb_logger_kwargs: dict[str, typing.Any] = <factory>, wandb_experiment_config_kwargs: dict[str, typing.Any] = <factory>, do_final_validation_on_metrics: bool = True, do_use_filesystem_sharing: bool = True)
Bases: object
- data_config : PytorchDatasetConfig
- final_validation_metrics_config : MetricsConfig
- optimization_config : OptimizationConfig
- pretraining_metrics_config : MetricsConfig
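The defaults for experiment_dir and save_dir in the signature above use interpolation syntax of the kind handled by OmegaConf and Hydra: '???' marks a mandatory value that must be supplied, and ${now:...} expands to the current time in the given strftime format. The following plain-Python sketch mimics how such a save_dir might resolve. It does not call either library, and the helper resolve_save_dir and the example path are hypothetical.

```python
from datetime import datetime


def resolve_save_dir(experiment_dir: str, now: datetime) -> str:
    """Mimic '${experiment_dir}/pretrain/${now:%Y-%m-%d_%H-%M-%S}' (sketch only)."""
    if experiment_dir == "???":
        # '???' denotes a missing mandatory value; it must be overridden.
        raise ValueError("experiment_dir is mandatory and must be set")
    return f"{experiment_dir}/pretrain/{now.strftime('%Y-%m-%d_%H-%M-%S')}"


# Hypothetical experiment directory and a fixed timestamp for reproducibility.
example = resolve_save_dir("/tmp/my_experiment", datetime(2024, 1, 2, 3, 4, 5))
print(example)  # /tmp/my_experiment/pretrain/2024-01-02_03-04-05
```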
- EventStream.transformer.lightning_modules.generative_modeling.train(cfg: PretrainConfig)
Runs the end-to-end training procedure for the pre-training model.
- Parameters:
- cfg: PretrainConfig
The pre-training config defining the generative modeling task.