Source code for EventStream.data.pytorch_dataset

import json
from collections import defaultdict
from pathlib import Path

import numpy as np
import polars as pl
import torch
from loguru import logger
from mixins import SeedableMixin
from nested_ragged_tensors.ragged_numpy import (
    NP_FLOAT_TYPES,
    NP_INT_TYPES,
    NP_UINT_TYPES,
    JointNestedRaggedTensorDict,
)
from tqdm.auto import tqdm

from ..utils import count_or_proportion
from .config import PytorchDatasetConfig, SeqPaddingSide, SubsequenceSamplingStrategy
from .types import PytorchBatch

# Type alias used to annotate the per-item dictionaries passed (as a list) to `PytorchDataset.collate`.
DATA_ITEM_T = dict[str, list[float]]


def to_int_index(col: pl.Expr) -> pl.Expr:
    """Maps each value of a column to its position in the sorted vocabulary of that column.

    The vocabulary is the sorted set of distinct non-null values of ``col``; nulls in the input stay null
    in the output.

    Args:
        col: The column expression whose values should be replaced by integer indices.

    Returns:
        An expression producing the integer indices, carrying the same output name as ``col``.

    Examples:
        >>> import polars as pl
        >>> X = pl.DataFrame({
        ...     'c': ['foo', 'bar', 'foo', 'bar', 'baz', None, 'bar', 'aba'],
        ...     'd': [1, 2, 3, 4, 5, 6, 7, 8]
        ... })
        >>> X.with_columns(to_int_index(pl.col('c')).alias("c_index"))
        shape: (8, 3)
        ┌──────┬─────┬─────────┐
        │ c    ┆ d   ┆ c_index │
        │ ---  ┆ --- ┆ ---     │
        │ str  ┆ i64 ┆ u32     │
        ╞══════╪═════╪═════════╡
        │ foo  ┆ 1   ┆ 3       │
        │ bar  ┆ 2   ┆ 1       │
        │ foo  ┆ 3   ┆ 3       │
        │ bar  ┆ 4   ┆ 1       │
        │ baz  ┆ 5   ┆ 2       │
        │ null ┆ 6   ┆ null    │
        │ bar  ┆ 7   ┆ 1       │
        │ aba  ┆ 8   ┆ 0       │
        └──────┴─────┴─────────┘
    """
    # Sorted vocabulary of the distinct, non-null values.
    sorted_vocab = col.drop_nulls().unique().sort()
    # Position of every input value within that vocabulary.
    positions = sorted_vocab.search_sorted(col, side="left")
    # Nulls have no vocabulary entry, so they are passed through as null.
    null_preserving = pl.when(col.is_null()).then(pl.lit(None)).otherwise(positions)
    return null_preserving.alias(col.meta.output_name())
[docs] class PytorchDataset(SeedableMixin, torch.utils.data.Dataset): """A PyTorch Dataset class. This class enables accessing the deep-learning friendly representation produced by `Dataset.build_DL_cached_representation` in a PyTorch Dataset format. The `getitem` method of this class will return a dictionary containing a subject's data from this deep learning representation, with event sequences sliced to be within max sequence length according to configuration parameters, and the `collate` method of this class will collate those output dictionaries into a `PytorchBatch` object usable by downstream pipelines. Upon construction, this class will try to load a number of dataset files from disk. These files should be saved in accordance with the `Dataset.save` method; in particular, * There should be pre-cached deep-learning representation parquet dataframes stored in ``config.save_dir / 'DL_reps' / f"{split}*.parquet"`` * There should be a vocabulary config object in json form stored in ``config.save_dir / 'vocabulary_config.json'`` * There should be a set of inferred measurement configs stored in ``config.save_dir / 'inferred_measurement_configs.json'`` * If a task dataframe name is specified in the configuration object, then there should be either a pre-cached task-specifid DL representation dataframe in ``config.save_dir / 'DL_reps' / 'for_task' / config.task_df_name / f"{split}.parquet"``, or a "raw" task dataframe, containing subject IDs, start and end times, and labels, stored in ``config.save_dir / task_dfs / f"{config.task_df_name}.parquet"``. In the case that the latter is all that exists, then the former will be constructed by limiting the input cached dataframe down to the appropriate sequences and adding label columns. This newly constructed datafrmae will then be saved in the former filepath for future use. This construction process should happen first on the train split, so that inferred task vocabularies are shared across splits. 
Args: config: Configuration options for the dataset. split: The split of data which should be used in this dataset (e.g., ``'train'``, ``'tuning'``, ``'held_out'``). This will dictate where the system looks for pre-cached deep-learning representation files. """ TYPE_CHECKERS = { "multi_class_classification": [ ( {pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64, pl.Int8, pl.Int16, pl.Int32, pl.Int64}, None, ), ({pl.Categorical(ordering="physical"), pl.Categorical(ordering="lexical")}, to_int_index), ({pl.Utf8}, to_int_index), ], "binary_classification": [({pl.Boolean}, lambda Y: Y.cast(pl.Float32))], "regression": [({pl.Float32, pl.Float64}, None)], } """Type checker and conversion parameters for labeled datasets."""
[docs] @classmethod def normalize_task(cls, col: pl.Expr, dtype: pl.DataType) -> tuple[str, pl.Expr]: """Normalizes the task labels in `col` of dtype `dtype` to a common format. Args: col: The column containing the task labels, in polars expression format. dtype: The polars data type of the task labels. Returns: The task type (a string key into the `TYPE_CHECKERS` dictionary) and the normalized column expression. Raises: TypeError: If the task labels are not of a supported type. """ for task_type, checkers in cls.TYPE_CHECKERS.items(): for valid_dtypes, normalize_fn in checkers: if dtype in valid_dtypes: return task_type, (col if normalize_fn is None else normalize_fn(col)) raise TypeError(f"Can't process label of {dtype} type!")
def __init__(self, config: PytorchDatasetConfig, split: str, just_cache: bool = False): super().__init__() self.config = config self.split = split logger.info("Reading vocabulary") self.read_vocabulary() logger.info("Reading splits & patient shards") self.read_shards() logger.info("Reading patient descriptors") self.read_patient_descriptors() if self.config.min_seq_len is not None and self.config.min_seq_len > 1: logger.info(f"Restricting to subjects with at least {config.min_seq_len} events") self.filter_to_min_seq_len() if self.config.train_subset_size not in (None, "FULL") and self.split == "train": logger.info(f"Filtering training subset size to {self.config.train_subset_size}") self.filter_to_subset() self.set_inter_event_time_stats() @property def static_dir(self) -> Path: return self.config.save_dir / "DL_reps" @property def task_dir(self) -> Path: return self.config.save_dir / "task_dfs" @property def NRTs_dir(self) -> Path: return self.config.save_dir / "NRT_reps"
[docs] def read_vocabulary(self): """Reads the vocabulary either from the ESGPT or MEDS dataset.""" self.vocabulary_config = self.config.vocabulary_config
[docs] def read_shards(self): """Reads the split-specific patient shards from the ESGPT or MEDS dataset.""" shards_fp = self.config.save_dir / "DL_shards.json" all_shards = json.loads(shards_fp.read_text()) self.shards = {sp: subjs for sp, subjs in all_shards.items() if sp.startswith(f"{self.split}/")} self.subj_map = {subj: sp for sp, subjs in self.shards.items() for subj in subjs}
@property def measurement_configs(self): """Grabs the measurement configs from the config.""" return self.config.measurement_configs
[docs] def read_patient_descriptors(self): """Reads the patient descriptors from the ESGPT or MEDS dataset.""" self.static_dfs = {} self.subj_indices = {} self.subj_seq_bounds = {} shards = tqdm(self.shards.keys(), total=len(self.shards), desc="Reading static shards", leave=False) for shard in shards: static_fp = self.static_dir / f"{shard}.parquet" df = pl.read_parquet( static_fp, columns=[ "subject_id", "start_time", "static_indices", "static_measurement_indices", "time_delta", ], use_pyarrow=True, ) self.static_dfs[shard] = df subject_ids = df["subject_id"] n_events = df.select(pl.col("time_delta").list.lengths().alias("n_events")).get_column("n_events") for i, (subj, n_events) in enumerate(zip(subject_ids, n_events)): if subj in self.subj_indices or subj in self.subj_seq_bounds: raise ValueError(f"Duplicate subject {subj} in {shard}!") self.subj_indices[subj] = i self.subj_seq_bounds[subj] = (0, n_events) if self.config.task_df_name is None: self.index = [(subj, *bounds) for subj, bounds in self.subj_seq_bounds.items()] self.labels = {} self.tasks = None self.task_types = None self.task_vocabs = None else: task_df_fp = self.task_dir / f"{self.config.task_df_name}.parquet" task_info_fp = self.task_dir / f"{self.config.task_df_name}_info.json" logger.info(f"Reading task constraints for {self.config.task_df_name} from {task_df_fp}") task_df = pl.read_parquet(task_df_fp, use_pyarrow=True) task_info = self.get_task_info(task_df) if task_info_fp.is_file(): loaded_task_info = json.loads(task_info_fp.read_text()) if loaded_task_info != task_info: raise ValueError( f"Task info differs from on disk!\nDisk:\n{loaded_task_info}\n" f"Local:\n{task_info}\nSplit: {self.split}" ) logger.info(f"Re-built existing {task_info_fp} and it matches.") else: task_info_fp.parent.mkdir(exist_ok=True, parents=True) task_info_fp.write_text(json.dumps(task_info)) idx_col = "_row_index" while idx_col in task_df.columns: idx_col = f"_{idx_col}" task_df_joint = ( task_df.select("subject_id", 
"start_time", "end_time") .with_row_index(idx_col) .group_by("subject_id") .agg("start_time", "end_time", idx_col) .join( pl.concat(self.static_dfs.values()).select( "subject_id", pl.col("start_time").alias("start_time_global"), "time_delta" ), on="subject_id", how="left", ) .with_columns( pl.col("time_delta") .list.eval(pl.element().fill_null(0).cum_sum()) .alias("min_since_start") ) ) min_at_task_start = ( (pl.col("start_time") - pl.col("start_time_global")).dt.total_seconds() / 60 ).alias("min_at_task_start") min_at_task_end = ( (pl.col("end_time") - pl.col("start_time_global")).dt.total_seconds() / 60 ).alias("min_at_task_end") start_idx_expr = (pl.col("min_since_start").search_sorted(pl.col("min_at_task_start"))).alias( "start_idx" ) end_idx_expr = (pl.col("min_since_start").search_sorted(pl.col("min_at_task_end"))).alias( "end_idx" ) task_df_joint = ( task_df_joint.explode(idx_col, "start_time", "end_time") .with_columns(min_at_task_start, min_at_task_end) .explode("min_since_start") .group_by("subject_id", idx_col, "min_at_task_start", "min_at_task_end", maintain_order=True) .agg(start_idx_expr.first(), end_idx_expr.first()) .sort(by=idx_col, descending=False) ) subject_ids = task_df_joint["subject_id"] start_indices = task_df_joint["start_idx"] end_indices = task_df_joint["end_idx"] self.labels = {t: task_df.get_column(t).to_list() for t in self.tasks} self.index = list(zip(subject_ids, start_indices, end_indices))
[docs] def get_task_info(self, task_df: pl.DataFrame): """Gets the task information from the task dataframe.""" self.tasks = sorted([c for c in task_df.columns if c not in ["subject_id", "start_time", "end_time"]]) self.task_types = {} self.task_vocabs = {} normalized_cols = [] for t in self.tasks: task_type, normalized_vals = self.normalize_task(col=pl.col(t), dtype=task_df.schema[t]) self.task_types[t] = task_type normalized_cols.append(normalized_vals.alias(t)) task_df = task_df.with_columns(normalized_cols) for t in self.tasks: match self.task_types[t]: case "binary_classification": self.task_vocabs[t] = [False, True] case "multi_class_classification": self.task_vocabs[t] = list(range(task_df.select(pl.col(t).max()).item() + 1)) case _: raise NotImplementedError(f"Task type {self.task_types[t]} not implemented!") return {"tasks": sorted(self.tasks), "vocabs": self.task_vocabs, "types": self.task_types}
[docs] def filter_to_min_seq_len(self): """Filters the dataset to only include subjects with at least `config.min_seq_len` events.""" if self.config.task_df_name is not None: logger.warning( f"Filtering task {self.config.task_df_name} to min_seq_len {self.config.min_seq_len}. " "This may result in incomparable model results against runs with different constraints!" ) orig_len = len(self) orig_n_subjects = len(set(self.subject_ids)) valid_indices = [ i for i, (subj, start, end) in enumerate(self.index) if end - start >= self.config.min_seq_len ] self.index = [self.index[i] for i in valid_indices] self.labels = {t: [t_labels[i] for i in valid_indices] for t, t_labels in self.labels.items()} new_len = len(self) new_n_subjects = len(set(self.subject_ids)) logger.info( f"Filtered data due to sequence length constraint (>= {self.config.min_seq_len}) from " f"{orig_len} to {new_len} rows and {orig_n_subjects} to {new_n_subjects} subjects." )
[docs] def filter_to_subset(self): """Filters the dataset to only include a subset of subjects.""" orig_len = len(self) orig_n_subjects = len(set(self.subject_ids)) rng = np.random.default_rng(self.config.train_subset_seed) subset_subjects = rng.choice( list(set(self.subject_ids)), size=count_or_proportion(orig_n_subjects, self.config.train_subset_size), replace=False, ) valid_indices = [i for i, (subj, start, end) in enumerate(self.index) if subj in subset_subjects] self.index = [self.index[i] for i in valid_indices] self.labels = {t: [t_labels[i] for i in valid_indices] for t, t_labels in self.labels.items()} new_len = len(self) new_n_subjects = len(set(self.subject_ids)) logger.info( f"Filtered data to subset of {self.config.train_subset_size} subjects from " f"{orig_len} to {new_len} rows and {orig_n_subjects} to {new_n_subjects} subjects." )
[docs] def set_inter_event_time_stats(self): """Sets the inter-event time statistics for the dataset.""" data_for_stats = pl.concat([x.lazy() for x in self.static_dfs.values()]) stats = ( data_for_stats.select( pl.col("time_delta").explode().drop_nulls().drop_nans().alias("inter_event_time") ) .select( pl.col("inter_event_time").min().alias("min"), pl.col("inter_event_time").log().mean().alias("mean_log"), pl.col("inter_event_time").log().std().alias("std_log"), ) .collect() ) if stats["min"].item() <= 0: bad_inter_event_times = data_for_stats.filter(pl.col("time_delta").list.min() <= 0).collect() bad_subject_ids = set(bad_inter_event_times["subject_id"].to_list()) warning_strs = [ f"Observed inter-event times <= 0 for {len(bad_inter_event_times)} subjects!", f"Bad Subject IDs: {', '.join(str(x) for x in bad_subject_ids)}", f"Global min: {stats['min'].item()}", ] if self.config.save_dir is not None: fp = self.config.save_dir / f"malformed_data_{self.split}.parquet" bad_inter_event_times.write_parquet(fp) warning_strs.append(f"Wrote malformed data records to {fp}") warning_strs.append("Removing malformed subjects") logger.warning("\n".join(warning_strs)) self.index = [x for x in self.index if x[0] not in bad_subject_ids] self.mean_log_inter_event_time_min = stats["mean_log"].item() self.std_log_inter_event_time_min = stats["std_log"].item()
@property def subject_ids(self) -> list[int]: return [x[0] for x in self.index] def __len__(self): return len(self.index) @property def has_task(self) -> bool: return self.config.task_df_name is not None @property def seq_padding_side(self) -> SeqPaddingSide: return self.config.seq_padding_side @property def max_seq_len(self) -> int: return self.config.max_seq_len @property def is_subset_dataset(self) -> bool: return self.config.train_subset_size != "FULL" def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: """Returns a Returns a dictionary corresponding to a single subject's data. The output of this will not be tensorized as that work will need to be re-done in the collate function regardless. The output will have structure: `` { 'time_delta': [seq_len], 'dynamic_indices': [seq_len, n_data_per_event] (ragged), 'dynamic_values': [seq_len, n_data_per_event] (ragged), 'dynamic_measurement_indices': [seq_len, n_data_per_event] (ragged), 'static_indices': [seq_len, n_data_per_event] (ragged), 'static_measurement_indices': [seq_len, n_data_per_event] (ragged), } `` 1. ``time_delta`` captures the time between each event and the subsequent event. 2. ``dynamic_indices`` captures the categorical metadata elements listed in `self.data_cols` in a unified vocabulary space spanning all metadata vocabularies. 3. ``dynamic_values`` captures the numerical metadata elements listed in `self.data_cols`. If no numerical elements are listed in `self.data_cols` for a given categorical column, the according index in this output will be `float('nan')`. 4. ``dynamic_measurement_indices`` captures which measurement vocabulary was used to source a given data element. 5. ``static_indices`` captures the categorical metadata elements listed in `self.static_cols` in a unified vocabulary. 6. ``static_measurement_indices`` captures which measurement vocabulary was used to source a given data element. 
""" return self._seeded_getitem(idx) @SeedableMixin.WithSeed def _seeded_getitem(self, idx: int) -> dict[str, list[float]]: """Returns a Returns a dictionary corresponding to a single subject's data. This function is a seedable version of `__getitem__`. """ subject_id, st, end = self.index[idx] shard = self.subj_map[subject_id] subject_idx = self.subj_indices[subject_id] static_row = self.static_dfs[shard][subject_idx].to_dict() out = { "static_indices": static_row["static_indices"].item().to_list(), "static_measurement_indices": static_row["static_measurement_indices"].item().to_list(), } if self.config.do_include_subject_id: out["subject_id"] = subject_id seq_len = end - st if seq_len > self.max_seq_len: match self.config.subsequence_sampling_strategy: case SubsequenceSamplingStrategy.RANDOM: start_offset = np.random.choice(seq_len - self.max_seq_len) case SubsequenceSamplingStrategy.TO_END: start_offset = seq_len - self.max_seq_len case SubsequenceSamplingStrategy.FROM_START: start_offset = 0 case _: raise ValueError( f"Invalid subsequence sampling strategy {self.config.subsequence_sampling_strategy}!" 
) st += start_offset end = min(end, st + self.max_seq_len) if self.config.do_include_subsequence_indices: out["start_idx"] = st out["end_idx"] = end out["dynamic"] = JointNestedRaggedTensorDict.load_slice(self.NRTs_dir / f"{shard}.pt", subject_idx)[ st:end ] if self.config.do_include_start_time_min: out["start_time"] = static_row["start_time"] = static_row[ "start_time" ].item().timestamp() / 60.0 + sum(static_row["time_delta"].item().to_list()[:st]) for t, t_labels in self.labels.items(): out[t] = t_labels[idx] return out def __dynamic_only_collate(self, batch: list[dict[str, list[float]]]) -> PytorchBatch: """An internal collate function for only dynamic data.""" keys = batch[0].keys() dense_keys = {k for k in keys if k not in ("dynamic", "static_indices", "static_measurement_indices")} if dense_keys: dense_collated = torch.utils.data.default_collate([{k: x[k] for k in dense_keys} for x in batch]) else: dense_collated = {} dynamic = JointNestedRaggedTensorDict.vstack([x["dynamic"] for x in batch]).to_dense( padding_side=self.seq_padding_side ) dynamic["event_mask"] = dynamic.pop("dim1/mask") dynamic["dynamic_values_mask"] = dynamic.pop("dim2/mask") & ~np.isnan(dynamic["dynamic_values"]) dynamic_collated = {} for k, v in dynamic.items(): if k.endswith("mask"): dynamic_collated[k] = torch.from_numpy(v) elif v.dtype in NP_UINT_TYPES + NP_INT_TYPES: dynamic_collated[k] = torch.from_numpy(v.astype(int)).long() elif v.dtype in NP_FLOAT_TYPES: dynamic_collated[k] = torch.from_numpy(v.astype(float)).float() else: raise TypeError(f"Don't know how to tensorify {k} of type {v.dtype}!") collated = {**dense_collated, **dynamic_collated} out_batch = {} out_batch["event_mask"] = collated["event_mask"] out_batch["dynamic_values_mask"] = collated["dynamic_values_mask"] out_batch["time_delta"] = torch.nan_to_num(collated["time_delta"].float(), nan=0) out_batch["dynamic_indices"] = collated["dynamic_indices"].long() out_batch["dynamic_measurement_indices"] = 
collated["dynamic_measurement_indices"].long() out_batch["dynamic_values"] = torch.nan_to_num(collated["dynamic_values"].float(), nan=0) if self.config.do_include_start_time_min: out_batch["start_time"] = collated["start_time"].float() if self.config.do_include_subsequence_indices: out_batch["start_idx"] = collated["start_idx"].long() out_batch["end_idx"] = collated["end_idx"].long() if self.config.do_include_subject_id: out_batch["subject_id"] = collated["subject_id"].long() out_batch = PytorchBatch(**out_batch) if not self.has_task: return out_batch out_labels = {} for task in self.tasks: match self.task_types[task]: case "multi_class_classification": out_labels[task] = collated[task].long() case "binary_classification": out_labels[task] = collated[task].float() case "regression": out_labels[task] = collated[task].float() case _: raise TypeError(f"Don't know how to tensorify task of type {self.task_types[task]}!") out_batch.stream_labels = out_labels return out_batch
[docs] def collate(self, batch: list[DATA_ITEM_T]) -> PytorchBatch: """Combines the ragged dictionaries produced by `__getitem__` into a tensorized batch. This function handles conversion of arrays to tensors and padding of elements within the batch across static data elements, sequence events, and dynamic data elements. Args: batch: A list of `__getitem__` format output dictionaries. Returns: A fully collated, tensorized, and padded batch. """ out_batch = self.__dynamic_only_collate(batch) max_n_static = max(len(x["static_indices"]) for x in batch) static_padded_fields = defaultdict(list) for e in batch: n_static = len(e["static_indices"]) static_delta = max_n_static - n_static for k in ("static_indices", "static_measurement_indices"): if static_delta > 0: static_padded_fields[k].append( torch.nn.functional.pad( torch.tensor(e[k], dtype=torch.long), (0, static_delta), value=0 ) ) else: static_padded_fields[k].append(torch.tensor(e[k], dtype=torch.long)) for k, v in static_padded_fields.items(): out_batch[k] = torch.cat([T.unsqueeze(0) for T in v], dim=0) return out_batch