EventStream.data.pytorch_dataset module
- class EventStream.data.pytorch_dataset.PytorchDataset(config: PytorchDatasetConfig, split: str)
Bases: SaveableMixin, SeedableMixin, TimeableMixin, Dataset

A PyTorch Dataset class built on a pre-processed DatasetBase instance.

This class exposes the deep-learning friendly representation produced by Dataset.build_DL_cached_representation in a PyTorch Dataset format. The __getitem__ method of this class returns a dictionary containing a subject’s data from this deep-learning representation, with event sequences sliced to fit within the maximum sequence length according to configuration parameters. The collate method of this class collates those output dictionaries into a PytorchBatch object usable by downstream pipelines.

Upon construction, this class tries to load a number of dataset files from disk. These files should be saved in accordance with the Dataset.save method; in particular:
- Pre-cached deep-learning representation parquet dataframes should be stored in config.save_dir / 'DL_reps' / f"{split}*.parquet".
- A vocabulary config object, in JSON form, should be stored in config.save_dir / 'vocabulary_config.json'.
- A set of inferred measurement configs should be stored in config.save_dir / 'inferred_measurement_configs.json'.
- If a task dataframe name is specified in the configuration object, then there should be either a pre-cached task-specific DL representation dataframe in config.save_dir / 'DL_reps' / 'for_task' / config.task_df_name / f"{split}.parquet", or a “raw” task dataframe, containing subject IDs, start and end times, and labels, stored in config.save_dir / 'task_dfs' / f"{config.task_df_name}.parquet". If only the latter exists, the former is constructed by limiting the input cached dataframe to the appropriate sequences and adding label columns. This newly constructed dataframe is then saved at the former filepath for future use. This construction should happen first on the train split, so that inferred task vocabularies are shared across splits.
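The on-disk layout described above can be summarized with a small path-resolution sketch. The helper expected_paths is hypothetical (it is not part of EventStream) and only lists the locations named in this docstring; it does not reproduce how PytorchDataset actually loads or validates these files. It also illustrates that the f"{split}*.parquet" pattern picks up every shard whose filename starts with the split name.

```python
from pathlib import Path
from typing import Optional


def expected_paths(save_dir: Path, split: str, task_df_name: Optional[str] = None) -> dict:
    """Hypothetical helper: lists the files the constructor is documented to look for."""
    save_dir = Path(save_dir)
    paths: dict = {
        # All DL-representation shards for this split, e.g. train_0.parquet, train_1.parquet.
        "dl_reps": sorted(save_dir.glob(f"DL_reps/{split}*.parquet")),
        "vocabulary_config": save_dir / "vocabulary_config.json",
        "inferred_measurement_configs": save_dir / "inferred_measurement_configs.json",
    }
    if task_df_name is not None:
        # Preferred source: a cached, task-specific DL representation for this split.
        paths["task_cached"] = save_dir / "DL_reps" / "for_task" / task_df_name / f"{split}.parquet"
        # Fallback source: the raw task dataframe (subject IDs, start/end times, labels).
        paths["task_raw"] = save_dir / "task_dfs" / f"{task_df_name}.parquet"
    return paths
```

Only the split prefix selects shards, so files for other splits (e.g. tuning_0.parquet) stored in the same DL_reps directory are not picked up for split='train'.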
- Parameters:
- config: PytorchDatasetConfig
Configuration options for the dataset.
- split: str
The split of data which should be used in this dataset (e.g., 'train', 'tuning', 'held_out'). This will dictate where the system looks for pre-cached deep-learning representation files.
- TYPE_CHECKERS = {'binary_classification': [({Boolean}, <function PytorchDataset.<lambda>>)], 'multi_class_classification': [({Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8}, None), ({Categorical}, <function to_int_index>), ({Utf8}, <function to_int_index>)], 'regression': [({Float32, Float64}, None)]}
Type checker and conversion parameters for labeled datasets.
- collate(batch: list[dict[str, list[float]]]) → PytorchBatch
Combines the ragged dictionaries produced by __getitem__ into a tensorized batch. This function handles conversion of arrays to tensors and padding of elements within the batch across static data elements, sequence events, and dynamic data elements.
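The core idea of collation, padding ragged per-subject sequences to a common length and recording which positions are real, can be sketched with torch.nn.utils.rnn.pad_sequence. This is a simplified illustration, not the actual collate implementation: the function name collate_sketch, the single key "time_delta", right-padding with 0.0, and the "event_mask" output are all assumptions for this example, and the real method also pads static and nested dynamic elements and returns a PytorchBatch rather than a dict.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def collate_sketch(batch: list[dict[str, list[float]]]) -> dict[str, torch.Tensor]:
    """Hypothetical simplified collate: pads one ragged per-event field and builds a mask."""
    # Assumed key name; real items carry several fields (static, per-event, per-measurement).
    seqs = [torch.tensor(item["time_delta"], dtype=torch.float32) for item in batch]
    # Right-pad every sequence with 0.0 up to the longest sequence in the batch.
    padded = pad_sequence(seqs, batch_first=True, padding_value=0.0)
    # True where a real (non-padding) event is present, False where padding was added.
    lengths = torch.tensor([len(s) for s in seqs])
    event_mask = torch.arange(padded.shape[1]).unsqueeze(0) < lengths.unsqueeze(1)
    return {"time_delta": padded, "event_mask": event_mask}


batch = [{"time_delta": [1.0, 2.0, 3.0]}, {"time_delta": [4.0]}]
out = collate_sketch(batch)
```

Keeping an explicit mask alongside the padded tensor lets downstream models ignore padding positions without relying on a sentinel value such as 0.0, which could also be a legitimate observation.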
- classmethod normalize_task(col: Expr, dtype: DataType) → tuple[str, Expr]
Normalizes the task labels in col of dtype dtype to a common format.
- Parameters:
- col: Expr
The column expression containing the task labels.
- dtype: DataType
The data type of the task labels in col.
- Returns:
The task type (a string key into the TYPE_CHECKERS dictionary) and the normalized column expression.
- Raises:
TypeError – If the task labels are not of a supported type.
- EventStream.data.pytorch_dataset.to_int_index(col: Expr) → Expr
Returns an integer index of the unique elements seen in this column.
The returned index is into a vocabulary sorted lexicographically.
Examples
>>> import polars as pl
>>> from EventStream.data.pytorch_dataset import to_int_index
>>> X = pl.DataFrame({
...     'c': ['foo', 'bar', 'foo', 'bar', 'baz', None, 'bar', 'aba'],
...     'd': [1, 2, 3, 4, 5, 6, 7, 8]
... })
>>> X.with_columns(to_int_index(pl.col('c')))
shape: (8, 2)
┌──────┬─────┐
│ c    ┆ d   │
│ ---  ┆ --- │
│ u32  ┆ i64 │
╞══════╪═════╡
│ 4    ┆ 1   │
│ 1    ┆ 2   │
│ 4    ┆ 3   │
│ 1    ┆ 4   │
│ 2    ┆ 5   │
│ null ┆ 6   │
│ 1    ┆ 7   │
│ 0    ┆ 8   │
└──────┴─────┘