EventStream.transformer.lightning_modules.fine_tuning module
- class EventStream.transformer.lightning_modules.fine_tuning.ESTForStreamClassificationLM(config: StructuredTransformerConfig | dict[str, Any], optimization_config: OptimizationConfig | dict[str, Any], pretrained_weights_fp: Path | str | None = None, do_debug_mode: bool = True)
Bases: LightningModule
A PyTorch Lightning Module for an ESTForStreamClassification model.
- configure_optimizers()
Configures optimizer and learning rate scheduler.
This module currently uses the AdamW optimizer with a configurable weight_decay. The learning rate warms up from 0, step by step, to the configurable self.optimization_config.init_lr, then undergoes polynomial decay as specified via self.optimization_config.
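The schedule described above can be sketched in plain Python. This is only an illustration under stated assumptions, not the module's actual code: the warmup is assumed to be linear, and the names `end_lr`, `warmup_steps`, `total_steps`, and `power` are hypothetical stand-ins for whatever the optimization config really exposes. Only `init_lr` comes from the text above.

```python
def lr_at_step(step, init_lr, end_lr, warmup_steps, total_steps, power=1.0):
    """Learning rate at a given optimizer step (hypothetical helper).

    Assumes a linear per-step warmup from 0 to init_lr, followed by
    polynomial decay from init_lr to end_lr over the remaining steps.
    """
    if warmup_steps > 0 and step < warmup_steps:
        # Warmup phase: scale init_lr by the fraction of warmup completed.
        return init_lr * step / warmup_steps
    # Decay phase: progress runs from 0 to 1 and is capped at 1.
    decay_steps = max(total_steps - warmup_steps, 1)
    progress = min(step - warmup_steps, decay_steps) / decay_steps
    return (init_lr - end_lr) * (1.0 - progress) ** power + end_lr


# Illustrative values only; none of these numbers come from the library.
schedule = [lr_at_step(s, 1e-3, 1e-5, 100, 1000) for s in (0, 50, 100, 1000)]
```

With these illustrative values the rate starts at 0, reaches half of `init_lr` midway through warmup, peaks at `init_lr` when warmup ends, and finishes at `end_lr`.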
- log_metrics(results: StreamClassificationModelOutput, skip_metrics: Sequence[str], prefix: str)
Logs metric results for a given output result.
- Parameters:
- results: StreamClassificationModelOutput
The results to assess across the suite of metrics.
- skip_metrics: Sequence[str]
A list of metrics to skip. Entries are partial rather than full metric names: any metric whose name contains an element of skip_metrics will be skipped. For example, if skip_metrics = ['AUROC', 'AUPRC'], then a metric named 'macro_AUROC' or 'micro_AUPRC' would be skipped, whereas a metric named 'weighted_accuracy' would not.
- prefix: str
The prefix to use when logging metric results, for example 'train', 'tuning', or 'held_out'.
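The substring matching described for skip_metrics can be sketched as follows. The helper name `should_skip` is hypothetical and is not part of the module; the metric names are the ones used in the description above.

```python
from typing import Sequence


def should_skip(metric_name: str, skip_metrics: Sequence[str]) -> bool:
    """Return True if any entry of skip_metrics occurs inside metric_name."""
    return any(partial in metric_name for partial in skip_metrics)


skip_metrics = ["AUROC", "AUPRC"]
metric_names = ["macro_AUROC", "micro_AUPRC", "weighted_accuracy"]

# Only metrics that contain none of the partial names are kept.
kept = [name for name in metric_names if not should_skip(name, skip_metrics)]
print(kept)  # prints ['weighted_accuracy']
```

Because matching is by substring, a short entry such as 'AU' would skip both AUROC and AUPRC metrics at once.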
- test_step(batch, batch_idx)
Test step.
Differs from training only in that it does not skip metrics.
- class EventStream.transformer.lightning_modules.fine_tuning.FinetuneConfig(experiment_dir: str | pathlib.Path | None = '${load_from_model_dir}/finetuning', load_from_model_dir: str | pathlib.Path | None = '???', task_df_name: str | None = '???', pretrained_weights_fp: pathlib.Path | str | None = '${load_from_model_dir}/pretrained_weights', save_dir: str | None = '${experiment_dir}/${task_df_name}/subset_size_${data_config.train_subset_size}/subset_seed_${data_config.train_subset_seed}/${now:%Y-%m-%d_%H-%M-%S}', wandb_logger_kwargs: dict[str, typing.Any] = <factory>, wandb_experiment_config_kwargs: dict[str, typing.Any] = <factory>, do_overwrite: bool = False, seed: int = 1, config: dict[str, typing.Any] = <factory>, optimization_config: EventStream.transformer.config.OptimizationConfig = <factory>, data_config: dict[str, typing.Any] | None = <factory>, trainer_config: dict[str, typing.Any] = <factory>, do_use_filesystem_sharing: bool = True)
Bases: object
- optimization_config: OptimizationConfig
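Several FinetuneConfig defaults are path templates that refer to other fields through `${...}` placeholders, and `'???'` appears to mark values the user must supply. The sketch below only emulates how such nested templates would expand, assuming simple textual substitution. It is not the library's resolution mechanism, the `expand` helper is hypothetical, and the two concrete values are placeholders chosen for illustration.

```python
import re

# Defaults copied from the signature above, except the two placeholder
# values standing in for the required ('???') fields.
values = {
    "load_from_model_dir": "/path/to/pretrained_model",  # placeholder value
    "task_df_name": "example_task",  # placeholder value
    "experiment_dir": "${load_from_model_dir}/finetuning",
    "pretrained_weights_fp": "${load_from_model_dir}/pretrained_weights",
}

PLACEHOLDER = re.compile(r"\$\{(\w+)\}")


def expand(template: str, values: dict) -> str:
    """Substitute ${name} placeholders repeatedly until none remain."""
    while True:
        expanded = PLACEHOLDER.sub(lambda m: values[m.group(1)], template)
        if expanded == template:
            return expanded
        template = expanded


# The leading part of the default save_dir template.
save_prefix = expand("${experiment_dir}/${task_df_name}", values)
print(save_prefix)  # prints /path/to/pretrained_model/finetuning/example_task
```

The remaining components of the default save_dir (the subset size, subset seed, and timestamp) follow the same pattern but depend on data_config and the current time, so they are left out of this sketch.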
- EventStream.transformer.lightning_modules.fine_tuning.train(cfg: FinetuneConfig)
Runs the end-to-end training procedure for the fine-tuning model.
- Parameters:
- cfg: FinetuneConfig
The fine-tuning configuration object, specifying the cohort and task to fine-tune for and the pretrained model to fine-tune from.