"""Various configuration classes for EventStream data objects."""
from __future__ import annotations
import dataclasses
import enum
import hashlib
import json
import random
from collections import OrderedDict, defaultdict
from collections.abc import Hashable, Sequence
from io import StringIO, TextIOBase
from pathlib import Path
from textwrap import shorten, wrap
from typing import Any, Union
import omegaconf
import pandas as pd
from loguru import logger
from ..utils import (
COUNT_OR_PROPORTION,
PROPORTION,
JSONableMixin,
StrEnum,
hydra_dataclass,
num_initial_spaces,
)
from .time_dependent_functor import AgeFunctor, TimeDependentFunctor, TimeOfDayFunctor
from .types import DataModality, InputDataType, InputDFType, TemporalityType
from .vocabulary import Vocabulary
# Type aliases used when describing input dataframes during pre-processing.
#
# DF_COL      -- a column reference: one column name, or a sequence of column names.
# INPUT_COL_T -- the type of an input column: a bare `InputDataType`, or an `InputDataType` paired with an
#                additional string parameter.
# DF_SCHEMA   -- a unified schema for an input dataframe; see the per-variant notes below.
DF_COL = Union[str, Sequence[str]]
INPUT_COL_T = Union[InputDataType, tuple[InputDataType, str]]
DF_SCHEMA = Union[
    tuple[list[DF_COL], INPUT_COL_T],  # a list of columns, all sharing one type
    tuple[DF_COL, INPUT_COL_T],  # a single column and its type
    dict[DF_COL, INPUT_COL_T],  # a mapping of columns to their types
    dict[DF_COL, tuple[str, INPUT_COL_T]],  # a mapping of input columns to (output name, type) pairs
    tuple[dict[DF_COL, str], INPUT_COL_T],  # a mapping of input columns to output names, one shared type
]
@dataclasses.dataclass
class DatasetSchema(JSONableMixin):
    """Represents the schema of an input dataset, including static and dynamic data sources.

    Contains the information necessary for extracting and pulling input dataset elements during a
    pre-processing pipeline. Inputs can be represented in either structured (typed) or plain (dictionary)
    form. There can only be one static schema currently, but arbitrarily many dynamic measurement schemas.
    During pre-processing the model will read all these dynamic input datasets and combine their outputs into
    the appropriate format. This can be written to or read from JSON files via the `JSONableMixin` base class
    methods.

    Attributes:
        static: The schema for the input dataset containing static (per-subject) information, in either object
            or dict form. Dictionaries (including dict subclasses) are converted to `InputDFSchema` objects.
        dynamic: A list of schemas for all dynamic dataset schemas, each in either object or dict form.
            Dictionaries are converted to `InputDFSchema` objects, and every dynamic schema has its
            `subject_id_col` overwritten with that of the static schema. `None` is treated as an empty list.
        dynamic_by_df: Set in `__post_init__` (not a dataclass field). Maps each dynamic schema's `input_df` to
            the list of dynamic schemas reading from it, in order of first appearance.

    Raises:
        ValueError: If the static schema is `None`, if there is not a subject ID column specified in the
            static schema, if the passed "static" schema is not typed as a static schema, or if any dynamic
            schema is typed as a static schema.

    Examples:
        >>> DatasetSchema(dynamic=[])
        Traceback (most recent call last):
            ...
        ValueError: Must specify a static schema!
        >>> DatasetSchema(
        ...     static=dict(type="event", event_type="foo", input_df="/path/to/df.csv", ts_col="col"),
        ...     dynamic=[]
        ... )
        Traceback (most recent call last):
            ...
        ValueError: Must pass a static schema config for static.
        >>> DatasetSchema(
        ...     static=dict(type="static", input_df="/path/to/df.csv", subject_id_col="col"),
        ...     dynamic=[dict(type="static", input_df="/path/to/df.csv", subject_id_col="col")]
        ... )
        Traceback (most recent call last):
            ...
        ValueError: Must pass dynamic schemas in self.dynamic!
        >>> DatasetSchema(
        ...     static=dict(type="static", input_df="/path/to/df.csv", subject_id_col="col"),
        ...     dynamic=None,
        ... ).dynamic_by_df
        {}
        >>> DS = DatasetSchema(
        ...     static=dict(type="static", input_df="/path/to/df.csv", subject_id_col="col"),
        ...     dynamic=[
        ...         dict(type="event", event_type="foo", input_df="/path/to/foo.csv", ts_col="col"),
        ...         dict(type="event", event_type="bar", input_df="/path/to/bar.csv", ts_col="col"),
        ...         dict(type="event", event_type="bar2", input_df="/path/to/bar.csv", ts_col="col2"),
        ...     ],
        ... )
        >>> DS.dynamic_by_df # doctest: +NORMALIZE_WHITESPACE
        {'/path/to/foo.csv': [InputDFSchema(input_df='/path/to/foo.csv', type='event', event_type='foo',
                                            subject_id_col='col', ts_col='col')], '/path/to/bar.csv': [InputDFSchema(input_df='/path/to/bar.csv',
                                            type='event', event_type='bar', subject_id_col='col', ts_col='col'),
                                            InputDFSchema(input_df='/path/to/bar.csv', type='event', event_type='bar2', subject_id_col='col',
                                            ts_col='col2')]}
    """

    static: dict[str, Any] | InputDFSchema | None = None
    dynamic: list[InputDFSchema | dict[str, Any]] = dataclasses.field(default_factory=list)

    def __post_init__(self):
        if self.static is None:
            raise ValueError("Must specify a static schema!")

        # `isinstance` (rather than an exact type check) so dict subclasses are converted as well.
        if isinstance(self.static, dict):
            self.static = InputDFSchema(**self.static)

        if not self.static.is_static:
            raise ValueError("Must pass a static schema config for static.")

        # Previously `dynamic=None` skipped normalization and then crashed when iterated below.
        if self.dynamic is None:
            self.dynamic = []

        new_dynamic = []
        for v in self.dynamic:
            if isinstance(v, dict):
                v = InputDFSchema(**v)
            # Reject static schemas before mutating a caller-provided object.
            if v.is_static:
                raise ValueError("Must pass dynamic schemas in self.dynamic!")
            # All dynamic inputs share the static schema's subject ID column.
            v.subject_id_col = self.static.subject_id_col
            new_dynamic.append(v)
        self.dynamic = new_dynamic

        # Group dynamic schemas by source dataframe; dict insertion order keeps first-appearance order.
        dynamic_by_df = defaultdict(list)
        for v in self.dynamic:
            dynamic_by_df[v.input_df].append(v)
        self.dynamic_by_df = dict(dynamic_by_df)
@dataclasses.dataclass
class VocabularyConfig(JSONableMixin):
    """Dataclass that describes the vocabulary of a dataset, for initializing model parameters.

    This does not configure a vocabulary, but rather describes the vocabulary learned during dataset
    pre-processing for an entire dataset. This description includes the sizes of all per-measurement
    vocabularies (measurements without a vocabulary, such as univariate regression measurements, are
    omitted as their vocabularies have size 1), vocabulary offsets per measurement, which detail how the
    various vocabularies are stuck together to form a unified vocabulary, the indices of each global
    measurement type, the generative modes used by each measurement, and the event type indices.

    Attributes:
        vocab_sizes_by_measurement: A dictionary mapping measurements to their respective vocabulary sizes.
        vocab_offsets_by_measurement: A dictionary mapping measurements to their respective vocabulary
            offsets.
        measurements_idxmap: A dictionary mapping measurements to their integer indices.
        measurements_per_generative_mode: A dictionary mapping data modality to a list of measurements.
        event_types_idxmap: A dictionary mapping event types to their respective indices.
    """

    vocab_sizes_by_measurement: dict[str, int] | None = None
    vocab_offsets_by_measurement: dict[str, int] | None = None
    measurements_idxmap: dict[str, dict[Hashable, int]] | None = None
    measurements_per_generative_mode: dict[DataModality, list[str]] | None = None
    event_types_idxmap: dict[str, int] | None = None

    @property
    def total_vocab_size(self) -> int:
        """Returns the total vocab size of the vocabulary described here.

        The total vocabulary size is the sum of (1) all the individual measurement vocabularies' sizes, (2)
        any offset the global vocabulary has from 0, to account for padding indices, and (3) any measurements
        who have length-1 vocabularies (which are not included in `vocab_sizes_by_measurement`) as is
        reflected by elements in the vocab offsets dictionary that aren't in the vocab sizes dictionary.

        Both `vocab_sizes_by_measurement` and `vocab_offsets_by_measurement` must be set (non-empty offsets).

        Examples:
            >>> config = VocabularyConfig(
            ...     vocab_sizes_by_measurement={"measurement1": 10, "measurement2": 3},
            ...     vocab_offsets_by_measurement={"measurement1": 5, "measurement2": 15, "measurement3": 18}
            ... )
            >>> config.total_vocab_size
            19
        """
        return (
            sum(self.vocab_sizes_by_measurement.values())
            # The smallest offset is the global vocabulary's starting index (room for padding).
            + min(self.vocab_offsets_by_measurement.values())
            # Measurements with offsets but no listed size each contribute a single vocabulary element.
            + (len(self.vocab_offsets_by_measurement) - len(self.vocab_sizes_by_measurement))
        )
class SeqPaddingSide(StrEnum):
    """Enumeration for the side of sequence padding during PyTorch Batch construction.

    Member values are produced by `enum.auto()` under the project's `StrEnum`.
    """

    RIGHT = enum.auto()
    """Pad on the right side (at the end of the sequence).

    This is the default during normal training.
    """

    LEFT = enum.auto()
    """Pad on the left side (at the beginning of the sequence).

    This is the default during generation.
    """
class SubsequenceSamplingStrategy(StrEnum):
    """Enumeration for subsequence sampling strategies.

    When the maximum allowed sequence length for a PyTorchDataset is shorter than the sequence length of a
    subject's data, this enumeration dictates how we sample a subsequence to include.
    """

    TO_END = enum.auto()
    """Sample subsequences of the maximum length up to the end of the permitted window.

    This is the default during fine-tuning and with task dataframes.
    """

    FROM_START = enum.auto()
    """Sample subsequences of the maximum length from the start of the permitted window."""

    RANDOM = enum.auto()
    """Sample subsequences of the maximum length randomly within the permitted window.

    This is the default during pre-training.
    """
[docs]
@hydra_dataclass
class PytorchDatasetConfig(JSONableMixin):
"""Configuration options for building a PyTorch dataset from a `Dataset`.
This is the main configuration object for a `PytorchDataset`. The `PytorchDataset` class specializes the
representation of the data in a base `Dataset` class for sequential deep learning. This dataclass is also
an acceptable `Hydra Structured Config`_ object with the name "pytorch_dataset_config".
.. _Hydra Structured Config: https://hydra.cc/docs/tutorials/structured_config/intro/
Attributes:
save_dir: Directory where the base dataset, including the deep learning representation outputs, is
saved.
max_seq_len: Maximum sequence length the dataset should output in any individual item.
min_seq_len: Minimum sequence length required to include a subject in the dataset.
seq_padding_side: Whether to pad smaller sequences on the right or the left.
subsequence_sampling_strategy: Strategy for sampling subsequences when an individual item's total
sequence length in the raw data exceeds the maximum allowed sequence length.
train_subset_size: If the training data should be subsampled randomly, this specifies the size of the
training subset. If `None` or "FULL", then the full training data is used.
train_subset_seed: If the training data should be subsampled randomly, this specifies the seed for
that random subsampling.
tuning_subset_size: If the tuning data should be subsampled randomly, this specifies the size of the
tuning subset. If `None` or "FULL", then the full tuning data is used.
tuning_subset_seed: If the tuning data should be subsampled randomly, this specifies the seed for
that random subsampling.
task_df_name: If the raw dataset should be limited to a task dataframe view, this specifies the name
of the task dataframe, and indirectly the path on disk from where that task dataframe will be
read (save_dir / "task_dfs" / f"{task_df_name}.parquet").
do_include_subject_id: Whether or not to include the subject ID of the individual for this batch.
do_include_subsequence_indices: Whether or not to include the start and end indices of the sampled
subsequence for the individual from their full dataset for this batch. This is sometimes used
during generative-based evaluation.
do_include_start_time_min: Whether or not to include the start time of the individual's sequence in
minutes since the epoch (1/1/1970) in the output data. This is necessary during generation, and
not used anywhere else currently.
Raises:
ValueError: If 'seq_padding_side' is not a valid value; If 'min_seq_len' is not a non-negative
integer; If 'max_seq_len' is not an integer greater or equal to 'min_seq_len'; If
'train_subset_seed' is not None when 'train_subset_size' is None or 'FULL'; If 'train_subset_size'
is negative when it's an integer; If 'train_subset_size' is not within (0, 1) when it's a float.
TypeError: If 'train_subset_size' is of unrecognized type.
Examples:
>>> config = PytorchDatasetConfig(
... save_dir='./dataset',
... max_seq_len=256,
... min_seq_len=2,
... seq_padding_side=SeqPaddingSide.RIGHT,
... subsequence_sampling_strategy=SubsequenceSamplingStrategy.RANDOM,
... train_subset_size="FULL",
... train_subset_seed=None,
... task_df_name=None,
... do_include_start_time_min=False
... )
>>> config_dict = config.to_dict()
>>> new_config = PytorchDatasetConfig.from_dict(config_dict)
>>> config == new_config
True
>>> config = PytorchDatasetConfig(train_subset_size=-1)
Traceback (most recent call last):
...
ValueError: If integral, train_subset_size must be positive! Got -1
>>> config = PytorchDatasetConfig(train_subset_size=1.2)
Traceback (most recent call last):
...
ValueError: If float, train_subset_size must be in (0, 1)! Got 1.2
>>> config = PytorchDatasetConfig(train_subset_size='200')
Traceback (most recent call last):
...
TypeError: train_subset_size is of unrecognized type <class 'str'>.
>>> import sys
>>> from loguru import logger
>>> logger.remove()
>>> _ = logger.add(sys.stdout, format="{message}")
>>> config = PytorchDatasetConfig(
... save_dir='./dataset',
... max_seq_len=256,
... min_seq_len=2,
... seq_padding_side='left',
... subsequence_sampling_strategy=SubsequenceSamplingStrategy.RANDOM,
... train_subset_size=100,
... train_subset_seed=None,
... task_df_name=None,
... do_include_start_time_min=False
... )
train_subset_size is set, but train_subset_seed is not. Setting to...
>>> assert config.train_subset_seed is not None
"""
save_dir: Path = omegaconf.MISSING
max_seq_len: int = 256
min_seq_len: int = 2
seq_padding_side: SeqPaddingSide = SeqPaddingSide.RIGHT
subsequence_sampling_strategy: SubsequenceSamplingStrategy = SubsequenceSamplingStrategy.RANDOM
train_subset_size: int | float | str = "FULL"
train_subset_seed: int | None = None
tuning_subset_size: int | float | str = "FULL"
tuning_subset_seed: int | None = None
task_df_name: str | None = None
do_include_subsequence_indices: bool = False
do_include_subject_id: bool = False
do_include_start_time_min: bool = False
# Trades off between speed/disk/mem and support
cache_for_epochs: int = 1
def __post_init__(self):
if self.cache_for_epochs is None:
self.cache_for_epochs = 1
if self.subsequence_sampling_strategy != "random" and self.cache_for_epochs > 1:
raise ValueError(
f"It does not make sense to cache for {self.cache_for_epochs} with non-random "
"subsequence sampling."
)
if self.seq_padding_side not in SeqPaddingSide.values():
raise ValueError(f"seq_padding_side invalid; must be in {', '.join(SeqPaddingSide.values())}")
if type(self.min_seq_len) is not int or self.min_seq_len < 0:
raise ValueError(f"min_seq_len must be a non-negative integer; got {self.min_seq_len}")
if type(self.max_seq_len) is not int or self.max_seq_len < self.min_seq_len:
raise ValueError(
f"max_seq_len must be an integer at least equal to min_seq_len; got {self.max_seq_len} "
f"(min_seq_len = {self.min_seq_len})"
)
if type(self.save_dir) is str and self.save_dir != omegaconf.MISSING:
self.save_dir = Path(self.save_dir)
match self.train_subset_size:
case int() as n if n < 0:
raise ValueError(f"If integral, train_subset_size must be positive! Got {n}")
case float() as frac if frac <= 0 or frac >= 1:
raise ValueError(f"If float, train_subset_size must be in (0, 1)! Got {frac}")
case int() | float() if (self.train_subset_seed is None):
seed = int(random.randint(1, int(1e6)))
logger.warning(f"train_subset_size is set, but train_subset_seed is not. Setting to {seed}")
self.train_subset_seed = seed
case None | "FULL" if self.train_subset_seed is not None:
logger.info(f"Removing train subset seed as train subset size is {self.train_subset_size}")
self.train_subset_seed = None
case None | "FULL" | int() | float():
pass
case _:
raise TypeError(f"train_subset_size is of unrecognized type {type(self.train_subset_size)}.")
match self.tuning_subset_size:
case int() as n if n < 0:
raise ValueError(f"If integral, tuning_subset_size must be positive! Got {n}")
case float() as frac if frac <= 0 or frac >= 1:
raise ValueError(f"If float, tuning_subset_size must be in (0, 1)! Got {frac}")
case int() | float() if (self.tuning_subset_seed is None):
seed = int(random.randint(1, int(1e6)))
print(f"WARNING! tuning_subset_size is set, but tuning_subset_seed is not. Setting to {seed}")
self.tuning_subset_seed = seed
case None | "FULL" | int() | float():
pass
case _:
raise TypeError(
f"tuning_subset_size is of unrecognized type {type(self.tuning_subset_size)}."
)
[docs]
def to_dict(self) -> dict:
"""Represents this configuration object as a plain dictionary."""
as_dict = dataclasses.asdict(self)
as_dict["save_dir"] = str(as_dict["save_dir"])
return as_dict
[docs]
@classmethod
def from_dict(cls, as_dict: dict) -> PytorchDatasetConfig:
"""Creates a new instance of this class from a plain dictionary."""
as_dict["save_dir"] = Path(as_dict["save_dir"])
return cls(**as_dict)
@property
def vocabulary_config_fp(self) -> Path:
return self.save_dir / "vocabulary_config.json"
@property
def vocabulary_config(self) -> VocabularyConfig:
return VocabularyConfig.from_json_file(self.vocabulary_config_fp)
@property
def measurement_config_fp(self) -> Path:
return self.save_dir / "inferred_measurement_configs.json"
@property
def measurement_configs(self) -> dict[str, MeasurementConfig]:
with open(self.measurement_config_fp) as f:
measurement_configs = {k: MeasurementConfig.from_dict(v) for k, v in json.load(f).items()}
return {k: v for k, v in measurement_configs.items() if not v.is_dropped}
@property
def DL_reps_dir(self) -> Path:
return self.save_dir / "DL_reps"
@property
def cached_task_dir(self) -> Path | None:
if self.task_df_name is None:
return None
else:
return self.save_dir / "DL_reps" / "for_task" / self.task_df_name
@property
def raw_task_df_fp(self) -> Path | None:
if self.task_df_name is None:
return None
else:
return self.save_dir / "task_dfs" / f"{self.task_df_name}.parquet"
@property
def task_info_fp(self) -> Path | None:
if self.task_df_name is None:
return None
else:
return self.cached_task_dir / "task_info.json"
@property
def _data_parameters_and_hash(self) -> tuple[dict[str, Any], str]:
params = sorted(
(
"save_dir",
"max_seq_len",
"min_seq_len",
"seq_padding_side",
"subsequence_sampling_strategy",
"train_subset_size",
"train_subset_seed",
"task_df_name",
)
)
params_list = []
for p in params:
v = str(getattr(self, p))
if (p == "train_subset_seed") and (self.train_subset_size in ("FULL", None)):
v = None
params_list.append((p, v))
params = tuple(params_list)
h = hashlib.blake2b(digest_size=8)
h.update(str(params).encode())
return {k: v for k, v in params}, h.hexdigest()
@property
def tensorized_cached_dir(self) -> Path:
if self.task_df_name is None:
base_dir = self.DL_reps_dir / "tensorized_cached"
else:
base_dir = self.cached_task_dir
return base_dir / self._data_parameters_and_hash[1]
@property
def _cached_data_parameters_fp(self) -> Path:
return self.tensorized_cached_dir / "data_parameters.json"
def _cache_data_parameters(self):
self._cached_data_parameters_fp.parent.mkdir(exist_ok=True, parents=True)
with open(self._cached_data_parameters_fp, mode="w") as f:
logger.info(f"Saving data parameters to {self._cached_data_parameters_fp}")
json.dump(self._data_parameters_and_hash[0], f)
[docs]
def tensorized_cached_files(self, split: str) -> dict[str, Path]:
if not (self.tensorized_cached_dir / split).is_dir():
return {}
return {fp.stem: fp for fp in (self.tensorized_cached_dir / split).glob("*.npz")}
[docs]
@dataclasses.dataclass
class MeasurementConfig(JSONableMixin):
"""The Configuration class for a measurement in the Dataset.
A measurement is any observation in the dataset; be it static or dynamic, categorical or continuous. This
class contains configuration options to define a measurement and dictate how it should be pre-processed,
embedded, and generated in generative models.
Attributes:
name:
Stores the name of this measurement; also the column in the appropriate internal dataframe
(`subjects_df`, `events_df`, or `dynamic_measurements_df`) that will contain this measurement. All
measurements will have this set.
The 'column' linkage has slightly different meanings depending on `self.modality`:
* If `modality == DataModality.UNIVARIATE_REGRESSION`, then this column stores the values
associated with this continuous-valued measure.
* If `modality == DataModality.MULTIVARIATE_REGRESSION`, then this column stores the keys that
dictate the dimensions for which the associated `values_column` has the values.
* Otherwise, this column stores the categorical values of this measure.
Similarly, it has slightly different meanings depending on `self.temporality`:
* If `temporality == TemporalityType.STATIC`, this is an existent column in the `subjects_df`
dataframe.
* If `temporality == TemporalityType.DYNAMIC`, this is an existent column in the
`dynamic_measurements_df` dataframe.
* Otherwise, (when `temporality == TemporalityType.FUNCTIONAL_TIME_DEPENDENT`), then this is
the name the *output-to-be-created* column will take in the `events_df` dataframe.
modality: The modality of this measurement. If `DataModality.UNIVARIATE_REGRESSION`, then this
measurement takes on single-variate continuous values. If `DataModality.MULTIVARIATE_REGRESSION`,
then this measurement consists of key-value pairs of categorical covariate identifiers and
continuous values. Keys are stored in the column reflected in `self.name` and values in
`self.values_column`.
temporality: How this measure varies in time. If `TemporalityType.STATIC`, this is a static
measurement. If `TemporalityType.FUNCTIONAL_TIME_DEPENDENT`, then this measurement is a
time-dependent measure that varies with time and static data in an analytically computable manner
(e.g., age). If `TemporalityType.DYNAMIC`, then this is a measurement that varies in time in a
non-a-priori computable manner.
observation_rate_over_cases: The fraction of valid "instances" in which this measure is observed at
all. For example, for a static measurement, this is the fraction of subjects for which this
measure is observed to take on a non-null value at least once. For a dynamic measurement, this is
the fraction of events for which this measure is observed to take on a non-null value at least
once. This is set dynamically during pre-procesisng, and not specified at construction.
observation_rate_per_case: The number of times this measure is observed to take on a non-null value
per possible valid "instance" where at least one measure is observed. For example, for a static
measurement, this is the number of times this measure is observed per subject when
this measure is observed at all. For a dynamic measurement, this is the number of times this
measure is observed per event when this measure is observed at all. This is set dynamically during
pre-procesisng, and not specified at construction.
functor: If `temporality == TemporalityType.FUNCTIONAL_TIME_DEPENDENT`, then this will be set to the
functor used to compute the value of a known-time-depedency measure. In this case, `functor` must
be a subclass of `data.time_dependent_functor.TimeDependentFunctor`. If `temporality` is anything
else, then this will be `None`.
vocabulary: The vocabulary for this column, realized as a `Vocabulary` object. Begins with `'UNK'`.
Not set on `modality==UNIVARIATE_REGRESSION` measurements.
values_column: For `modality==MULTIVARIATE_REGRESSION` measurements, this will store the name of the
column which will contain the numerical values corresponding to this measurement. Otherwise will
be `None`.
measurement_metadata: Stores metadata about the numerical values corresponding to this measurement.
This can take one of two forms, depending on the measurement modality. If
`modality==UNIVARIATE_REGRESSION`, then this will be a `pd.Series` whose index will contain the
set of possible column headers listed below. If `modality==MULTIVARIATE_REGRESSION`, then this
will be a `pd.DataFrame`, whose index will contain the possible regression covariate identifier
keys and whose columns will contain the set of possible columns listed below.
Metadata Columns:
* drop_lower_bound: A lower bound such that values either below or at or below this level will
be dropped (key presence will be retained for multivariate regression measures). Optional.
* drop_lower_bound_inclusive: This must be set if `drop_lower_bound` is set. If this is true,
then values will be dropped if they are $<=$ `drop_lower_bound`. If it is false, then values
will be dropped if they are $<$ `drop_lower_bound`.
* censor_lower_bound: A lower bound such that values either below or at or below this level,
but above the level of `drop_lower_bound`, will be replaced with the value
`censor_lower_bound`. Optional.
* drop_upper_bound An upper bound such that values either above or at or above this level will
be dropped (key presence will be retained for multivariate regression measures). Optional.
* drop_upper_bound_inclusive: This must be set if `drop_upper_bound` is set. If this is true,
then values will be dropped if they are $>=$ `drop_upper_bound`. If it is false, then values
will be dropped if they are $>$ `drop_upper_bound`.
* censor_upper_bound: An upper bound such that values either above or at or above this level,
but below the level of `drop_upper_bound`, will be replaced with the value
`censor_upper_bound`. Optional.
* value_type: To which kind of value (e.g., integer, categorical, float) this key corresponds.
Must be an element of the enum `NumericMetadataValueType`. Optional. If not pre-specified,
will be inferred from the data.
* thresh_large: The learned upper bound for inlier values.
* thresh_small: The learned lower bound for inlier values.
* mean: The mean to which values will be standardized.
* std: The standard deviation to which values will be standardized.
modifiers: Stores a list of additional column names that modify this measurement that should be
tracked with this measurement record through the dataset.
Raises:
ValueError: If the configuration is not self consistent (e.g., a functor specified on a
non-functional_time_dependent measure).
NotImplementedError: If the configuration relies on a measurement configuration that is not yet
supported, such as numeric, static measurements.
Examples:
>>> cfg = MeasurementConfig(
... name='key',
... modality='multi_label_classification',
... temporality='dynamic',
... vocabulary=Vocabulary(['foo', 'bar', 'baz'], [0.3, 0.4, 0.3]),
... )
>>> cfg.is_numeric
False
>>> cfg.is_dropped
False
>>> cfg = MeasurementConfig(
... name='key',
... modality='univariate_regression',
... temporality='dynamic',
... _measurement_metadata=pd.Series([1, 0.2], index=['censor_upper_bound', 'censor_lower_bound']),
... )
>>> cfg.is_numeric
True
>>> cfg.is_dropped
False
>>> cfg = MeasurementConfig(
... name='key',
... modality='multivariate_regression',
... temporality='dynamic',
... values_column='vals',
... _measurement_metadata=pd.DataFrame(
... {'censor_lower_bound': [1, 0.2, 0.1]},
... index=pd.Index(['foo', 'bar', 'baz'], name='key'),
... ),
... vocabulary=Vocabulary(['foo', 'bar', 'baz'], [0.3, 0.4, 0.3]),
... )
>>> cfg.is_numeric
True
>>> cfg.is_dropped
False
>>> cfg = MeasurementConfig(
... name='key',
... modality='multi_label_classification',
... temporality='dynamic',
... modifiers=['foo', 'bar'],
... )
>>> cfg = MeasurementConfig(
... name='key',
... modality='multi_label_classification',
... temporality='dynamic',
... modifiers=[1, 2],
... )
Traceback (most recent call last):
...
ValueError: `self.modifiers` must be a list of strings; got element 1.
>>> MeasurementConfig()
Traceback (most recent call last):
...
ValueError: `self.temporality = None` Invalid! Must be in static, dynamic, functional_time_dependent
>>> MeasurementConfig(
... temporality=TemporalityType.FUNCTIONAL_TIME_DEPENDENT,
... functor=None,
... )
Traceback (most recent call last):
...
ValueError: functor must be set for functional_time_dependent measurements!
>>> MeasurementConfig(
... temporality=TemporalityType.STATIC,
... functor=AgeFunctor(dob_col="date_of_birth"),
... )
Traceback (most recent call last):
...
ValueError: functor should be None for static measurements! Got ...
>>> MeasurementConfig(
... temporality=TemporalityType.DYNAMIC,
... modality=DataModality.MULTIVARIATE_REGRESSION,
... _measurement_metadata=pd.Series([1, 10], index=['censor_lower_bound', 'censor_upper_bound']),
... values_column='vals',
... )
Traceback (most recent call last):
...
ValueError: If set, measurement_metadata must be a DataFrame on a multivariate_regression\
MeasurementConfig. Got <class 'pandas.core.series.Series'>
censor_lower_bound 1
censor_upper_bound 10
dtype: int64
"""
FUNCTORS = {
"AgeFunctor": AgeFunctor,
"TimeOfDayFunctor": TimeOfDayFunctor,
}
PREPROCESSING_METADATA_COLUMNS = OrderedDict(
{
"value_type": str,
"mean": float,
"std": float,
"thresh_small": float,
"thresh_large": float,
}
)
# Present in all measures
name: str | None = None
temporality: TemporalityType | None = None
modality: DataModality | None = None
observation_rate_over_cases: float | None = None
observation_rate_per_case: float | None = None
# Specific to time-dependent measures
functor: TimeDependentFunctor | None = None
# Specific to categorical or partially observed multivariate regression measures.
vocabulary: Vocabulary | None = None
# Specific to numeric measures
values_column: str | None = None
_measurement_metadata: pd.DataFrame | pd.Series | str | Path | None = None
modifiers: list[str] | None = None
def __post_init__(self):
    """Dataclass hook: validate the freshly constructed config for internal consistency."""
    self._validate()
def _validate(self):
"""Checks the internal state of `self` and ensures internal consistency and validity."""
match self.temporality:
case TemporalityType.STATIC:
if self.functor is not None:
raise ValueError(
f"functor should be None for {self.temporality} measurements! Got {self.functor}"
)
if self.is_numeric:
raise NotImplementedError(
f"Numeric data modalities like {self.modality} not yet supported on static measures."
)
case TemporalityType.DYNAMIC:
if self.functor is not None:
raise ValueError(
f"functor should be None for {self.temporality} measurements! Got {self.functor}"
)
if self.modality == DataModality.SINGLE_LABEL_CLASSIFICATION:
raise ValueError(
f"{self.modality} on {self.temporality} measurements is not currently supported, as "
"event aggregation can turn single-label tasks into multi-label tasks in a manner "
"that is not currently automatically detected or compensated for."
)
case TemporalityType.FUNCTIONAL_TIME_DEPENDENT:
if self.functor is None:
raise ValueError(f"functor must be set for {self.temporality} measurements!")
if self.modality is None:
self.modality = self.functor.OUTPUT_MODALITY
elif self.modality not in (DataModality.DROPPED, self.functor.OUTPUT_MODALITY):
raise ValueError(
"self.modality must either be DataModality.DROPPED or "
f"{self.functor.OUTPUT_MODALITY} for {self.temporality} measures; got {self.modality}"
)
case _:
raise ValueError(
f"`self.temporality = {self.temporality}` Invalid! Must be in "
f"{', '.join(TemporalityType.values())}"
)
err_strings = []
match self.modality:
case DataModality.MULTIVARIATE_REGRESSION:
if self.values_column is None:
err_strings.append(f"values_column must be set on a {self.modality} MeasurementConfig")
if (self.measurement_metadata is not None) and not isinstance(
self.measurement_metadata, pd.DataFrame
):
err_strings.append(
f"If set, measurement_metadata must be a DataFrame on a {self.modality} "
f"MeasurementConfig. Got {type(self.measurement_metadata)}\n"
f"{self.measurement_metadata}"
)
case DataModality.UNIVARIATE_REGRESSION:
if self.values_column is not None:
err_strings.append(
f"values_column must be None on a {self.modality} MeasurementConfig. "
f"Got {self.values_column}"
)
if (self.measurement_metadata is not None) and not isinstance(
self.measurement_metadata, pd.Series
):
err_strings.append(
f"If set, measurement_metadata must be a Series on a {self.modality} "
f"MeasurementConfig. Got {type(self.measurement_metadata)}\n"
f"{self.measurement_metadata}"
)
case DataModality.SINGLE_LABEL_CLASSIFICATION | DataModality.MULTI_LABEL_CLASSIFICATION:
if self.values_column is not None:
err_strings.append(
f"values_column must be None on a {self.modality} MeasurementConfig. "
f"Got {self.values_column}"
)
if self._measurement_metadata is not None:
err_strings.append(
f"measurement_metadata must be None on a {self.modality} MeasurementConfig. "
f"Got {type(self.measurement_metadata)}\n{self.measurement_metadata}"
)
case DataModality.DROPPED:
if self.vocabulary is not None:
err_strings.append(
f"vocabulary must be None on a {self.modality} MeasurementConfig. "
f"Got {self.vocabulary}"
)
if self._measurement_metadata is not None:
err_strings.append(
f"measurement_metadata must be None on a {self.modality} MeasurementConfig. "
f"Got {type(self.measurement_metadata)}\n{self.measurement_metadata}"
)
case _:
raise ValueError(f"`self.modality = {self.modality}` Invalid!")
if err_strings:
raise ValueError("\n".join(err_strings))
if self.modifiers is not None:
for mod in self.modifiers:
if not isinstance(mod, str):
raise ValueError(f"`self.modifiers` must be a list of strings; got element {mod}.")
[docs]
def drop(self):
    """Marks this measurement as dropped and clears the state a dropped measurement may not carry.

    After this call the modality is `DataModality.DROPPED`, and both the measurement metadata and the
    vocabulary are `None`, which is what `_validate` requires of a dropped measurement.

    Examples:
        >>> cfg = MeasurementConfig(
        ...     name='key',
        ...     modality='multivariate_regression',
        ...     temporality='dynamic',
        ...     values_column='vals',
        ...     _measurement_metadata=pd.DataFrame(
        ...         {'censor_lower_bound': [1, 0.2, 0.1]},
        ...         index=pd.Index(['foo', 'bar', 'baz'], name='key'),
        ...     ),
        ...     vocabulary=Vocabulary(['foo', 'bar', 'baz'], [0.3, 0.4, 0.3]),
        ... )
        >>> cfg.drop()
        >>> cfg.modality
        <DataModality.DROPPED: 'dropped'>
        >>> assert cfg._measurement_metadata is None
        >>> assert cfg.vocabulary is None
        >>> assert cfg.is_dropped
    """
    self.modality = DataModality.DROPPED
    # A dropped measurement must carry neither metadata nor a vocabulary.
    for attr_name in ("_measurement_metadata", "vocabulary"):
        setattr(self, attr_name, None)
@property
def is_dropped(self) -> bool:
    """Whether this measurement's modality equals `DataModality.DROPPED`."""
    dropped_modality = DataModality.DROPPED
    return self.modality == dropped_modality
@property
def is_numeric(self) -> bool:
    """Whether this measurement has a regression (numerical-valued) modality.

    True for `DataModality.MULTIVARIATE_REGRESSION` and `DataModality.UNIVARIATE_REGRESSION`.
    """
    numeric_modalities = (
        DataModality.MULTIVARIATE_REGRESSION,
        DataModality.UNIVARIATE_REGRESSION,
    )
    return self.modality in numeric_modalities
@property
def measurement_metadata(self) -> pd.DataFrame | pd.Series | None:
match self._measurement_metadata:
case None | pd.DataFrame() | pd.Series():
return self._measurement_metadata
case [(Path() | str()) as base_dir, str() as fn]:
fp = Path(base_dir) / fn
case (Path() | str()) as fp:
fp = Path(fp)
case _:
raise ValueError(f"_measurement_metadata is invalid! Got {type(self._measurement_metadata)}!")
out = pd.read_csv(fp, index_col=0)
if self.modality == DataModality.UNIVARIATE_REGRESSION:
if out.shape[1] != 1:
raise ValueError(
f"For {self.modality}, measurement metadata at {fp} should be a series, but "
f"it has shape {out.shape} (expecting out.shape[1] == 1)!"
)
out = out.iloc[:, 0]
elif self.modality != DataModality.MULTIVARIATE_REGRESSION:
raise ValueError(
"Only DataModality.UNIVARIATE_REGRESSION and DataModality.MULTIVARIATE_REGRESSION "
f"measurements should have measurement metadata paths stored. Got {fp} on "
f"{self.modality} measurement!"
)
return out
@measurement_metadata.setter
def measurement_metadata(self, new_metadata: pd.DataFrame | pd.Series | None):
if new_metadata is None:
self._measurement_metadata = None
return
match self._measurement_metadata:
case [Path() as base_dir, str() as fn]:
new_metadata.to_csv(base_dir / fn)
case Path() | str() as fp:
new_metadata.to_csv(fp)
case _:
self._measurement_metadata = new_metadata
[docs]
def to_dict(self) -> dict:
"""Represents this configuration object as a plain dictionary."""
as_dict = dataclasses.asdict(self)
match self._measurement_metadata:
case pd.DataFrame():
as_dict["_measurement_metadata"] = self.measurement_metadata.to_dict(orient="tight")
case pd.Series():
as_dict["_measurement_metadata"] = self.measurement_metadata.to_dict(into=OrderedDict)
case Path():
as_dict["_measurement_metadata"] = str(self._measurement_metadata)
if self.temporality == TemporalityType.FUNCTIONAL_TIME_DEPENDENT:
as_dict["functor"] = self.functor.to_dict()
if as_dict.get("vocabulary", None) is not None:
as_dict["vocabulary"]["obs_frequencies"] = [
float(x) for x in as_dict["vocabulary"]["obs_frequencies"]
]
return as_dict
[docs]
@classmethod
def from_dict(cls, as_dict: dict, base_dir: Path | None = None) -> MeasurementConfig:
"""Build a configuration object from a plain dictionary representation."""
if as_dict["vocabulary"] is not None:
as_dict["vocabulary"] = Vocabulary(**as_dict["vocabulary"])
match as_dict["_measurement_metadata"], as_dict["modality"]:
case None, _:
pass
case str() as full_path, _:
full_path = Path(full_path)
if full_path.parts[-2] == "inferred_measurement_metadata":
prior_base_dir = "/".join(full_path.parts[:-2])
relative_path = "/".join(full_path.parts[-2:])
else:
raise ValueError(f"Can't process old path format of {full_path}")
if base_dir is not None:
as_dict["_measurement_metadata"] = [base_dir, relative_path]
else:
as_dict["_measurement_metadata"] = [str(prior_base_dir), relative_path]
case [str() as prior_base_dir, str() as relative_path], _:
if base_dir is not None:
as_dict["_measurement_metadata"] = [base_dir, relative_path]
else:
as_dict["_measurement_metadata"] = [str(prior_base_dir), relative_path]
case dict(), DataModality.MULTIVARIATE_REGRESSION:
as_dict["_measurement_metadata"] = pd.DataFrame.from_dict(
as_dict["_measurement_metadata"], orient="tight"
)
case dict(), DataModality.UNIVARIATE_REGRESSION:
as_dict["_measurement_metadata"] = pd.Series(as_dict["_measurement_metadata"])
case _:
raise ValueError(
f"{as_dict['_measurement_metadata']} and {as_dict['modality']} incompatible!"
)
if as_dict["functor"] is not None:
if as_dict["temporality"] != TemporalityType.FUNCTIONAL_TIME_DEPENDENT:
raise ValueError(
"Only TemporalityType.FUNCTIONAL_TIME_DEPENDENT measures can have functors. Got "
f"{as_dict['temporality']}"
)
as_dict["functor"] = cls.FUNCTORS[as_dict["functor"]["class"]].from_dict(as_dict["functor"])
return cls(**as_dict)
def __eq__(self, other: object) -> bool:
    """Returns True if `other` is a `MeasurementConfig` with an identical plain-dict representation.

    Comparisons against other types return `NotImplemented`, so Python falls back to its default
    handling instead of raising `AttributeError` on a missing `to_dict`.
    """
    if not isinstance(other, MeasurementConfig):
        return NotImplemented
    return self.to_dict() == other.to_dict()
[docs]
def describe(
self, line_width: int = 60, wrap_lines: bool = False, stream: TextIOBase | None = None
) -> int | None:
"""Provides a plain-text description of the measurement.
Prints the following information about the MeasurementConfig object:
1. The measurement's name, temporality, modality, and observation frequency.
2. What value types (e.g., integral, float, etc.) it's values take on, if the measurement is a
numerical modality whose values may take on distinct value types.
3. Details about its internal `self.vocabulary` object, via `Vocabulary.describe`.
Args:
line_width: The maximum width of each line in the description.
wrap_lines: Whether to wrap lines that exceed the `line_width`.
stream: The stream to write the description to. If `None`, the description is printed to stdout.
Returns:
The number of characters written to the stream if a stream was provided, otherwise `None`.
Raises:
ValueError: if the calling object is misconfigured.
Examples:
>>> vocab = Vocabulary(
... vocabulary=['apple', 'banana', 'pear', 'UNK'],
... obs_frequencies=[3, 4, 1, 2],
... )
>>> cfg = MeasurementConfig(
... name="MVR",
... values_column='bar',
... temporality='dynamic',
... modality='multivariate_regression',
... observation_rate_over_cases=0.6816,
... observation_rate_per_case=1.32,
... _measurement_metadata=pd.DataFrame(
... {'value_type': ['float', 'categorical', 'categorical']},
... index=pd.Index(['apple', 'pear', 'banana'], name='MVR'),
... ),
... vocabulary=vocab,
... )
>>> cfg.describe(line_width=100)
MVR: dynamic, multivariate_regression observed 68.2%, 1.3/case on average
Value Types:
2 categorical
1 float
Vocabulary:
4 elements, 20.0% UNKs
Frequencies: █▆▁
Elements:
(40.0%) banana
(30.0%) apple
(10.0%) pear
>>> cfg.modality = 'wrong'
>>> cfg.describe()
Traceback (most recent call last):
...
ValueError: Can't describe wrong measure MVR!
"""
lines = []
lines.append(
f"{self.name}: {self.temporality}, {self.modality} "
f"observed {100*self.observation_rate_over_cases:.1f}%, "
f"{self.observation_rate_per_case:.1f}/case on average"
)
match self.modality:
case DataModality.UNIVARIATE_REGRESSION:
lines.append(f"Value is a {self.measurement_metadata.value_type}")
case DataModality.MULTIVARIATE_REGRESSION:
lines.append("Value Types:")
for t, cnt in self.measurement_metadata.value_type.value_counts().items():
lines.append(f" {cnt} {t}")
case DataModality.MULTI_LABEL_CLASSIFICATION:
pass
case DataModality.SINGLE_LABEL_CLASSIFICATION:
pass
case _:
raise ValueError(f"Can't describe {self.modality} measure {self.name}!")
if self.vocabulary is not None:
SIO = StringIO()
self.vocabulary.describe(line_width=line_width - 2, stream=SIO, wrap_lines=wrap_lines)
lines.append("Vocabulary:")
lines.extend(f" {line}" for line in SIO.getvalue().split("\n"))
line_indents = [num_initial_spaces(line) for line in lines]
if wrap_lines:
lines = [
wrap(line, width=line_width, initial_indent="", subsequent_indent=(" " * ind))
for line, ind in zip(lines, line_indents)
]
else:
lines = [
shorten(line, width=line_width, initial_indent=(" " * ind))
for line, ind in zip(lines, line_indents)
]
desc = "\n".join(lines)
if stream is None:
print(desc)
return
return stream.write(desc)
[docs]
@dataclasses.dataclass
class DatasetConfig(JSONableMixin):
    """Configuration options for a Dataset class.

    This is the core configuration object for Dataset objects. Contains configuration options for
    pre-processing a dataset already in the "Subject-Events-Measurements" data model or interpreting an
    existing dataset. This configures details such as

    1. Which measurements should be extracted and included in the raw dataset, via the `measurement_configs`
       arg.
    2. What filtering parameters should be applied to eliminate infrequently observed variables or columns.
    3. How/whether or not numerical values should be re-cast as categorical or integral types.
    4. Configuration options for outlier detector or normalization models.
    5. Time aggregation controls.
    6. The output save directory.

    These configuration options do not include options to extract the raw dataset from source. For options for
    raw dataset extraction, see `DatasetSchema` and `InputDFSchema`, and for options for the raw script
    builder, see `configs/dataset_base.yml`.

    Attributes:
        measurement_configs: The dataset configuration for this `Dataset`. Keys are measurement names, and
            values are `MeasurementConfig` objects detailing configuration parameters for that measure.
            Measurement names / dictionary keys are also used as source columns for the data of that measure,
            though in the case of `DataModality.MULTIVARIATE_REGRESSION` measures, this name will reference
            the categorical regression target index column and the config will also contain a reference to a
            values column name which points to the column containing the associated numerical values.
            Columns not referenced in any configs are not pre-processed. Measurement configs are checked for
            validity upon creation. Dictionary keys must match measurement config object names if such are
            specified; if measurement config object names are not specified, they will be set to their
            associated dictionary keys.
        min_events_per_subject: An optional integer. It is not validated in `__post_init__`; presumably the
            minimum number of events a subject must have to be retained (TODO: confirm against its users).
        min_valid_column_observations: The minimum number of column observations or proportion of possible
            events that contain a column that must be observed for the column to be included in the training
            set. If fewer than this many observations are observed, the entire column will be dropped. Can be
            either an integer count (> 1) or a proportion (of total vocabulary size) in (0, 1). If `None`, no
            constraint is applied.
        min_valid_vocab_element_observations: The minimum number or proportion of observations of a particular
            metadata vocabulary element that must be observed for the element to be included in the training
            set vocabulary. If fewer than this many observations are observed, observed elements will be
            dropped. Can be either an integer count (> 1) or a proportion (of total vocabulary size) in (0, 1).
            If `None`, no constraint is applied.
        min_true_float_frequency: The minimum proportion of true float values that must be observed in order
            for observations to be treated as true floating point numbers, not integers. Must be a float in
            (0, 1) or `None`.
        min_unique_numerical_observations: The minimum number of unique values a numerical column must have in
            the training set to be treated as a numerical type (rather than an implied categorical or ordinal
            type). Numerical entries with fewer than this many observations will be converted to categorical
            or ordinal types. Can be either an integer count (> 1) or a proportion (of total numerical
            observations) in (0, 1). If `None`, no constraint is applied.
        outlier_detector_config: Configuration options for outlier detection. If not `None`, must be a plain
            `dict` containing the key `'cls'`, which points to the class used for outlier detection. All other
            keys and values are keyword arguments to be passed to the specified class. The API of these
            objects is expected to mirror scikit-learn outlier detection model APIs. If `None`, numerical
            outlier values are not removed.
        center_and_scale: Whether or not to center and scale numerical values.
        save_dir: The output save directory for this dataset. Will be converted to a `pathlib.Path` upon
            creation if it is not already one.
        agg_by_time_scale: Aggregate events into temporal buckets at this frequency. Uses the string language
            described here:
            https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.group_by_dynamic.html

    Raises:
        ValueError: If configuration parameters are invalid (e.g., proportion parameters being > 1, etc.).
            Note that a non-`dict` `outlier_detector_config` also raises `ValueError`, not `TypeError`.
        TypeError: If count-or-proportion or proportion parameters are of invalid types.

    Examples:
        >>> cfg = DatasetConfig(
        ...     measurement_configs={
        ...         "meas1": MeasurementConfig(
        ...             temporality=TemporalityType.DYNAMIC,
        ...             modality=DataModality.MULTI_LABEL_CLASSIFICATION,
        ...         ),
        ...     },
        ...     min_valid_column_observations=0.5,
        ...     save_dir="/path/to/save/dir",
        ... )
        >>> cfg.save_dir
        PosixPath('/path/to/save/dir')
        >>> cfg.to_dict() # doctest: +NORMALIZE_WHITESPACE
        {'measurement_configs':
            {'meas1':
                {'name': 'meas1',
                 'temporality': <TemporalityType.DYNAMIC: 'dynamic'>,
                 'modality': <DataModality.MULTI_LABEL_CLASSIFICATION: 'multi_label_classification'>,
                 'observation_rate_over_cases': None,
                 'observation_rate_per_case': None,
                 'functor': None,
                 'vocabulary': None,
                 'values_column': None,
                 '_measurement_metadata': None,
                 'modifiers': None}},
         'min_events_per_subject': None,
         'agg_by_time_scale': '1h',
         'min_valid_column_observations': 0.5,
         'min_valid_vocab_element_observations': None,
         'min_true_float_frequency': None,
         'min_unique_numerical_observations': None,
         'outlier_detector_config': None,
         'center_and_scale': True,
         'save_dir': '/path/to/save/dir'}
        >>> cfg2 = DatasetConfig.from_dict(cfg.to_dict())
        >>> assert cfg == cfg2
        >>> DatasetConfig(
        ...     measurement_configs={
        ...         "meas1": MeasurementConfig(
        ...             name="invalid_name",
        ...             temporality=TemporalityType.DYNAMIC,
        ...             modality=DataModality.MULTI_LABEL_CLASSIFICATION,
        ...         ),
        ...     },
        ... )
        Traceback (most recent call last):
            ...
        ValueError: Measurement config meas1 has name invalid_name which differs from dict key!
        >>> DatasetConfig(
        ...     min_valid_column_observations="invalid type"
        ... )
        Traceback (most recent call last):
            ...
        TypeError: min_valid_column_observations must either be a fraction (float between 0 and 1) or count\
        (int > 1). Got <class 'str'> of invalid type
        >>> measurement_configs = {
        ...     "meas1": MeasurementConfig(
        ...         temporality=TemporalityType.DYNAMIC,
        ...         modality=DataModality.MULTI_LABEL_CLASSIFICATION,
        ...     ),
        ... }
        >>> # Make one of the measurements invalid to show that validity is re-checked...
        >>> measurement_configs["meas1"].temporality = None
        >>> DatasetConfig(
        ...     measurement_configs=measurement_configs,
        ...     min_valid_column_observations=0.5,
        ...     save_dir="/path/to/save/dir",
        ... )
        Traceback (most recent call last):
            ...
        ValueError: Measurement config meas1 invalid!
    """

    # A fresh empty dict per instance (default_factory avoids a shared mutable default).
    measurement_configs: dict[str, MeasurementConfig] = dataclasses.field(default_factory=lambda: {})
    min_events_per_subject: int | None = None
    agg_by_time_scale: str | None = "1h"
    min_valid_column_observations: COUNT_OR_PROPORTION | None = None
    min_valid_vocab_element_observations: COUNT_OR_PROPORTION | None = None
    min_true_float_frequency: PROPORTION | None = None
    min_unique_numerical_observations: COUNT_OR_PROPORTION | None = None
    outlier_detector_config: dict[str, Any] | None = None
    center_and_scale: bool = True
    # A `str` given here is converted to a `Path` in `__post_init__`.
    save_dir: Path | None = None
def __post_init__(self):
    """Validates parameter values and normalizes `save_dir` to a `pathlib.Path`.

    Measurement configs without a name are named after their dictionary key. Every measurement config
    is then re-validated via its `_validate` method.

    Raises:
        ValueError: If any of the following holds:
            - A measurement config's name differs from its key.
            - A numeric threshold is out of range.
            - `outlier_detector_config` is neither `None` nor a `dict`.
            - A measurement config fails validation.
        TypeError: If a count-or-proportion or proportion threshold has an unsupported type.
    """
    # Names default to their dictionary keys; an explicit name must agree with its key.
    for key, meas_cfg in self.measurement_configs.items():
        if meas_cfg.name is None:
            meas_cfg.name = key
            continue
        if meas_cfg.name != key:
            raise ValueError(
                f"Measurement config {key} has name {meas_cfg.name} which differs from dict key!"
            )

    # Thresholds accepting either a proportion in (0, 1) or an integral count > 1.
    for param in (
        "min_valid_column_observations",
        "min_valid_vocab_element_observations",
        "min_unique_numerical_observations",
    ):
        value = getattr(self, param)
        if value is None:
            continue
        if isinstance(value, float):
            if not 0 < value < 1:
                raise ValueError(f"{param} must be in (0, 1) if float; got {value}!")
        elif isinstance(value, int):
            if not value > 1:
                raise ValueError(f"{param} must be > 1 if integral; got {value}!")
        else:
            raise TypeError(
                f"{param} must either be a fraction (float between 0 and 1) or count (int > 1). Got "
                f"{type(value)} of {value}"
            )

    # Thresholds that must be a proportion in (0, 1).
    for param in ("min_true_float_frequency",):
        value = getattr(self, param)
        if value is None:
            continue
        if not isinstance(value, float):
            raise TypeError(
                f"{param} must be a fraction (float between 0 and 1). Got {type(value)} of {value}"
            )
        if not 0 < value < 1:
            raise ValueError(f"{param} must be in (0, 1) if float; got {value}!")

    # Only an exact `dict` (not a subclass or mapping) is accepted here.
    for param in ("outlier_detector_config",):
        value = getattr(self, param)
        if value is not None and type(value) is not dict:
            raise ValueError(f"{param} must be either None or a dictionary! Got {value}")

    for key, meas_cfg in self.measurement_configs.items():
        try:
            meas_cfg._validate()
        except Exception as e:
            raise ValueError(f"Measurement config {key} invalid!") from e

    if type(self.save_dir) is str:
        self.save_dir = Path(self.save_dir)
[docs]
def to_dict(self) -> dict:
    """Represents this configuration object as a plain dictionary.

    `save_dir` (if set) is rendered as an absolute path string, and each measurement config is
    serialized via its own `to_dict` method rather than the generic `dataclasses.asdict` form.

    Returns:
        A plain dictionary representation of self (nested through measurement configs as well).
    """
    plain = dataclasses.asdict(self)
    if self.save_dir is not None:
        plain["save_dir"] = str(self.save_dir.absolute())

    serialized_measurements = {}
    for meas_name, meas_cfg in self.measurement_configs.items():
        serialized_measurements[meas_name] = meas_cfg.to_dict()
    plain["measurement_configs"] = serialized_measurements

    return plain
[docs]
@classmethod
def from_dict(cls, as_dict: dict) -> DatasetConfig:
    """Build a configuration object from a plain dictionary representation.

    Args:
        as_dict: The plain dictionary representation to be converted, e.g. the output of `to_dict`.
            Neither it nor its nested measurement-config dictionaries are modified. Missing
            `measurement_configs` / `save_dir` entries fall back to their defaults.

    Returns:
        A DatasetConfig instance containing the same data as `as_dict`.
    """
    # Copy (including each nested measurement dict) so repeated conversion of the same input works.
    # Converting in place would leave already-built objects in the caller's dict.
    as_dict = dict(as_dict)
    as_dict["measurement_configs"] = {
        k: MeasurementConfig.from_dict(dict(v))
        for k, v in (as_dict.get("measurement_configs") or {}).items()
    }

    save_dir = as_dict.get("save_dir")
    if isinstance(save_dir, str):
        as_dict["save_dir"] = Path(save_dir)

    return cls(**as_dict)
def __eq__(self, other: object) -> bool:
    """Returns true if `other` is a `DatasetConfig` with the same plain-dict representation.

    Comparisons against other types return `NotImplemented` instead of raising `AttributeError`.
    """
    if not isinstance(other, DatasetConfig):
        return NotImplemented
    return self.to_dict() == other.to_dict()