# Source code for EventStream.data.visualize

from __future__ import annotations

import dataclasses

import pandas as pd
import plotly.express as px
import polars as pl
from plotly.graph_objs._figure import Figure

from ..utils import JSONableMixin


@dataclasses.dataclass
class Visualizer(JSONableMixin):
    """A visualization configuration and plotting class.

    This class helps visualize `Dataset` objects. It is both a configuration object and performs the actual
    data manipulations for final visualization, interfacing only with the `Dataset` object to obtain
    appropriately sampled and processed cuts of the data to visualize. It currently produces the following
    plots. All plots are broken down by `static_covariates`, which are covariates that are constant for each
    subject; if no static covariates are configured, no plots are produced.

    ## Analyzing the data over time (only produced if `plot_by_time` is True)
    Given an $x$-axis of time $t$, the following plots are produced:

    - "Active Subjects": $y$ = the number of active subjects at time $t$ (i.e. the number of subjects
      who have at least one event before $t$ and have not yet had their last event at $t$).
    - "Cumulative Subjects": $y$ = the number of cumulative subjects at time $t$ (i.e., the number of
      subjects who have at least one event before $t$).
    - "Cumulative Events": $y$ = the number of events the dataset would obtain were it to be
      terminated at time $t$.
    - "Events / Subject": $y$ = the average number of events per subject as would be observed were the
      dataset to be terminated at time $t$.
    - "Events / (Subject, Time)": $y$ = the average rate of events per unit time per subject at time
      $t$
    - "Age Distribution over Time": A 2D Density Heatmap plot showing the distributions of the ages of
      active subjects in the dataset at time $t$. Only produced if `age_col` is specified. Age is binned
      into `n_age_buckets` buckets.

    ## Analyzing the data over age (only produced if `plot_by_age` is True)
    Given an $x$-axis of age bucket $a$, the following plots are produced:

    - "Cumulative Subjects": $y$ = the number of subjects in the dataset who have an event in the age
      bucket $a$.
    - "Cumulative Events": $y$ = the number of events included in the dataset that occur at an age up to
      or before $a$.
    - "Events / Subject": $y$ = the average number of events per subject that occur when the subject
      is at age bucket $a$.

    Attributes:
        subset_size: When plotting, use an IID random subsample (over subjects) of the input dataset of
            this size. This makes plotting much faster, and is statistically unbiased, though can increase
            variance.
        subset_random_seed: If subsampling the raw data, use this random seed to control that subsampling.
        static_covariates: When plotting, split plots by these static covariates.
        plot_by_time: If `True`, also plot how the dataset changes over time.
        time_unit: If `plot_by_time` is `True`, aggregate timepoints into buckets of this size.
        plot_by_age: If `True`, plot how dataset characteristics evolve with subject age.
        age_col: The column in the Dataset's `events_df` where age is stored. This should typically be the
            name of the measurement employing the `AgeFunctor` time dependent functor object, unless age
            is pre-computed in the dataset.
        dob_col: This is used to compute ages of subjects at inferred timepoints created dynamically
            during plotting. This string should point to the date of birth (in datetime format) within the
            subjects dataframe.
        n_age_buckets: If `plot_by_age` is `True`, this controls how many buckets ages are discretized
            into to limit plot granularity.

    Raises:
        ValueError: If

            * `subset_size` is specified but `subset_random_seed` is not.
            * `plot_by_age` is `True`, but `age_col` or `n_age_buckets` is `None`
            * `age_col` is specified but `dob_col` is not
            * `plot_by_time` is `True`, but `time_unit` is None

    Examples:
        >>> V = Visualizer()
        >>> V = Visualizer(
        ...     subset_size=100, subset_random_seed=1,
        ...     plot_by_age=True, age_col='age', dob_col='dob', n_age_buckets=100,
        ...     plot_by_time=True, time_unit='1y',
        ... )
        >>> V = Visualizer(subset_size=100)
        Traceback (most recent call last):
            ...
        ValueError: subset_size is specified, but subset_random_seed is not!
        >>> V = Visualizer(plot_by_age=True, age_col='age', n_age_buckets=None)
        Traceback (most recent call last):
            ...
        ValueError: plot_by_age is True, but n_age_buckets is unspecified!
        >>> V = Visualizer(age_col='age')
        Traceback (most recent call last):
            ...
        ValueError: age_col is specified, but dob_col is not!
        >>> V = Visualizer(plot_by_time=True, time_unit=None)
        Traceback (most recent call last):
            ...
        ValueError: plot_by_time is True, but time_unit is unspecified!
    """

    subset_size: int | None = None
    subset_random_seed: int | None = None

    static_covariates: list[str] = dataclasses.field(default_factory=list)

    plot_by_time: bool = True
    time_unit: str | None = "1y"

    plot_by_age: bool = False
    age_col: str | None = None
    dob_col: str | None = None
    n_age_buckets: int | None = 200

    def __post_init__(self) -> None:
        """Validate that the configured options are mutually consistent.

        Raises:
            ValueError: If any of the conditions listed in the class docstring hold.
        """
        if self.subset_size is not None and self.subset_random_seed is None:
            raise ValueError("subset_size is specified, but subset_random_seed is not!")

        if self.plot_by_age:
            if self.age_col is None:
                raise ValueError("plot_by_age is True, but age_col is unspecified!")
            if self.n_age_buckets is None:
                raise ValueError("plot_by_age is True, but n_age_buckets is unspecified!")

        if self.age_col is not None and self.dob_col is None:
            raise ValueError("age_col is specified, but dob_col is not!")

        if self.plot_by_time and self.time_unit is None:
            raise ValueError("plot_by_time is True, but time_unit is unspecified!")

    @staticmethod
    def _normalize_to_pandas(df: pl.DataFrame, covariate: str | None = None) -> pd.DataFrame:
        """Convert `df` to pandas and make `covariate` a clean categorical column for plotting.

        Args:
            df: The dataframe to convert; only its `to_pandas()` method is used.
            covariate: If given, the name of the column to normalize. Missing values in this column are
                replaced with the category ``"UNK"`` and unused categories are dropped.

        Returns:
            The converted pandas dataframe. If `covariate` is `None`, it is returned without changes.
        """
        pandas_df = df.to_pandas()
        if covariate is None:
            return pandas_df

        # The `.cat` accessor only exists on categorical columns; cast anything else (e.g. strings or
        # booleans) so that non-categorical covariates do not fail below.
        if not isinstance(pandas_df[covariate].dtype, pd.CategoricalDtype):
            pandas_df[covariate] = pandas_df[covariate].astype("category")

        if pandas_df[covariate].isna().any():
            # "UNK" must be a registered category before it can be used as a fill value.
            if "UNK" not in pandas_df[covariate].cat.categories:
                pandas_df[covariate] = pandas_df[covariate].cat.add_categories("UNK")
            pandas_df[covariate] = pandas_df[covariate].fillna("UNK")

        pandas_df[covariate] = pandas_df[covariate].cat.remove_unused_categories()
        return pandas_df

    def plot_counts_over_time(self, in_events_df: pl.DataFrame) -> list[Figure]:
        """Plot subject and event counts over time, one set of line plots per static covariate.

        Args:
            in_events_df: Events dataframe. This method reads the columns ``"timestamp"``,
                ``"start_time"``, ``"end_time"``, ``"subject_id"``, ``"event_id"``, and every column named
                in `static_covariates`.

        Returns:
            For each static covariate, five line figures (in order): "Active Subjects", "Cumulative
            Subjects", "Cumulative Events", "Average Events / Subject" and "New Events / Subject / time".
            An empty list is returned if `plot_by_time` is `False` or no static covariates are configured.
        """
        figures = []
        # Without static covariates the loop below produces nothing, so skip the aggregation entirely.
        if not self.plot_by_time or not self.static_covariates:
            return figures

        is_first_event = pl.col("timestamp") == pl.col("start_time")
        is_last_event = pl.col("timestamp") == pl.col("end_time")

        in_events_df = (
            in_events_df.sort("timestamp", descending=False)
            .with_columns(
                # A subject becomes active at their first event (+1) and inactive at their last (-1).
                # A subject whose first and last events coincide contributes nothing.
                pl.when(is_first_event & is_last_event)
                .then(0)
                .when(is_first_event)
                .then(1)
                .when(is_last_event)
                .then(-1)
                .otherwise(0)
                .alias("active_subj_increment"),
                pl.when(is_first_event).then(1).otherwise(0).alias("cumulative_subj_increment"),
            )
            .group_by_dynamic(
                index_column="timestamp",
                every=self.time_unit,
                by=self.static_covariates,
            )
            .agg(
                pl.col("subject_id").n_unique().alias("n_subjects"),
                pl.col("event_id").n_unique().alias("n_events"),
                pl.col("active_subj_increment").sum().alias("active_subjects_delta"),
                pl.col("cumulative_subj_increment").sum().alias("cumulative_subjects_delta"),
            )
            .sort("timestamp", descending=False)
        )

        for static_covariate in self.static_covariates:
            plt_kwargs = {"x": "timestamp", "color": static_covariate}

            # Collapse the other static covariates so there is one row per (time bucket, covariate value).
            events_df = (
                in_events_df.group_by("timestamp", static_covariate)
                .agg(
                    pl.col("n_subjects").sum(),
                    pl.col("n_events").sum(),
                    pl.col("active_subjects_delta").sum(),
                    pl.col("cumulative_subjects_delta").sum(),
                )
                .with_columns(
                    (pl.col("n_events") / pl.col("n_subjects")).alias("events_per_subject_per_time"),
                )
                .sort("timestamp", descending=False)
            )

            # "Active Subjects": $y$ = the number of active subjects at time $t$ (i.e. the number of
            # subjects who have at least one event before $t$ and have not yet had their last event at
            # $t$).
            # "Cumulative Subjects": $y$ = the number of cumulative subjects at time $t$ (i.e., the number
            # of subjects who have at least one event before $t$).
            subjects_as_of_time = self._normalize_to_pandas(
                events_df.select(
                    "timestamp",
                    static_covariate,
                    pl.col("active_subjects_delta").cumsum().over(static_covariate).alias("Active Subjects"),
                    pl.col("cumulative_subjects_delta")
                    .cumsum()
                    .over(static_covariate)
                    .alias("Cumulative Subjects"),
                ),
                static_covariate,
            )

            figures.extend(
                [
                    px.line(subjects_as_of_time, y="Active Subjects", **plt_kwargs),
                    px.line(subjects_as_of_time, y="Cumulative Subjects", **plt_kwargs),
                ]
            )

            # "Cumulative Events": $y$ = the number of events the dataset would obtain were it to be
            # terminated at time $t$.
            # "Events / Subject": $y$ = the average number of events per subject as would be observed were
            # the dataset to be terminated at time $t$.
            # "Events / (Subject, Time)": $y$ = the average rate of events per unit time per subject at
            # time $t$
            events_as_of_time = self._normalize_to_pandas(
                events_df.select(
                    "timestamp",
                    static_covariate,
                    pl.col("n_events").cumsum().over(static_covariate).alias("Cumulative Events"),
                    (
                        pl.col("n_events").cumsum().over(static_covariate)
                        / pl.col("cumulative_subjects_delta").cumsum().over(static_covariate)
                    ).alias("Average Events / Subject"),
                    pl.col("events_per_subject_per_time").alias("New Events / Subject / time"),
                ),
                static_covariate,
            )

            figures.extend(
                [
                    px.line(events_as_of_time, y="Cumulative Events", **plt_kwargs),
                    px.line(events_as_of_time, y="Average Events / Subject", **plt_kwargs),
                    px.line(events_as_of_time, y="New Events / Subject / time", **plt_kwargs),
                ]
            )

        return figures

    def plot_static_variables_breakdown(self, static_variables: pl.DataFrame) -> list[Figure]:
        """Plot the number of unique subjects per value of each static covariate.

        Args:
            static_variables: Dataframe containing ``"subject_id"`` and every column named in
                `static_covariates`.

        Returns:
            One bar figure per static covariate; an empty list if no static covariates are configured.
        """
        figures = []
        # Iterating over an empty `static_covariates` naturally yields an empty list. Always returning a
        # list matters because `plot` passes this result to `list.extend`.
        for static_covariate in self.static_covariates:
            subject_counts = static_variables.group_by(static_covariate).agg(
                pl.col("subject_id").n_unique().alias("# Subjects")
            )
            figures.append(
                px.bar(
                    self._normalize_to_pandas(subject_counts, static_covariate),
                    x=static_covariate,
                    y="# Subjects",
                )
            )
        return figures

    def plot_counts_over_age(self, events_df: pl.DataFrame) -> list[Figure]:
        """Plot subject and event counts over bucketed subject age, per static covariate.

        Args:
            events_df: Events dataframe. This method reads the column named by `age_col`, the columns
                ``"subject_id"`` and ``"event_id"``, and every column named in `static_covariates`.

        Returns:
            For each static covariate, four line figures (in order): "% Subjects with Event @ Age",
            "Events @ Age / Subject", "Events <= Age / Subject" and
            "Events @ Age / (Subjects with >= 1 Event @ Age)". An empty list is returned if `plot_by_age`
            is `False`, no static covariates are configured, or the age column has no non-null values.
        """
        figures = []
        # The window expressions below partition by the static covariates, so there is nothing valid to
        # compute (and nothing to plot) when none are configured.
        if not self.plot_by_age or not self.static_covariates:
            return figures

        min_age = events_df[self.age_col].min()
        max_age = events_df[self.age_col].max()
        if min_age is None or max_age is None:
            # No observed ages at all: the bucket width is undefined.
            return figures

        age_bucket_size = (max_age - min_age) / (self.n_age_buckets)

        events_df = (
            events_df.with_columns(
                # Bucket on the configured age column rather than a hard-coded "age" column.
                (pl.col(self.age_col) / age_bucket_size)
                .round(0)
                .cast(pl.Int64, strict=False)
                .alias("age_bucket"),
                pl.col("subject_id").n_unique().over(*self.static_covariates).alias("total_n_subjects"),
            )
            .drop_nulls("age_bucket")
            .group_by("age_bucket", *self.static_covariates)
            .agg(
                pl.col(self.age_col).mean(),
                pl.col("event_id").n_unique().alias("n_events"),
                pl.col("subject_id").n_unique().alias("n_subjects_at_age"),
                pl.col("total_n_subjects").first(),
            )
            .sort(by=self.age_col, descending=False)
            .with_columns(
                pl.col("n_events").cumsum().over(*self.static_covariates).alias("Cumulative Events"),
            )
        )

        for static_covariate in self.static_covariates:
            plt_kwargs = {"x": self.age_col, "color": static_covariate}

            counts_at_age = self._normalize_to_pandas(
                events_df.group_by("age_bucket", static_covariate)
                .agg(
                    # Subject-weighted mean age of the merged rows within this bucket.
                    (
                        (pl.col(self.age_col) * pl.col("n_subjects_at_age")).sum()
                        / pl.col("n_subjects_at_age").sum()
                    ).alias(self.age_col),
                    pl.col("n_subjects_at_age").sum().alias("Subjects with Event @ Age"),
                    pl.col("n_events").sum().alias("Events @ Age"),
                    pl.col("Cumulative Events").sum().alias("Events <= Age"),
                    pl.col("total_n_subjects").sum().alias("Total Subjects"),
                )
                .with_columns(
                    (pl.col("Subjects with Event @ Age") / pl.col("Total Subjects")).alias(
                        "% Subjects with Event @ Age"
                    ),
                    (pl.col("Events @ Age") / pl.col("Subjects with Event @ Age")).alias(
                        "Events @ Age / (Subjects with >= 1 Event @ Age)"
                    ),
                    (pl.col("Events @ Age") / pl.col("Total Subjects")).alias("Events @ Age / Subject"),
                    (pl.col("Events <= Age") / pl.col("Total Subjects")).alias("Events <= Age / Subject"),
                )
                .sort(self.age_col, descending=False),
                static_covariate,
            )

            figures.extend(
                [
                    px.line(counts_at_age, y="% Subjects with Event @ Age", **plt_kwargs),
                    px.line(counts_at_age, y="Events @ Age / Subject", **plt_kwargs),
                    px.line(counts_at_age, y="Events <= Age / Subject", **plt_kwargs),
                    px.line(
                        counts_at_age,
                        y="Events @ Age / (Subjects with >= 1 Event @ Age)",
                        **plt_kwargs,
                    ),
                ]
            )

        return figures

    def plot_events_per_patient(self, events_df: pl.DataFrame) -> list[Figure]:
        """Plot histograms of the number of unique events per subject, colored by static covariate.

        Args:
            events_df: Events dataframe containing ``"subject_id"``, ``"event_id"`` and every column
                named in `static_covariates`.

        Returns:
            One histogram figure per static covariate; an empty list if none are configured.
        """
        events_per_patient = events_df.group_by("subject_id", *self.static_covariates).agg(
            pl.col("event_id").n_unique().alias("# of Events")
        )

        return [
            px.histogram(self._normalize_to_pandas(events_per_patient, c), x="# of Events", color=c)
            for c in self.static_covariates
        ]

    def plot(
        self,
        subjects_df: pl.DataFrame,
        events_df: pl.DataFrame,
        dynamic_measurements_df: pl.DataFrame,
    ) -> list[Figure]:
        """Produce all configured figures for the given dataset dataframes.

        Args:
            subjects_df: Subjects dataframe containing ``"subject_id"`` and every column named in
                `static_covariates`.
            events_df: Events dataframe containing ``"subject_id"`` and ``"timestamp"``, plus whatever
                columns the individual plotting methods read.
            dynamic_measurements_df: Accepted as part of the interface; not read by this method.

        Returns:
            The concatenation, in order, of the figures from `plot_static_variables_breakdown`,
            `plot_counts_over_time`, `plot_counts_over_age` and `plot_events_per_patient`.
        """
        # Each subject's first and last event times drive the active / cumulative subject counts.
        subj_ranges = events_df.group_by("subject_id").agg(
            pl.col("timestamp").min().alias("start_time"),
            pl.col("timestamp").max().alias("end_time"),
        )
        static_variables = subj_ranges.join(
            subjects_df.select("subject_id", *self.static_covariates), on="subject_id"
        )

        events_df = events_df.join(static_variables, on="subject_id")

        figs = []
        figs.extend(self.plot_static_variables_breakdown(static_variables))
        figs.extend(self.plot_counts_over_time(events_df))
        figs.extend(self.plot_counts_over_age(events_df))
        figs.extend(self.plot_events_per_patient(events_df))

        return figs