Source code for stats_misc.machine_learning.sksurv_utils

"""
A collection of utils for scikit-survival. Currently, the module focuses on
downstream extraction of predictions and outcomes, as well as on tools to help
with model validation.

The code can likely be generalised further to work with non-sksurv models as
well.
"""
# imports
import warnings
import pandas as pd
import numpy as np
from typing import Callable
from sksurv.metrics import (
    cumulative_dynamic_auc,
    integrated_brier_score,
)
from sksurv.nonparametric import kaplan_meier_estimator

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs] def surv_model_auc(model:Callable|None, times: list[float], data: tuple[str, np.ndarray, np.ndarray], train_y: np.ndarray) -> pd.DataFrame: """ Computes time-dependent AUCs for survival models using cumulative dynamic AUC. Parameters ---------- model : `callable` or `NoneType` A fitted survival model with a `predict(X)` method. If `None`, assumes `d_tup[1]` already contains predicted risk scores. times : `list` [`float`] Time points at which to evaluate the time-dependent AUC. data : `tuple` [`str`, `np.ndarray`, `np.ndarray`] The tuple should contain (label, test_y, test_x or predicted_risk), where `test_y` is a structured array of survival data (e.g., from `sksurv.util.Surv.from_arrays`), and `label` is a descriptor of the data split (e.g., 'test', 'validation'). train_y : np.ndarray Structured array of survival outcomes for training data, used to construct the risk sets. Returns ------- pd.DataFrame A long-format DataFrame with columns: - "Data split": label of the data subset - "Time": time points evaluated - "AUC by time": corresponding AUC values - "Mean AUC": mean of AUCs across time points (NaNs ignored) Notes ----- This function evaluates the performance of a survival model across multiple time points using the cumulative/dynamic AUC approach (as implemented in `cumulative_dynamic_auc` from `sksurv.metrics`). It handles multiple data splits or datasets and aggregates the results into a tidy DataFrame. 
Examples -------- >>> from sksurv.util import Surv >>> import numpy as np >>> train_y = Surv.from_arrays([True, False], [10, 20]) >>> test_y = Surv.from_arrays([True, False], [5, 25]) >>> pred = np.array([0.1, 0.4]) >>> data = ("test", test_y, pred) >>> times = [5, 10, 15] >>> surv_model_auc(None, times, data, train_y) """ # ### check input # TODO # #### calculations # do we need to make predictions if model is not None: pred = model.predict(data[2]) else: pred = data[2] # the auc across time per_time_auc = [] for t in times: try: auc_array, _ = cumulative_dynamic_auc( train_y, data[1], pred, [t] ) per_time_auc.append(auc_array[0]) except ValueError: # NOTE add a warning per_time_auc.append(np.nan) # the average auc mean_auc = np.nanmean(per_time_auc) # results dict res_dict = { "Data split" : data[0], "Time" : times, "AUC by time" : per_time_auc, "Mean AUC" : mean_auc } # return return pd.DataFrame(res_dict).set_index('Time')
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def surv_model_predict(model: Callable, X: np.ndarray, times: list[float],
                       ) -> pd.DataFrame:
    """
    Returns the predicted event probability at each requested follow-up time.

    Parameters
    ----------
    model : `callable`
        A fitted model exposing `predict_survival_function(X)`, which should
        return one callable survival function per sample (as scikit-survival
        models such as CoxPHSurvivalAnalysis or RandomSurvivalForest do).
    X : `np.ndarray`
        Feature data, passed unchanged to `predict_survival_function`.
    times : `list` [`float`]
        Time points at which each survival function is evaluated, in the same
        units used for training the model.

    Returns
    -------
    pd.DataFrame
        One row per sample and one column per time point, named
        "Predicted risk at {t}", holding 1 - survival probability at `t`.
    """
    # one callable per sample
    survival_functions = model.predict_survival_function(X)
    # build the table column by column, one column per time point
    risk_columns = {}
    for t in times:
        column_label = f"Predicted risk at {t}"
        risk_columns[column_label] = [
            1.0 - surv_fn(t) for surv_fn in survival_functions
        ]
    return pd.DataFrame(risk_columns)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def event_by_time(Y: np.ndarray, times: list[float],
                  event_name: str = 'event', time_name: str = 'time',
                  ) -> pd.DataFrame:
    """
    Flags, per subject, whether the event was observed on or before each of
    the supplied time points.

    Parameters
    ----------
    Y : `np.ndarray`
        Structured array holding an event-indicator field and a follow-up
        time field, typically created using sksurv.util.Surv.from_arrays.
    times : `list` [`float`]
        Time points at which to assess event occurrence.
    event_name : `str`, default `event`
        Name of the event-indicator field in `Y`.
    time_name : `str`, default `time`
        Name of the follow-up time field in `Y`.

    Returns
    -------
    pd.DataFrame
        One row per subject and one column per time point, named
        "Event at {t}". Entries are integers: 1 when the event indicator is
        set and the follow-up time is at most `t`, otherwise 0.
    """
    indicator_columns = {}
    for t in times:
        # an event counts only when it was observed no later than t
        occurred_by_t = Y[event_name] & (Y[time_name] <= t)
        indicator_columns[f"Event at {t}"] = occurred_by_t.astype(int)
    return pd.DataFrame(indicator_columns)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def surv_model_calibration_table(model: Callable, X: np.ndarray,
                                 Y: np.ndarray, times: list[float],
                                 ) -> pd.DataFrame:
    """
    Builds a side-by-side table of predicted event probabilities and observed
    outcomes for each requested time point.

    Parameters
    ----------
    model : `callable`
        A fitted scikit-survival model supporting
        `predict_survival_function(X)`.
    X : `np.ndarray`
        Feature data for prediction.
    Y : `np.ndarray`
        Structured array with fields ('event', 'time'), typically created
        using sksurv.util.Surv.from_arrays.
    times : `list` [`float`]
        Time points at which to compute risks and determine event status.

    Returns
    -------
    pd.DataFrame
        The columns produced by `surv_model_predict` followed by the columns
        produced by `event_by_time`, concatenated column-wise.
    """
    # predictions first, observed outcomes second
    table_parts = [
        surv_model_predict(model, X, times),
        event_by_time(Y, times),
    ]
    return pd.concat(table_parts, axis=1)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# NOTE think about how to generalise this without needing to depend on the
# model; basically one would need an array of predicted survival rates with
# dims: subjects by time, with elements containing the survival rate.
class SurvivalModelBrierScore(object):
    """
    Evaluates a sklearn-survival model using the integrated Brier score, and
    computes the baseline (non-informative) Brier score based on the event
    incidence in supplied test data.

    Parameters
    ----------
    model : `Callable`
        A fitted sklearn-survival model that implements a
        `predict_survival_function` method returning survival functions for
        individuals.
    times : `np.ndarray`
        A sequence of time points over which the integrated Brier score is
        computed.

    Attributes
    ----------
    model : `Callable`
        The survival model used for prediction including the method
        `predict_survival_function`
    times : `np.ndarray`
        An array of time points used to evaluate the Brier score.
    baseline_brier_score : `float` or `NoneType`
        The Brier score of a non-informative model that always predicts the
        event incidence (π(1 - π)), calculated from the test data.
    ibs : `float` or `NoneType`
        The integrated Brier score of the model on the test data.
    bss : `float` or `NoneType`
        The Brier skill score, defined as one minus the ratio of `ibs` to
        `baseline_brier_score`.

    Methods
    -------
    __call__(Y_train, Y_test, X_test)
        Computes the integrated Brier score and baseline Brier score based on
        predictions over the specified time grid.
    get_brier_skill_score()
        Returns the Brier skill score (BSS), which quantifies improvement
        over the baseline Brier score.
    """
    # /////////////////////////////////////////////////////////////////////////
    def __init__(self, model: Callable, times: np.ndarray):
        self.model = model
        self.times = np.asarray(times)
        self.baseline_brier_score = None
        self.ibs = None
        self.bss = None
    # /////////////////////////////////////////////////////////////////////////
    @staticmethod
    def _event_incidence(Y):
        """
        Returns the fraction of records in `Y` whose event indicator is set.

        The event indicator is taken to be the first field (structured
        arrays) or the first element (any other sequence of records).
        """
        outcomes = np.asarray(Y)
        if outcomes.dtype.names:
            # read the field directly; this avoids a Python-level loop and
            # does not break when the array carries additional fields
            indicators = outcomes[outcomes.dtype.names[0]]
        else:
            indicators = np.asarray([record[0] for record in Y])
        return np.mean(indicators)
    # /////////////////////////////////////////////////////////////////////////
    def __call__(self, Y_train, Y_test, X_test):
        """
        Computes the integrated Brier score on the test data using the model,
        and calculates the baseline Brier score from the event incidence in
        the test set.

        Parameters
        ----------
        Y_train : `np.ndarray`
            Survival data for training (event indicator and time).
        Y_test : `np.ndarray`
            Survival data for evaluation.
        X_test : `np.ndarray`
            Features for evaluation.

        Returns
        -------
        float
            The integrated Brier score.

        Notes
        -----
        The method also updates the attributes `ibs` and
        `baseline_brier_score`, and resets `bss` to `None` so that a skill
        score from a previous evaluation is never reported alongside the new
        scores. The baseline Brier score is calculated as π(1 - π), where π
        is the observed event incidence in `Y_test`.

        See Also
        --------
        get_brier_skill_score : Returns the Brier Skill Score (BSS),
        measuring relative performance to the baseline.
        """
        # get predictions: rows are individuals, columns are time points
        surv_preds = self.model.predict_survival_function(X_test)
        preds = np.asarray([
            [fn(t) for t in self.times] for fn in surv_preds
        ])
        # get the brier score
        self.ibs = integrated_brier_score(Y_train, Y_test, preds, self.times)
        # calculate the baseline brier score
        incidence = self._event_incidence(Y_test)
        self.baseline_brier_score = incidence * (1 - incidence)
        # any previously computed skill score refers to older data
        self.bss = None
        # return
        return self.ibs
    # /////////////////////////////////////////////////////////////////////////
    def get_brier_skill_score(self):
        """
        Computes the Brier Skill Score (BSS), which represents the
        improvement of the model over the baseline (non-informative)
        predictor.

        Raises
        ------
        AttributeError
            Raises an error if the `__call__` method has not run.

        Returns
        -------
        float
            The Brier Skill Score, 1 - ibs / baseline_brier_score.
        """
        if self.ibs is None or self.baseline_brier_score is None:
            raise AttributeError('Please first run the __call__ method.')
        # estimate the BSS
        self.bss = 1 - (self.ibs / self.baseline_brier_score)
        # return
        return self.bss
    # /////////////////////////////////////////////////////////////////////////
    def __repr__(self):
        CLASS_NAME = type(self).__name__
        return (f"{CLASS_NAME}(model={self.model.__class__.__name__}, "
                f"n_times={len(self.times)})"
                )
    # /////////////////////////////////////////////////////////////////////////
    def __str__(self):
        CLASS_NAME = type(self).__name__
        lines = [
            f"{CLASS_NAME} for model: {self.model.__class__.__name__}",
            f" Time points: {len(self.times)} evaluated",
        ]
        if self.ibs is not None:
            lines.append(f" Integrated Brier Score (IBS): {self.ibs:.4f}")
        else:
            lines.append(" IBS not yet evaluated.")
        if self.baseline_brier_score is not None:
            lines.append(
                f" Baseline Brier Score (π(1 - π)): "
                f"{self.baseline_brier_score:.4f}"
            )
        if self.bss is not None:
            lines.append(
                f" The Brier Skill Score: "
                f"{self.bss:.4f}"
            )
        # return
        return "\n".join(lines)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # THink about how to generalise this without using model
class SurvivalModelCalibration(object):
    """
    Creates groups based on the predicted survival and compares the predicted
    survival to a non-parametric estimate of the observed survival.

    Parameters
    ----------
    model : `Callable`
        A fitted scikit-survival model with a `predict_survival_function`
        method.
    n_groups : `int`, default 5
        Number of participant groups, created from quantiles of the
        predicted survival.
    nonparametric_estimator : `Callable`, default `kaplan_meier_estimator`
        A function called as `estimator(event_indicator, durations)` that
        returns a `(times, estimate)` pair, e.g. `kaplan_meier_estimator`
        from `sksurv.nonparametric`.

    Attributes
    ----------
    group_summary_ : dict [int, dict]
        summary statistics for each group: predicted survival and observed
        estimate. Empty until the instance has been called.
    surv_rate
        The survival functions returned by the model; set by `__call__`.
    groups : `np.ndarray`
        Zero-based group id assigned to each individual; set by `__call__`.
    """
    # /////////////////////////////////////////////////////////////////////////
    def __str__(self):
        cls = type(self).__name__
        lines = [
            f"{cls} with model: {self.model.__class__.__name__}",
            f" Number of groups: {self.n_groups}",
            f" Estimator: {self.nonparametric_estimator.__name__}",
        ]
        if self.group_summary_:
            lines.append(f" Group summary available for "
                         f"{len(self.group_summary_)} group(s).")
        else:
            lines.append(" Group summary not yet computed.")
        return "\n".join(lines)
    # /////////////////////////////////////////////////////////////////////////
    def __repr__(self):
        cls = type(self).__name__
        model_name = self.model.__class__.__name__
        return (f"{cls}(model={model_name}, "
                f"n_groups={self.n_groups}, "
                f"estimator={self.nonparametric_estimator.__name__})")
    # /////////////////////////////////////////////////////////////////////////
    def __init__(self, model: Callable, n_groups: int = 5,
                 nonparametric_estimator: Callable = kaplan_meier_estimator,
                 ):
        self.model = model
        self.n_groups = n_groups
        self.nonparametric_estimator = nonparametric_estimator
        self.group_summary_: dict[int, dict] = {}
    # /////////////////////////////////////////////////////////////////////////
    def __call__(self, X: np.ndarray, durations: np.ndarray,
                 events: np.ndarray, verbose: bool = True,
                 ) -> dict[int, dict]:
        """
        Compute the predicted and non-parametric survival curves per group.

        Groups individuals by predicted survival into `n_groups`, then
        computes both the mean predicted survival and non-parametric survival
        for each group.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Feature matrix for prediction (must match training features).
        durations : array-like of shape (n_samples,)
            Follow-up time for each individual (event or censoring time).
        events : array-like of shape (n_samples,)
            Event indicator: 1 if event occurred, 0 if censored.
        verbose : `bool`, default `True`
            Whether to emit a `RuntimeWarning` for each group that ends up
            without individuals.

        Returns
        -------
        dict [int, dict]
            A dictionary keyed by the one-based group id. Groups without
            individuals are omitted. Each nested dict contains:
            - `predicted_times`,
            - `predicted_survival`,
            - `observed_times`,
            - `observed_survival`.
        """
        # accept lists and other array-likes, which do not support boolean
        # mask indexing on their own
        durations = np.asarray(durations)
        events = np.asarray(events)
        # survival rate
        self.surv_rate = self.model.predict_survival_function(X)
        # create groups from the quantiles of each individual's mean
        # predicted survival; only the inner cut points are needed
        average = [fn.y.mean() for fn in self.surv_rate]
        quantiles = np.percentile(
            average, np.linspace(0, 100, self.n_groups + 1))
        self.groups = np.digitize(average, quantiles[1:-1], right=True)
        # getting the survival times
        times = np.unique(np.concatenate([fn.x for fn in self.surv_rate]))
        # clean summary dict
        self.group_summary_ = {}
        # ### getting estimates
        for i in range(self.n_groups):
            idx = self.groups == i
            # skip if there are no entries for a specific group
            if not idx.any():
                if verbose:
                    warnings.warn(
                        f"Risk group {i + 1} contains no individuals. "
                        "This may be due to tied risk scores or imbalanced "
                        "group sizes.",
                        RuntimeWarning
                    )
                continue
            # At each time point computes the average of survival
            # probabilities across all individuals in the group.
            pred_surv = np.mean(
                [fn(times)
                 for fn, in_group in zip(self.surv_rate, idx) if in_group],
                axis=0)
            # The non-parametric estimators
            nonp_times, nonp_surv = self.nonparametric_estimator(
                events[idx] == 1, durations[idx])
            # populate results
            self.group_summary_[i + 1] = {
                'predicted_times': times,
                'predicted_survival': pred_surv,
                'observed_times': nonp_times,
                'observed_survival': nonp_surv,
            }
        # return
        return self.group_summary_