"""
A collection of utils for scikit-survival. Currently, the module focuses on
downstream extraction of predictions and outcomes, as well as on tools to help
with model validation.
The code can likely be generalised further to work with non-sksurv models as
well.
"""
# imports
import warnings
import pandas as pd
import numpy as np
from typing import Callable
from sksurv.metrics import (
cumulative_dynamic_auc,
integrated_brier_score,
)
from sksurv.nonparametric import kaplan_meier_estimator
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs]
def surv_model_auc(model:Callable|None, times: list[float],
data: tuple[str, np.ndarray, np.ndarray],
train_y: np.ndarray) -> pd.DataFrame:
"""
Computes time-dependent AUCs for survival models using cumulative
dynamic AUC.
Parameters
----------
model : `callable` or `NoneType`
A fitted survival model with a `predict(X)` method. If `None`,
assumes `d_tup[1]` already contains predicted risk scores.
times : `list` [`float`]
Time points at which to evaluate the time-dependent AUC.
data : `tuple` [`str`, `np.ndarray`, `np.ndarray`]
The tuple should contain (label, test_y, test_x or predicted_risk),
where `test_y` is a structured array of survival data
(e.g., from `sksurv.util.Surv.from_arrays`), and `label` is a
descriptor of the data split (e.g., 'test', 'validation').
train_y : np.ndarray
Structured array of survival outcomes for training data, used to
construct the risk sets.
Returns
-------
pd.DataFrame
A long-format DataFrame with columns:
- "Data split": label of the data subset
- "Time": time points evaluated
- "AUC by time": corresponding AUC values
- "Mean AUC": mean of AUCs across time points (NaNs ignored)
Notes
-----
This function evaluates the performance of a survival model across
multiple time points using the cumulative/dynamic AUC approach
(as implemented in `cumulative_dynamic_auc` from `sksurv.metrics`).
It handles multiple data splits or datasets and aggregates the results
into a tidy DataFrame.
Examples
--------
>>> from sksurv.util import Surv
>>> import numpy as np
>>> train_y = Surv.from_arrays([True, False], [10, 20])
>>> test_y = Surv.from_arrays([True, False], [5, 25])
>>> pred = np.array([0.1, 0.4])
>>> data = ("test", test_y, pred)
>>> times = [5, 10, 15]
>>> surv_model_auc(None, times, data, train_y)
"""
# ### check input
# TODO
# #### calculations
# do we need to make predictions
if model is not None:
pred = model.predict(data[2])
else:
pred = data[2]
# the auc across time
per_time_auc = []
for t in times:
try:
auc_array, _ = cumulative_dynamic_auc(
train_y, data[1], pred, [t]
)
per_time_auc.append(auc_array[0])
except ValueError:
# NOTE add a warning
per_time_auc.append(np.nan)
# the average auc
mean_auc = np.nanmean(per_time_auc)
# results dict
res_dict = {
"Data split" : data[0],
"Time" : times,
"AUC by time" : per_time_auc,
"Mean AUC" : mean_auc
}
# return
return pd.DataFrame(res_dict).set_index('Time')
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs]
def surv_model_predict(model:Callable, X: np.ndarray, times: list[float],
                       ) -> pd.DataFrame:
    """
    Tabulate the predicted event probability of every sample at a set of
    follow-up times.

    Parameters
    ----------
    model : `callable`
        An object exposing `predict_survival_function(X)`, which must return
        one callable survival function per sample (as fitted scikit-survival
        models such as CoxPHSurvivalAnalysis do).
    X : `np.ndarray`
        Feature data handed to `predict_survival_function`.
    times : `list` [`float`]
        Follow-up times at which each survival function is evaluated.

    Returns
    -------
    pd.DataFrame
        One row per sample and one column per entry of `times`, labelled
        "Predicted risk at <time>", holding 1 - survival probability.
    """
    survival_curves = model.predict_survival_function(X)
    # one column per requested follow-up time
    risk_columns = {}
    for horizon in times:
        label = f"Predicted risk at {horizon}"
        # risk is the complement of the survival probability
        risk_columns[label] = [
            1.0 - curve(horizon) for curve in survival_curves
        ]
    return pd.DataFrame(risk_columns)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs]
def event_by_time(Y: np.ndarray, times: list[float],
                  event_name:str='event',
                  time_name:str='time',
                  ) -> pd.DataFrame:
    """
    Flag, per sample, whether the event was observed on or before each of
    the supplied time points.

    Parameters
    ----------
    Y : `np.ndarray`
        Structured array holding an event indicator field and a time field,
        typically created using sksurv.util.Surv.from_arrays.
    times : `list` [`float`]
        Time points at which to assess event occurrence.
    event_name : `str`, default `event`
        Name of the event indicator field in `Y`.
    time_name : `str`, default `time`
        Name of the time field in `Y`.

    Returns
    -------
    pd.DataFrame
        One row per sample and one column per entry of `times`, labelled
        "Event at <time>". A cell is 1 when the event indicator is set and
        the recorded time is less than or equal to the time point, else 0.
    """
    observed = Y[event_name]
    follow_up = Y[time_name]
    columns = {}
    for cutoff in times:
        # an event counts only if it was observed no later than the cutoff
        happened = observed & (follow_up <= cutoff)
        columns[f"Event at {cutoff}"] = happened.astype(int)
    return pd.DataFrame(columns)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs]
def surv_model_calibration_table(model: Callable, X: np.ndarray,
                                 Y: np.ndarray, times: list[float],
                                 event_name: str = 'event',
                                 time_name: str = 'time',
                                 ) -> pd.DataFrame:
    """
    Creates a table of predicted event probability and actual observed outcomes
    by time.

    Parameters
    ----------
    model : `callable`
        A fitted scikit-survival model supporting
        `predict_survival_function(X)`.
    X : `pd.DataFrame` or `np.ndarray`
        Feature data for prediction.
    Y : `np.ndarray`
        Structured array with an event indicator field and a time field,
        typically created using sksurv.util.Surv.from_arrays.
    times : `list` [`float`]
        Time points at which to compute risks and determine event status.
    event_name : `str`, default `event`
        Name of the event indicator field in `Y`; forwarded to
        `event_by_time`.
    time_name : `str`, default `time`
        Name of the time field in `Y`; forwarded to `event_by_time`.

    Returns
    -------
    pd.DataFrame
        A DataFrame combining, column-wise, the predicted risk
        probabilities ("Predicted risk at <time>") and the 0/1 indicators
        of event occurrence ("Event at <time>") for each time point.
    """
    risks = surv_model_predict(model, X, times)
    # forward the field names so structured arrays with non-default names
    # (e.g. Surv.from_arrays(..., name_event=..., name_time=...)) are supported
    events = event_by_time(Y, times, event_name=event_name,
                           time_name=time_name)
    return pd.concat([risks, events], axis=1)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# NOTE think about how to generalise this without depending on the model;
# basically one would need an array of predicted survival rates with dims
# subjects by time, with elements containing the survival rate.
[docs]
class SurvivalModelBrierScore:
    """
    Scores a scikit-survival model with the integrated Brier score and
    derives a baseline (non-informative) Brier score from the event
    incidence in the supplied test data.

    Parameters
    ----------
    model : `Callable`
        A fitted scikit-survival model that implements a
        `predict_survival_function` method returning one survival function
        per individual.
    times : `np.ndarray`
        A sequence of time points over which the integrated Brier score is
        computed.

    Attributes
    ----------
    model : `Callable`
        The survival model used for prediction.
    times : `np.ndarray`
        The time points, converted to an array.
    baseline_brier_score : `float` or `NoneType`
        π(1 - π), with π the event incidence in the test data. `None` until
        the instance has been called.
    ibs : `float` or `NoneType`
        The integrated Brier score of the model on the test data. `None`
        until the instance has been called.
    bss : `float` or `NoneType`
        The Brier skill score, one minus the ratio of `ibs` to
        `baseline_brier_score`. `None` until `get_brier_skill_score` has
        been called.

    Methods
    -------
    __call__(Y_train, Y_test, X_test)
        Computes the integrated Brier score and baseline Brier score based on
        predictions over the specified time grid.
    get_brier_skill_score()
        Returns the Brier skill score (BSS), which quantifies improvement
        over the baseline Brier score.
    """
    # /////////////////////////////////////////////////////////////////////////
    def __init__(self, model:Callable, times:np.ndarray):
        self.model = model
        self.times = np.asarray(times)
        # results; populated by __call__ and get_brier_skill_score
        self.baseline_brier_score = None
        self.ibs = None
        self.bss = None
    # /////////////////////////////////////////////////////////////////////////
    def __call__(self, Y_train, Y_test, X_test):
        """
        Evaluate the model on the test data.

        Parameters
        ----------
        Y_train : `np.ndarray`
            Survival data for training (event indicator and time).
        Y_test : `np.ndarray`
            Survival data for evaluation; iterated as (event, time) pairs.
        X_test : `np.ndarray`
            Features for evaluation.

        Returns
        -------
        float
            The integrated Brier score.

        Notes
        -----
        The method also updates the attributes `ibs` and
        `baseline_brier_score`. The baseline Brier score is calculated as
        π(1 - π), where π is the mean of the event indicators in `Y_test`.

        See Also
        --------
        get_brier_skill_score : Returns the Brier Skill Score (BSS), measuring
            relative performance to the baseline.
        """
        # survival probability of every individual at every time point
        survival_curves = self.model.predict_survival_function(X_test)
        rows = []
        for curve in survival_curves:
            rows.append([curve(t) for t in self.times])
        survival_matrix = np.asarray(rows)
        # the integrated brier score of the model
        self.ibs = integrated_brier_score(
            Y_train, Y_test, survival_matrix, self.times
        )
        # the baseline: variance of a Bernoulli with the observed incidence
        observed_flags = np.asarray([occurred for occurred, _ in Y_test])
        pi = np.mean(observed_flags)
        self.baseline_brier_score = pi * (1 - pi)
        return self.ibs
    # /////////////////////////////////////////////////////////////////////////
    def get_brier_skill_score(self):
        """
        Return the Brier Skill Score (BSS): one minus the ratio of the
        integrated Brier score to the baseline Brier score.

        Raises
        ------
        AttributeError
            If the instance has not been called yet, so that `ibs` or
            `baseline_brier_score` is still `None`.

        Returns
        -------
        float
            The Brier Skill Score, also stored in `bss`.
        """
        required = (self.ibs, self.baseline_brier_score)
        if any(value is None for value in required):
            raise AttributeError('Please first run the __call__ method.')
        ratio = self.ibs / self.baseline_brier_score
        self.bss = 1 - ratio
        return self.bss
    # /////////////////////////////////////////////////////////////////////////
    def __repr__(self):
        own_name = type(self).__name__
        model_name = self.model.__class__.__name__
        n_times = len(self.times)
        return f"{own_name}(model={model_name}, n_times={n_times})"
    # /////////////////////////////////////////////////////////////////////////
    def __str__(self):
        own_name = type(self).__name__
        model_name = self.model.__class__.__name__
        report = [
            f"{own_name} for model: {model_name}",
            f" Time points: {len(self.times)} evaluated",
        ]
        report.append(
            " IBS not yet evaluated." if self.ibs is None
            else f" Integrated Brier Score (IBS): {self.ibs:.4f}"
        )
        # the remaining entries are only shown once they have been computed
        if self.baseline_brier_score is not None:
            baseline = f"{self.baseline_brier_score:.4f}"
            report.append(f" Baseline Brier Score (π(1 - π)): " + baseline)
        if self.bss is not None:
            skill = f"{self.bss:.4f}"
            report.append(f" The Brier Skill Score: " + skill)
        return "\n".join(report)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# NOTE think about how to generalise this without using the model.
[docs]
class SurvivalModelCalibration(object):
    """
    Creates groups based on the predicted survival and compares the predicted
    survival curve of each group to a non-parametric estimate.

    Parameters
    ----------
    model : `Callable`
        A fitted scikit-survival model with a `predict_survival_function`
        method.
    n_groups : `int`, default 5
        Number of participant groups, created from quantiles of the mean
        predicted survival.
    nonparametric_estimator : `Callable`, default `kaplan_meier_estimator`
        A callable invoked as `estimator(event, time)` and returning a
        `(times, survival)` pair, e.g. `kaplan_meier_estimator` from
        `sksurv.nonparametric`.

    Attributes
    ----------
    group_summary_ : dict [int, dict]
        Summary per group: predicted survival and observed estimate. Empty
        until the instance has been called.
    surv_rate : array-like
        The survival functions returned by the model; set by `__call__`.
    groups : `np.ndarray`
        Zero-based group index of each individual; set by `__call__`.
    """
    # /////////////////////////////////////////////////////////////////////////
    def __init__(self, model:Callable, n_groups: int = 5,
                 nonparametric_estimator: Callable = kaplan_meier_estimator,
                 ):
        self.model = model
        self.n_groups = n_groups
        self.nonparametric_estimator = nonparametric_estimator
        self.group_summary_: dict[int, dict] = {}
    # /////////////////////////////////////////////////////////////////////////
    def _estimator_name(self) -> str:
        # Callables such as functools.partial objects have no __name__;
        # fall back to the type name so printing never fails.
        estimator = self.nonparametric_estimator
        return getattr(estimator, "__name__", type(estimator).__name__)
    # /////////////////////////////////////////////////////////////////////////
    def __str__(self):
        cls = type(self).__name__
        lines = [
            f"{cls} with model: {self.model.__class__.__name__}",
            f" Number of groups: {self.n_groups}",
            f" Estimator: {self._estimator_name()}",
        ]
        if self.group_summary_:
            lines.append(f" Group summary available for "
                         f"{len(self.group_summary_)} group(s).")
        else:
            lines.append(" Group summary not yet computed.")
        return "\n".join(lines)
    # /////////////////////////////////////////////////////////////////////////
    def __repr__(self):
        cls = type(self).__name__
        model_name = self.model.__class__.__name__
        return (f"{cls}(model={model_name}, "
                f"n_groups={self.n_groups}, "
                f"estimator={self._estimator_name()})")
    # /////////////////////////////////////////////////////////////////////////
    def __call__(self, X:np.ndarray, durations:np.ndarray, events:np.ndarray,
                 verbose:bool=True,
                 ) -> dict[int, dict]:
        """
        Compute the predicted and non-parametric survival curves per group.

        Individuals are assigned to `n_groups` groups by quantiles of their
        mean predicted survival probability; for each non-empty group the
        mean predicted survival curve and the non-parametric estimate are
        computed.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Feature matrix for prediction (must match training features).
        durations : array-like of shape (n_samples,)
            Follow-up time for each individual (event or censoring time).
        events : array-like of shape (n_samples,)
            Event indicator: 1 if event occurred, 0 if censored.
        verbose : `bool`, default `True`
            Whether to emit a `RuntimeWarning` for each group that contains
            no individuals.

        Returns
        -------
        dict [int, dict]
            Maps the one-based group id (group 1 holds the lowest mean
            predicted survival) to a dict with the keys:
            - `predicted_times`,
            - `predicted_survival`,
            - `observed_times`,
            - `observed_survival`.
            Empty groups are omitted.
        """
        # the docstring promises array-like input; coerce so that boolean
        # mask indexing below also works for plain lists
        durations = np.asarray(durations)
        events = np.asarray(events)
        # survival rate
        self.surv_rate = self.model.predict_survival_function(X)
        # create groups from quantiles of the mean predicted survival
        average = [fn.y.mean() for fn in self.surv_rate]
        quantiles = np.percentile(average,
                                  np.linspace(0, 100, self.n_groups + 1))
        # only the inner cut points are used, so the minimum and maximum
        # fall into the first and last group respectively
        self.groups = np.digitize(average, quantiles[1:-1], right=True)
        # getting the survival times: union over all survival functions
        times = np.unique(np.concatenate([fn.x for fn in self.surv_rate]))
        # clean summary dict
        self.group_summary_ = {}
        # ### getting estimates
        for i in range(self.n_groups):
            idx = self.groups == i
            # skip if there are no entries for a specific group
            if not idx.any():
                if verbose:
                    warnings.warn(
                        f"Risk group {i + 1} contains no individuals. "
                        "This may be due to tied risk scores or imbalanced "
                        "group sizes.",
                        RuntimeWarning,
                        stacklevel=2,
                    )
                continue
            # At each time point computes the average of survival
            # probabilities across all individuals in the group.
            pred_surv = np.mean(
                [fn(times) for fn, member in zip(self.surv_rate, idx)
                 if member],
                axis=0)
            # The non-parametric estimator
            nonp_times, nonp_surv = self.nonparametric_estimator(
                events[idx] == 1, durations[idx])
            # populate results
            self.group_summary_[i + 1] = {
                'predicted_times': times,
                'predicted_survival': pred_surv,
                'observed_times': nonp_times,
                'observed_survival': nonp_surv,
            }
        # return
        return self.group_summary_