A scalable discrete-time survival model for neural networks
 Academic Editor
 Jun Pang
 Subject Areas
 Data Mining and Machine Learning, Data Science
 Keywords
 Survival analysis, Neural networks, Machine learning
 Copyright
 © 2019 Gensheimer and Narasimhan
 Licence
 This is an open access article distributed under the terms of the Creative Commons Attribution License, which permits unrestricted use, distribution, reproduction and adaptation in any medium and for any purpose provided that it is properly attributed. For attribution, the original author(s), title, publication source (PeerJ) and either DOI or URL of the article must be cited.
 Cite this article
 2019. A scalable discrete-time survival model for neural networks. PeerJ 7:e6257 https://doi.org/10.7717/peerj.6257
Abstract
There is currently great interest in applying neural networks to prediction tasks in medicine. It is important for predictive models to be able to use survival data, where each patient has a known follow-up time and event/censoring indicator. This avoids information loss when training the model and enables generation of predicted survival curves. In this paper, we describe a discrete-time survival model that is designed to be used with neural networks, which we refer to as Nnet-survival. The model is trained with the maximum likelihood method using mini-batch stochastic gradient descent (SGD). The use of SGD enables rapid convergence and application to large datasets that do not fit in memory. The model is flexible, so that the baseline hazard rate and the effect of the input data on hazard probability can vary with follow-up time. It has been implemented in the Keras deep learning framework, and source code for the model and several examples is available online. We demonstrate the performance of the model on both simulated and real data and compare it to existing models Cox-nnet and Deepsurv.
Introduction
With the popularization of deep learning and the increasing size of medical datasets, there has been increasing interest in the use of machine learning to improve medical care. Several recent papers have described use of neural network or other machine learning techniques to predict future clinical outcomes (Rajkomar et al., 2018; Kwong et al., 2017; Miotto et al., 2016; Avati et al., 2017). The outcome measure is generally evaluated at one follow-up time point, and there is often little discussion of how to deal with censored data (e.g., patients lost to follow-up before the follow-up time point). This is not ideal, as information about censored patients is lost and the model would need to be retrained to make predictions at different time points. Because of these issues, modern predictive models generally use Cox proportional hazards regression or a parametric survival model instead of simpler methods such as logistic regression that discard time-to-event information (Cooney, Dudina & Graham, 2009).
Several authors have described solutions for modeling time-to-event data with neural networks. These are generally adaptations of linear models such as the Cox proportional hazards model (Cox, 1972). Approaches include a discrete-time survival model with a heuristic loss function (Brown, Branford & Moran, 1997), a parametric model with predicted survival time having a Weibull distribution (Martinsson, 2016), and adaptations of the Cox proportional hazards model (Faraggi & Simon, 1995; Ching, Zhu & Garmire, 2018; Katzman et al., 2018). Most of the models assume proportional hazards (the effect of each predictor variable is the same at all values of follow-up time). This is not a very realistic assumption for most clinical situations. In the past, when models were typically trained using dozens or hundreds of patients, it was often not possible to demonstrate violation of proportional hazards. However, in the modern era of datasets of thousands or millions of patients, it will usually be possible to demonstrate violation of the proportional hazards assumption, either by plotting residuals or with a statistical test.
In this paper, we describe Nnet-survival, a discrete-time survival model that is theoretically justified, naturally deals with non-proportional hazards, and is trained rapidly by mini-batch gradient descent. It may be useful in several situations, especially when non-proportional hazards are known to be present, for very large datasets that do not fit in memory, or when predictor data is a good fit for a neural network approach (such as image or text data). We have published source code for the use of the model with the Keras deep learning library, which is available at http://github.com/MGensheimer/nnet-survival.
Materials & Methods
Relationship to prior work
In this section we describe prior approaches to the problem and illustrate some pitfalls that are addressed with our model.
Several authors have adapted the Cox proportional hazards model to neural networks (Faraggi & Simon, 1995; Ching, Zhu & Garmire, 2018; Katzman et al., 2018). This is potentially attractive since the Cox model has been shown to be very useful and is familiar to most medical researchers. One issue with this approach is that the partial likelihood for each individual depends not only on the model output for that individual, but also on the output for all individuals with longer survival. This would preclude the use of stochastic gradient descent (SGD) since with SGD only a small number of individuals are visible to the model at a time. Therefore, the entire dataset would need to be used for each gradient descent step. This is undesirable because it slows down convergence, cannot be applied to datasets that do not fit into memory (“out-of-core learning”), and could result in getting stuck in a local minimum of the loss function (Bottou, 1991).
An alternative approach that avoids the above issue is to use a fully parametric survival model, such as a discrete-time model. See Section 7.5 of Rodriguez (2016) for a brief overview of discrete-time survival models. Brown et al. proposed a discrete-time survival model using neural networks (Brown, Branford & Moran, 1997). This model can easily be trained with SGD, which is attractive. Follow-up time is divided into a set of fixed intervals. For each time interval the conditional hazard probability is estimated: the probability of failure in the interval, given that the individual has survived at least to the beginning of the interval. For each time interval j, the neural network loss is defined as (adapted from Eq. 17 in Brown, Branford & Moran 1997): (1)$\frac{1}{2}\sum_{i=1}^{d_j}\left(1-h_j^i\right)^2+\frac{1}{2}\sum_{i=d_j+1}^{r_j}\left(h_j^i\right)^2$ where $h_j^i$ is the hazard probability for individual i during time interval j, there are $r_j$ individuals “in view” during the interval j (i.e., have not experienced failure or censoring before the beginning of the interval) and the first $d_j$ of them suffer a failure during this interval. The overall loss function is the sum of the losses for each time interval.
The authors note that in the case of a null model with no predictor variables, minimizing the loss in Eq. (1) results in an estimate of the hazard probabilities that equals the Kaplan–Meier maximum likelihood estimate: $\hat{h}_j=\frac{d_j}{r_j}$. While this is true, the equivalence does not hold once each individual’s hazard depends on the value of predictor variables.
A more theoretically justified loss function, which we use in our model, would be the negative of the log likelihood function of a statistical survival model. This likelihood function has been well studied for discrete-time survival models in a non-deep learning context. Adapting Eq. (3.4) from Cox & Oakes (1984) and Eq. (2.17) from Singer & Willett (1993), the contribution of time interval j to the overall log likelihood is: (2)$\sum_{i=1}^{d_j}\ln\left(h_j^i\right)+\sum_{i=d_j+1}^{r_j}\ln\left(1-h_j^i\right).$
This is similar but not identical to Eq. (1) and can be shown to produce different values of the model parameters for anything more complex than the null model (for an example, see the file brown1997_loss_function_example.md in our GitHub repository).
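As a check on the null-model statement, the following short derivation is a sketch that writes h for the common hazard shared by all $r_j$ individuals in interval j, and $L_1$, $L_2$ (labels introduced here for illustration) for the Eq. (1) loss and the negative of Eq. (2):

```latex
\begin{align*}
L_1(h) &= \tfrac{1}{2}\,d_j\,(1-h)^2 + \tfrac{1}{2}\,(r_j-d_j)\,h^2,
&\frac{dL_1}{dh} &= -d_j(1-h) + (r_j-d_j)\,h = r_j h - d_j = 0
\;\Rightarrow\; h = \frac{d_j}{r_j},\\[4pt]
L_2(h) &= -d_j \ln h - (r_j-d_j)\ln(1-h),
&\frac{dL_2}{dh} &= -\frac{d_j}{h} + \frac{r_j-d_j}{1-h} = 0
\;\Rightarrow\; h = \frac{d_j}{r_j}.
\end{align*}
```

Both losses share the same minimizer only because a single h is fitted. Once $h_j^i$ differs between individuals through shared network weights, the per-individual gradients differ (for a failure, $-(1-h_j^i)$ under Eq. (1) versus $-1/h_j^i$ under the negative of Eq. (2)), which is consistent with the two losses yielding different parameter estimates.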
The proposed model using Eq. (2) naturally incorporates time-varying baseline hazard rate and non-proportional hazards if each time interval output node is fully connected to the last hidden layer’s neurons. The neural network has n-dimensional output where n is the number of time intervals, giving a separate hazard rate for each time interval.
There are several attractive features of the proposed model:

It is theoretically justified and fits into the established literature on survival modeling

The loss function depends only on the information contained in the current mini-batch, which enables rapid training with mini-batch SGD and application to arbitrary-size datasets

It is flexible and can be adapted to specific situations. For instance, for small sample size where we wish to minimize the number of neural network parameters, it is easy to incorporate a proportional hazards constraint so that the effect of the input data on the hazard function does not vary with follow-up time.
Model formulation
Follow-up time is divided into n intervals which are left-closed and right-open. Let $[t_1, t_2, \ldots, t_n]$ be the times at the upper limit of each interval. The conditional hazard probability $h_j$ is defined as the probability of failure in interval j, given that the individual has survived at least to the beginning of the interval. $h_j$ can vary per individual according to the input and the weights of the neural network. The predicted probability of an individual surviving at least to the end of interval j is: (3)$S_j=\prod_{i=1}^{j}\left(1-h_i\right).$
The model likelihood can be divided either by time interval as in Eq. (2), or by individual. For a neural network trained with mini-batches of individuals, the latter formulation translates more easily into computer code. For an individual with failure during interval j (i.e., uncensored), the likelihood is the probability of surviving through intervals 1 through j − 1, multiplied by the probability of failing during interval j: (4)$lik=h_j\prod_{i=1}^{j-1}\left(1-h_i\right)$ (5)$loglik=\ln\left(h_j\right)+\sum_{i=1}^{j-1}\ln\left(1-h_i\right).$
For a censored individual with a censoring time $t_c$ which falls in the second half of interval j − 1 or the first half of interval j (i.e., $\frac{1}{2}\left(t_{j-2}+t_{j-1}\right)\le t_c<\frac{1}{2}\left(t_{j-1}+t_j\right)$), the likelihood is the probability of surviving through intervals 1 through j − 1: (6)$lik=\prod_{i=1}^{j-1}\left(1-h_i\right)$ (7)$loglik=\sum_{i=1}^{j-1}\ln\left(1-h_i\right).$
It can be seen that individuals with a censoring time in the second half of an interval are given “credit” for surviving that interval; without this, there would be a downward bias on the survival estimates (Brown, Branford & Moran, 1997).
The full log likelihood of the observed data is the sum of the log likelihoods for each individual. In the neural network survival model, we wish to maximize the likelihood, so we set the loss to equal the negative log likelihood and minimize the loss by stochastic gradient descent or mini-batch gradient descent.
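To make Eq. (3) concrete, the short NumPy sketch below converts a vector of conditional hazard probabilities into cumulative survival probabilities. The function name and the example hazards are hypothetical and chosen only for illustration.

```python
import numpy as np

def survival_from_hazards(hazards):
    """Cumulative survival S_j = prod_{i<=j} (1 - h_i) for one individual."""
    hazards = np.asarray(hazards, dtype=float)
    return np.cumprod(1.0 - hazards)

# Hypothetical conditional hazards for three consecutive intervals
example_hazards = [0.1, 0.2, 0.5]
example_survival = survival_from_hazards(example_hazards)
print(example_survival)
```

Because each $S_j$ multiplies in one more factor no larger than 1, the resulting curve can never increase with follow-up time.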
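The per-individual log likelihoods in Eqs. (5) and (7) can be sketched directly in NumPy. This is an illustrative helper with a hypothetical name and signature, not the code from the project repository; j is the 1-based interval index used in the equations.

```python
import numpy as np

def individual_loglik(hazards, j, failed):
    """Log likelihood of one individual.

    hazards: conditional hazard probabilities h_1..h_n (stored 0-indexed)
    j:       1-based interval index, as in Eqs. (4)-(7)
    failed:  True  -> failure in interval j, Eq. (5)
             False -> censored, survived intervals 1..j-1, Eq. (7)
    """
    hazards = np.asarray(hazards, dtype=float)
    # Sum of ln(1 - h_i) over intervals 1..j-1 (empty sum is 0 when j == 1)
    survived = np.sum(np.log(1.0 - hazards[: j - 1]))
    if failed:
        return np.log(hazards[j - 1]) + survived
    return survived

h = [0.1, 0.2, 0.5]
print(individual_loglik(h, 3, failed=True))   # ln(0.5) + ln(0.9) + ln(0.8)
print(individual_loglik(h, 3, failed=False))  # ln(0.9) + ln(0.8)
```

Summing this quantity over the individuals in a mini-batch and negating it gives the loss described above.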
Determination of hazard probability
For each time interval, the hazard probability will vary according to the input data. We have implemented two approaches to mapping input data to hazard probabilities:
With the flexible version, the final hidden layer (e.g., the “Max pooling” layer in Fig. 1) is densely connected to the output layer (the “Fully connected” layer in Fig. 1). The output layer has n neurons, where n is the number of time intervals. The log odds of surviving each time interval is equal to the dot product of the incoming values and the kernel weights, plus the bias weight. Then, using a sigmoid activation function, log odds are converted to the conditional probability of surviving this interval. With this approach, both the baseline hazard rate and the effect of input data on the hazard rate can vary freely with follow-up time. This approach is most appropriate for larger datasets or when the proportional hazards assumption is known to be violated.
With the proportional hazards version, the baseline hazard probability is allowed to vary freely with time interval, but the effect of input data on hazard rate does not vary with follow-up time (if a certain combination of input data results in a high rate of death in the early follow-up period, it will also result in a high rate of death in the late follow-up period). This is implemented by setting the final hidden layer to have a single neuron, and densely connecting the prior hidden layer to the final hidden layer without any bias weights. The final hidden layer neuron value is Xβ, where X is the value of the prior hidden layer neurons and β is the weights between the prior hidden layer and the final hidden layer. The Xβ notation is meant to echo that of the “linear predictor” in standard survival analysis, for instance section 18.2 of Harrell Jr (2015). The conditional probability of surviving the interval i is (adapted from Eq. (18.13) in Harrell Jr (2015)): (8)$1-h_i=\left(1-h_{base}\right)^{\exp\left(X\beta\right)}$ where $h_{base}$ is the baseline hazard probability for this time interval. The $h_{base}$ values are estimated as part of the neural network by training a set of n weights, which are each transformed by a sigmoid function to convert baseline log odds of surviving each time interval into baseline probability of survival. These sigmoid-transformed weights, along with the final hidden layer value, contribute to the n-dimensional output layer according to Eq. (8). See class PropHazards in file nnet_survival.py in the GitHub repository. The proportional hazards approach is useful for small datasets where one wishes to reduce overfitting by minimizing the number of parameters to optimize. It also makes it easier to interpret the reasons for the model’s predictions.
This version is very similar to a traditional proportional hazards discrete-time survival model using a complementary log–log link (see Rodriguez (2016), section 7.5.3: “Discrete Survival and the CLogLog Link”).
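The proportional hazards mapping can be illustrated outside of Keras with a few lines of NumPy. The sketch below is not the PropHazards class from the repository; the function name and argument names are hypothetical. It assumes n baseline weights on the log-odds scale and a scalar linear predictor Xβ for one individual, and raises the baseline conditional survival $1-h_{base}$ to the power exp(Xβ).

```python
import numpy as np

def prop_hazards_survival(baseline_logit, linear_predictor):
    """Conditional survival per interval, 1 - h_i = (1 - h_base)^exp(X*beta).

    baseline_logit:   n weights, baseline log odds of surviving each interval
    linear_predictor: scalar X*beta for one individual
    """
    baseline_logit = np.asarray(baseline_logit, dtype=float)
    baseline_surv = 1.0 / (1.0 + np.exp(-baseline_logit))  # sigmoid -> 1 - h_base
    return baseline_surv ** np.exp(linear_predictor)

# Hypothetical baseline weights for three intervals
logits = [0.0, 1.0, 2.0]
print(prop_hazards_survival(logits, 0.0))          # equals the baseline survival
print(prop_hazards_survival(logits, np.log(2.0)))  # baseline survival squared
```

A linear predictor of zero reproduces the baseline, and a positive value lowers conditional survival in every interval by the same multiplicative effect on the log-survival scale, which is the proportional hazards constraint.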
Implementation
We implemented Nnet-survival in the Python language, using the Keras library with Tensorflow backend (code at http://github.com/MGensheimer/nnet-survival). A custom loss function is used which represents the negative log likelihood of the survival model. The output of the neural network is an n-dimensional vector $surv_{pred}$, where n is the number of time intervals. Each element represents the predicted conditional probability of surviving that time interval, or $1-h_j$. An individual’s predicted probability of surviving through the end of time interval j is given by Eq. (3). An example neural network architecture using the “flexible” version of the discrete-time survival model is shown in Fig. 1.
Each individual used to train the model has a known failure/censoring time t and censoring indicator, which are transformed into a vector format for use in the model. Vector surv_{s} has length n and represents the time intervals the individual has survived through; vector surv_{f} also has length n and represents the time interval during which failure occurred, if it occurred.
For individuals with failure (uncensored), for time interval j: (9)$surv_{s}\left(j\right)=\begin{cases}1, & \text{if } t\ge t_j\\ 0, & \text{otherwise}\end{cases}$ (10)$surv_{f}\left(j\right)=\begin{cases}1, & \text{if } t_{j-1}\le t<t_j\\ 0, & \text{otherwise}\end{cases}$
For censored individuals: (11)$surv_{s}\left(j\right)=\begin{cases}1, & \text{if } t\ge \frac{1}{2}\left(t_{j-1}+t_j\right)\\ 0, & \text{otherwise}\end{cases}$
and (12)$surv_{f}\left(j\right)=0.$
The log likelihood for each individual is: (13)$loglik=\sum_{i=1}^{n}\left[\ln\left(1+surv_{s}\left(i\right)\cdot\left(surv_{pred}\left(i\right)-1\right)\right)+\ln\left(1-surv_{f}\left(i\right)\cdot surv_{pred}\left(i\right)\right)\right]$ which is a restatement of Eqs. (5) and (7) to work with the vector encoding of actual and predicted survival.
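The vector encoding of Eqs. (9) to (12) can be sketched as follows. This is an illustrative helper with a hypothetical name and signature (the repository may organize this differently); it assumes the interval boundaries are passed as an array starting at time 0, so that interval j spans breaks[j-1] to breaks[j].

```python
import numpy as np

def encode_survival(t, failed, breaks):
    """Return (surv_s, surv_f) for one individual.

    t:      failure or censoring time
    failed: True if failure was observed, False if censored
    breaks: boundaries [t_0 = 0, t_1, ..., t_n]
    """
    breaks = np.asarray(breaks, dtype=float)
    lower, upper = breaks[:-1], breaks[1:]
    if failed:
        surv_s = (t >= upper).astype(float)                   # Eq. (9)
        surv_f = ((t >= lower) & (t < upper)).astype(float)   # Eq. (10)
    else:
        midpoints = 0.5 * (lower + upper)
        surv_s = (t >= midpoints).astype(float)               # Eq. (11)
        surv_f = np.zeros(len(upper))                         # Eq. (12)
    return surv_s, surv_f

breaks = [0, 10, 20, 30]
print(encode_survival(15, True, breaks))   # failure in the second interval
print(encode_survival(16, False, breaks))  # censored after the midpoint of the second interval
print(encode_survival(14, False, breaks))  # censored before that midpoint
```

The last two calls show the half-interval rule: censoring at 16 earns credit for the second interval, while censoring at 14 does not.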
The loss function is the negative of Eq. (13). The loss function is minimized using gradient descent; Keras performs automatic differentiation of the loss function in order to calculate the gradient. In our experiments, using the custom loss function extended running time very slightly compared to standard loss functions such as mean squared error.
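For readers who want to verify the loss numerically, the following NumPy sketch evaluates Eq. (13) for a single individual: the sum over intervals of ln(1 + surv_s·(surv_pred − 1)) + ln(1 − surv_f·surv_pred). It is written in plain NumPy rather than Keras so that it runs without a deep learning backend; the function names are hypothetical, and a practical implementation would probably also clip predictions away from 0 and 1 to avoid taking the log of zero (an assumption, not something stated in the text).

```python
import numpy as np

def surv_loglik(surv_s, surv_f, surv_pred):
    """Per-individual log likelihood from the vector encoding (Eq. 13)."""
    surv_s = np.asarray(surv_s, dtype=float)
    surv_f = np.asarray(surv_f, dtype=float)
    surv_pred = np.asarray(surv_pred, dtype=float)
    # Intervals survived contribute ln(surv_pred); all others contribute ln(1) = 0
    survived_term = np.log(1.0 + surv_s * (surv_pred - 1.0))
    # The failure interval contributes ln(1 - surv_pred) = ln(h); others ln(1) = 0
    failed_term = np.log(1.0 - surv_f * surv_pred)
    return np.sum(survived_term + failed_term)

def surv_loss(surv_s, surv_f, surv_pred):
    """Negative log likelihood, the quantity to minimize."""
    return -surv_loglik(surv_s, surv_f, surv_pred)

# Failure in the second of three intervals; predicted conditional survival per interval
pred = [0.9, 0.8, 0.5]
print(surv_loss([1, 0, 0], [0, 1, 0], pred))  # -(ln 0.9 + ln 0.2)
```

The example reproduces Eq. (5) with j = 2: the individual survives interval 1 (ln 0.9) and fails in interval 2 (ln of the hazard, 1 − 0.8).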
The cut-points for the time intervals can be varied according to the specific application. In most of our experiments we have used 15–40 time intervals, spaced out more widely with increasing follow-up time. This ensures that around the same number of survival events fall into each time interval, which may help ensure reliable estimates for all time intervals. Other authors have suggested using at least ten time intervals to avoid bias in the survival estimates (Breslow & Crowley, 1974). In our experiments we have found that the model’s performance is fairly robust to choice of specific cut-points.
Performance evaluation: simulated data
We ran several experiments with simulated data to assess correctness of the model. The code is available in nnet_survival_examples.py in the GitHub repository.
Simple model with one predictor variable
We first tested a very simple survival model with one binary predictor variable. Five thousand simulated patients were created. Half of the patients had predictor variable value of 0 and were the poor prognosis patients. For this group, survival times were drawn from an exponential distribution with median survival of 200 days. The other half of the patients had predictor variable value of 1 and were the good prognosis patients. Their survival times were drawn from an exponential distribution with median survival of 400 days. For both groups, some patients were censored; censoring time was drawn from an exponential distribution with a median (half-life) of 400 days. This survival model used the flexible version of Nnet-survival (i.e., non-proportional hazards) with no hidden layers and 39 time intervals spanning the range of 0 to 1,780 days.
To evaluate the correctness of this model, we created calibration curves: we plotted and compared actual vs. model-predicted survival curves for the two groups. For each of the two groups, a Kaplan–Meier survival curve was plotted to show actual survival. Then, for each group, a model-predicted survival curve was generated: for each follow-up time point, the average of predicted survival for all patients in that group was calculated and displayed.
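A dataset of this kind can be simulated with a few lines of NumPy. The sketch below is an assumption-laden illustration rather than the code in nnet_survival_examples.py: the function name, the seed, and the return format are hypothetical. It uses the fact that an exponential distribution with median m has scale (mean) m / ln(2).

```python
import numpy as np

def simulate_two_group_data(n_patients=5000, seed=0):
    """Simulate one binary predictor, exponential survival, exponential censoring."""
    rng = np.random.default_rng(seed)
    x = np.repeat([0, 1], n_patients // 2)            # 0 = poor, 1 = good prognosis
    median_survival = np.where(x == 0, 200.0, 400.0)  # days
    event_time = rng.exponential(median_survival / np.log(2))
    censor_time = rng.exponential(400.0 / np.log(2), size=x.size)
    observed_time = np.minimum(event_time, censor_time)
    failed = event_time <= censor_time                # False means censored
    return x, observed_time, failed

x, observed_time, failed = simulate_two_group_data()
print(observed_time[x == 0].mean(), observed_time[x == 1].mean())
print(failed[x == 0].mean(), failed[x == 1].mean())
```

With these settings, the poor-prognosis group should show shorter observed times and a higher fraction of observed failures, since its event times tend to precede the censoring times more often.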
Optimal width of time intervals
We investigated whether model performance depended on time interval width. Similarly to the prior example, we simulated a population of 5,000 patients with one binary predictor variable. Survival time distribution was generated using a Weibull distribution, with scale parameter depending on the predictor variable value. Median survival time for the overall population was 182 days. We used the flexible version of Nnet-survival to predict survival time. Four options for time intervals were evaluated:

Uniform intervals with width of 1 year

Uniform intervals with width of 1 month

Uniform intervals with width of 1 week

Increasing width of intervals with increasing follow-up time, with half-life for interval width of 1 year. Specifically, the time interval borders were placed at: $\frac{-\ln\left(1-x\right)\cdot 365}{\ln\left(2\right)}$ for x in [0.0, 0.05, 0.10, …, 0.95]
Discrimination performance was assessed with Harrell’s C-index.
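The fourth option can be computed directly. The sketch below reads the border formula as −ln(1 − x)·365 / ln(2), so that x = 0.5 maps to exactly one half-life (365 days); the variable names are illustrative.

```python
import numpy as np

halflife = 365.0                      # days
x = np.linspace(0.0, 0.95, 20)        # 0.0, 0.05, ..., 0.95
breaks = -np.log(1.0 - x) * halflife / np.log(2.0)

print(np.round(breaks, 1))
print(np.round(np.diff(breaks), 1))   # interval widths grow with follow-up time
```

Under exponential survival with a 365-day half-life, each of these intervals would contain about 5% of all failures, which is one way to keep the number of events per interval roughly constant.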
Convolutional neural network for MNIST dataset
One area in which neural networks have shown clear superiority to other model types is in analysis of 2D image data, for which convolutional neural networks provide state-of-the-art results. We wished to demonstrate use of a convolutional neural network as part of a survival model. For this, we used the MNIST dataset (Lecun et al., 1998). This dataset includes images of 70,000 handwritten digits with gold-standard labels, divided into a training set of 60,000 digits and a test set of 10,000 digits. We created a simulated scenario in which each image corresponds to one patient, and patients with higher digits tend to have shorter survival. The images could be imagined as X-ray images of tumors, with higher digits representing larger, more deadly tumors. The goal of the model is to predict survival distribution for each test set patient.
We used only images with digits 0 through 4, leaving 30,596 training set images and 5,139 test set images. Image size was 28 × 28 pixels. Patients’ survival times were drawn from an exponential distribution with scale parameter depending on digit. The scale parameter (with units of days) was set to (14)$\beta=\frac{365\cdot\exp\left(-0.9\cdot digit\right)}{\ln\left(2\right)}$ with the probability density function being: (15)$f\left(t;\frac{1}{\beta}\right)=\frac{1}{\beta}\exp\left(-\frac{t}{\beta}\right).$
Therefore, median survival ranged from 365 days for digit 0 down to 10 days for digit 4. The setup is illustrated in Fig. 2.
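The stated medians can be checked with a short calculation. The sketch below reads Eq. (14) with a negative exponent, β = 365·exp(−0.9·digit) / ln(2), which is what the quoted range of 365 days down to about 10 days implies; variable names are illustrative.

```python
import numpy as np

digits = np.arange(5)                                  # digits 0 through 4
scale = 365.0 * np.exp(-0.9 * digits) / np.log(2.0)   # beta, in days
median_survival = scale * np.log(2.0)                  # median of an exponential = beta * ln(2)

for digit, median in zip(digits, median_survival):
    print(digit, round(float(median), 1))

# Drawing one simulated survival time per digit (seed chosen arbitrarily)
rng = np.random.default_rng(0)
simulated_times = rng.exponential(scale)
```

Each unit increase in digit multiplies the median survival by exp(−0.9), roughly 0.41.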
A five-layer neural network architecture was used, with two convolutional layers of kernel size 3 × 3, followed by a max-pooling layer, a fully connected layer of size 4 neurons, then the output layer. The flexible version of the Nnet-survival model was used, so that non-proportional hazards were possible. The Adam optimizer was used. Model performance was evaluated using the C-index to measure discrimination, and calibration curves (actual vs. predicted survival curves) to evaluate calibration. As the Nnet-survival model is flexible and the predicted survival curve for each patient can have a different shape, there is no unique ordering of patients by prognosis (i.e., when comparing two patients, one could have a higher probability of 1-year survival but the other could have a higher probability of 2-year survival). Therefore, to calculate C-index, the model’s predicted probability of 1-year survival was used to rank the patients.
Performance evaluation: SUPPORT study (real data)
We evaluated the performance of the Nnet-survival model and other similar models using real patient data. We wished to use a publicly available dataset with time-to-event data on a large number of patients. With a large sample size, we could use data splitting to formally test model performance, and would also be able to evaluate for violations of the proportional hazards assumption of the standard Cox proportional hazards model.
For the real dataset, we chose to use the Study to Understand Prognoses and Preferences for Outcomes and Risks of Treatments (SUPPORT) (Knaus et al., 1995). In this multicenter study, 9,105 hospitalized patients had detailed data recorded including diagnoses, laboratory values, follow-up time and vital status. The dataset is publicly available on the Vanderbilt Biostatistics web site. The task for the survival models was to predict each patient’s life expectancy with good discrimination and calibration.
Some patients had missing values for one or more predictor variables; in this case we imputed the missing data by using the median value in the sample, or for laboratory values, using the recommended default value listed on the Vanderbilt Biostatistics web site. If more than 4,000 patients were missing a value for the variable, that variable was excluded from analysis. After processing, there were 39 predictor variables. Patients were divided with a 70%/30% split into training and test sets. The processed dataset is available at our project’s GitHub page.
We tested four models on the SUPPORT study dataset:

Our model, Nnet-survival (flexible version, so that non-proportional hazards were possible)

Cox-nnet (Ching, Zhu & Garmire, 2018)

Deepsurv (Katzman et al., 2018)

Standard Cox proportional hazards model
All three neural network models used a simple multilayer perceptron architecture with a single hidden layer. The Cox-nnet default parameters specify a hidden layer size of 7 neurons when input dimension is 39, which we felt to be a reasonable choice, so a hidden layer size of 7 was used for the three models. For all three neural network models, L2 regularization was used to help prevent overfitting. The regularization strength parameter was chosen using 10-fold cross validation on the training set, using log likelihood as the performance metric. No regularization was used for the standard Cox proportional hazards model. For Nnet-survival, 19 follow-up time intervals were used, extending out to 6 years (around the maximum follow-up time of the SUPPORT study), with larger spacing for later intervals due to the decreased density of failure events with increasing follow-up time. The RMSprop optimizer was used for Nnet-survival.
As Cox-nnet and Deepsurv only output a prognostic index for each patient, not a predicted survival curve, we generated predicted survival curves for these methods by using the Breslow method to generate a baseline hazard function (Breslow, 1974).
To evaluate the models’ discrimination performance, we used Harrell’s C-index. To calculate C-index for Nnet-survival, the model’s predicted probability of 1-year survival was used to rank the patients.
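The idea behind ranking patients by predicted 1-year survival can be illustrated with a simplified concordance calculation. The sketch below is not the implementation used in the study; it is a plain O(n²) version with a hypothetical function name that ignores pairs with tied observed times and counts tied predictions as half-concordant.

```python
import numpy as np

def harrell_c_index(time, failed, predicted_survival):
    """Fraction of usable pairs that are ordered concordantly.

    A pair is usable when the patient with the shorter observed time had an
    observed failure. It is concordant when that patient also has the lower
    predicted survival probability.
    """
    time = np.asarray(time, dtype=float)
    failed = np.asarray(failed, dtype=bool)
    pred = np.asarray(predicted_survival, dtype=float)
    concordant, usable = 0.0, 0
    for i in range(len(time)):
        if not failed[i]:
            continue  # a censored patient cannot anchor a usable pair
        for j in range(len(time)):
            if time[i] < time[j]:
                usable += 1
                if pred[i] < pred[j]:
                    concordant += 1.0
                elif pred[i] == pred[j]:
                    concordant += 0.5
    return concordant / usable  # assumes at least one usable pair

# Hypothetical predicted 1-year survival probabilities for three patients
print(harrell_c_index([100, 200, 300], [True, True, True], [0.1, 0.5, 0.9]))
print(harrell_c_index([100, 200, 300], [True, True, True], [0.9, 0.5, 0.1]))
```

A value of 0.5 corresponds to random ordering, so the two calls above bracket the possible range with a perfectly concordant and a perfectly discordant ranking.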
To evaluate model calibration, we used a published adaptation of the Brier score for censored data (Graf et al., 1999). This was implemented using the ipred R package. We also created calibration plots for specific followup times (Royston & Altman, 2013).
We tested the running time of each method by fitting each model using a range of dataset sizes. Simulated datasets ranging from 1,000 to 1,000,000 patients were created by sampling from the 9,105 SUPPORT study patients with replacement. Each combination of model and sample size was run three times and the results were averaged. Each model was trained for 1,000 epochs. An Ubuntu Linux server with 3.6 GHz Intel Xeon E5-1650 CPUs and 32 GB of RAM was used. The models were constrained to run on one CPU core. Python version 3.5.2 was used for the Nnet-survival, Cox-nnet, and standard Cox proportional hazards models; Python version 2.7.12 was used for Deepsurv. R version 3.4.3 was used to calculate Brier scores (R Core Team, 2017). The code for the SUPPORT study analysis is available in support_study.py in the GitHub repository.
Results
Simulated data
Simple model with one predictor variable
We tested a simple survival model with one binary predictor variable and no hidden layers. The calibration curves for the two groups of patients are shown in Fig. 3. It can be seen that calibration is excellent: the actual and predicted survival curves are superimposed.
Model convergence was found to be reliable. The model was optimized repeatedly with different random starting weights and converged to very similar final loss/likelihood values.
Optimal width of time intervals
The discrimination performance of the survival model was robust to various choices of time interval width and configuration (constant width, or increasing width with increasing follow-up time). For each of the four time interval options, discrimination performance was identical with C-index of 0.66.
Convolutional neural network for MNIST dataset
We used the MNIST dataset of handwritten digits to simulate a scenario where each patient has an X-ray image of a tumor, and survival time distribution depends on the appearance of the tumor. Digits 0 through 4 were used, with lower digits having longer median survival. The model’s task was to accurately predict survival time for each patient. There were 30,596 images in the training set used to train the model’s weights, and 5,139 in the test set used to evaluate model performance.
The Nnet-survival model produced good performance. C-index for the test set was 0.713, compared to 0.770 for a “perfect” model that used the true digit as the predictor variable. Calibration was excellent, as seen in Fig. 4.
SUPPORT study (real data)
Four survival models (Nnet-survival, Cox-nnet, Deepsurv, and a standard Cox proportional hazards model) were tested using the SUPPORT study dataset of 9,105 hospitalized patients.
We found that several predictor variables violated the proportional hazards assumption of the standard Cox model, with an example given in Fig. 5. This provides an opportunity for our Nnet-survival model to have improved calibration compared to the other three models.
All models were trained/fit using the 70% of patients in the training set (n = 6,373). Then, performance was measured using the remaining 30% of patients in the test set (n = 2,732). Discrimination performance was very similar for all models, with test set C-index around 0.73 (Table 1). Table 1 also shows calibration performance as measured by the Brier score (lower is better). Nnet-survival had the best calibration performance at all three follow-up time points, though the differences were fairly small. Calibration was also assessed visually using calibration plots (Fig. 6). Our Nnet-survival model appeared to have the best calibration at the 6 month and 1 year time points, with Cox-nnet and the standard Cox model tending to underpredict survival probability for the best-prognosis patients.
Table 1

| Model | C-index | Brier score: 6 months | Brier score: 1 year | Brier score: 3 years |
| --- | --- | --- | --- | --- |
| Nnet-survival | 0.732 | 0.181 | 0.184 | 0.177 |
| Cox-nnet | 0.735 | 0.183 | 0.185 | 0.177 |
| Deepsurv | 0.730 | 0.184 | 0.187 | 0.179 |
| Cox PH | 0.734 | 0.183 | 0.186 | 0.178 |

We compared running time of the three neural network models for various training set sizes, with results shown in Fig. 7. Simulated datasets of size 1,000 to 1,000,000 were created by sampling from the SUPPORT study dataset with replacement. Each model was run for 1,000 epochs. For sample sizes of 100,000 and higher, the Cox-nnet model ran out of memory on a computer with 32 GB memory; therefore, for this model running times could only be calculated for sample sizes of 1,000 to 31,622.
Discussion
We presented Nnet-survival, a discrete-time survival model for neural networks. It is theoretically justified since the likelihood function is used as the loss function, and naturally incorporates non-proportional hazards. Because it is a parametric model, it can be trained with mini-batch gradient descent as the likelihood/loss depends only on the patients in the current mini-batch. This enables fast training and use on datasets that do not fit in memory, and can avoid the network getting stuck in a local minimum of the loss function (Bottou, 1991). This is in contrast to models based on the Cox proportional hazards model such as Cox-nnet (Ching, Zhu & Garmire, 2018) and Deepsurv (Katzman et al., 2018), which require the entire training set to be used for each model update (batch gradient descent). The Nnet-survival model can be applied to a variety of neural network architectures, including multilayer perceptrons and convolutional neural networks.
In our experiments, the model performed well on both simulated and real datasets. It was challenging to find a publicly available dataset that would potentially highlight the advantages of the model. Ideally, such a dataset would have large sample size, predictor data such as images or text that are well-suited to neural networks, and time-to-event outcome data. Since no such dataset was available to our knowledge, we used the SUPPORT study dataset of 9,105 hospitalized patients, which has moderate sample size and time-to-event outcome data, but has low-dimensional predictor data that may not result in a benefit from a neural network approach. For this dataset, our model’s discrimination and calibration performance was similar to several other neural network survival models and a traditional Cox proportional hazards model. In running time tests, its running time was similar to Deepsurv (Katzman et al., 2018) and better than Cox-nnet (Ching, Zhu & Garmire, 2018) for sample sizes >1,000. Interestingly, Cox-nnet ran out of memory for larger dataset sizes, because it stores an n by n matrix where n is the sample size (variable name R_matrix_train in the Cox-nnet code).
While our model has several advantages and we think it will be useful for a broad range of applications, it does have some drawbacks. The discretization of follow-up time results in a less smooth predicted survival curve compared to a non-discrete parametric survival model such as a Weibull accelerated failure time model. As long as a sufficient number of time intervals is used, this is not a large practical concern; for instance, with 19 intervals the curves in Fig. 6 appear very smooth. Unlike a parametric survival model, the model does not provide survival predictions past the end of the last time interval, so it is recommended to extend the last interval past the last follow-up time of interest.
The advantages of parametric survival models and our discrete-time survival model could be combined in the future using a flexible parametric model, such as the cubic spline-based model of Royston and Parmar, implemented in the flexsurv R package (Royston & Parmar, 2002; Jackson, 2016). Complex non-proportional hazards models can be created in this way, and likely could be implemented in deep learning packages.
Conclusions
Our discrete-time survival model allows for non-proportional hazards, can be used with stochastic gradient descent, allows rapid training time, and was found to produce good discrimination and calibration performance with both simulated and real data. For these reasons, it may be useful to medical researchers.