EventStream.evaluation.general_generative_evaluation module
- class EventStream.evaluation.general_generative_evaluation.ESTForTrajectoryGeneration(config: StructuredTransformerConfig | dict[str, Any], pretrained_weights_fp: Path)
Bases: LightningModule
A PyTorch Lightning module for zero-shot classification via generation with an EST model.
- predict_step(batch: PytorchBatch, batch_idx: int) → list[PytorchBatch]
Prediction step.
Generates new samples and writes them out.
- class EventStream.evaluation.general_generative_evaluation.GenerateConfig(load_from_model_dir: str | pathlib.Path = '???', seed: int = 1, pretrained_weights_fp: pathlib.Path | None = None, save_dir: str | None = None, do_overwrite: bool = False, optimization_config: EventStream.transformer.config.OptimizationConfig = <factory>, task_df_name: str | None = None, data_config_overrides: dict[str, typing.Any] | None = <factory>, trainer_config: dict[str, typing.Any] = <factory>, task_specific_params: dict[str, typing.Any] = <factory>, config_overrides: dict[str, typing.Any] = <factory>, parallelize_conversion: int | None = None)
Bases: object
- optimization_config : OptimizationConfig
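The `<factory>` defaults in the signature above are how a dataclass field declared with `default_factory` is rendered, and `'???'` is the string OmegaConf uses as its "missing value" marker, which suggests `load_from_model_dir` must be supplied by the caller. The following is a minimal sketch of that pattern using only the standard library. `GenerateConfigSketch` is a hypothetical stand-in that copies a few field names from the signature; it is not the real `GenerateConfig`, and the empty-dict factories are an assumption, since the actual factory values are not shown here.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

# Same sentinel string that appears in the signature above.
MISSING = "???"


@dataclass
class GenerateConfigSketch:
    """Hypothetical stand-in; NOT the real EventStream GenerateConfig."""

    load_from_model_dir: str | Path = MISSING
    seed: int = 1
    save_dir: str | None = None
    do_overwrite: bool = False
    # Assumption: empty dicts. The real factories may return something else.
    trainer_config: dict[str, Any] = field(default_factory=dict)
    config_overrides: dict[str, Any] = field(default_factory=dict)


# "/path/to/model" is a placeholder path, used for illustration only.
a = GenerateConfigSketch(load_from_model_dir="/path/to/model")
b = GenerateConfigSketch(load_from_model_dir="/path/to/model")

# Each instance gets its own dict, so mutating one does not leak into the other.
a.trainer_config["accelerator"] = "cpu"
```

Using a factory rather than a literal `{}` default is what keeps `a.trainer_config` and `b.trainer_config` independent; a dataclass rejects a bare mutable default for exactly this reason.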
- EventStream.evaluation.general_generative_evaluation.generate_trajectories(cfg: GenerateConfig)