EventStream.data.pytorch_dataset module

class EventStream.data.pytorch_dataset.PytorchDataset(config: PytorchDatasetConfig, split: str, just_cache: bool = False)

Bases: SeedableMixin, Dataset

A PyTorch Dataset class.

This class provides access to the deep-learning-friendly representation produced by Dataset.build_DL_cached_representation in PyTorch Dataset format. The __getitem__ method returns a dictionary containing a single subject's data from this representation, with event sequences sliced to fit within the maximum sequence length set by the configuration. The collate method combines those output dictionaries into a PytorchBatch object usable by downstream pipelines.
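The description above says only that sequences are sliced to fit within the maximum sequence length. The following is a hedged illustration of one way such slicing could work. The helper name slice_to_max_len and the three strategy names ("from_start", "to_end", "random") are assumptions made for this sketch. They are not confirmed options of this module.

```python
from __future__ import annotations

import random


def slice_to_max_len(
    events: list,
    max_seq_len: int,
    strategy: str = "to_end",
    rng: random.Random | None = None,
) -> list:
    """Return a contiguous window of at most max_seq_len events (hypothetical helper)."""
    n = len(events)
    if n <= max_seq_len:
        # Sequences that already fit are returned unchanged.
        return list(events)
    if strategy == "from_start":
        start = 0  # keep the earliest events
    elif strategy == "to_end":
        start = n - max_seq_len  # keep the most recent events
    elif strategy == "random":
        rng = rng or random.Random()
        start = rng.randint(0, n - max_seq_len)  # inclusive upper bound
    else:
        raise ValueError(f"Unknown strategy: {strategy!r}")
    return list(events[start : start + max_seq_len])
```

During training, a random window exposes the model to different parts of long histories. Keeping the final events is a natural choice when predictions are anchored at the end of the record.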

Upon construction, this class will try to load a number of dataset files from disk. These files should be saved in accordance with the Dataset.save method; in particular:

  • There should be pre-cached deep-learning representation parquet dataframes stored in config.save_dir / 'DL_reps' / f"{split}*.parquet"

  • There should be a vocabulary config object in json form stored in config.save_dir / 'vocabulary_config.json'

  • There should be a set of inferred measurement configs stored in config.save_dir / 'inferred_measurement_configs.json'

  • If a task dataframe name is specified in the configuration object, then there should be either a pre-cached task-specific DL representation dataframe in config.save_dir / 'DL_reps' / 'for_task' / config.task_df_name / f"{split}.parquet", or a “raw” task dataframe containing subject IDs, start and end times, and labels, stored in config.save_dir / 'task_dfs' / f"{config.task_df_name}.parquet". If only the latter exists, the former will be constructed by restricting the cached input dataframe to the appropriate sequences and adding label columns. The newly constructed dataframe will then be saved to the former filepath for future use. This construction should happen on the train split first, so that inferred task vocabularies are shared across splits.
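The file layout listed above can be summarized with pathlib. The sketch below assumes a plain save_dir path and an optional task name. The helpers expected_files and missing_task_cache are hypothetical and exist only to make the layout concrete. They do not reproduce the class's actual loading code.

```python
from __future__ import annotations

from pathlib import Path


def expected_files(save_dir: Path, split: str, task_df_name: str | None = None) -> dict[str, Path]:
    """Hypothetical helper mapping each required artifact to its on-disk location."""
    files = {
        # Glob pattern (not a concrete file) for the split's cached DL representation shards.
        "dl_reps_glob": save_dir / "DL_reps" / f"{split}*.parquet",
        "vocabulary_config": save_dir / "vocabulary_config.json",
        "inferred_measurement_configs": save_dir / "inferred_measurement_configs.json",
    }
    if task_df_name is not None:
        # Pre-cached, task-specific DL representation for this split.
        files["cached_task_dl_rep"] = save_dir / "DL_reps" / "for_task" / task_df_name / f"{split}.parquet"
        # "Raw" task dataframe: subject IDs, start/end times, and labels.
        files["raw_task_df"] = save_dir / "task_dfs" / f"{task_df_name}.parquet"
    return files


def missing_task_cache(save_dir: Path, split: str, task_df_name: str) -> bool:
    """True if only the raw task dataframe exists, so the cached version must be built."""
    f = expected_files(save_dir, split, task_df_name)
    return f["raw_task_df"].exists() and not f["cached_task_dl_rep"].exists()
```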

Parameters:
config: PytorchDatasetConfig

Configuration options for the dataset.

split: str

The split of data which should be used in this dataset (e.g., 'train', 'tuning', 'held_out'). This will dictate where the system looks for pre-cached deep-learning representation files.

property NRTs_dir : Path
TYPE_CHECKERS = {
    'binary_classification': [({Boolean}, <function PytorchDataset.<lambda>>)],
    'multi_class_classification': [
        ({Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8}, None),
        ({Categorical(ordering='lexical'), Categorical(ordering='physical')}, <function to_int_index>),
        ({String}, <function to_int_index>)
    ],
    'regression': [({Float32, Float64}, None)]
}

Type checker and conversion parameters for labeled datasets. Each task type maps to a list of (accepted polars data types, conversion function or None) pairs.

collate(batch: list[dict[str, list[float]]]) → PytorchBatch

Combines the ragged dictionaries produced by __getitem__ into a tensorized batch.

This function handles conversion of arrays to tensors and padding of elements within the batch across static data elements, sequence events, and dynamic data elements.

Parameters:
batch: list[dict[str, list[float]]]

A list of __getitem__ format output dictionaries.

Returns:

A fully collated, tensorized, and padded batch.
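Below is a rough, framework-free sketch of the padding step, with tensor conversion omitted. The hypothetical helper pad_batch pads ragged per-subject sequences to the longest length in the batch and builds a mask that marks real elements. Its side argument is an assumption loosely modeled on the seq_padding_side property. It is not the method's actual signature.

```python
from __future__ import annotations


def pad_batch(
    seqs: list[list[float]], pad_value: float = 0.0, side: str = "right"
) -> tuple[list[list[float]], list[list[bool]]]:
    """Pad ragged sequences to a common length; return (padded values, mask)."""
    max_len = max((len(s) for s in seqs), default=0)
    padded: list[list[float]] = []
    mask: list[list[bool]] = []
    for s in seqs:
        n_pad = max_len - len(s)
        if side == "right":
            # Real data first, padding appended after it.
            padded.append(list(s) + [pad_value] * n_pad)
            mask.append([True] * len(s) + [False] * n_pad)
        elif side == "left":
            # Padding first, so the most recent events align at the end.
            padded.append([pad_value] * n_pad + list(s))
            mask.append([False] * n_pad + [True] * len(s))
        else:
            raise ValueError(f"Unknown padding side: {side!r}")
    return padded, mask
```

In a training loop, the collate method itself would typically be handed to a PyTorch DataLoader through its collate_fn argument, e.g. DataLoader(dataset, batch_size=32, collate_fn=dataset.collate).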

filter_to_min_seq_len()

Filters the dataset to only include subjects with at least config.min_seq_len events.
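Here is a minimal sketch of the filtering rule. It assumes per-subject event counts are available as a mapping from subject ID to count, which is a hypothetical input format. The comparison is inclusive, matching "at least".

```python
from __future__ import annotations


def subjects_meeting_min_len(seq_lens: dict[int, int], min_seq_len: int) -> list[int]:
    """Hypothetical helper: subject IDs with at least min_seq_len events, sorted."""
    return sorted(sid for sid, n_events in seq_lens.items() if n_events >= min_seq_len)
```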

filter_to_subset()

Filters the dataset to only include a subset of subjects.

get_task_info(task_df: DataFrame)

Gets the task information from the task dataframe.

property has_task : bool
property is_subset_dataset : bool
property max_seq_len : int
property measurement_configs

Grabs the measurement configs from the config.

classmethod normalize_task(col: Expr, dtype: DataType) → tuple[str, Expr]

Normalizes the task labels in col of dtype dtype to a common format.

Parameters:
col: Expr

The column containing the task labels, in polars expression format.

dtype: DataType

The polars data type of the task labels.

Returns:

The task type (a string key into the TYPE_CHECKERS dictionary) and the normalized column expression.

Raises:

TypeError – If the task labels are not of a supported type.
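The lookup described above can be mirrored on plain Python lists. This sketch has several deliberate differences from the real method. Dtypes are identified by name strings rather than polars DataType objects. CHECKERS is a hypothetical stand-in for TYPE_CHECKERS. The boolean converter assumes labels become 1.0/0.0, because what the real lambda does is not documented here.

```python
from __future__ import annotations


def _bool_to_float(values: list) -> list:
    # Assumption: boolean labels become 1.0 / 0.0; nulls stay None.
    return [None if v is None else float(v) for v in values]


def _to_int_index(values: list) -> list:
    # Lexicographic vocabulary index, mirroring to_int_index; nulls stay None.
    vocab = sorted({v for v in values if v is not None})
    index = {v: i for i, v in enumerate(vocab)}
    return [None if v is None else index[v] for v in values]


# Hypothetical stand-in for TYPE_CHECKERS, keyed on dtype names.
CHECKERS: dict = {
    "binary_classification": [({"Boolean"}, _bool_to_float)],
    "multi_class_classification": [
        ({"Int8", "Int16", "Int32", "Int64", "UInt8", "UInt16", "UInt32", "UInt64"}, None),
        ({"Categorical", "String"}, _to_int_index),
    ],
    "regression": [({"Float32", "Float64"}, None)],
}


def normalize_labels(values: list, dtype_name: str) -> tuple[str, list]:
    """Return (task_type, normalized labels), or raise TypeError for unsupported dtypes."""
    for task_type, checkers in CHECKERS.items():
        for dtypes, convert in checkers:
            if dtype_name in dtypes:
                # A None converter means the labels are used as-is.
                return task_type, (list(values) if convert is None else convert(values))
    raise TypeError(f"Unsupported task label dtype: {dtype_name}")
```

The dtype of the label column alone decides the task type. Integer labels are therefore treated as multi-class targets, and float labels are treated as regression targets.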

read_patient_descriptors()

Reads the patient descriptors from the ESGPT or MEDS dataset.

read_shards()

Reads the split-specific patient shards from the ESGPT or MEDS dataset.

read_vocabulary()

Reads the vocabulary from either the ESGPT or MEDS dataset.

property seq_padding_side : SeqPaddingSide
set_inter_event_time_stats()

Sets the inter-event time statistics for the dataset.
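The docstring does not say which statistics are computed. Assuming they summarize the logarithm of the gaps between consecutive events (mean and standard deviation), a minimal stdlib sketch could look like the one below. The function name, the time units, the skipping of non-positive gaps, and the use of the population standard deviation are all assumptions.

```python
from __future__ import annotations

import math
import statistics


def log_inter_event_time_stats(event_times: list[list[float]]) -> tuple[float, float]:
    """Mean and population std of log inter-event times across all subjects.

    Each inner list is assumed to hold one subject's event times, sorted ascending.
    """
    log_gaps: list[float] = []
    for times in event_times:
        for prev, cur in zip(times, times[1:]):
            gap = cur - prev
            # Assumption: non-positive gaps are skipped, since log is undefined there.
            if gap > 0:
                log_gaps.append(math.log(gap))
    if not log_gaps:
        raise ValueError("No positive inter-event times found.")
    return statistics.fmean(log_gaps), statistics.pstdev(log_gaps)
```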

property static_dir : Path
property subject_ids : list[int]
property task_dir : Path
EventStream.data.pytorch_dataset.to_int_index(col: Expr) → Expr

Returns an integer index of the unique elements seen in this column.

The returned index is into a vocabulary sorted lexicographically. Null values remain null.

Parameters:
col: Expr

The column containing the data to be converted into integer indices.

Examples

>>> import polars as pl
>>> X = pl.DataFrame({
...     'c': ['foo', 'bar', 'foo', 'bar', 'baz', None, 'bar', 'aba'],
...     'd': [1, 2, 3, 4, 5, 6, 7, 8]
... })
>>> X.with_columns(to_int_index(pl.col('c')).alias("c_index"))
shape: (8, 3)
┌──────┬─────┬─────────┐
│ c    ┆ d   ┆ c_index │
│ ---  ┆ --- ┆ ---     │
│ str  ┆ i64 ┆ u32     │
╞══════╪═════╪═════════╡
│ foo  ┆ 1   ┆ 3       │
│ bar  ┆ 2   ┆ 1       │
│ foo  ┆ 3   ┆ 3       │
│ bar  ┆ 4   ┆ 1       │
│ baz  ┆ 5   ┆ 2       │
│ null ┆ 6   ┆ null    │
│ bar  ┆ 7   ┆ 1       │
│ aba  ┆ 8   ┆ 0       │
└──────┴─────┴─────────┘
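For readers without polars at hand, the same mapping can be mirrored on a plain Python list. The function to_int_index_list is a hypothetical helper, not part of this module. It sorts the distinct non-null values lexicographically and replaces each value with its position, leaving nulls as None.

```python
from __future__ import annotations

from typing import Optional


def to_int_index_list(values: list[Optional[str]]) -> list[Optional[int]]:
    """Hypothetical list-based analogue of to_int_index."""
    # Vocabulary of distinct non-null values, sorted lexicographically.
    vocab = sorted({v for v in values if v is not None})
    index = {v: i for i, v in enumerate(vocab)}
    # Nulls stay null; everything else maps to its vocabulary position.
    return [None if v is None else index[v] for v in values]
```

Applied to the column c from the example above, to_int_index_list(['foo', 'bar', 'foo', 'bar', 'baz', None, 'bar', 'aba']) returns [3, 1, 3, 1, 2, None, 1, 0], matching c_index.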