"""The base class for core dataset processing logic.
Attributes:
INPUT_DF_T: This defines the type of the allowable input dataframes -- e.g., databases, filepaths,
dataframes, etc.
DF_T: This defines the type of internal dataframes -- e.g. polars DataFrames.
"""
import abc
import copy
import itertools
import json
from collections import defaultdict
from collections.abc import Hashable, Sequence
from pathlib import Path
from typing import Any, Generic, TypeVar
import humanize
import numpy as np
import pandas as pd
import polars as pl
from loguru import logger
from mixins import SaveableMixin, SeedableMixin, TimeableMixin, TQDMableMixin
from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict
from plotly.graph_objs._figure import Figure
from tqdm.auto import tqdm
from ..utils import lt_count_or_proportion
from .config import (
DatasetConfig,
DatasetSchema,
InputDFSchema,
MeasurementConfig,
VocabularyConfig,
)
from .types import DataModality, InputDataType, InputDFType, TemporalityType
from .visualize import Visualizer
from .vocabulary import Vocabulary
# Type variable for the allowable *input* dataframe sources (e.g., databases, filepaths, dataframes).
INPUT_DF_T = TypeVar("INPUT_DF_T")
# Type variable for the internal dataframe type used by a concrete processing library.
DF_T = TypeVar("DF_T")
[docs]
class DatasetBase(
abc.ABC, Generic[DF_T, INPUT_DF_T], SeedableMixin, SaveableMixin, TimeableMixin, TQDMableMixin
):
"""A unified base class for dataset objects using different processing libraries.
Args:
config: Configuration object for this dataset.
subjects_df: The dataframe containing all static, subject-level data. If this is specified,
`events_df` and `dynamic_measurements_df` should also be specified. Otherwise, this will be built
from source via the extraction pipeline defined in `input_schema`.
events_df: The dataframe containing all event timestamps, types, and subject IDs. If this is
specified, `subjects_df` and `dynamic_measurements_df` should also be specified. Otherwise, this
will be built from source via the extraction pipeline defined in `input_schema`.
dynamic_measurements_df: The dataframe containing all time-varying measurement observations. If this
is specified, `subjects_df` and `events_df` should also be specified. Otherwise, this will be
built from source via the extraction pipeline defined in `input_schema`.
input_schema: The schema configuration object to define the extraction pipeline for pulling raw data
from source and produce the `subjects_df`, `events_df`, `dynamic_measurements_df` input view.
"""
# Serialization settings used by `save` / `load`. The attributes listed in `_DEL_BEFORE_SAVING_ATTRS`
# are written to their own files by `save` (the three dataframes, ``config.json``, and
# ``inferred_measurement_configs.json``) instead of being pickled into ``E.pkl``.
_PICKLER: str = "dill"
"""Dictates via which pickler the `_save` and `_load` methods will save/load objects of this class, as
defined in `SaveableMixin`."""
_DEL_BEFORE_SAVING_ATTRS: list[str] = [
"_subjects_df",
"_events_df",
"_dynamic_measurements_df",
"config",
"inferred_measurement_configs",
]
"""Attributes that are saved via separate files, and will be deleted before pickling."""
# Filename stems for the three dataframes; `subjects_fp` / `events_fp` / `dynamic_measurements_fp`
# join them with `DF_SAVE_FORMAT` as the file extension.
DF_SAVE_FORMAT: str = "parquet"
"""The save format for internal dataframes in this dataset."""
SUBJECTS_FN: str = "subjects_df"
"""The name for the ``subjects_df`` save file."""
EVENTS_FN: str = "events_df"
"""The name for the ``events_df`` save file."""
DYNAMIC_MEASUREMENTS_FN: str = "dynamic_measurements_df"
"""The name for the ``dynamic_measurements_df`` save file."""
[docs]
@classmethod
def subjects_fp(cls, save_dir: Path) -> Path:
    """Build the path of the ``subjects_df`` file inside `save_dir`.

    The file name is ``{SUBJECTS_FN}.{DF_SAVE_FORMAT}``, so subclasses may override either attribute.
    """
    file_name = "{}.{}".format(cls.SUBJECTS_FN, cls.DF_SAVE_FORMAT)
    return save_dir / file_name
[docs]
@classmethod
def events_fp(cls, save_dir: Path) -> Path:
    """Build the path of the ``events_df`` file inside `save_dir`.

    The file name is ``{EVENTS_FN}.{DF_SAVE_FORMAT}``.
    """
    stem, extension = cls.EVENTS_FN, cls.DF_SAVE_FORMAT
    return save_dir / "{}.{}".format(stem, extension)
[docs]
@classmethod
def dynamic_measurements_fp(cls, save_dir: Path) -> Path:
    """Build the path of the ``dynamic_measurements_df`` file inside `save_dir`.

    The file name is ``{DYNAMIC_MEASUREMENTS_FN}.{DF_SAVE_FORMAT}``.
    """
    name_parts = (cls.DYNAMIC_MEASUREMENTS_FN, cls.DF_SAVE_FORMAT)
    return save_dir / "{}.{}".format(*name_parts)
@classmethod
@abc.abstractmethod
def _load_input_df(
    cls,
    df: INPUT_DF_T,
    columns: list[tuple[str, InputDataType | tuple[InputDataType, str]]],
    subject_id_col: str | None = None,
    filter_on: dict[str, bool | list[Any]] | None = None,
) -> DF_T:
    """Abstract hook: read the raw input source `df` into the processing library's dataframe type.

    Args:
        df: The raw input source to load.
        columns: The columns to load, each paired with its input data type specification.
        subject_id_col: Optional name of the subject ID column in the input.
        filter_on: Optional mapping of column name to allowed values; presumably applied as a row filter
            by implementations (TODO confirm against subclasses).

    Returns:
        The loaded dataframe.
    """
    raise NotImplementedError("Must be implemented by subclass.")
@classmethod
@abc.abstractmethod
def _process_events_and_measurements_df(
    cls,
    df: DF_T,
    event_type: str,
    columns_schema: dict[str, tuple[str, InputDataType]],
    ts_col: str | list[str] | None = None,
) -> tuple[DF_T, DF_T | None]:
    """Splits an input events-and-measurements dataframe into an events and a measurements dataframe.

    Implementations should perform the following steps:

    1. Ensure a unified timestamp column named ``'timestamp'`` exists, representing the minimum of the
       passed timestamps. `build_event_and_measurement_dfs` already provides this column (via
       `_resolve_ts_col` or `_split_range_events_df`) and does not pass `ts_col`, which is therefore
       optional.
    2. Add a categorical event type column, either from a column (if `event_type` begins with ``'COL:'``)
       or with the constant value `event_type`.
    3. Extract and rename the columns present in `columns_schema`.
    4. Add an integer ``event_id`` column.
    5. Split the dataframe into an events dataframe, storing ``event_id``, ``subject_id``,
       ``event_type``, and ``timestamp``, and a measurements dataframe, storing ``event_id`` and all
       other data columns.

    Args:
        df: The input dataframe for a single event schema.
        event_type: The event type name, or ``'COL:<column>'`` to read event types from a column.
        columns_schema: Mapping from input column name to ``(output column name, data type)``.
        ts_col: Optional timestamp column(s) to resolve into ``'timestamp'``; `None` when the caller has
            already resolved it.

    Returns:
        The events dataframe and the measurements dataframe. `build_event_and_measurement_dfs` treats a
        `None` events dataframe as empty (and skips it) and ignores a `None` measurements dataframe.
    """
    raise NotImplementedError("Must be implemented by subclass.")
@classmethod
@abc.abstractmethod
def _split_range_events_df(
    cls,
    df: DF_T,
    start_ts_col: str | list[str] = "start_time",
    end_ts_col: str | list[str] = "end_time",
) -> tuple[DF_T, DF_T, DF_T]:
    """Splits a range-events dataframe into "equal", "start", and "end" events dataframes.

    Implementations should perform the following steps:

    1. Produce unified start and end timestamp columns representing the minimum of the passed start and
       end timestamps, respectively.
    2. Filter out records where the end timestamp is earlier than the start timestamp.
    3. Split the dataframe into 3 events dataframes, all with only a single timestamp column, named
       ``'timestamp'``:

       (a) An "EQ" dataframe, where start timestamp == end timestamp,
       (b) A "start" dataframe, with start events, and
       (c) An "end" dataframe, with end events.

    The column defaults match `build_event_and_measurement_dfs`, which resolves the range timestamps
    into ``'start_time'`` / ``'end_time'`` columns and then calls this method with only `df`.

    Args:
        df: The range-events dataframe.
        start_ts_col: The start timestamp column(s).
        end_ts_col: The end timestamp column(s).

    Returns:
        The (EQ, start, end) events dataframes, in that order.
    """
    raise NotImplementedError("Must be implemented by subclass.")
@classmethod
@abc.abstractmethod
def _concat_dfs(cls, dfs: list[DF_T]) -> DF_T:
    """Abstract hook: combine the dataframes in `dfs` into one dataframe.

    `build_event_and_measurement_dfs` uses this to merge the per-schema events dataframes and the
    per-schema measurements dataframes.
    """
    raise NotImplementedError("Must be implemented by subclass.")
@classmethod
@abc.abstractmethod
def _resolve_ts_col(cls, df: DF_T, ts_col: str | list[str], out_name: str = "timestamp") -> DF_T:
    """Abstract hook: add a `datetime` column `out_name` holding the minimum of the column(s) `ts_col`.

    Args:
        df: The dataframe to augment.
        ts_col: A single column name or a list of column names to take the minimum over.
        out_name: The name of the produced column.
    """
    raise NotImplementedError("Must be implemented by subclass.")
@classmethod
@abc.abstractmethod
def _rename_cols(cls, df: DF_T, to_rename: dict[str, str]) -> DF_T:
    """Abstract hook: rename the columns of `df` per the ``{in_name: out_name}`` mapping `to_rename`."""
    raise NotImplementedError("Must be implemented by subclass.")
[docs]
@classmethod
def build_subjects_dfs(cls, schema: InputDFSchema) -> DF_T:
    """Builds and returns the subjects dataframe from `schema`.

    Args:
        schema: The input schema defining the subjects dataframe. This will include a definition of the
            input dataframe, the subject ID column, the static measurements columns to load, etc.

    Returns:
        The built `subjects_df`, with input columns renamed to their output names as given by
        `schema.unified_schema`. Only the dataframe is returned; `__init__` consumes it as such.
    """
    subjects_df = cls._load_input_df(
        schema.input_df,
        schema.columns_to_load,
        filter_on=schema.filter_on,
        subject_id_col=schema.subject_id_col,
    )
    # `unified_schema` maps input column name -> (output column name, data type); only the names are
    # needed to rename.
    rename_map = {in_col: out_col for in_col, (out_col, _) in schema.unified_schema.items()}
    return cls._rename_cols(subjects_df, rename_map)
[docs]
@classmethod
def build_event_and_measurement_dfs(
cls,
subject_id_col: str,
schemas_by_df: dict[INPUT_DF_T, list[InputDFSchema]],
) -> tuple[DF_T, DF_T]:
"""Builds and returns events and measurements dataframes from the input schema map.
Args:
subject_id_col: The name of the column containing (input) subject IDs.
schemas_by_df: A mapping from input dataframe to associated event/measurement schemas.
Returns:
Both the built `events_df` and `dynamic_measurements_df`.
"""
all_events_and_measurements = []
event_types = []
for df, schemas in schemas_by_df.items():
all_columns = []
all_columns.extend(itertools.chain.from_iterable(s.columns_to_load for s in schemas))
try:
df = cls._load_input_df(df, all_columns, subject_id_col)
except Exception as e:
raise ValueError(f"Errored while loading {df}") from e
for schema in schemas:
if schema.filter_on:
logger.debug("Filtering")
df = cls._filter_col_inclusion(schema.filter_on)
match schema.type:
case InputDFType.EVENT:
logger.debug("Processing Event")
df = cls._resolve_ts_col(df, schema.ts_col, "timestamp")
all_events_and_measurements.append(
cls._process_events_and_measurements_df(
df=df,
event_type=schema.event_type,
columns_schema=schema.unified_schema,
)
)
event_types.append(schema.event_type)
case InputDFType.RANGE:
logger.debug("Processing Range")
df = cls._resolve_ts_col(df, schema.start_ts_col, "start_time")
df = cls._resolve_ts_col(df, schema.end_ts_col, "end_time")
for et, unified_schema, sp_df in zip(
schema.event_type,
schema.unified_schema,
cls._split_range_events_df(df=df),
):
all_events_and_measurements.append(
cls._process_events_and_measurements_df(
sp_df, columns_schema=unified_schema, event_type=et
)
)
event_types.extend(schema.event_type)
case _:
raise ValueError(f"Invalid schema type {schema.type}.")
all_events, all_measurements = [], []
for event_type, (events, measurements) in zip(event_types, all_events_and_measurements):
if events is None:
logger.warning(f"Empty new events dataframe of type {event_type}!")
continue
all_events.append(events)
if measurements is not None:
all_measurements.append(measurements)
return cls._concat_dfs(all_events), cls._concat_dfs(all_measurements)
@classmethod
@abc.abstractmethod
def _read_df(cls, fp: Path, **kwargs) -> DF_T:
    """Abstract hook: read and return the dataframe stored at `fp`.

    The lazily-loading dataframe properties call this with only `fp`.
    """
    raise NotImplementedError
@classmethod
@abc.abstractmethod
def _write_df(cls, df: DF_T, fp: Path, **kwargs):
    """Abstract hook: write `df` to the file `fp`.

    `save` passes a ``do_overwrite`` keyword argument.
    """
    raise NotImplementedError
@property
def subjects_df(self) -> DF_T:
    """The subjects dataframe, read from disk on first access when not held in memory.

    If the ``_subjects_df`` attribute is missing or `None`, the dataframe is read from
    ``subjects_fp(config.save_dir)`` and cached on the instance before being returned.
    """
    in_memory = getattr(self, "_subjects_df", None)
    if in_memory is not None:
        return in_memory

    subjects_fp = self.subjects_fp(self.config.save_dir)
    logger.info(f"Loading subjects from {subjects_fp}...")
    self._subjects_df = self._read_df(subjects_fp)
    return self._subjects_df

@subjects_df.setter
def subjects_df(self, new_df: DF_T):
    self._subjects_df = new_df
@property
def events_df(self) -> DF_T:
    """The events dataframe, read from disk on first access when not held in memory.

    If the ``_events_df`` attribute is missing or `None`, the dataframe is read from
    ``events_fp(config.save_dir)`` and cached on the instance before being returned.
    """
    in_memory = getattr(self, "_events_df", None)
    if in_memory is not None:
        return in_memory

    events_fp = self.events_fp(self.config.save_dir)
    logger.info(f"Loading events from {events_fp}...")
    self._events_df = self._read_df(events_fp)
    return self._events_df

@events_df.setter
def events_df(self, new_df: DF_T):
    self._events_df = new_df
@property
def dynamic_measurements_df(self) -> DF_T:
    """The dynamic measurements dataframe, read from disk on first access when not held in memory.

    If the ``_dynamic_measurements_df`` attribute is missing or `None`, the dataframe is read from
    ``dynamic_measurements_fp(config.save_dir)`` and cached on the instance before being returned.
    """
    in_memory = getattr(self, "_dynamic_measurements_df", None)
    if in_memory is not None:
        return in_memory

    dynamic_measurements_fp = self.dynamic_measurements_fp(self.config.save_dir)
    logger.info(f"Loading dynamic_measurements from {dynamic_measurements_fp}...")
    self._dynamic_measurements_df = self._read_df(dynamic_measurements_fp)
    return self._dynamic_measurements_df

@dynamic_measurements_df.setter
def dynamic_measurements_df(self, new_df: DF_T):
    self._dynamic_measurements_df = new_df
[docs]
@classmethod
def load(cls, load_dir: Path) -> "DatasetBase":
    """Re-load an instance of the calling class from the directory `load_dir`.

    The directory is expected to use the following distributed layout:

    * ``config.json``: the base configuration object, in JSON format.
    * ``inferred_measurement_configs.json``: optional, present once the saved dataset has been fit. It
      holds the pre-processed measurement configs with inferred parameters. These configs may in turn
      keep attributes (such as their `measurement_metadata` dataframes) in further files, to facilitate
      lazy loading.
    * The raw or fully pre-processed subjects, events, and measurements dataframes, under their
      respective filenames (`SUBJECTS_FN`, `EVENTS_FN`, `DYNAMIC_MEASUREMENTS_FN`).
    * ``E.pkl``: all remaining attributes, in pickle format.

    If the stored config's ``save_dir`` differs from `load_dir`, it is re-pointed at `load_dir`.

    Args:
        load_dir: The path to the directory on disk from which the dataset should be loaded.

    Returns:
        The reloaded dataset.

    Raises:
        FileNotFoundError: If either the attributes file or config file do not exist.
    """
    config = DatasetConfig.from_json_file(load_dir / "config.json")
    if config.save_dir != load_dir:
        logger.info(f"Updating config.save_dir from {config.save_dir} to {load_dir}")
        config.save_dir = load_dir

    extra_attrs = {"config": config}

    measurement_configs_fp = load_dir / "inferred_measurement_configs.json"
    if measurement_configs_fp.is_file():
        with open(measurement_configs_fp) as f:
            extra_attrs["inferred_measurement_configs"] = {
                name: MeasurementConfig.from_dict(cfg, base_dir=load_dir)
                for name, cfg in json.load(f).items()
            }

    return super()._load(load_dir / "E.pkl", **extra_attrs)
[docs]
def save(self, **kwargs):
    """Saves the calling object to disk, in the directory `self.config.save_dir`.

    This function stores to disk the internal parameters of the calling object, in the following format:

    * The base configuration object is stored in the file ``'config.json'``, in JSON format.
    * If the dataset has already been fit, then the pre-processed measurement configs with inferred
      parameters are stored in ``'inferred_measurement_configs.json'``, in JSON format. Each config's
      `measurement_metadata` is first cached under ``'inferred_measurement_metadata/'`` so it can be
      loaded lazily.
    * Remaining attributes are stored in pickle format at ``'E.pkl'``.
    * The vocabulary configuration is stored in ``'vocabulary_config.json'``.
    * The raw or fully pre-processed subjects, events, and measurements dataframes are stored in their
      respective filenames (`SUBJECTS_FN`, `EVENTS_FN`, `DYNAMIC_MEASUREMENTS_FN`).

    Args:
        do_overwrite: Keyword only; if passed with a value evaluating to `True`, then the system will
            overwrite any files that exist, rather than erroring.

    Raises:
        FileExistsError: If any of the desired filepaths already exist and `do_overwrite` is False.
    """
    save_dir = self.config.save_dir
    save_dir.mkdir(parents=True, exist_ok=True)

    do_overwrite = kwargs.get("do_overwrite", False)

    self.config.to_json_file(save_dir / "config.json", do_overwrite=do_overwrite)

    if self._is_fit:
        inferred_measurement_configs_fp = save_dir / "inferred_measurement_configs.json"
        # This file is written with a plain `open`, so enforce `do_overwrite` explicitly, before writing
        # any fit-state files, to honor the documented FileExistsError contract.
        if inferred_measurement_configs_fp.exists() and not do_overwrite:
            raise FileExistsError(
                f"{inferred_measurement_configs_fp} already exists and do_overwrite is False!"
            )

        # Directory for the per-measurement metadata caches below. Their paths are presumably resolved
        # relative to `save_dir` (TODO confirm in `cache_measurement_metadata`).
        (save_dir / "inferred_measurement_metadata").mkdir(parents=True, exist_ok=True)
        for measure, measurement_config in self.inferred_measurement_configs.items():
            measurement_config.cache_measurement_metadata(
                save_dir, f"inferred_measurement_metadata/{measure}.csv"
            )

        inferred_measurement_configs = {
            k: v.to_dict() for k, v in self.inferred_measurement_configs.items()
        }
        with open(inferred_measurement_configs_fp, mode="w") as f:
            json.dump(inferred_measurement_configs, f)

    super()._save(save_dir / "E.pkl", **kwargs)

    vocab_config_fp = save_dir / "vocabulary_config.json"
    self.vocabulary_config.to_json_file(vocab_config_fp, do_overwrite=do_overwrite)

    self._write_df(self.subjects_df, self.subjects_fp(save_dir), do_overwrite=do_overwrite)
    self._write_df(self.events_df, self.events_fp(save_dir), do_overwrite=do_overwrite)
    self._write_df(
        self.dynamic_measurements_df, self.dynamic_measurements_fp(save_dir), do_overwrite=do_overwrite
    )
def __init__(
    self,
    config: DatasetConfig,
    subjects_df: DF_T | None = None,
    events_df: DF_T | None = None,
    dynamic_measurements_df: DF_T | None = None,
    input_schema: DatasetSchema | None = None,
    **kwargs,
):
    """Initializes the dataset either from pre-built dataframes or from an extraction schema.

    Exactly one of the two sources must be used: either all of `subjects_df`, `events_df`, and
    `dynamic_measurements_df` are passed (and `input_schema` is `None`), or only `input_schema` is
    passed and the three dataframes are built from it.

    Args:
        config: Configuration object for this dataset.
        subjects_df: The static, subject-level dataframe.
        events_df: The event timestamps / types / subject IDs dataframe.
        dynamic_measurements_df: The time-varying measurement observations dataframe.
        input_schema: The extraction schema used to build the three dataframes from source.
        **kwargs: Forwarded to the mixin initializers; ``do_overwrite`` is also stored on the instance.

    Raises:
        ValueError: If the dataframes and `input_schema` are not specified consistently.
    """
    super().__init__(**kwargs)

    if "do_overwrite" in kwargs:
        self.do_overwrite = kwargs["do_overwrite"]

    if input_schema is None:
        # Without an extraction schema, all three dataframes must be provided directly.
        if subjects_df is None or events_df is None or dynamic_measurements_df is None:
            raise ValueError(
                "Must set input_schema if subjects_df, events_df, or dynamic_measurements_df are None!"
            )
    else:
        if subjects_df is not None:
            raise ValueError("Can't set subjects_df if input_schema is not None!")
        if events_df is not None:
            raise ValueError("Can't set events_df if input_schema is not None!")
        if dynamic_measurements_df is not None:
            raise ValueError("Can't set dynamic_measurements_df if input_schema is not None!")

        subjects_df = self.build_subjects_dfs(input_schema.static)
        logger.debug("Extracting events and measurements dataframe...")
        events_df, dynamic_measurements_df = self.build_event_and_measurement_dfs(
            input_schema.static.subject_id_col,
            input_schema.dynamic_by_df,
        )
        logger.debug("Built events and measurements dataframe")

    self.config = config
    self._is_fit = False

    # After pre-processing, we may infer new types or otherwise change measurement configuration, so
    # we store a separate configuration object for post-processing. It is initialized as empty as we have
    # not yet pre-processed anything.
    self.inferred_measurement_configs = {}

    self._validate_and_set_initial_properties(subjects_df, events_df, dynamic_measurements_df)

    self.split_subjects = {}
def _validate_and_set_initial_properties(self, subjects_df, events_df, dynamic_measurements_df):
"""Validates the input dataframes and sets initial properties of the calling object.
This validates that the initial dataframes are appropriately configured, re-sets certain types to
minimal-memory ``dtypes`` (e.g., ensuring ID columns are set to the smallest valid ``uint`` type), and
sets non-DF parameters such as `subject_ids`, `event_types`, and `n_events_per_subject`.
Args:
subjects_df: The subjects dataframe.
events_df: The events dataframe.
dynamic_measurements_df: The dynamic measurements dataframe.
"""
self.subject_ids = []
self.event_types = []
self.n_events_per_subject = {}
self.events_df = events_df
self.dynamic_measurements_df = dynamic_measurements_df
if self.events_df is not None:
self._agg_by_time()
self._sort_events()
(
self.subjects_df,
self.events_df,
self.dynamic_measurements_df,
) = self._validate_initial_dfs(subjects_df, self.events_df, self.dynamic_measurements_df)
self._update_subject_event_properties()
@abc.abstractmethod
def _validate_initial_dfs(
    self, subjects_df: DF_T, events_df: DF_T, dynamic_measurements_df: DF_T
) -> tuple[DF_T, DF_T, DF_T]:
    """Abstract hook: validate the three input dataframes and shrink their internal types.

    Returns:
        The (subjects, events, dynamic measurements) dataframes, possibly re-typed to reduce memory use.
    """
    raise NotImplementedError("This method must be implemented by a subclass.")
@abc.abstractmethod
def _update_subject_event_properties(self):
    """Abstract hook: recompute `subject_ids`, `event_types`, and `n_events_per_subject`."""
    raise NotImplementedError("This method must be implemented by a subclass.")
@TimeableMixin.TimeAs
def _filter_subjects(self):
    """Keep only subjects with at least `config.min_events_per_subject` events.

    This is a no-op when that threshold is `None`. Otherwise the subjects and events dataframes are
    restricted to the retained subject IDs. The dynamic measurements dataframe is restricted to the event
    IDs remaining in the filtered events dataframe.
    """
    threshold = self.config.min_events_per_subject
    if threshold is None:
        return

    kept_subjects = [
        subject_id for subject_id, n_events in self.n_events_per_subject.items() if n_events >= threshold
    ]

    # NOTE(review): `subject_ids`, `event_types`, and `n_events_per_subject` are not recomputed here,
    # so they still describe the unfiltered data afterwards. Confirm no caller depends on them.
    self.subjects_df = self._filter_col_inclusion(self.subjects_df, {"subject_id": kept_subjects})
    self.events_df = self._filter_col_inclusion(self.events_df, {"subject_id": kept_subjects})

    surviving_event_ids = list(self.events_df["event_id"])
    self.dynamic_measurements_df = self._filter_col_inclusion(
        self.dynamic_measurements_df, {"event_id": surviving_event_ids}
    )
@TimeableMixin.TimeAs
@abc.abstractmethod
def _agg_by_time(self):
    """Abstract hook: bucket events in time according to `self.config.agg_by_time_scale`.

    Implementations group `events_df` by ``subject_id`` and timestamp bucket, and merge the event types
    of co-bucketed events into a combined category whose name joins them with ``'&'``. They then re-point
    the rows of `dynamic_measurements_df` at the new event IDs. No numerical aggregation (mean, etc.)
    takes place; all data is kept and only measurement event IDs change.
    """
    raise NotImplementedError("This method must be implemented by a subclass.")
@TimeableMixin.TimeAs
@abc.abstractmethod
def _sort_events(self):
    """Abstract hook: order events ascending by subject ID, then by timestamp."""
    raise NotImplementedError("This method must be implemented by a subclass.")
[docs]
@SeedableMixin.WithSeed
@TimeableMixin.TimeAs
def split(
    self,
    split_fracs: Sequence[float],
    split_names: Sequence[str] | None = None,
    mandatory_set_IDs: dict[str, set[int] | None] | None = None,
):
    """Splits the underlying dataset into random sets by `subject_id`.

    Args:
        split_fracs: The fractional sizes of the desired splits. If it sums to < 1, the remainder will be
            tracked **in an extra split** at the end of the list. All split fractions must be positive
            floating point numbers no greater than 1. The sum is compared to 1 with a small absolute
            tolerance (1e-9). Floating-point round-off, as in ``[0.7, 0.2, 0.1]``, therefore neither adds
            a spurious extra split nor is rejected as exceeding 1.
        split_names: If specified, assigns the passed names to each split. Must be of the same size as
            `split_fracs` (after it is expanded to sum to 1 if necessary). If unset, and there are two
            splits, it defaults to ``['train', 'held_out']``. If there are three, it defaults to
            ``['train', 'tuning', 'held_out']``. If more than 3, it defaults to
            ``['split_0', 'split_1', ...]``. Split names of `train`, `tuning`, and `held_out` have
            special significance and are used elsewhere in the model. If `split_names` does not reflect
            those, other things may not work down the line.
        mandatory_set_IDs: Maps split name to an optional set of subject IDs that make up that split. If a
            split name is included in mandatory_set_IDs, it should _not_ be included in `split_fracs`,
            as the size of the split is determined by the IDs in this object. Any IDs in this object will
            be excluded from _all_ other splits, and split fractions will be taken over the remaining,
            unused IDs. Entries whose value is `None` specify no subjects and are ignored.

    Raises:
        ValueError: if `split_fracs` is empty, contains anything outside the range of (0, 1], sums to
            something > 1, or is not of the same length as `split_names`. Also raised if a split name
            appears in both `split_names` and `mandatory_set_IDs`.
    """
    split_fracs = list(split_fracs)
    total_frac = sum(split_fracs)
    # Fractions that are meant to sum to 1 frequently miss by ~1e-16 in floating point.
    sums_to_one = bool(np.isclose(total_frac, 1, rtol=0, atol=1e-9))

    if (
        not split_fracs
        or min(split_fracs) <= 0
        or max(split_fracs) > 1
        or (total_frac > 1 and not sums_to_one)
    ):
        raise ValueError(
            "split_fracs invalid! Want a list of numbers in (0, 1] that sums to no more than 1; got "
            f"{repr(split_fracs)}"
        )

    if not sums_to_one:
        split_fracs.append(1 - total_frac)

    if split_names is None:
        if len(split_fracs) == 2:
            split_names = ["train", "held_out"]
        elif len(split_fracs) == 3:
            split_names = ["train", "tuning", "held_out"]
        else:
            split_names = [f"split_{i}" for i in range(len(split_fracs))]
    elif len(split_names) != len(split_fracs):
        raise ValueError(
            f"split_names and split_fracs must be the same length; got {len(split_names)} and "
            f"{len(split_fracs)}"
        )

    if mandatory_set_IDs is None:
        mandatory_set_IDs = {}
    else:
        # The annotation permits `None` values. Drop them rather than crashing in `chain.from_iterable`
        # below or storing a `None` split.
        mandatory_set_IDs = {k: v for k, v in mandatory_set_IDs.items() if v is not None}

    intersecting_split_names = set(split_names).intersection(mandatory_set_IDs.keys())
    if intersecting_split_names:
        raise ValueError(
            "Splits with specified sizes overlap with those with pre-set populations! "
            f"{', '.join(intersecting_split_names)}"
        )

    subjects_to_split = set(self.subject_ids) - set(
        itertools.chain.from_iterable(mandatory_set_IDs.values())
    )

    # As split fractions may not result in integer split sizes, we shuffle the split names and fractions
    # so that the splits that exceed the desired size are not always the last ones in the original passed
    # order.
    split_names_idx = np.random.permutation(len(split_names))
    split_names = [split_names[i] for i in split_names_idx]
    split_fracs = [split_fracs[i] for i in split_names_idx]

    subjects = np.random.permutation(list(subjects_to_split))
    split_lens = (np.array(split_fracs[:-1]) * len(subjects)).round().astype(int)
    # The last split takes whatever remains, so every subject is assigned exactly once.
    split_lens = np.append(split_lens, len(subjects) - split_lens.sum())

    subjects_per_split = np.split(subjects, split_lens.cumsum())

    self.split_subjects = {k: set(v) for k, v in zip(split_names, subjects_per_split)}
    self.split_subjects = {**self.split_subjects, **mandatory_set_IDs}
@classmethod
@abc.abstractmethod
def _filter_col_inclusion(cls, df: DF_T, col_inclusion_targets: dict[str, bool | Sequence[Any]]) -> DF_T:
    """Abstract hook: keep the rows of `df` matching `col_inclusion_targets`.

    Each key of `col_inclusion_targets` is a column name. Its value is either a sequence of allowed
    values, or a boolean. In this class a boolean is passed as ``{measure: True}`` in `fit_measurements`;
    presumably it selects rows where the column is non-null (TODO confirm in subclasses).
    """
    raise NotImplementedError("This method must be implemented by a subclass.")
def _get_split_subject_ids(self, split: str):
    """Returns the subject IDs that `split()` assigned to the split named `split`.

    Raises:
        KeyError: If no such split exists, e.g. `split()` has not been called yet, or was called with
            names that do not include `split`.
    """
    try:
        return self.split_subjects[split]
    except KeyError as e:
        raise KeyError(
            f"Split {split!r} not found; available splits: {list(self.split_subjects)}. "
            "Has `split()` been called with this split name?"
        ) from e

def _subjects_df_for_split(self, split: str) -> DF_T:
    """Returns `subjects_df` restricted to the subjects of `split`."""
    return self._filter_col_inclusion(self.subjects_df, {"subject_id": self._get_split_subject_ids(split)})

def _events_df_for_split(self, split: str) -> DF_T:
    """Returns `events_df` restricted to the subjects of `split`."""
    return self._filter_col_inclusion(self.events_df, {"subject_id": self._get_split_subject_ids(split)})

def _dynamic_measurements_df_for_split(self, split: str) -> DF_T:
    """Returns `dynamic_measurements_df` restricted to the events of the subjects of `split`."""
    event_ids = self._events_df_for_split(split)["event_id"]
    return self._filter_col_inclusion(self.dynamic_measurements_df, {"event_id": list(event_ids)})

@property
def train_subjects_df(self) -> DF_T:
    """Returns the train set split of subjects_df."""
    return self._subjects_df_for_split("train")

@property
def tuning_subjects_df(self) -> DF_T:
    """Returns the tuning set split of subjects_df."""
    return self._subjects_df_for_split("tuning")

@property
def held_out_subjects_df(self) -> DF_T:
    """Returns the held-out set split of subjects_df."""
    return self._subjects_df_for_split("held_out")

@property
def train_events_df(self) -> DF_T:
    """Returns the train set split of events_df."""
    return self._events_df_for_split("train")

@property
def tuning_events_df(self) -> DF_T:
    """Returns the tuning set split of events_df."""
    return self._events_df_for_split("tuning")

@property
def held_out_events_df(self) -> DF_T:
    """Returns the held-out set split of events_df."""
    return self._events_df_for_split("held_out")

@property
def train_dynamic_measurements_df(self) -> DF_T:
    """Returns the train set split of dynamic_measurements_df."""
    return self._dynamic_measurements_df_for_split("train")

@property
def tuning_dynamic_measurements_df(self) -> DF_T:
    """Returns the tuning set split of dynamic_measurements_df."""
    return self._dynamic_measurements_df_for_split("tuning")

@property
def held_out_dynamic_measurements_df(self) -> DF_T:
    """Returns the held-out set split of dynamic_measurements_df."""
    return self._dynamic_measurements_df_for_split("held_out")
[docs]
@TimeableMixin.TimeAs
def preprocess(self):
    """Fit all pre-processing parameters over the train set, then transform every observation.

    The stages run in order:

    1. drop subjects that have too few events;
    2. pre-compute the `FUNCTIONAL_TIME_DEPENDENT` temporality measurements and store them in the events
       dataframe;
    3. fit all pre-processing parameters over the observed measurements;
    4. transform all data with the fit parameters.
    """
    # (log message, method name); each method is looked up only when its stage runs.
    stages = (
        ("Filtering subjects", "_filter_subjects"),
        ("Adding time derived measurements", "_add_time_dependent_measurements"),
        ("Fitting pre-processing parameters", "fit_measurements"),
        ("Transforming variables.", "transform_measurements"),
    )
    for message, method_name in stages:
        logger.info(message)
        getattr(self, method_name)()

    logger.info("Done with preprocessing")
@TimeableMixin.TimeAs
@abc.abstractmethod
def _add_time_dependent_measurements(self):
    """Abstract hook: compute `FUNCTIONAL_TIME_DEPENDENT` measurement values and add them to events_df."""
    raise NotImplementedError("This method must be implemented by a subclass.")
@TimeableMixin.TimeAs
def _get_source_df(self, config: MeasurementConfig, do_only_train: bool = True) -> tuple[str, str, DF_T]:
"""Returns the name of the source attribute, its id column, and that dataframe for `config`.
Measurements with different configs are stored in different internal dataframes (e.g., `STATIC`
measurements in `subjects_df`, `DYNAMIC` measurements in `dynamic_measurements_df`), and are goverend
by different natural ID columns. This function gets and returns the appropriate attribute name, ID
column name for that attribute, and the associated dataframe.
Args:
config: The measurement config for which we should retrieve the source dataframe.
do_only_train: Whether or not we should also return only these data on the train set or not.
Raises:
ValueError: If the passed measurement config has an invalid temporality type.
"""
match config.temporality:
case TemporalityType.DYNAMIC:
source_attr = "dynamic_measurements_df"
source_id = "measurement_id"
if do_only_train:
source_df = self.train_dynamic_measurements_df
else:
source_df = self.dynamic_measurements_df
case TemporalityType.STATIC:
source_attr = "subjects_df"
source_id = "subject_id"
source_df = self.train_subjects_df if do_only_train else self.subjects_df
case TemporalityType.FUNCTIONAL_TIME_DEPENDENT:
source_attr = "events_df"
source_id = "event_id"
source_df = self.train_events_df if do_only_train else self.events_df
case _:
raise ValueError(f"Called get_source_df on temporality type {config.temporality}!")
return source_attr, source_id, source_df
[docs]
@TimeableMixin.TimeAs
def fit_measurements(self):
    """Fits all preprocessing parameters over the training dataset, according to `self.config`.

    For every non-dropped measurement config, a deep copy is stored in
    `self.inferred_measurement_configs`. That copy is fit over the train split of the measurement's
    source dataframe:

    1. Drop the measurement if its column is absent from the source dataframe.
    2. Record observation rates. Drop the measurement if it has no possible observations, no actual
       observations, or too few (per `self.config.min_valid_column_observations`).
    3. For numeric measurements, fit the measurement metadata.
    4. Fit the vocabulary (unless one is pre-set) and filter out rare elements (per
       `self.config.min_valid_vocab_element_observations`).
    5. Drop the measurement if only ``'UNK'`` remains in its vocabulary.

    Raises:
        ValueError: if fitting preprocessing parameters fails for a given measurement.
    """
    self._is_fit = False

    for measure, base_config in self.config.measurement_configs.items():
        if base_config.is_dropped:
            continue

        # Fit a copy so `self.config` keeps the user-specified configuration.
        config = copy.deepcopy(base_config)
        self.inferred_measurement_configs[measure] = config

        _, _, source_df = self._get_source_df(config, do_only_train=True)

        # 1. Drop the measurement if its column is absent.
        if measure not in source_df:
            logger.warning(f"Measure {measure} not found! Dropping...")
            config.drop()
            continue

        total_possible, total_observed, raw_total_observed = self._total_possible_and_observed(
            measure, config, source_df
        )
        source_df = self._filter_col_inclusion(source_df, {measure: True})

        # 2. Record observation rates, dropping measurements that are never or too rarely observed.
        if total_possible == 0:
            logger.info(f"Found no possible events for {measure}!")
            config.drop()
            continue

        config.observation_rate_over_cases = total_observed / total_possible

        if total_observed == 0:
            # Without this guard the per-case rate below divides by zero. A measurement never observed
            # in the train set cannot be fit.
            logger.info(f"Found no observations for {measure}!")
            config.drop()
            continue

        config.observation_rate_per_case = raw_total_observed / total_observed

        if lt_count_or_proportion(
            total_observed, self.config.min_valid_column_observations, total_possible
        ):
            config.drop()
            continue

        # 3. Fit numerical metadata (value types, outlier / normalizer parameters, etc.).
        if config.is_numeric:
            config.add_missing_mandatory_metadata_cols()
            try:
                config.measurement_metadata = self._fit_measurement_metadata(measure, config, source_df)
            except Exception as e:
                # `Exception`, not `BaseException`: KeyboardInterrupt / SystemExit must propagate as-is.
                raise ValueError(f"Fitting measurement metadata failed for measure {measure}!") from e

        # 4. Fit the vocabulary and eliminate elements that occur too rarely.
        if config.vocabulary is None:
            config.vocabulary = self._fit_vocabulary(measure, config, source_df)

        if config.vocabulary is not None:
            if self.config.min_valid_vocab_element_observations is not None:
                config.vocabulary.filter(
                    len(source_df), self.config.min_valid_vocab_element_observations
                )

            # 5. If all observations were eliminated, drop the column.
            if config.vocabulary.vocabulary == ["UNK"]:
                config.drop()

    self._is_fit = True
@abc.abstractmethod
def _total_possible_and_observed(
    self, measure: str, config: MeasurementConfig, source_df: DF_T
) -> tuple[int, int, int]:
    """Abstract hook: count the possible, actual, and raw observations of `measure` in `source_df`.

    * Possible: the number of subjects (static measurements) or unique events (dynamic or functional
      time dependent measurements) where the measurement could appear.
    * Actual: how many of those have a non-null value in the measurement column. For a multivariate
      regression measurement, the key column (not the value column) must be non-null.
    * Raw: the total observation count, including duplicate observations per possible instance.

    Args:
        measure: The name of the measurement.
        config: The measurement config for the given measurement.
        source_df: The dataframe to count over.

    Returns:
        The (possible, actual, raw) counts, in that order.
    """
    raise NotImplementedError("This method must be implemented by a subclass.")
@abc.abstractmethod
def _fit_measurement_metadata(
    self, measure: str, config: MeasurementConfig, source_df: DF_T
) -> pd.DataFrame:
    """Fits the pre-processing metadata of a numeric measurement over `source_df` and returns it.

    The metadata dataframe holds the per-measurement pre-processing parameters for numerical variables, such
    as the value type, outlier-model parameters, and normalizer parameters.

    Args:
        measure: The name of the measurement.
        config: The measurement config for the given measurement.
        source_df: The dataframe over which the metadata is fit.

    Returns:
        The fit measurement-metadata dataframe.
    """
    raise NotImplementedError("This method must be implemented by a subclass.")
@TimeableMixin.TimeAs
@abc.abstractmethod
def _fit_vocabulary(self, measure: str, config: MeasurementConfig, source_df: DF_T) -> Vocabulary:
    """Fits the vocabulary of a categorical measurement over `source_df` and returns it.

    Args:
        measure: The name of the measurement.
        config: The measurement config for the given measurement.
        source_df: The dataframe over which the vocabulary is fit.

    Returns:
        The fit vocabulary.
    """
    raise NotImplementedError("This method must be implemented by a subclass.")
@TimeableMixin.TimeAs
@abc.abstractmethod
def _update_attr_df(self, attr: str, id_col: str, df: DF_T, cols_to_update: list[str]):
    """Overwrites columns `cols_to_update` of the dataframe stored at ``self.<attr>`` with values from `df`.

    Every value in those columns of the stored dataframe is first reset to null. The stored dataframe is
    then updated by ID (`id_col`) with the corresponding column values from `df`.
    """
    raise NotImplementedError("This method must be implemented by a subclass.")
@TimeableMixin.TimeAs
@abc.abstractmethod
def _transform_numerical_measurement(
    self, measure: str, config: MeasurementConfig, source_df: DF_T
) -> DF_T:
    """Applies the fit pre-processing for numerical measurement `measure`, as governed by `config`.

    Implementations are expected to, in order:
        1. Map keys to categorical representations for categorical keys.
        2. Remove values tied to dropped or categorical keys.
        3. Remove hard outliers and apply the configured censoring.
        4. Cast values to their desired types.
        5. Add inlier/outlier indicators and remove learned outliers.
        6. Normalize the values.

    Args:
        measure: The column name of the governing measurement to transform.
        config: The configuration object governing this measure.
        source_df: The dataframe containing the measure to be transformed.

    Returns:
        The transformed dataframe.
    """
    raise NotImplementedError("This method must be implemented by a subclass.")
@TimeableMixin.TimeAs
@abc.abstractmethod
def _transform_categorical_measurement(
    self, measure: str, config: MeasurementConfig, source_df: DF_T
) -> DF_T:
    """Casts the elements of `measure` to categorical column types using the learned vocabularies.

    Args:
        measure: The column name of the governing measurement to transform.
        config: The configuration object governing this measure.
        source_df: The dataframe containing the measure to be transformed.

    Returns:
        The transformed dataframe.
    """
    raise NotImplementedError("This method must be implemented by a subclass.")
@property
def has_static_measurements(self):
    """Whether a subjects dataframe exists and at least one fit measurement is static."""
    if self.subjects_df is None:
        return False
    return any(
        measurement_cfg.temporality == TemporalityType.STATIC
        for measurement_cfg in self.measurement_configs.values()
    )
@property
def measurement_configs(self):
    """The fit measurement configs, keyed by measurement name, excluding dropped ones.

    Raises:
        ValueError: if the dataset has not been fit yet.
    """
    if self._is_fit:
        return {
            name: measurement_cfg
            for name, measurement_cfg in self.inferred_measurement_configs.items()
            if not measurement_cfg.is_dropped
        }
    raise ValueError("Can't call measurement_configs if not yet fit!")
@property
def dynamic_numerical_columns(self):
    """Lists ``(key column, values column)`` pairs for every numeric, dynamic measurement."""
    pairs = []
    for name, measurement_cfg in self.measurement_configs.items():
        if not measurement_cfg.is_numeric:
            continue
        if measurement_cfg.temporality == TemporalityType.DYNAMIC:
            pairs.append((name, measurement_cfg.values_column))
    return pairs
@property
def time_dependent_numerical_columns(self):
    """Lists the names of all numeric, functional time-dependent measurements.

    Unlike `dynamic_numerical_columns`, this returns bare measurement names, not column pairs.
    """
    names = []
    for name, measurement_cfg in self.measurement_configs.items():
        if measurement_cfg.is_numeric and (
            measurement_cfg.temporality == TemporalityType.FUNCTIONAL_TIME_DEPENDENT
        ):
            names.append(name)
    return names
@property
def measurement_idxmaps(self):
    """Maps ``"event_type"`` and each measurement with a fit vocabulary to its element -> index map."""
    idxmaps = {"event_type": {event_type: idx for idx, event_type in enumerate(self.event_types)}}
    idxmaps.update(
        (name, measurement_cfg.vocabulary.idxmap)
        for name, measurement_cfg in self.measurement_configs.items()
        if measurement_cfg.vocabulary is not None
    )
    return idxmaps
@property
def measurement_vocabs(self):
    """Maps ``"event_type"`` and each measurement with a fit vocabulary to its list of vocabulary elements."""
    vocabs = {"event_type": self.event_types}
    vocabs.update(
        (name, measurement_cfg.vocabulary.vocabulary)
        for name, measurement_cfg in self.measurement_configs.items()
        if measurement_cfg.vocabulary is not None
    )
    return vocabs
@abc.abstractmethod
def _get_flat_ts_rep(self, **kwargs) -> DF_T:
    """Returns the flat, per-event ("at timestamp") representation for a chunk of subjects.

    `cache_flat_representation` calls this with the keyword arguments ``feature_columns`` and
    ``include_only_subjects``.
    """
    # The base class cannot implement this; a concrete subclass must override it.
    raise NotImplementedError("This method must be implemented by a subclass.")
@abc.abstractmethod
def _get_flat_static_rep(self, **kwargs) -> DF_T:
    """Returns the flat, static (per-subject) representation for a chunk of subjects.

    `cache_flat_representation` calls this with the keyword arguments ``feature_columns`` and
    ``include_only_subjects``.
    """
    # The base class cannot implement this; a concrete subclass must override it.
    raise NotImplementedError("This method must be implemented by a subclass.")
@classmethod
@abc.abstractmethod
def _summarize_over_window(cls, df: DF_T | Path, window_size: str) -> DF_T:
    """Summarizes a flat per-event representation over a trailing historical window.

    Args:
        df: The flat per-event representation to summarize. `cache_flat_representation` passes the path of an
            already-written ``at_ts`` parquet file here, not an in-memory dataframe.
        window_size: The window duration string. See the `window_sizes` argument of
            `cache_flat_representation` for the syntax.

    Returns:
        The windowed summary, which `cache_flat_representation` writes to disk.
    """
    # The base class cannot implement this; a concrete subclass must override it.
    raise NotImplementedError("This method must be implemented by a subclass.")
def _resolve_flat_rep_cache_params(
self,
feature_inclusion_frequency: float | dict[str, float] | None = None,
include_only_measurements: Sequence[str] | None = None,
) -> tuple[dict[str, float] | None, set[str]]:
if include_only_measurements is None:
if isinstance(feature_inclusion_frequency, dict):
include_only_measurements = sorted(list(feature_inclusion_frequency.keys()))
else:
include_only_measurements = sorted(list(self.measurement_configs.keys()))
else:
include_only_measurements = sorted(list(set(include_only_measurements)))
if isinstance(feature_inclusion_frequency, float):
feature_inclusion_frequency = {m: feature_inclusion_frequency for m in include_only_measurements}
return feature_inclusion_frequency, include_only_measurements
def _get_flat_rep_feature_cols(
self,
feature_inclusion_frequency: float | dict[str, float] | None = None,
window_sizes: list[str] | None = None,
include_only_measurements: set[str] | None = None,
) -> list[str]:
feature_inclusion_frequency, include_only_measurements = self._resolve_flat_rep_cache_params(
feature_inclusion_frequency, include_only_measurements
)
feature_columns = []
for m, cfg in self.measurement_configs.items():
if m not in include_only_measurements:
continue
features = None
if cfg.vocabulary is not None:
vocab = copy.deepcopy(cfg.vocabulary)
if feature_inclusion_frequency is not None:
m_freq = feature_inclusion_frequency[m]
vocab.filter(total_observations=None, min_valid_element_freq=m_freq)
features = vocab.vocabulary
elif cfg.modality == DataModality.UNIVARIATE_REGRESSION:
features = [m]
else:
raise ValueError(f"Config with modality {cfg.modality} should have a Vocabulary!")
match cfg.temporality:
case TemporalityType.STATIC:
temps = [str(cfg.temporality)]
match cfg.modality:
case DataModality.UNIVARIATE_REGRESSION:
aggs = ["value"]
case DataModality.SINGLE_LABEL_CLASSIFICATION:
aggs = ["present"]
case _:
raise ValueError(f"{cfg.modality} invalid with {cfg.temporality}")
case TemporalityType.FUNCTIONAL_TIME_DEPENDENT if window_sizes is None:
temps = [str(cfg.temporality)]
match cfg.modality:
case DataModality.UNIVARIATE_REGRESSION:
aggs = ["value"]
case DataModality.SINGLE_LABEL_CLASSIFICATION:
aggs = ["present"]
case _:
raise ValueError(f"{cfg.modality} invalid with {cfg.temporality}")
case TemporalityType.FUNCTIONAL_TIME_DEPENDENT if window_sizes is not None:
temps = window_sizes
match cfg.modality:
case DataModality.UNIVARIATE_REGRESSION:
aggs = ["count", "has_values_count", "sum", "sum_sqd", "min", "max"]
case DataModality.SINGLE_LABEL_CLASSIFICATION:
aggs = ["count"]
case _:
raise ValueError(f"{cfg.modality} invalid with {cfg.temporality}")
case TemporalityType.DYNAMIC:
temps = [str(cfg.temporality)] if window_sizes is None else window_sizes
match cfg.modality:
case DataModality.UNIVARIATE_REGRESSION | DataModality.MULTIVARIATE_REGRESSION:
aggs = ["count", "has_values_count", "sum", "sum_sqd", "min", "max"]
case DataModality.MULTI_LABEL_CLASSIFICATION:
aggs = ["count"]
case _:
raise ValueError(f"{cfg.modality} invalid with {cfg.temporality}")
for temp in temps:
for feature in features:
for agg in aggs:
feature_columns.append(f"{temp}/{m}/{feature}/{agg}")
return sorted(feature_columns)
[docs]
@TimeableMixin.TimeAs
def cache_flat_representation(
    self,
    subjects_per_output_file: int | None = None,
    feature_inclusion_frequency: float | dict[str, float] | None = None,
    window_sizes: list[str] | None = None,
    include_only_measurements: set[str] | None = None,
    do_overwrite: bool = False,
    do_update: bool = True,
):
    """Writes a flat (historically summarized) representation of the dataset to disk.

    This method caches a set of files useful for building flat representations of the dataset to disk,
    suitable for, e.g., sklearn-style modeling for downstream tasks. It produces the following:
        * A new directory ``self.config.save_dir / "flat_reps"``, which contains:
            * A ``params.json`` file with the configuration arguments.
            * Subdirectories ``static`` and ``at_ts``, containing parquet files with flat (i.e., wide)
              static and per-event representations per subject, broken out by split and subject chunk.
            * Subdirectories ``over_history/<split>/<window_size>``, containing summarized views over each
              past time window per subject per event, for all time periods in ``window_sizes``, if any.

    Args:
        subjects_per_output_file: The number of subjects to include in each output file. Lowering this
            number increases the number of files written, making the process of creating and leveraging
            these files slower but more memory efficient. Splits with fewer subjects than this are written
            as a single chunk.
        feature_inclusion_frequency: The base feature inclusion frequency that dictates which features
            can be included in the flat representation. It can be a float, in which case it applies across
            all measurements; `None`, in which case no filtering is applied; or a dictionary from
            measurement type to a float giving a per-measurement-type inclusion cutoff.
        window_sizes: Beyond writing out a raw, per-event flattened representation, the dataset can also
            summarize these flattened representations over the historical windows specified in this
            argument. These are strings specifying time deltas, using this syntax: `link`_. Each window size
            is summarized to a separate directory, and shares the same subject file split as the raw
            representation files.
        include_only_measurements: Measurement types can also be filtered out wholesale from both
            representations. If this is not None, only these measurements are included.
        do_overwrite: If `True`, this function overwrites the data already stored in the target save
            directory.
        do_update: If `True`, this function will (a) ensure that the parameters are the same or are
            mappable to one another (critically, _it may_ default to an existing subject split if one has
            been used historically, overwriting the specified `subjects_per_output_file` parameter!), then
            (b) attempt to write only those files that are not yet written to disk across the historical
            summarization targets.

    Raises:
        ValueError: if `subjects_per_output_file` is not positive, or if `do_update` is set and the stored
            parameters conflict with the requested ones.
        FileExistsError: if an output file exists and neither `do_update` nor `do_overwrite` is set.

    .. _link: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.group_by_rolling.html # noqa: E501
    """
    logger.info("Caching flat representations")
    if subjects_per_output_file is not None and subjects_per_output_file <= 0:
        raise ValueError(f"subjects_per_output_file must be positive; got {subjects_per_output_file}")

    self._seed(1, "cache_flat_representation")

    feature_inclusion_frequency, include_only_measurements = self._resolve_flat_rep_cache_params(
        feature_inclusion_frequency, include_only_measurements
    )

    flat_dir = self.config.save_dir / "flat_reps"
    flat_dir.mkdir(exist_ok=True, parents=True)

    sp_subjects = {}
    for split, split_subjects in self.split_subjects.items():
        if subjects_per_output_file is None:
            sp_subjects[split] = [[int(x) for x in split_subjects]]
        else:
            # `np.array_split` rejects 0 sections, which the floor division yields for splits smaller than
            # a single chunk; such splits are written as one chunk.
            n_chunks = max(1, len(split_subjects) // subjects_per_output_file)
            sp_subjects[split] = [
                [int(e) for e in x]
                for x in np.array_split(np.random.permutation(list(split_subjects)), n_chunks)
            ]

    params = {
        "subjects_per_output_file": subjects_per_output_file,
        "feature_inclusion_frequency": feature_inclusion_frequency,
        "include_only_measurements": include_only_measurements,
        "subject_chunks_by_split": sp_subjects,
    }

    params_fp = flat_dir / "params.json"
    if params_fp.exists():
        if do_update:
            with open(params_fp) as f:
                old_params = json.load(f)

            if old_params["subjects_per_output_file"] != params["subjects_per_output_file"]:
                logger.info(
                    "Standardizing chunk size to existing record "
                    f"({old_params['subjects_per_output_file']})."
                )
                params["subjects_per_output_file"] = old_params["subjects_per_output_file"]
                params["subject_chunks_by_split"] = old_params["subject_chunks_by_split"]

            old_params["include_only_measurements"] = sorted(old_params["include_only_measurements"])

            if old_params != params:
                err_strings = ["Asked to update but parameters differ:"]
                old = set(old_params.keys())
                new = set(params.keys())
                if old != new:
                    err_strings.append("Keys differ: ")
                    if old - new:
                        err_strings.append(f"  old - new = {old - new}")
                    if new - old:
                        err_strings.append(f"  new - old = {new - old}")

                for k in old & new:
                    old_val = old_params[k]
                    new_val = params[k]

                    if old_val != new_val:
                        err_strings.append(f"Values differ for {k}:")
                        err_strings.append(f"  Old: {old_val}")
                        err_strings.append(f"  New: {new_val}")

                raise ValueError("\n".join(err_strings))
        elif not do_overwrite:
            raise FileExistsError(f"do_overwrite is {do_overwrite} and {params_fp} exists!")

    with open(params_fp, mode="w") as f:
        json.dump(params, f)

    def should_skip(fp: Path) -> bool:
        # True when `fp` exists and `do_update` says to keep it. Raises when an existing file may be neither
        # kept nor overwritten. False means (re)write it.
        if not fp.exists():
            return False
        if do_update:
            return True
        if not do_overwrite:
            raise FileExistsError(f"do_overwrite is {do_overwrite} and {fp} exists!")
        return False

    # 0. Identify Output Columns
    # We set window_sizes to None here because we want to get the feature column names for the raw flat
    # representation, not the summarized one.
    feature_columns = self._get_flat_rep_feature_cols(
        feature_inclusion_frequency=feature_inclusion_frequency,
        window_sizes=None,
        include_only_measurements=include_only_measurements,
    )

    # 1. Produce static representation
    static_subdir = flat_dir / "static"
    for sp, subjects in tqdm(list(params["subject_chunks_by_split"].items()), desc="Flattening Splits"):
        sp_dir = static_subdir / sp
        for i, subjects_list in enumerate(tqdm(subjects, desc="Subject chunks", leave=False)):
            fp = sp_dir / f"{i}.parquet"
            if should_skip(fp):
                continue

            df = self._get_flat_static_rep(
                feature_columns=feature_columns,
                include_only_subjects=subjects_list,
            )
            self._write_df(df, fp, do_overwrite=do_overwrite)

    # 2. Produce raw representation
    ts_subdir = flat_dir / "at_ts"
    ts_dfs = {}  # split -> per-chunk file paths; reused as the inputs for the windowed summaries below.
    for sp, subjects in tqdm(list(params["subject_chunks_by_split"].items()), desc="Flattening Splits"):
        ts_dfs[sp] = []
        sp_dir = ts_subdir / sp
        for i, subjects_list in enumerate(tqdm(subjects, desc="Subject chunks", leave=False)):
            fp = sp_dir / f"{i}.parquet"
            ts_dfs[sp].append(fp)
            if should_skip(fp):
                continue

            df = self._get_flat_ts_rep(
                feature_columns=feature_columns,
                include_only_subjects=subjects_list,
            )
            self._write_df(df, fp, do_overwrite=do_overwrite)

    if window_sizes is None:
        return

    # 3. Produce summarized history representations
    history_subdir = flat_dir / "over_history"
    for window_size in tqdm(window_sizes, desc="History window sizes"):
        for sp, df_fps in tqdm(list(ts_dfs.items()), desc="Windowing Splits", leave=False):
            for i, df_fp in enumerate(tqdm(df_fps, desc="Subject chunks", leave=False)):
                fp = history_subdir / sp / window_size / f"{i}.parquet"
                if should_skip(fp):
                    continue

                df = self._summarize_over_window(df_fp, window_size)
                # Pass `do_overwrite` like the other writes, so overwriting is honored here too.
                self._write_df(df, fp, do_overwrite=do_overwrite)
[docs]
@TimeableMixin.TimeAs
def cache_deep_learning_representation(
    self, subjects_per_output_file: int | None = None, do_overwrite: bool = False
):
    """Writes a deep-learning friendly representation of the dataset to disk.

    Subjects are assigned to shards keyed ``"{split}/{chunk_idx}"``. The assignment is persisted to
    ``self.config.save_dir / "DL_shards.json"`` and reused on later calls. For each shard, a dataframe is
    written under ``DL_reps`` and a `JointNestedRaggedTensorDict` under ``NRT_reps``.

    The deep learning dataframe has one row per subject, with the following columns:
        * ``subject_id``: An unsigned integer type column holding the ID of the subject for each row.
        * ``start_time``: A `datetime` type column holding the start time of the subject's record.
        * ``static_indices``: A ragged, sparse representation of the categorical static measurements
          observed for this subject. Each element is itself a list of unsigned integers, the indices into
          the unified vocabulary of the static measurements observed for that subject.
        * ``static_measurement_indices``: Matches ``static_indices`` in shape and order, but contains
          unsigned integer indices into the unified measurement vocabulary, identifying which measurement
          each observation belongs to.
        * ``time_delta``: A ragged array with one entry per event for the subject. This is the time column
          read when building the NRT representation. NOTE(review): presumably the time between consecutive
          events; confirm in the implementing subclass.
        * ``dynamic_indices``: A doubly ragged array holding the indices of the observed values within the
          unified vocabulary, per event per subject. Each subject's data is an array of arrays, each
          containing only the indices observed at that event.
        * ``dynamic_measurement_indices``: A doubly ragged array holding the indices of the observed
          measurements, per event per subject. Each subject's data is an array of arrays, each containing
          only the indices of measurements observed at that event. It matches ``dynamic_indices`` in shape
          and order.
        * ``dynamic_values``: A doubly ragged array holding the value associated with each entry of
          ``dynamic_indices``, in the same shape and order. Missing values are null in the dataframe and
          become NaN in the NRT representation.

    Args:
        subjects_per_output_file: How big to chunk the dataset for writing to disk. Larger values make
            fewer chunks but increase the memory cost. Only used when no shard file exists yet.
        do_overwrite: Whether or not to overwrite any existing file on disk.
    """
    logger.info("Caching DL representations")
    if subjects_per_output_file is None:
        logger.warning("Sharding is recommended for DL representations.")

    DL_dir = self.config.save_dir / "DL_reps"
    NRT_dir = self.config.save_dir / "NRT_reps"

    shards_fp = self.config.save_dir / "DL_shards.json"
    if shards_fp.exists():
        # Reuse the historical sharding so previously written shard files stay consistent.
        shards = json.loads(shards_fp.read_text())
    else:
        # Seed before permuting so the subject -> shard assignment is reproducible, as in
        # `cache_flat_representation`.
        self._seed(1, "cache_deep_learning_representation")

        if subjects_per_output_file is None:
            subject_chunks = [self.subject_ids]
        else:
            subjects = np.random.permutation(list(self.subject_ids))
            subject_chunks = np.array_split(
                subjects,
                np.arange(subjects_per_output_file, len(subjects), subjects_per_output_file),
            )

        subject_chunks = [[int(x) for x in c] for c in subject_chunks]

        # Build each split's subject set once, not once per (chunk, split) pair.
        split_subject_sets = {
            split: {int(x) for x in split_subjects} for split, split_subjects in self.split_subjects.items()
        }

        shards = {}
        for chunk_idx, subjects_list in enumerate(subject_chunks):
            chunk_subjects = set(subjects_list)
            for split, split_subject_set in split_subject_sets.items():
                shards[f"{split}/{chunk_idx}"] = sorted(chunk_subjects & split_subject_set)

        shards_fp.write_text(json.dumps(shards))

    for shard_key, subjects_list in self._tqdm(list(shards.items()), desc="Shards"):
        DL_fp = DL_dir / f"{shard_key}.{self.DF_SAVE_FORMAT}"
        DL_fp.parent.mkdir(exist_ok=True, parents=True)

        cached_df = None
        if DL_fp.exists() and not do_overwrite:
            logger.info(f"Skipping {DL_fp} as it already exists.")
        else:
            logger.info(f"Caching {shard_key} to {DL_fp}")
            cached_df = self.build_DL_cached_representation(subject_ids=subjects_list)
            self._write_df(cached_df, DL_fp, do_overwrite=do_overwrite)

        NRT_fp = NRT_dir / f"{shard_key}.pt"
        NRT_fp.parent.mkdir(exist_ok=True, parents=True)
        if NRT_fp.exists() and not do_overwrite:
            logger.info(f"Skipping {NRT_fp} as it already exists.")
            continue

        if cached_df is None:
            # Read the cached DL shard back only when the NRT file actually needs to be built.
            cached_df = self._read_df(DL_fp)

        logger.info(f"Caching NRT for {shard_key} to {NRT_fp}")
        # TODO(mmd): This breaks the API isolation a bit, as we assume polars here. But that's fine.
        jnrt_dict = {
            k: cached_df[k].to_list()
            for k in ["time_delta", "dynamic_indices", "dynamic_measurement_indices"]
        }
        jnrt_dict["dynamic_values"] = (
            cached_df["dynamic_values"]
            .list.eval(pl.element().list.eval(pl.element().fill_null(float("nan"))))
            .to_list()
        )

        jnrt_dict = JointNestedRaggedTensorDict(jnrt_dict)
        jnrt_dict.save(NRT_fp)
@property
def vocabulary_config(self) -> VocabularyConfig:
    """Builds the `VocabularyConfig` implied by this (fit) dataset.

    The result gathers, in one concise object, the vocabulary information that downstream deep-learning code
    needs:
    * per-measurement vocabulary sizes;
    * offsets into the unified vocabulary;
    * the unified measurement and event-type index maps;
    * which dynamic measurements fall under each generative data modality.
    """
    generative_modes = defaultdict(list)
    # The event type is always listed under single-label classification.
    generative_modes[DataModality.SINGLE_LABEL_CLASSIFICATION].append("event_type")

    dynamic_measurements = [
        (name, measurement_cfg)
        for name, measurement_cfg in self.measurement_configs.items()
        if measurement_cfg.temporality == TemporalityType.DYNAMIC
    ]
    for name, measurement_cfg in dynamic_measurements:
        generative_modes[measurement_cfg.modality].append(name)
        # Multivariate regression measurements are additionally listed under multi-label classification.
        if measurement_cfg.modality == DataModality.MULTIVARIATE_REGRESSION:
            generative_modes[DataModality.MULTI_LABEL_CLASSIFICATION].append(name)

    vocab_sizes = {name: len(idxmap) for name, idxmap in self.measurement_idxmaps.items()}
    return VocabularyConfig(
        vocab_sizes_by_measurement=vocab_sizes,
        vocab_offsets_by_measurement=self.unified_vocabulary_offsets,
        measurements_idxmap=self.unified_measurements_idxmap,
        event_types_idxmap=self.unified_vocabulary_idxmap["event_type"],
        measurements_per_generative_mode=dict(generative_modes),
    )
@property
def unified_measurements_vocab(self) -> list[str]:
    """``"event_type"`` followed by every fit measurement name, in sorted order."""
    return ["event_type", *sorted(self.measurement_configs)]
@property
def unified_measurements_idxmap(self) -> dict[str, int]:
    """Maps each entry of `unified_measurements_vocab` to its 1-based position."""
    return {
        measurement: position
        for position, measurement in enumerate(self.unified_measurements_vocab, start=1)
    }
@property
def unified_vocabulary_offsets(self) -> dict[str, int]:
    """Returns the position in the unified vocabulary at which each measurement's vocabulary starts.

    Offsets begin at 1. Each measurement with a vocabulary occupies ``len(vocabulary)`` slots; every other
    measurement occupies a single slot.
    """
    # `measurement_vocabs` is a property that rebuilds its dict on every access, so compute it once.
    vocabs = self.measurement_vocabs

    offsets, curr_offset = {}, 1
    for m in self.unified_measurements_vocab:
        offsets[m] = curr_offset
        curr_offset += len(vocabs[m]) if m in vocabs else 1
    return offsets
@property
def unified_vocabulary_idxmap(self) -> dict[str, dict[str, int]]:
    """Provides a unified idxmap spanning all measurements' vocabularies (concatenated via offsets).

    Measurements with a vocabulary map each element to its local index plus the measurement's offset.
    Measurements without one map their own name to their offset.
    """
    # `measurement_idxmaps` is a property that rebuilds its maps on every access, so compute it once.
    per_measurement_idxmaps = self.measurement_idxmaps

    idxmaps = {}
    for m, offset in self.unified_vocabulary_offsets.items():
        if m in per_measurement_idxmaps:
            idxmaps[m] = {v: i + offset for v, i in per_measurement_idxmaps[m].items()}
        else:
            idxmaps[m] = {m: offset}
    return idxmaps
@property
def unified_vocabulary_flat(self) -> list[str]:
vocab_size = max(self.unified_vocabulary_idxmap[self.unified_measurements_vocab[-1]].values()) + 1
vocab = [None for _ in range(vocab_size)]
vocab[0] = "UNK"
for m, idxmap in self.unified_vocabulary_idxmap.items():
for e, i in idxmap.items():
vocab[i] = e
return vocab
[docs]
@abc.abstractmethod
def build_DL_cached_representation(
    self, subject_ids: list[int] | None = None, do_sort_outputs: bool = False
) -> DF_T:
    """Builds the deep-learning format dataframe for the given subjects.

    The expected output columns are those listed in the docstring of `cache_deep_learning_representation`,
    which calls this once per shard.

    Args:
        subject_ids: The subjects to include. NOTE(review): presumably `None` means all subjects; confirm in
            the implementing subclass.
        do_sort_outputs: An output-ordering flag whose exact semantics are defined by the implementing
            subclass.
    """
    raise NotImplementedError("This method must be implemented by a subclass.")
@abc.abstractmethod
def _denormalize(self, events_df: DF_T, col: str) -> DF_T:
    """Reverses the normalization of column `col` in `events_df` and returns the resulting dataframe."""
    raise NotImplementedError("This method must be implemented by a subclass.")
[docs]
def describe(
    self,
    do_print_measurement_summaries: bool = True,
    viz_config: Visualizer | None = None,
) -> list[Figure] | None:
    """Prints a textual summary of the dataset and optionally returns visualizations of it.

    Args:
        do_print_measurement_summaries: If `True`, also prints each fit measurement config's description.
            Any config without a name is given its measurement key as its name.
        viz_config: If not `None`, the dataset is also visualized via `self.visualize`.

    Returns:
        The figures from `self.visualize` when `viz_config` is given; otherwise `None`.
    """
    n_subjects = humanize.intword(len(self.subjects_df))
    n_events = humanize.intword(len(self.events_df))
    n_measurements = humanize.intword(len(self.dynamic_measurements_df))
    print(
        f"Dataset has {n_subjects} subjects, "
        f"with {n_events} events and "
        f"{n_measurements} measurements."
    )

    if do_print_measurement_summaries:
        fit_configs = self.measurement_configs
        print(f"Dataset has {len(fit_configs)} measurements:")
        for measurement, measurement_cfg in fit_configs.items():
            if measurement_cfg.name is None:
                measurement_cfg.name = measurement
            measurement_cfg.describe(line_width=60)
            print()

    return None if viz_config is None else self.visualize(viz_config)
[docs]
def visualize(
    self,
    viz_config: Visualizer,
) -> list[Figure]:
    """Visualizes the dataset, along several axes.

    If ``viz_config.subset_size`` is set, only a random subset of distinct subjects is plotted. At most
    ``subset_size`` subjects are sampled without replacement, after seeding via `self._seed`. The return
    value of `self._seed` is stored back into ``viz_config.subset_random_seed``. If ``viz_config.age_col``
    is set, that events column is denormalized before plotting.

    Args:
        viz_config: The visualization configuration; its ``plot`` method produces the figures.

    Returns:
        The figures produced by ``viz_config.plot``.
    """
    if viz_config.subset_size is not None:
        viz_config.subset_random_seed = self._seed(seed=viz_config.subset_random_seed, key="visualize")

        all_subject_ids = list(self.subject_ids)
        # Sample distinct subjects. With replacement (numpy's default), duplicates would silently shrink the
        # subset. Cap at the population size so `replace=False` cannot fail.
        n_to_sample = min(viz_config.subset_size, len(all_subject_ids))
        subject_ids = list(np.random.choice(all_subject_ids, n_to_sample, replace=False))

        subjects_df = self._filter_col_inclusion(self.subjects_df, {"subject_id": subject_ids})
        events_df = self._filter_col_inclusion(self.events_df, {"subject_id": subject_ids})
        dynamic_measurements_df = self._filter_col_inclusion(
            self.dynamic_measurements_df, {"event_id": list(events_df["event_id"])}
        )
    else:
        subjects_df = self.subjects_df
        events_df = self.events_df
        dynamic_measurements_df = self.dynamic_measurements_df

    if viz_config.age_col is not None:
        events_df = self._denormalize(events_df, viz_config.age_col)

    return viz_config.plot(subjects_df, events_df, dynamic_measurements_df)