EventStream.evaluation.general_generative_evaluation module

class EventStream.evaluation.general_generative_evaluation.ESTForTrajectoryGeneration(config: StructuredTransformerConfig | dict[str, Any], pretrained_weights_fp: Path)

Bases: LightningModule

A PyTorch Lightning module for zero-shot classification via generation with an EST model.

predict_step(batch: PytorchBatch, batch_idx: int) list[PytorchBatch]

Prediction step.

Generates new samples and writes them out.
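To illustrate the predict_step contract, here is a minimal, dependency-free sketch. It assumes (hypothetically) that each element of the returned list[PytorchBatch] is one generated sample of the input batch. ToyBatch, ToyTrajectoryGenerator, and run_predict are illustrative stand-ins, not part of EventStream or PyTorch Lightning, and the fake "generation" just appends fixed event codes instead of sampling from a model. run_predict mimics how Lightning's Trainer.predict gathers one predict_step output per batch, in order.

```python
from dataclasses import dataclass


@dataclass
class ToyBatch:
    # Hypothetical stand-in for PytorchBatch: a list of event codes per subject.
    events: list[list[int]]


class ToyTrajectoryGenerator:
    """Hypothetical stand-in for ESTForTrajectoryGeneration (no Lightning, no model)."""

    def __init__(self, num_samples: int = 2, max_new_events: int = 3) -> None:
        self.num_samples = num_samples
        self.max_new_events = max_new_events

    def _generate(self, batch: ToyBatch, sample_idx: int) -> ToyBatch:
        # Placeholder "generation": append deterministic fake events to each subject.
        # A real model would sample continuations autoregressively instead.
        return ToyBatch(
            events=[
                seq + [sample_idx * 100 + k for k in range(self.max_new_events)]
                for seq in batch.events
            ]
        )

    def predict_step(self, batch: ToyBatch, batch_idx: int) -> list[ToyBatch]:
        # One generated batch per sample, mirroring the list[PytorchBatch] return type.
        return [self._generate(batch, i) for i in range(self.num_samples)]


def run_predict(
    module: ToyTrajectoryGenerator, batches: list[ToyBatch]
) -> list[list[ToyBatch]]:
    # Mimics Trainer.predict with a single dataloader: one output per batch, in order.
    return [module.predict_step(batch, i) for i, batch in enumerate(batches)]


outputs = run_predict(ToyTrajectoryGenerator(), [ToyBatch(events=[[1, 2], [3]])])
print(len(outputs), len(outputs[0]))  # → 1 2
print(outputs[0][1].events)  # → [[1, 2, 100, 101, 102], [3, 100, 101, 102]]
```

Writing the samples out is omitted from this sketch; according to the docstring, the real predict_step also handles that step.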

class EventStream.evaluation.general_generative_evaluation.GenerateConfig(load_from_model_dir: str | pathlib.Path = '???', seed: int = 1, pretrained_weights_fp: pathlib.Path | None = None, save_dir: str | None = None, do_overwrite: bool = False, optimization_config: EventStream.transformer.config.OptimizationConfig = <factory>, task_df_name: str | None = None, data_config_overrides: dict[str, typing.Any] | None = <factory>, trainer_config: dict[str, typing.Any] = <factory>, task_specific_params: dict[str, typing.Any] = <factory>, config_overrides: dict[str, typing.Any] = <factory>, parallelize_conversion: int | None = None)

Bases: object

config_overrides : dict[str, Any]
data_config_overrides : dict[str, Any] | None
do_overwrite : bool = False
load_from_model_dir : str | Path = '???'
optimization_config : OptimizationConfig
parallelize_conversion : int | None = None
pretrained_weights_fp : Path | None = None
save_dir : str | None = None
seed : int = 1
task_df_name : str | None = None
task_specific_params : dict[str, Any]
trainer_config : dict[str, Any]
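GenerateConfig is a dataclass, and two parts of its rendered signature are easy to misread. The '???' default for load_from_model_dir is the same string as OmegaConf's MISSING marker, which suggests the field is required and must be supplied (for example, through a Hydra-style override). <factory> is how Python's dataclasses render a field declared with default_factory, so each instance gets its own fresh mutable default. The sketch below mirrors a subset of the fields in a hypothetical GenerateConfigSketch (not the real class) and adds a hypothetical missing_fields helper; it uses only the standard library, and the real default factories may populate non-empty defaults.

```python
from __future__ import annotations

import inspect
from dataclasses import dataclass, field, fields
from pathlib import Path
from typing import Any

MISSING = "???"  # same string as omegaconf.MISSING; marks a required value


@dataclass
class GenerateConfigSketch:
    # Hypothetical mirror of a subset of GenerateConfig's fields and defaults.
    load_from_model_dir: str | Path = MISSING
    seed: int = 1
    pretrained_weights_fp: Path | None = None
    save_dir: str | None = None
    do_overwrite: bool = False
    task_df_name: str | None = None
    parallelize_conversion: int | None = None
    # Mutable defaults need default_factory; the signature renders them as <factory>.
    # Assumption: plain empty dicts; the real factories may differ.
    config_overrides: dict[str, Any] = field(default_factory=dict)
    trainer_config: dict[str, Any] = field(default_factory=dict)


def missing_fields(cfg: GenerateConfigSketch) -> list[str]:
    # Names of fields still holding the '???' placeholder.
    return [f.name for f in fields(cfg) if getattr(cfg, f.name) == MISSING]


print(missing_fields(GenerateConfigSketch()))  # → ['load_from_model_dir']
# "models/my_run" is a hypothetical path used only for illustration.
print(missing_fields(GenerateConfigSketch(load_from_model_dir="models/my_run")))  # → []
print("<factory>" in str(inspect.signature(GenerateConfigSketch)))  # → True
```

Using default_factory rather than a shared dict literal means that mutating one config's config_overrides never leaks into another instance.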

EventStream.evaluation.general_generative_evaluation.generate_trajectories(cfg: GenerateConfig)