Bayesian Uncertainty Estimation by Hamiltonian Monte Carlo: Applications to Cardiac MRI Segmentation

Yidong Zhao1, João Tourais1, Iain Pierce2,3, Christian Nitsche2,3, Thomas A. Treibel2,3, Sebastian Weingärtner1, Artur M. Schweidtmann4, Qian Tao1
1: Department of Imaging Physics, Delft University of Technology, Delft, The Netherlands, 2: Barts Heart Centre, Barts Health NHS Trust, London, United Kingdom, 3: Institute of Cardiovascular Science, University College, London, United Kingdom, 4: Department of Chemical Engineering, Delft University of Technology, Delft, The Netherlands
Publication date: 2024/06/23
https://doi.org/10.59275/j.melba.2024-88fa

Abstract

Deep learning (DL)-based methods have achieved state-of-the-art performance for many medical image segmentation tasks. Nevertheless, recent studies show that deep neural networks (DNNs) can be miscalibrated and overconfident, leading to "silent failures" that are risky for clinical applications. Bayesian DL provides an intuitive approach to DL failure detection, based on posterior probability estimation. However, the posterior is intractable for large medical image segmentation DNNs. To tackle this challenge, we propose a Bayesian learning framework using Hamiltonian Monte Carlo (HMC), tempered by cold posterior (CP) to accommodate medical data augmentation, named HMC-CP. For HMC computation, we further propose a cyclical annealing strategy, capturing both local and global geometries of the posterior distribution, enabling highly efficient Bayesian DNN training with the same computational budget as training a single DNN. The resulting Bayesian DNN outputs an ensemble segmentation along with the segmentation uncertainty. We evaluate the proposed HMC-CP extensively on cardiac magnetic resonance image (MRI) segmentation, using in-domain steady-state free precession (SSFP) cine images as well as out-of-domain datasets of quantitative T1 and T2 mapping. Our results show that the proposed method improves both segmentation accuracy and uncertainty estimation for in- and out-of-domain data, compared with well-established baseline methods such as Monte Carlo Dropout and Deep Ensembles. Additionally, we establish a conceptual link between HMC and the commonly known stochastic gradient descent (SGD) and provide general insight into the uncertainty of DL. This uncertainty is implicitly encoded in the training dynamics but often overlooked. With reliable uncertainty estimation, our method provides a promising direction toward trustworthy DL in clinical applications. We release our code at https://gitlab.tudelft.nl/yidongzhao/hmc_uncertainty

Keywords

Uncertainty estimation · Bayesian deep learning · Hamiltonian Monte Carlo · segmentation · cardiac MRI




1 Introduction

Image segmentation is an integral part of medical image post-processing in a wide range of clinical applications (Chen et al., 2020). However, the manual delineation of anatomical features or organs is a demanding and time-consuming task in clinical practice. Deep learning (DL)-based automatic segmentation methods, in particular the U-Net and its variants (Ronneberger et al., 2015; Isensee et al., 2021), have demonstrated excellent performance in automated medical image segmentation and have become the de facto standard in the literature (Bernard et al., 2018; Campello et al., 2021). Nonetheless, the robustness and reliability of deep neural networks (DNNs) remain a major concern for clinical use when tested on data with domain shift (Campello et al., 2021; Yan et al., 2019, 2020). Ideally, model uncertainty could be indicated by the Softmax confidence score (Guo et al., 2017). However, recent studies found DNNs to be seriously miscalibrated (Guo et al., 2017; Minderer et al., 2021; Wang et al., 2021), i.e., the confidence score provided by the Softmax output does not match the empirical accuracy (Mehrtash et al., 2020). When tested on unseen, heterogeneous data, DL models often output high-confidence Softmax scores for erroneous predictions, leading to "silent failures" (Gonzalez et al., 2021). This severely undermines the trustworthiness of DL models to clinicians and patients and poses high risks for clinical applications. Therefore, accurate uncertainty estimation, i.e., reporting low confidence when an error is likely to occur, has important clinical implications for the trustworthiness of DL systems in real-world deployment (Jungo and Reyes, 2019).

1.1 Related Work

Uncertainty in medical imaging segmentation has recently drawn the attention of the community (Kohl et al., 2018; Baumgartner et al., 2019; Wang et al., 2019; Jungo and Reyes, 2019; Jungo et al., 2020; Gonzalez et al., 2021; Mehrtash et al., 2020). Previous work has investigated two different types of uncertainty. Part of the work focused on the intrinsic ambiguity of contour definition, inherent to the difficulty of segmentation tasks, which is referred to as the aleatoric uncertainty (Hora, 1996; Der Kiureghian and Ditlevsen, 2009). This uncertainty cannot be reduced by collecting more data (Hüllermeier and Waegeman, 2021). Kohl et al. proposed a Probabilistic U-Net (Kohl et al., 2018) that models the variation among experts in manual contouring and aims at generating multiple feasible segmentation masks to estimate the uncertainty. Baumgartner et al. proposed PHi-Seg, which assumes that the segmentation map is intrinsically ambiguous and governed by hierarchical latent features, so that probabilistic predictions can be made by sampling from the learned latent feature distribution (Baumgartner et al., 2019). Test-time augmentation (TTA) (Wang et al., 2019) was also proposed to estimate the aleatoric uncertainty of contours by averaging predictions on augmented inputs. A more advanced method proposed by Ouyang et al. (2022) combines TTA, adaptive temperature scaling, and shape feasibility. However, the aleatoric uncertainty reflects the intrinsic ambiguity of the task rather than indicating failures of trained networks.

Another type of uncertainty is the epistemic uncertainty, which reflects the model uncertainty when tested on heterogeneous data (Hüllermeier and Waegeman, 2021; Hora, 1996; Der Kiureghian and Ditlevsen, 2009; Kendall and Gal, 2017). Most epistemic estimation methods fall into the Bayesian neural network (BNN) framework (MacKay, 1995; Lampinen and Vehtari, 2001; Wang and Yeung, 2020; Marcot and Penman, 2019). BNNs model the posterior probability of the learned DNN weights (MacKay, 1995; Neal, 2012). The predictive uncertainty is inferred from the distribution of the DNN weights and subsequently that of the DNN predictions (Gal and Ghahramani, 2016; Kendall and Gal, 2017; Mehrtash et al., 2020). Modern BNN architectures also learn the distribution of explainable variables that govern the mapping from images to segmentations, instead of weights, for better generalization (Gao et al., 2023). However, the posterior distribution is prohibitively difficult to derive analytically for large networks (Blundell et al., 2015; Blei et al., 2017; Neal, 2012). As such, the Variational Inference (VI) method (Blei et al., 2017) has been proposed for posterior approximation, modeling network weights as independent Gaussian random variables (Blundell et al., 2015). Such approximation is intrinsically limited by the strong Gaussian posterior assumption. Moreover, it practically doubles the network's parameters and can become unstable during training (Ovadia et al., 2019; Jospin et al., 2022). Monte Carlo Dropout (MC-Dropout) (Gal and Ghahramani, 2016; Kendall and Gal, 2017), proposed by Gal et al., can be considered a VI proxy, assuming that the posterior of the weights is modulated by Bernoulli random variables. In the same spirit, Bayesian SegNet (Kendall et al., 2015) employed dropout in the bottleneck layers of fully convolutional networks for uncertainty estimation.

Unfortunately, uncertainty estimation by VI and its MC-Dropout proxy remains insufficient for large DL models: recent work reported silent failures, poor calibration, and degraded segmentation performance (Folgoc et al., 2021; Gonzalez and Mukhopadhyay, 2021; Gonzalez et al., 2021). Fort et al. (2019) showed that VI-based methods, including MC-Dropout, only explore a limited, local weight space due to their restrictive assumptions. In comparison, Deep Ensembles (Lakshminarayanan et al., 2016; Mehrtash et al., 2020) estimate network uncertainty by averaging independently trained network instances, which can be seen as a combination of maximum-a-posteriori (MAP) solutions (Fort et al., 2019). The ability to explore solutions globally makes Deep Ensembles the best-performing uncertainty estimation method so far (Ovadia et al., 2019; Abdar et al., 2021; Fort et al., 2019; Gustafsson et al., 2020). However, theoretical and practical limitations remain: first, Deep Ensembles ignore the local posterior geometry around each MAP solution, which has been reported to be important for DNN calibration (Garipov et al., 2018; Maddox et al., 2019; Mingard et al., 2021); second, the time complexity of Deep Ensembles grows linearly with the number of models, which becomes computationally prohibitive given that training even a single network is time- and energy-consuming.

In this work, we aim to address the aforementioned limitations of BNNs for medical image segmentation: the VI and MC-Dropout methods have limited approximation capacity, while Deep Ensembles fail to cover the local posterior distribution and are computationally inefficient. We propose to use the Markov Chain Monte Carlo (MCMC) approach (Hammersley, 2013; Hastings, 1970), in particular Hamiltonian Monte Carlo (HMC) (Neal et al., 2011; Chen et al., 2014). HMC treats sampling from a target distribution as modeling particle motion (Risken, 1989; Särkkä and Solin, 2019). It is theoretically guaranteed that simulating the Hamiltonian dynamics yields samples conforming to the target distribution (Neal et al., 2011); hence HMC promises improved BNN uncertainty estimation compared with previous methods based on restrictive assumptions.

In the literature, Izmailov et al. (2021) employed full-batch HMC to explore the precise posterior of neural networks. However, full-batch HMC does not scale to large neural networks because of its computational cost (Izmailov et al., 2021). The first scalable HMC with stochastic gradients (SGHMC) for neural network posterior estimation was proposed by Chen et al. (2014). Further work revealed that tempering the posterior is needed for HMC sampling with stochastic gradients (Zhang et al., 2019; Wenzel et al., 2020). However, these early attempts (Chen et al., 2014; Izmailov et al., 2018; Zhang et al., 2019) focused on simple classification or regression tasks. Research on the behavior of segmentation networks, with their dense outputs, under HMC dynamics is seriously limited and warrants further investigation. For example, Wenzel et al. (2020) reported the necessity of a tempered posterior for classification tasks, but Izmailov et al. (2021) claimed that tempering is not necessary for full-batch HMC and is rather an artifact of data augmentation. This raises questions about the posterior choice in Bayesian segmentation networks, where data augmentation is heavily used. A standard method of posterior modeling is to use the Gibbs distribution (Lifshitz and Landau, 1984) and treat the predefined loss function as the "energy" (Carvalho et al., 2020; Kapoor et al., 2022). However, training data augmentation, commonly used in medical imaging applications due to data scarcity, renders exact posterior modeling intractable, as the number of independently observed data samples becomes ambiguous after augmentation. This leads to the so-called "dirty likelihood" effect and results in degraded BNN performance (Nabarro et al., 2021; Wenzel et al., 2020). Moreover, the sampling strategy of the HMC chain remains understudied for segmentation networks, including, e.g., the number of HMC samples needed for proper calibration. In this work, we investigate and evaluate the cold posterior in Bayesian segmentation to address the "dirty likelihood" in the presence of data augmentation.

Additionally, a largely unanswered question is whether diversity in the posterior weight space $\bm{W}$ propagates to the function space $f_{\bm{W}}(\cdot)$. We differentiate these two spaces because there is no simple relationship between them due to symmetry (a permuted set of weights, for example, can lead to the same function), whereas function-space diversity critically determines the quality of uncertainty estimation (Kendall and Gal, 2017; Fort et al., 2019). Limited research has investigated the functional diversity of BNNs, and none for medical image segmentation applications. For classification, Fort et al. (2019) used cosine similarity to analyze the similarity between posterior weights and evaluated the predictive diversity in function space by comparing classification agreement. Segmentation networks, however, have much more complicated, high-dimensional outputs (Ronneberger et al., 2015). In this work, we propose to evaluate the function-space diversity of BNNs for segmentation uncertainty, beyond that of the posterior weight space.

In practice, the out-of-domain performance of uncertainty estimation is crucial for failure detection in cardiac MRI because of the domain shifts caused by imaging protocol variations. Previous works (Ovadia et al., 2019; Izmailov et al., 2021) studied network behavior under simulated distortions such as additive noise and blurring. However, real-world domain shift can be much more complicated than such in-silico distortion simulations, e.g., novel quantitative MRI mapping versus conventional cine MRI. The performance of image segmentation and uncertainty estimation under such drastic domain shifts has not been systematically studied. Finally, for ease of use in clinical practice and scalable data analysis (Jungo and Reyes, 2019; Czolbe et al., 2021), an aggregated confidence score that detects segmentation failure at the image level would be highly relevant clinically, as it obviates the need for users to review voxel-level uncertainty maps (Kendall et al., 2015; Czolbe et al., 2021).

1.2 Contributions

This study substantially extends the theory, analysis, and application of our previous work published at MICCAI 2022 (Zhao et al., 2022), in which we proposed ensembling training checkpoints during SGD with momentum. In particular, we develop the theoretical foundation of HMC uncertainty estimation and subsume the previously proposed method as a special case. Specifically, we make the following contributions:

  • We propose a Bayesian DL framework for medical image segmentation using HMC-CP, which delivers improved uncertainty estimation compared with state-of-the-art baseline methods. Owing to the natural resemblance between HMC sampling and SGD optimization, and with the novel annealing learning-rate strategy for multi-modal posterior sampling, the proposed method is highly computationally efficient. We systematically investigate the effect of the cold posterior in cardiac MRI segmentation networks and study the calibration performance for various numbers of posterior samples.

  • We extensively analyze the functional diversity of the Bayesian segmentation networks by the proposed and other existing methods. We demonstrate that the proposed method yields superior functional diversity compared with other methods, leading to more accurate uncertainty estimation.

  • We propose an image-level uncertainty score for ease of use in clinical practice and evaluate our method on datasets covering substantial domain shifts, from cine to quantitative MRI data. Empirical results show that the proposed score can effectively detect segmentation failures on both in-domain and out-of-domain datasets.

2 Methods

2.1 Posterior Modelling of Segmentation Networks

BNNs admit a statistical model $p(\bm{w})$ as the prior distribution over the network weights (Jospin et al., 2022), which characterizes the weight distribution before observing any data. Following (Carvalho et al., 2020; Hammam et al., 2021), we assume a zero-mean Gaussian weight prior $p(\bm{w}) \sim \mathcal{N}\left(0, \frac{1}{\lambda}\mathbb{I}\right)$, where $\lambda$ controls the prior variance. According to Bayes' Theorem, the weight distribution can be re-estimated after observing the dataset, yielding the posterior (Neal, 2012). In this section, we define the posterior distribution of segmentation models.

2.1.1 The Weight Posterior

Given a training dataset $\mathcal{D}=\left\{\bm{x}_i,\bm{y}_i\right\}_{i=1}^{n}$ with $n$ image-label pairs, the training procedure learns a weight setting $\bm{w}$ that minimizes the discrepancy between $f_{\bm{w}}(\bm{x}_i)$ and $\bm{y}_i$ for $i=1,2,\dots,n$, where $f_{\bm{w}}$ is the DNN parameterized by $\bm{w}$. The prior distribution of the weights shrinks to the posterior in the presence of $\mathcal{D}$. According to Bayes' Theorem, the following relationship holds:

$$p\left(\bm{w}|\mathcal{D}\right) \propto p\left(\mathcal{D}|\bm{w}\right)\cdot p\left(\bm{w}\right), \tag{1}$$

where the likelihood term $p\left(\mathcal{D}|\bm{w}\right)$ measures how well the network prediction $f_{\bm{w}}(\bm{x})$ on a training sample $\bm{x}$ with weights $\bm{w}$ aligns with the ground truth $\bm{y}$ in the training set.

In this work, we adopt the widely used nnU-Net (Isensee et al., 2021) as our $f_{\bm{w}}$, and use a combination of soft-Dice loss $\mathcal{L}_{\mathrm{DSC}}$ and cross-entropy loss $\mathcal{L}_{\mathrm{CE}}$ to estimate the discrepancy between the network prediction $f_{\bm{w}}(\bm{x})$ and the ground truth $\bm{y}$. For an image with $N$ voxels, let $p(\hat{\bm{y}}_i=c|\bm{x}_i,\bm{w})$ be the predictive probability of voxel $\bm{x}_i$ belonging to class $c$, with $C$ semantic classes in total; the soft-Dice loss is defined as:

$$\mathcal{L}_{\mathrm{DSC}} = -2\sum_{c=1}^{C}\frac{\sum_{i=1}^{N} p(\hat{\bm{y}}_i=c|\bm{x}_i,\bm{w})\cdot(\bm{y}_i=c)}{\sum_{i=1}^{N} p(\hat{\bm{y}}_i=c|\bm{x}_i,\bm{w})+(\bm{y}_i=c)}, \tag{2}$$

and the cross-entropy loss is defined as:

$$\mathcal{L}_{\mathrm{CE}} = -\frac{1}{N}\sum_{i=1}^{N}\sum_{c=1}^{C}\log p(\hat{\bm{y}}_i=c|\bm{x}_i,\bm{w})\cdot(\bm{y}_i=c). \tag{3}$$

The total loss $\mathcal{L}(\bm{w}) = \mathcal{L}_{\mathrm{DSC}}(\bm{w}) + \mathcal{L}_{\mathrm{CE}}(\bm{w})$ measures how likely the training samples are to be observed under the weight setting $\bm{w}$. In this work, we follow (Carvalho et al., 2020; Wenzel et al., 2020) and define the likelihood with the Gibbs distribution (Lifshitz and Landau, 1984):

$$p\left(\mathcal{D}|\bm{w}\right) \propto \exp\left[-\sum_{i=1}^{n}\mathcal{L}\left(f_{\bm{w}}(\bm{x}_i),\bm{y}_i\right)\right]. \tag{4}$$

We aim to draw samples that maximize the log-posterior $\log p\left(\bm{w}|\mathcal{D}\right)$, which is equivalent to minimizing the following energy function during training:

$$\begin{aligned}
U(\bm{w}) &= -\log p\left(\bm{w}|\mathcal{D}\right) = -\log p\left(\mathcal{D}|\bm{w}\right) - \log p\left(\bm{w}\right) \\
&\propto -\log\exp\left[-\sum_{i=1}^{n}\mathcal{L}\left(f_{\bm{w}}(\bm{x}_i),\bm{y}_i\right)\right] - \log\exp\left[-\frac{\bm{w}^{T}\bm{w}}{2\lambda^{-1}}\right] \\
&= \sum_{i=1}^{n}\mathcal{L}\left(f_{\bm{w}}(\bm{x}_i),\bm{y}_i\right) + \frac{1}{2}\lambda\left\|\bm{w}\right\|_2^2,
\end{aligned} \tag{5}$$

where $\lambda$ is the inverse of the Gaussian prior variance. Note that the energy function $U$ is equivalent to the loss function in standard neural network training with SGD-momentum, and the Gaussian prior with variance $\lambda^{-1}$ reduces to the commonly used $L_2$ regularization term in the energy (loss) function. In practice, we choose $\lambda = 3\times10^{-5}$, which forms a relatively weak prior assumption because of the high prior variance. We also study the effect of varying $\lambda$ in Sec. 4.4.
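
To make the correspondence with standard training concrete, the following is a minimal PyTorch-style sketch of the energy in Eq. (5); `model` and `seg_loss` are hypothetical stand-ins for the nnU-Net and its combined soft-Dice/cross-entropy loss, and on a mini-batch the returned value is a stochastic estimate of the full-data energy.

```python
import torch

def energy(model, images, labels, seg_loss, lam=3e-5):
    """Sketch of the energy U(w) in Eq. (5), up to an additive constant.

    `model` and `seg_loss` are assumptions: any segmentation network and
    any summed soft-Dice + cross-entropy loss can be plugged in.
    """
    # Data-fit term: negative log-likelihood under the Gibbs model of Eq. (4).
    data_term = seg_loss(model(images), labels)
    # Prior term: zero-mean Gaussian prior with variance 1/lam, i.e. L2 weight decay.
    prior_term = 0.5 * lam * sum((p ** 2).sum() for p in model.parameters())
    return data_term + prior_term
```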

2.1.2 The Cold Posterior with Tempering

For medical image segmentation, data augmentation has proved to be a highly practical and effective strategy to overcome the data scarcity problem (Isensee et al., 2021; Campello et al., 2021; Chlap et al.). Extensive data augmentation is also performed explicitly in the nnU-Net (Isensee et al., 2021). Data augmentation, however, violates the independent and identically distributed (i.i.d.) assumption on the data samples, leading to the so-called dirty likelihood (Nabarro et al., 2021). We propose to mitigate this undesirable effect of data augmentation on the likelihood by tempering the likelihood in Eq. (4), as recently suggested (Wenzel et al., 2020; Nabarro et al., 2021):

$$p_{\mathrm{cold}}\left(\bm{w}|\mathcal{D}\right) \propto \exp\left(-U(\bm{w})/T\right), \tag{6}$$

where $T$ is called the "temperature", in analogy to Maxwell-Boltzmann statistics in physics, and counteracts the dirty-likelihood effect of data augmentation. A "cold" temperature $T<1$ compensates for the increased number of training samples introduced by data augmentation and limits the variance of the posterior samples (Nabarro et al., 2021; Wenzel et al., 2020).

2.2 Bayesian Inference and Voxel-wise Uncertainty

With the posterior estimate of the weight distribution $p(\bm{w}|\mathcal{D})$, the prediction for a test image $\bm{x}^*$ can be made by integration:

$$p(\bm{y}^*|\bm{x}^*,\mathcal{D}) = \int_{\bm{w}} p\left(\bm{w}|\mathcal{D}\right) p\left(\bm{y}^*|\bm{x}^*,\bm{w}\right) d\bm{w}, \tag{7}$$

which, however, cannot be solved analytically without restrictive assumptions on the exact form of the likelihood and prior (e.g., Gaussian) (Gal and Ghahramani, 2016; Blei et al., 2017). We approximate the integral in Eq. (7) using the Monte Carlo method, which is assumption-free and scales to large networks:

$$p(\bm{y}^*|\bm{x}^*,\mathcal{D}) \approx \frac{1}{M}\sum_{j=1}^{M} p\left(\bm{y}^*|\bm{x}^*,\bm{w}_j\right), \tag{8}$$

where $\{\bm{w}_j\}_{j=1}^{M}$ are $M$ samples drawn from the posterior distribution $p\left(\bm{w}|\mathcal{D}\right)$ and $p\left(\bm{y}^*|\bm{x}^*,\bm{w}_j\right) = f_{\bm{w}_j}(\bm{x}^*)$ is the voxel-wise probabilistic prediction made by the network with weights $\bm{w}_j$. In practice, the $M$ samples are checkpoints saved during the training or posterior sampling process, and they form an ensemble without the need to train multiple networks. Subsequently, we estimate the predictive uncertainty map based on the voxel-wise binary entropy $\mathcal{H}_c$ of each class or the categorical distribution entropy $\mathcal{H}$ (Mehrtash et al., 2020):

$$\mathcal{H}_c = -p_c\log p_c - (1-p_c)\log(1-p_c), \tag{9}$$

$$\mathcal{H} = -\sum_{c=1}^{C} p_c\log p_c, \tag{10}$$

where $p_c = p(\bm{y}_i^*=c|\bm{x}^*,\mathcal{D})$.
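
As a concrete illustration, here is a minimal sketch of the Monte Carlo marginalization in Eq. (8) and the entropy maps in Eqs. (9)-(10); the model list, tensor shapes, and function names are our assumptions, not the released implementation.

```python
import torch

def ensemble_predict(models, image):
    """Monte Carlo marginalization of Eq. (8): average the softmax maps of
    M posterior samples (checkpoints). `image` is a batched (B, ...) tensor."""
    with torch.no_grad():
        probs = torch.stack([m(image).softmax(dim=1) for m in models])  # (M, B, C, H, W)
    return probs.mean(dim=0)  # (B, C, H, W): approximates p(y*|x*, D)

def entropy_maps(p, eps=1e-12):
    """Voxel-wise uncertainty of Eqs. (9)-(10) from ensemble probabilities p."""
    h_cat = -(p * (p + eps).log()).sum(dim=1)                       # Eq. (10), categorical
    h_bin = -(p * (p + eps).log() + (1 - p) * (1 - p + eps).log())  # Eq. (9), per class
    return h_cat, h_bin
```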

2.3 Posterior Sampling via Hamiltonian Monte Carlo

HMC is an MCMC variant that can efficiently generate samples conforming to a given distribution (Neal et al., 2011; Chen et al., 2014) and scales well to high dimensions (Speagle, 2019). In this section, we introduce HMC sampling of the cold posterior defined in Eq. (6).

2.3.1 Stochastic Gradient Hamiltonian Monte Carlo

Figure 1: With a limited amount of training data, the network admits infinite weight solutions that can explain the training set. The posterior of weights models the probability density of the solution space, which is characterized by multiple local optima. The HMC chain (black line) is guided by the momentum (red arrow) which accumulates the gradient (purple arrow) to approach the local optima. The noise (green arrow) encourages the exploration of the low-loss surface. Multiple local optima can be visited by the chain via the annealing strategy. The weight space sampling is essentially similar to training the networks with SGD with momentum. In practice, checkpoints during the chain simulation are saved as posterior samples to form ensembles for function space marginalization.

We propose to draw samples from the cold posterior in Eq. (6) by HMC (Neal et al., 2011). HMC builds a Markov chain by simulating the motion of a particle in an energy field $U(\bm{w})$, at position $\bm{w}$ with momentum $\bm{r}$. The particle dynamics are governed by the Hamiltonian $H(\bm{w},\bm{r}) = U(\bm{w}) + \frac{1}{2}\bm{r}^{T}\bm{M}^{-1}\bm{r}$, in which the potential-energy gradient $\nabla U(\bm{w})$ drives the particle toward a low-energy state (equivalently, the low-loss region of the weight space). The auxiliary momentum term $\frac{1}{2}\bm{r}^{T}\bm{M}^{-1}\bm{r}$ acts as the kinetic energy that makes the particle explore the low-energy surface. Without loss of generality, the mass $\bm{M}$ can be set to the identity matrix.

Simulating the HMC dynamics requires evaluating the full-batch gradient $\nabla U(\bm{w})$ (Neal et al., 2011). In practice, however, we only have access to the stochastic gradient estimated on a mini-batch of size $n_b$:

$$\begin{aligned}
\nabla\tilde{U}(\bm{w}) &= \frac{1}{n_b}\sum_{i=1}^{n_b}\left[\nabla\mathcal{L}\left(f_{\bm{w}}(\bm{x}_i),\bm{y}_i\right) + \lambda\bm{w}\right] \\
&= \frac{1}{n}\left[\nabla U(\bm{w}) + \mathcal{N}(0, 2\bm{V})\right],
\end{aligned} \tag{11}$$

where $U$ is defined in Eq. (5) and $\nabla\frac{1}{2}\|\bm{w}\|_2^2 = \bm{w}$. The stochastic gradient estimate $\nabla\tilde{U}(\bm{w})$ introduces additional noise of strength $\bm{V}$ into the true gradient. In the presence of such noise, the stationary distribution of the HMC samples is no longer the target distribution (Chen et al., 2014).

To address this problem, we use the Stochastic Gradient Hamiltonian Monte Carlo (SGHMC) (Chen et al., 2014), which adds to the momentum update a friction term that compensates for the stochastic gradient noise, together with injected Gaussian noise, such that the dampening friction matches the noise level. In practice, the posterior tempered by $T$ in Eq. (6) leads to the Hamiltonian $H(\bm{w},\bm{r}) = \frac{1}{T}U(\bm{w}) + \frac{1}{2}\bm{r}^{T}\bm{M}^{-1}\bm{r}$, and the gradient scales linearly to $\frac{1}{T}\nabla\tilde{U}(\bm{w})$. The Markov chain can be simulated according to the discrete form of SGHMC:

$$\left\{\begin{aligned}
\bm{w}_{t+1} &= \bm{w}_t + \bm{r}_t \\
\bm{r}_{t+1} &= (1-\mu)\bm{r}_t - \frac{1}{T}\eta_t n \nabla\tilde{U}(\bm{w}_t) + \sqrt{2\eta_t\mu}\,\mathcal{N}(0,\mathbb{I}),
\end{aligned}\right. \tag{12}$$

where $\mu$ is the friction coefficient and $\eta_t$ is the step size of the HMC simulation. Note that the momentum update rule in Eq. (12) is equivalent to the following form:

$$\bm{r}'_{t+1} = (1-\mu)\bm{r}'_t - \eta_t n \nabla\tilde{U}(\bm{w}_t) + \sqrt{2\eta_t\mu T}\,\mathcal{N}(0,\mathbb{I}), \tag{13}$$

obtained by multiplying both sides by $T$ and substituting $\bm{r}' = T\bm{r}$ for the original $\bm{r}$. The dynamics of $(\bm{w},\bm{r})$ in Eq. (12) yield samples whose stationary distribution is exactly $p_{\mathrm{cold}}(\bm{w}|\mathcal{D})$; this can be proven rigorously via the Fokker-Planck equation of the stationary distribution of SGHMC (Särkkä and Solin, 2019; Chen et al., 2014).

Here we note that when $T=0$, Eq. (12) is exactly the update rule of SGD with momentum, where $1-\mu$ corresponds to the momentum coefficient and $\eta_t$ to the learning rate. The length of the Markov chain equals the number of iterations of network training with SGD-momentum. In this case, the single source of stochastic noise is the gradient estimation noise of strength $\bm{V}$ in $\nabla\tilde{U}(\bm{w}_t)$. As $T$ increases, additional noise is injected and perturbs the gradient direction, which can be viewed as an SGD-momentum process with elevated gradient estimation noise. In summary, the HMC sampling process is equivalent to the network optimization process, as shown in Eq. (12), and thus comes at no additional cost: the optimizer acts as a posterior sampler, and sampling essentially amounts to saving checkpoints. An overview of the proposed method is shown in Fig. 1: the gradient (potential force) drives the chain toward high-posterior-density regions, while the momentum term and the injected noise keep the chain exploring the vicinity of a local optimum.
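
For clarity, a minimal sketch of one SGHMC update in the momentum-rescaled form of Eq. (13); `grad_U` is assumed to return the mini-batch gradient already scaled by the dataset size $n$, and all names are illustrative.

```python
import torch

def sghmc_step(w, r, grad_U, eta, mu=0.01, T=1e-5):
    """One SGHMC update (Eq. 13). With T=0 this reduces to SGD with momentum.

    w, r   : parameter and momentum tensors
    grad_U : stochastic gradient of the energy at w, scaled by dataset size n
    eta    : step size (learning rate); mu: friction; T: temperature
    """
    noise = torch.randn_like(r) * (2.0 * eta * mu * T) ** 0.5  # injected Gaussian noise
    r = (1.0 - mu) * r - eta * grad_U + noise                  # friction + gradient + noise
    w = w + r                                                  # position update
    return w, r
```

Setting `T=0` recovers plain SGD with momentum, which is exactly the conceptual link exploited above.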

2.3.2 Annealing Learning Rate and Sample Thinning

The learning rate controls convergence in both optimization and Monte Carlo sampling. We present a novel learning-rate rescheduling approach for more accurate posterior sampling. Specifically, to let the chain explore a broader area and prevent it from converging to a single point, we keep the learning rate at a constant non-zero value after the first $\gamma$ fraction of the training budget. To capture the multimodal posterior geometry typical of complex DNNs (Zhang et al., 2019; Huang et al., 2017), we further propose cyclical annealing training, such that the Markov chain can visit multiple modes of the posterior. We divide the training budget of $T_E$ epochs into $N_C$ cycles, each consuming $T_c = \frac{T_E}{N_C}$ epochs. In particular, we use a high learning rate at the beginning of each cycle, such that the perturbation is strong enough to drive the chain into different posterior modes.

The overall learning rate for the Hamiltonian dynamics simulation is formulated as:

$$\eta_{t_e} = \begin{cases}
\eta_r, & t_c < T_r \\
\eta_0 \cdot \left(1 - \frac{\min\left\{t_c,\ \gamma T_c\right\}}{T_c}\right)^{0.9}, & t_c \geq T_r,
\end{cases} \tag{14}$$

where $t_c = t_e \bmod T_c$ is the intra-cycle epoch number and $\eta_r$ is the high restart learning rate, applied during the first $T_r$ epochs of each cycle.

The first $\gamma$ fraction of epochs in each cycle is considered the burn-in stage of SGHMC (Zhang et al., 2019). The weights computed at each iteration after the burn-in stage can be seen as samples of the posterior distribution. However, a single iteration causes only a marginal change in the weights, so consecutive samples can be highly correlated; this auto-correlation significantly reduces the number of effective samples in a Markov chain (Hammersley, 2013). Moreover, collecting every sample after the burn-in stage would require substantial disk space and make inference extremely time-consuming. After the burn-in stage, we therefore adopt a sample-thinning strategy (Hammersley, 2013) and collect samples only at the end of every fourth epoch (every 1000 iterations): $\bm{W} = \left\{\bm{w}_{t_j} \mid t_j \bmod T_c \geq \gamma T_c,\ 1 \leq t_j \leq T_E,\ t_j \bmod 4 = 0\right\}$.
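
A sketch of the cyclical schedule in Eq. (14), using the settings reported later in Sec. 3.2 ($N_C=3$, $\eta_0=0.02$, $\eta_r=0.2$, $T_r=10$, $\gamma=0.6$) as defaults; the function name and argument layout are ours.

```python
def cyclical_lr(epoch, total_epochs=1000, n_cycles=3, eta0=0.02, eta_r=0.2,
                T_r=10, gamma=0.6):
    """Cyclical annealing learning rate of Eq. (14).

    Each cycle restarts at a high rate eta_r for T_r epochs, then follows a
    polynomial decay that is frozen after the gamma fraction of the cycle,
    keeping the chain exploring at a constant non-zero step size.
    """
    T_c = total_epochs // n_cycles     # epochs per cycle
    t_c = epoch % T_c                  # intra-cycle epoch number
    if t_c < T_r:
        return eta_r                   # restart phase: strong perturbation
    return eta0 * (1.0 - min(t_c, gamma * T_c) / T_c) ** 0.9
```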

2.4 Weight and Functional Space Diversity

We differentiate two types of diversity, for the weights $\bm{w}$ and for the function $f_{\bm{w}}$, respectively. Although both are relevant to uncertainty estimation, their relationship is complex and largely understudied. To investigate the diversity of the weights, we use the mutual cosine similarity (Larrazabal et al., 2021) as a weight-space diversity metric, defined as:

$$\mathrm{sim}_{\mathrm{cos}}(\bm{w}_i, \bm{w}_j) = \frac{\langle\bm{w}_i, \bm{w}_j\rangle}{\|\bm{w}_i\|_2 \cdot \|\bm{w}_j\|_2}. \tag{15}$$

Additionally, we monitor the volume of the high-dimensional space explored by the chain, via a rectangular approximation:

$$\mathrm{vol}(\bm{W}) = \prod_{s=1}^{n_\sigma} \sigma_s, \tag{16}$$

where the $\sigma_s$ are the first $n_\sigma$ singular values of the matrix $\bm{W} = [\bm{w}_1, \bm{w}_2, \dots, \bm{w}_M]$. In practice, we choose $n_\sigma = 5$ because the first five singular values explain at least 90% of the total weight variance; $n_\sigma = 5$ is also a computationally practical choice given the extremely high dimension of $\bm{W}$.
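
A sketch of these two weight-space diagnostics, Eqs. (15)-(16), assuming each posterior sample has been flattened into one row of a matrix `W`; for the very high-dimensional `W` of a U-Net, a truncated or randomized SVD would be used in practice.

```python
import torch

def cosine_similarity_matrix(W):
    """Mutual cosine similarity of Eq. (15); W is (M, D), one flattened
    weight sample per row. Returns the (M, M) pairwise similarity matrix."""
    Wn = W / W.norm(dim=1, keepdim=True)
    return Wn @ Wn.T

def explored_volume(W, n_sigma=5):
    """Rectangular volume approximation of Eq. (16) from the top n_sigma
    singular values of the sample matrix."""
    sigma = torch.linalg.svdvals(W)   # singular values in descending order
    return sigma[:n_sigma].prod()
```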

To investigate diversity in the function space, we propose to evaluate the variation of the predictions of $f_{\bm{w}}$ on the validation set. Given two functions $f_{\bm{w}_i}$ and $f_{\bm{w}_j}$, we measure their function-space distance on a validation set $\mathcal{D}_{val}$ as:

$$d(f_{\bm{w}_i}, f_{\bm{w}_j}) = 1 - \frac{1}{|\mathcal{D}_{val}|}\sum_{\bm{x}\in\mathcal{D}_{val}} \mathrm{DSC}\left(\bm{e}\circ f_{\bm{w}_i}(\bm{x}),\ \bm{e}\circ f_{\bm{w}_j}(\bm{x})\right), \tag{17}$$

where $\bm{e} = \mathbb{I}\left(f_E(\bm{x}) \neq \bm{y}\right)$ indicates the voxels at which the ensemble prediction $f_E(\bm{x})$ is erroneous with respect to the ground truth $\bm{y}$. We note that focusing only on such misclassified voxels better manifests the difference between functions: in practice, most voxels in an image $\bm{x}$ are correctly classified (e.g., background), which leads to an over-optimistically high agreement despite the diversity in the organ segmentation.
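
A minimal sketch of Eq. (17) for a single image and binary foreground masks (boolean tensors); the averaging over $\mathcal{D}_{val}$ and the multi-class case are omitted, and the names are ours.

```python
import torch

def functional_distance(seg_i, seg_j, seg_ens, label, eps=1e-8):
    """Per-image functional distance of Eq. (17), binary-foreground case.

    seg_i, seg_j : hard segmentations of two posterior samples (bool)
    seg_ens      : ensemble prediction f_E(x); label: ground truth y
    """
    e = seg_ens != label              # e = I(f_E(x) != y): ensemble error mask
    a, b = seg_i & e, seg_j & e       # e ∘ f_w(x): predictions restricted to errors
    inter = (a & b).sum().float()
    dsc = (2 * inter + eps) / (a.sum() + b.sum() + eps)
    return 1.0 - dsc
```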

2.5 Voxel-wise Calibration Metrics

To quantify the performance of voxel-wise calibration and uncertainty estimation, we use the Expected Calibration Error (ECE) (Guo et al., 2017), the Brier score (Br) (Brier, 1950), and the negative log-likelihood (NLL) (Ovadia et al., 2019). For a segmentation task with $N$ voxels in total, the confidence range from 0% to 100% is divided equally into $B$ bins, and the ECE is defined as:

$$\mathrm{ECE} = \sum_{i=1}^{B} \frac{|B_i|}{N} \cdot \left|\mathrm{conf}(B_i) - \mathrm{acc}(B_i)\right|, \tag{18}$$

where $B_i$ is the set of voxels whose confidence falls into the $i$-th bin, $\mathrm{conf}(B_i)$ is the mean confidence, and $\mathrm{acc}(B_i)$ is the mean accuracy of that bin. The Brier score quantifies the deviation of the predictive categorical distribution from the one-hot ground-truth label:

$$\mathrm{Br} = \frac{1}{N}\sum_{i=1}^{N}\sum_{c=1}^{C}\left[p(\bm{y}_i^*=c|\bm{x}^*) - (\bm{y}_i=c)\right]^2, \tag{19}$$

and the NLL metric is defined as:

$$\mathrm{NLL} = -\frac{1}{N}\sum_{i=1}^{N}\sum_{c=1}^{C}(\bm{y}_i=c)\cdot\log p(\bm{y}_i^*=c|\bm{x}^*). \tag{20}$$
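
A compact sketch of the three voxel-wise metrics in Eqs. (18)-(20), assuming `probs` holds the ensemble probabilities of Eq. (8) flattened to one row per voxel; the binning scheme and all names are ours.

```python
import torch

def calibration_metrics(probs, labels, n_bins=10, eps=1e-12):
    """ECE, Brier score, and NLL (Eqs. 18-20) over all voxels.

    probs  : (N, C) predictive probabilities
    labels : (N,) ground-truth class indices (int64)
    """
    conf, pred = probs.max(dim=1)
    correct = (pred == labels).float()
    # ECE: bin voxels by confidence, compare mean confidence vs. mean accuracy.
    ece = torch.tensor(0.0)
    edges = torch.linspace(0, 1, n_bins + 1)
    for lo, hi in zip(edges[:-1], edges[1:]):
        m = (conf > lo) & (conf <= hi)
        if m.any():
            ece += m.float().mean() * (conf[m].mean() - correct[m].mean()).abs()
    onehot = torch.nn.functional.one_hot(labels, probs.shape[1]).float()
    brier = ((probs - onehot) ** 2).sum(dim=1).mean()       # Eq. (19)
    nll = -(onehot * (probs + eps).log()).sum(dim=1).mean() # Eq. (20)
    return ece, brier, nll
```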

2.6 Image-level Confidence Score and Failure Detection

Figure 2: Uncertainty maps indicate possible over- and under-segmentation. We estimate the true foreground (TF), false foreground (FF), and false background (FB) areas using the estimated uncertainty and aggregate them into the final image-level score.

Based on the estimated entropy map, we aggregate the voxel-wise uncertainty and derive an image-level confidence score as a segmentation failure indicator. Specifically, we estimate the correct segmentation (true foreground, TF), the over-segmentation (false foreground, FF), and under-segmentation (false background, FB) areas as follows:

$$\begin{aligned}
\mathrm{TF} &= S_c \cdot (1 - H_c), \\
\mathrm{FF} &= S_c \cdot H_c, \\
\mathrm{FB} &= (1 - S_c) \cdot H_c,
\end{aligned} \tag{21}$$

where $S_c$ is the segmentation map of class $c$ and $H_c$ is the corresponding entropy map. The final confidence score of the generated segmentation map $S_c$ with uncertainty $H_c$ is computed in analogy to the Dice coefficient:

$$C(S_c) = \frac{2|\mathrm{TF}|}{2|\mathrm{TF}| + |\mathrm{FF}| + |\mathrm{FB}|}. \tag{22}$$

Examples of the estimated TF, FF, and FB maps are shown in Fig. 2. We detect possible failures based on the computed confidence score $C(S_c)$ and measure the image-level failure-detection performance by the area under the receiver operating characteristic curve (AUC). Empirically, for cardiac MRI applications, a prediction on a 2D image slice with a Dice score lower than 80% and an average symmetric surface distance (ASSD) greater than 2 mm is considered a segmentation failure.
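
A minimal sketch of the image-level score in Eqs. (21)-(22) for one class, assuming the binary entropy map has been normalized to [0, 1]; the names are ours.

```python
import torch

def image_confidence(seg_c, H_c):
    """Image-level confidence score of Eqs. (21)-(22) for one class.

    seg_c : binary segmentation map of class c (float tensor, values 0/1)
    H_c   : voxel-wise binary entropy map of class c, normalized to [0, 1]
    """
    tf = (seg_c * (1.0 - H_c)).sum()      # true foreground: confident foreground
    ff = (seg_c * H_c).sum()              # false foreground: uncertain foreground
    fb = ((1.0 - seg_c) * H_c).sum()      # false background: uncertain background
    return 2 * tf / (2 * tf + ff + fb)    # Dice-style aggregation (Eq. 22)
```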

3 Experiments

3.1 Dataset

Figure 3: Representative images of ACDC, M&M and QMRI datasets. ACDC and M&M are SSFP cine images and the contrast variation is relatively minor. QMRI baseline images have a larger contrast change compared to the training set (ACDC).

We evaluated the proposed HMC-CP method on nnU-Net (Isensee et al., 2021), an established U-Net architecture, for cardiac MRI segmentation. We used the ACDC dataset (Bernard et al., 2018) for training and validation, which consists of short-axis end-diastolic (ED) and end-systolic (ES) steady-state free precession (SSFP) cine cardiac MRI of 100 subjects, acquired at 1.5T (Aera, Siemens Healthineers, Erlangen, Germany) and 3T (Tim Trio, Siemens Healthineers, Erlangen, Germany). The original training part of the dataset, with 100 subjects, was randomly split into a training set (80%) and a validation set (20%). The validation set was used for selecting the best temperature and for studying the influence of the number of samples on the segmentation and calibration performance. Based on the validation results, we evaluated our method on the ACDC test set of 50 subjects to assess in-domain performance. For out-of-domain performance, we tested the proposed method and the methods in comparison on a completely independent dataset, the Multi-center Multi-vendor (M&M) cardiac MRI dataset (Campello et al., 2021), which contains 320 SSFP cine scans of ES and ED images collected from six medical centers using different 1.5T scanners (Siemens Avanto, Germany; Philips Achieva, Netherlands; GE Signa Excite; Canon Vantage Orian) and a 3T scanner (Siemens Skyra, Germany). Additionally, we evaluated the proposed method on a quantitative MRI (QMRI) dataset containing 112 modified Look-Locker inversion recovery (MOLLI) $T_1$-mapping (Messroghli et al., 2004) and $T_2$-prep-based $T_2$-mapping (Giri et al., 2009) images, collected from 8 healthy subjects at 3T (Prisma, Siemens Healthineers, Erlangen, Germany). Each acquisition contains several baseline images (8 for $T_1$-mapping and 5 for $T_2$-mapping) read out during the $T_1$ or $T_2$ relaxation process. Exemplar images of the three datasets are shown in Fig. 3. Three classes of ground-truth labels are provided in the ACDC and M&M datasets: left ventricle cavity (LV), myocardium (MYO), and right ventricle (RV). For the QMRI dataset, the LV and MYO regions were manually annotated on the second baseline image, which has relatively good contrast.

3.2 Methods in Comparison

We implemented several baseline methods, including PHi-Seg (Baumgartner et al., 2019), Bayesian SegNet with Dropout (Kendall et al., 2015), and Deep Ensembles (Lakshminarayanan et al., 2016; Mehrtash et al., 2020), and compared them with the proposed method in terms of both segmentation and uncertainty estimation. We used the automatically configured nnU-Net (Isensee et al., 2021) architecture, a commonly used reference for medical image segmentation. All methods were trained for 1000 epochs. We set the initial learning rate to $\eta_0 = 0.02$ and used a fixed batch size of 40 for all methods.

  • PHi-Seg: We implemented PHi-Seg (Baumgartner et al., 2019) in the nnU-Net framework with 6 resolution levels, 5 latent levels, and a latent feature depth of 4 for the prior, likelihood, and posterior networks. At inference time, we drew $M=30$ realizations of the hierarchical latent features from the prior network output and decoded the features with the likelihood network.

  • MC-Dropout: Following Bayesian SegNet (Kendall et al., 2015), we inserted dropout layers into the innermost three layers on both the encoder and decoder sides of the U-Net. The dropout rate was set to $p=0.5$ at both training and test time, and we ran $M=30$ forward passes at test time.

  • Deep Ensembles: Deep Ensembles of 15 models were trained by SGD-momentum with random initialization for 1000 epochs, with the standard exponential learning-rate decay.

  • SGHMC Variants: We ran the proposed SGHMC method for $N_C = 3$ cycles of 333 epochs each. The first $\gamma = 0.60$ fraction of each cycle was the burn-in stage; afterward, noise was injected into the momentum update, with the noise level controlled by the temperature $T$ as in Eq. (12). To investigate the effect of the cold posterior, we trained networks using SGHMC with temperatures $T \in \{0, 10^{-6}, 10^{-5}, 10^{-4}, 10^{-3}\}$, where $T=0$ corresponds to SGD-momentum with a constant learning rate (SGD-Const) (Zhao et al., 2022). The restart learning rate was set to $\eta_r = 0.2$ and the number of restart epochs to $T_r = 10$. Checkpoints were collected as posterior samples every 4 epochs after the burn-in stage until the end of training. The weight evolution in the last training cycle was used for the single-mode sampling baseline (SGHMC-Single), while the checkpoints across all three cycles form the set of multi-modal weight samples (SGHMC-Multi).

4 Results

4.1 Posterior Geometry and Chain Trajectory of SGHMC

Figure 4: Loss landscape and chain trajectory during training: (a) The loss landscape around a MAP solution. (b) Applying cyclical training promotes the diversity of solutions; the triangular marks indicate the three modes of solutions on the loss surface found in the three training cycles. (We visualize the training trajectory on 2D planes of the loss landscape of trained U-Nets: for checkpoints belonging to the same posterior mode $\bm{W} = \{\bm{w}_{t_j}\}_{j=1}^{M}$, we perform singular value decomposition on the centered $\bm{W}$ to find the first five principal components $\bm{v}_p$, $p \in \{1,2,\dots,5\}$, and visualize the validation loss $\mathcal{L}(\bar{\bm{w}} + a\bm{v}_2 + b\bm{v}_3)$ as a function of $(a,b)$, where $\bar{\bm{w}} = \frac{1}{M}\sum_{j=1}^{M}\bm{w}_{t_j}$. For weights $\bm{w}_1$, $\bm{w}_2$, and $\bm{w}_3$ drawn from three posterior modes, we performed Gram-Schmidt orthogonalization (Garipov et al., 2018) of $\bm{w}_2 - \bm{w}_1$ and $\bm{w}_3 - \bm{w}_1$, used the resulting orthogonal vectors $\bm{u}$ and $\bm{v}$ as the basis, and visualized the surface of $\mathcal{L}(\bm{w}_1 + a\bm{u} + b\bm{v})$ with varying $(a,b)$ (Garipov et al., 2018).) (c) The t-SNE map of the collected weight samples illustrates three clusters of local weight samples. (d) The cosine similarity of weight samples collected in the three cycles suggests that weights drawn from a single cycle (mode) of the chain correlate with each other, while weight modes from different cycles are diverse.

Fig. 4 (b) depicts the loss surface on the plane interpolated from three training cycles, illustrating the multi-modality of the U-Net solution space. Via cyclical learning, the weight iterates visited multiple posterior modes, as observed in the t-SNE visualization of the training trajectory in Fig. 4 (c). Fig. 4 (d) visualizes the cosine similarity (Larrazabal et al., 2021) of checkpoints collected in the three cycles, showing that local weight checkpoints are similar to each other, while cyclical training promotes the orthogonality of weights in different modes.

4.2 Diversity in Function Space

A high degree of diversity in the function space leads to better uncertainty estimation. We analyzed the functional diversity of all methods in comparison by evaluating the distance between function instances defined in Eq. (17). The results are shown in Fig. 5 (a)-(e) as confusion matrices, and the mutual diversity levels of all methods are summarized in Fig. 5 (f). PHi-Seg and MC-Dropout have lower functional diversity than Deep Ensembles and the proposed HMC variants, SGHMC-Single and SGHMC-Multi. SGHMC-Single yielded slightly lower diversity than Deep Ensembles; SGHMC-Multi, however, showed the highest diversity, surpassing Deep Ensembles.

Figure 5: Confusion matrices showing the functional diversity of PHi-Seg (a), MC-Dropout (b), Deep Ensembles (c), and our proposed SGHMC variants, SGHMC-Single (d) and SGHMC-Multi (e). The ensemble of all function instances is denoted as E, at the lower-right corner of each matrix. Each entry in a confusion matrix represents the mutual function-space distance between two functions, as defined in Sec. 2.4. Panel (f) summarizes the functional distances from (a)-(e) by showing the mean of the rows of the confusion matrices.

4.3 The Effect of Cold Posterior

Figure 6: The cold posterior effect: (a) Calibration performance with various temperatures. (b) Mean Dice over LV/MYO/RV on the validation set of each posterior sample. (c)-(d) The functional and weight space diversity w.r.t. varying temperature.

We also studied the effect of varying temperature on calibration and segmentation, both with and without data augmentation. Fig. 6 (a) shows the calibration performance as the temperature increases from $T=0$ to $10^{-3}$. We observe that the model does not favor a cold posterior when augmentation is turned off, as the NLL consistently improves as the temperature increases from $0$ to $10^{-3}$. With data augmentation, however, the NLL is relatively insensitive to temperature changes, and the NLL drop is marginal compared with the no-augmentation case. Fig. 6 (b) shows that the mean Dice across LV/MYO/RV drops in both cases as the temperature increases; we conjecture that this is because the injected noise causes sampling at sub-optimal loss levels. The best segmentation accuracy is achieved at $T=10^{-5}$. Fig. 6 (c)-(d) reveal that a higher temperature drives the chain to explore a broader weight space, as the weight-space volume increases exponentially with temperature. The functional diversity, however, is more sensitive to the weight-space volume change when augmentation is turned off, as shown in Fig. 6 (c). With augmentation on, we observe that a cold posterior at $T=10^{-5}$ provides good calibration and improved segmentation performance. In the following, we use $T=10^{-5}$ to evaluate our method on the test sets.

4.4 The Effect of Prior

In Fig. 7, we study the effect of varying the prior strength $\lambda$ defined in Eq. (5). A smaller $\lambda$ indicates a higher prior variance and thus a weaker prior assumption. From Fig. 7 (a)-(b), we observe that the stronger prior assumption with $\lambda = 3\times10^{-4}$ causes a significant drop in Dice, while the calibration performance measured by NLL improves. As the prior strength increases, its contribution to the posterior geometry becomes more pronounced, and the likelihood term that fits the training data can settle at a sub-optimal level, leading to lower accuracy. However, the smaller $\lambda = 3\times10^{-6}$ also causes a slight performance drop compared with $\lambda = 3\times10^{-5}$, which shows that proper regularization benefits accuracy.

Figure 7: The effect of prior λ𝜆\lambda: (a) Mean Dice over LV/MYO/RV on the validation set of each posterior sample. (b) Calibration performance with various prior strengths.

4.5 Calibration and Segmentation Performance

Figure 8: Calibration quality and segmentation performance as functions of the number of models averaged, on the validation set. (a)-(c) show the calibration metrics ECE (a), Br (b), and NLL (c) versus the number of models averaged. Averaging $M \geq 15$ models using SGHMC significantly improves network calibration compared to relying solely on a single weight setting. Ensembles of SGHMC samples consistently improve the segmentation and achieve better performance than MC-Dropout and PHi-Seg. The best calibration is achieved by SGHMC with the cyclical annealing learning rate (SGHMC-Multi). (d)-(f) show the segmentation performance versus the number of models for the three segmentation classes: LV (d), MYO (e), and RV (f). Ensembles of multi-modal SGHMC samples achieved the best performance on RV and MYO.

In Fig. 8, we show the calibration quality and segmentation performance as functions of $M$, the number of models (predictions) averaged, on the in-distribution validation set. The segmentation performance of PHi-Seg is not competitive with the other methods in comparison. Ensembling consistently improves segmentation for all methods; however, the improvement from ensembling MC-Dropout and PHi-Seg predictions is relatively marginal compared with Deep Ensembles and HMC, in accordance with our finding that PHi-Seg and MC-Dropout lack functional diversity. On the in-distribution validation set, an ensemble of 30 cyclical SGHMC (SGHMC-Multi) samples achieved the best performance. Fig. 8 (a)-(c) report the calibration results measured by ECE, Brier score, and NLL, respectively. We observe that a single model ($M=1$) is poorly calibrated for all methods, while combining more predictions consistently improves calibration. MC-Dropout improves calibration, but not as much as the ensemble of SGHMC with a constant learning rate (SGD-Const) at $T=0$. Compared with SGD-Const ($T=0$), SGHMC-Single ($T=10^{-5}$) achieved better calibration, which was further surpassed by SGHMC-Multi, which ensembles samples from multiple posterior modes.

Table 1: Segmentation performance measured by the Dice score [%] on ACDC, M&M, and QMRI datasets. Highlighted in the original are the best and the second-best results.

| Methods | ACDC LV | ACDC MYO | ACDC RV | M&M LV | M&M MYO | M&M RV | QMRI LV | QMRI MYO |
|---|---|---|---|---|---|---|---|---|
| Vanilla (M=1) | 92.80±5.15 | 87.03±3.65 | 90.08±5.71 | 88.64±7.75 | 82.97±5.49 | 86.34±9.36 | 74.17±38.45 | 63.88±31.51 |
| PhiSeg (M=30) | 93.64±3.81 | 87.84±3.50 | 88.47±7.32 | 89.79±8.11 | 84.05±5.60 | 84.83±11.16 | 87.11±25.87 | 74.59±21.77 |
| Dropout (M=30) | 93.36±4.41 | 87.47±3.66 | 89.50±6.55 | 90.19±6.49 | 83.30±5.45 | 86.45±9.16 | 85.44±28.08 | 72.51±23.96 |
| Deep Ens. (M=30) | 93.84±3.79 | 87.91±3.54 | 90.92±5.39 | 90.13±6.96 | 84.48±5.25 | 87.43±8.53 | 82.31±31.52 | 70.29±27.04 |
| SGD Const. LR (M=30) | 93.70±4.27 | 87.97±3.57 | 90.61±5.88 | 90.28±6.71 | 84.21±5.14 | 87.26±9.80 | 85.77±27.38 | 73.11±23.80 |
| SGHMC, single (M=30) | 93.54±5.13 | 87.91±3.45 | 90.16±6.65 | 90.48±6.64 | 84.21±5.26 | 87.40±8.98 | 86.31±26.11 | 74.22±22.27 |
| SGHMC, multi (M=30) | 93.88±3.82 | 87.97±3.65 | 90.74±5.79 | 90.31±6.88 | 84.43±5.22 | 87.38±9.03 | 89.88±20.30 | 76.20±19.16 |
Figure 9: Segmentation performance evaluated on ACDC, M&M, and QMRI datasets. Both HMC variants (single- and multi-modal samples) substantially improved the segmentation performance compared to the single model prediction (Vanilla), especially on the QMRI dataset with a strong domain shift.

The segmentation performance measured by the Dice score is reported in Fig. 9 and Table 1. The results show that the proposed method improves segmentation performance over the Vanilla model on all test sets, indicating that Bayesian inference by averaging posterior samples leads to more accurate predictions. On the ACDC test set and M&M, Deep Ensembles and the proposed method perform marginally better than the other methods in comparison, and there is little difference between the single-modal and multi-modal variants of the proposed method, or between the proposed method and traditional methods like Deep Ensembles. However, significantly increasing the segmentation performance over traditional methods is not our primary purpose; we are more interested in the calibration and uncertainty quantification performance.

The results also show that domain shifts cause a performance drop, but the proposed method, which marginalizes over HMC samples, is more robust to contrast changes. Comparing the results on the three datasets, we observe a drastic performance drop as the domain shift increases from cine (ACDC, M&M) to QMRI: the MYO Dice of the Vanilla model drops from 87.03% to 63.88%. On the QMRI dataset, which has the largest domain shift, the proposed method exhibits high robustness and achieves the highest Dice scores on both LV (89.88%) and MYO (76.20%).

Table 2: Calibration performance measured by ECE, Brier score (Br), and NLL on the ACDC, M&M, and QMRI datasets (all values in %; lower is better). The best results are shown in bold and the second-best in italics.
| Methods | ACDC ECE | ACDC Br | ACDC NLL | M&M ECE | M&M Br | M&M NLL | QMRI ECE | QMRI Br | QMRI NLL |
|---|---|---|---|---|---|---|---|---|---|
| Vanilla (M=1) | 5.45±1.63 | 2.76±0.80 | 9.45±3.87 | 6.49±2.81 | 3.28±1.40 | 11.96±7.25 | 29.66±25.39 | 14.92±12.65 | 76.01±83.30 |
| PHi-Seg (M=30) | 4.55±1.86 | 2.61±0.96 | 6.91±3.94 | 5.55±2.82 | 3.11±1.44 | 8.68±5.55 | **14.87±18.65** | **8.03±9.16** | 28.23±48.75 |
| Dropout (M=30) | 5.09±1.71 | 2.69±0.87 | 8.10±3.49 | 6.19±2.50 | 3.23±1.29 | 10.66±5.75 | 22.79±19.06 | 11.60±9.10 | 42.03±47.43 |
| Deep Ens. (M=15) | *3.49±1.20* | *2.20±0.63* | 5.55±2.01 | *4.49±2.14* | **2.71±1.12** | 7.51±4.35 | 20.67±19.94 | 10.73±9.33 | 28.61±37.71 |
| SGD Const. LR (M=30) | 3.68±1.40 | 2.27±0.70 | 5.38±1.93 | 4.81±2.35 | 2.81±1.24 | 7.65±4.64 | 19.92±18.65 | 10.35±8.98 | 26.91±33.44 |
| SGHMC, single (M=30) | 3.54±1.23 | 2.24±0.64 | *5.17±1.68* | 4.75±2.23 | 2.81±1.19 | *7.43±4.21* | 20.11±17.11 | 10.33±8.03 | *26.11±27.29* |
| SGHMC, multi (M=30) | **3.29±1.18** | **2.19±0.63** | **4.89±1.53** | **4.38±2.19** | *2.72±1.17* | **6.91±4.00** | *19.81±15.60* | *10.18±7.56* | **25.16±24.91** |
Figure 10: Voxel-wise calibration performance measured by NLL on the ACDC, M&M, and QMRI datasets. The HMC variants significantly improved calibration and achieved the lowest NLL.

In Fig. 10 and Table 2, we report the voxel-wise calibration metrics of the compared methods. Calibration degrades consistently as the domain shift becomes larger. Overall, the multi-modal SGHMC predictions deliver the best calibration, significantly better than baselines such as MC-Dropout and PHi-Seg on the cine datasets (ACDC and M&M). This is in accordance with the results in Sec. 4.2, which show that the HMC samples have higher functional diversity. On the QMRI dataset, PHi-Seg achieved the best ECE and Brier score, but the single-modal and multi-modal SGHMC variants have the lowest NLL. Compared with Dropout and Deep Ensembles, the proposed method remains robustly well calibrated even in the presence of a large domain shift.

4.6 Automated Failure Detection

Table 3: AUC [%] of failure detection on the ACDC, M&M, and QMRI datasets. The best results are shown in bold and the second-best in italics.
| Methods | ACDC LV | ACDC MYO | ACDC RV | M&M LV | M&M MYO | M&M RV | QMRI LV | QMRI MYO |
|---|---|---|---|---|---|---|---|---|
| Vanilla | 73.54 | 83.43 | 69.80 | 85.67 | 88.69 | 79.56 | 22.22 | 37.04 |
| PHi-Seg | 59.41 | 60.05 | 65.88 | 71.02 | 71.31 | 74.10 | 22.01 | 43.58 |
| Dropout | *92.35* | *96.99* | 84.57 | 87.23 | 89.23 | 84.94 | 62.47 | 70.75 |
| Deep Ens. | **93.21** | 96.42 | 84.57 | *88.43* | 90.71 | 85.93 | **93.68** | *89.41* |
| SGHMC, single | 91.49 | **97.27** | *88.20* | 88.31 | *91.08* | **86.98** | 86.83 | **91.47** |
| SGHMC, multi | 91.12 | 95.54 | **88.21** | **89.33** | **92.05** | *86.70* | *92.88* | 89.20 |
Figure 11: ROC curves of RV segmentation failure detection on the ACDC (a) and M&M (b) datasets, and of MYO segmentation failure detection on the QMRI dataset (c).

With the proposed image-level uncertainty score, we can automatically detect segmentation failures on the test datasets. Table 3 lists the failure-detection AUC values on the three datasets. MC-Dropout, Deep Ensembles, and the proposed HMC variants achieve similar performance in detecting LV and MYO segmentation failures on the ACDC dataset. However, the proposed method better detects segmentation failures of the most challenging anatomy, the RV, with an AUC of 88.21%, higher than Deep Ensembles (84.57%) and MC-Dropout (84.57%). This is also reflected in the ROC curves in Fig. 11 (a), where the ROC of the proposed method encloses those of Deep Ensembles and MC-Dropout. The improvement is also observed on the M&M and QMRI datasets. The proposed method achieves a remarkable AUC of 91.47% on the QMRI dataset, significantly outperforming MC-Dropout (70.75%). Despite its good voxel-wise calibration, PHi-Seg suffers from severe silent failures on QMRI, with an AUC of 43.58%. More detailed ROC curves are shown in Appendix A, Fig. 15.
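For concreteness, the evaluation behind Table 3 can be sketched as follows; the Dice threshold that defines a segmentation "failure" below is an illustrative assumption, not necessarily the exact criterion used in our experiments.

```python
# A minimal sketch of image-level failure detection: rank images by a scalar
# confidence score and measure how well low confidence predicts low Dice.
import numpy as np
from sklearn.metrics import roc_auc_score, roc_curve

def failure_detection_auc(confidences, dice_scores, dice_threshold=0.8):
    """confidences, dice_scores: arrays of shape (N,), one entry per image.
    dice_threshold is a hypothetical failure criterion for illustration."""
    failure = (np.asarray(dice_scores) < dice_threshold).astype(int)
    # Low confidence should flag failure, so use the negated confidence as score.
    score = -np.asarray(confidences)
    fpr, tpr, _ = roc_curve(failure, score)   # ROC as plotted in Fig. 11
    return roc_auc_score(failure, score), (fpr, tpr)
```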

4.7 Qualitative Results

Figure 12: Visualization of predictions by HMC samples on both in-domain (ACDC) and out-of-domain (M&M, QMRI) images.

Fig. 12 visualizes the segmentation predictions made by HMC samples. The predictions are consistent on confident inputs, for example, the cine case from the ACDC dataset. On uncertain inputs such as the M&M cine case, however, the samples make diverse predictions in the RV basal areas. As the contrast change increases further, the network also makes divergent predictions on the LV blood pool and myocardium because of the low contrast of the QMRI images.
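As a sketch of how such inter-sample disagreement translates into the voxel-wise uncertainty maps shown in Figs. 13 and 14, one can take the entropy of the ensemble-averaged predictive distribution (names and shapes below are our own illustration):

```python
# Predictive entropy of the ensemble: high where the HMC samples disagree
# or are individually uncertain, low where they agree confidently.
import numpy as np

def predictive_entropy(member_probs, eps=1e-12):
    """member_probs: (M, C, H, W) softmax outputs of M posterior samples.
    Returns an (H, W) uncertainty map."""
    mean_p = member_probs.mean(axis=0)                  # (C, H, W)
    return -(mean_p * np.log(mean_p + eps)).sum(axis=0)
```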

Figure 13: Segmentation predictions with Dice scores of LV/MYO/RV (first row) on cine images and the corresponding confidence scores (Conf.) overlaid on the pixel-wise uncertainty maps (second row). Red values indicate high confidence scores on segmentation failures.
Figure 14: Segmentation predictions with Dice scores of LV and MYO (first row) on QMRI images and the corresponding confidence scores (Conf.) overlaid on the pixel-wise uncertainty maps (second row). Red values indicate high confidence scores on segmentation failures.

In Fig. 13, we visualize the predictions on the cine images and the uncertainty maps estimated by all compared methods. Fig. 13 (a) depicts a mid-ventricular slice on which all methods make accurate predictions, with uncertainty concentrated only on the borders between anatomical structures. However, in all four cases, the Vanilla network only expresses uncertainty along these border areas, even in the presence of erroneous predictions (i.e., silent failures). In general, for cardiac MRI segmentation tasks, uncertainty and segmentation failure occur more frequently on the right ventricle. A typical failure case is shown in Fig. 13 (b), on which PHi-Seg outputs zero uncertainty on RV voxels.

Fig. 14 (a) shows a QMRI case on which the network successfully segments MYO and LV, and all methods provide high confidence scores. Fig. 14 (b) depicts the uncertainty estimates on a difficult QMRI image: under the strong domain shift, PHi-Seg and MC-Dropout fail to detect the segmentation failures. In all cases, our proposed HMC variants demonstrate much better uncertainty estimation, without the risk of “silent failures”.

5 Discussion and Conclusion

DL models have achieved extraordinary performance in medical image segmentation tasks. However, they are generally miscalibrated and can fail silently, which seriously undermines their trustworthiness in clinical use. In this work, we propose a Bayesian deep learning framework for uncertainty estimation and failure detection in DL-based medical image segmentation.

To make the high-dimensional Bayesian segmentation problem computationally tractable, we formulated it as a posterior sampling problem, which can be solved by HMC. Diversity in function space plays a critical role in uncertainty estimation. We have shown that, in the context of cardiac MRI segmentation, HMC can effectively produce network samples that make diverse segmentation predictions, with a thinning interval of 4 epochs in our practice. Moreover, ensembling more than 15 samples from one chain simulation provides comparable or even superior performance to training 15 Deep Ensemble models. Comparing the functional diversity of the methods, we observe that degraded calibration strongly correlates with a lack of functional diversity: HMC samples and Deep Ensembles produce functions of high diversity, and their calibration is superior to the other methods, including PHi-Seg and MC-Dropout.

We also found that the cold posterior (CP) is crucial for accurate and well-calibrated segmentation. Specifically, we observe that HMC at temperature $T=0$ forms an outstanding baseline for model calibration; this setting has also been successfully validated on the segmentation of shapes with multiple connected components, such as lymph nodes (Salahuddin et al., 2023). The performance can be further improved by raising the temperature slightly to a cold but nonzero $T=10^{-5}$. On the other hand, when data augmentation is turned off, a cold temperature causes a significant drop in model calibration.
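For reference, the tempered ("cold") posterior targeted by the sampler follows the standard formulation (Wenzel et al., 2020):

$$p_T(\theta \mid \mathcal{D}) \;\propto\; \exp\!\left(-\frac{U(\theta)}{T}\right), \qquad U(\theta) = -\log p(\mathcal{D} \mid \theta) - \log p(\theta),$$

where $T=1$ recovers the standard Bayes posterior, while $T \to 0$ concentrates the distribution around MAP solutions, recovering standard (augmented) network training.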

In general, marginalization over diverse functions benefits segmentation accuracy: all methods improve by combining predictions made by multiple network samples. Comparing HMC with PHi-Seg and MC-Dropout, a relatively higher function-space diversity improves segmentation accuracy. However, excessive function-space diversity can harm accuracy: HMC at a higher temperature $T \geq 10^{-4}$ produces an elevated level of functional diversity, but the segmentation accuracy drops. In contrast, calibration consistently improves with increasing temperature.

In particular, our proposed method addresses two major issues of current Bayesian learning methods for large, over-parameterized neural networks. First, unlike the VI family of Bayesian methods (including VI proxies such as MC-Dropout), it makes no restrictive assumptions on the weight posterior. Second, with the proposed cyclical annealing strategy, the method is highly computationally efficient, consuming the same computational budget as a single round of standard network training. The resemblance between HMC and SGD with momentum makes posterior sampling as straightforward as saving checkpoints during network training. Nonetheless, we acknowledge that with more computational resources, Deep Ensembles can be trained in parallel, which is also time-efficient and further enables combining predictions from different network architectures. We did not cover ensembling across architectures in this work, which we see as a limitation; however, our method still forms a strong baseline with less training effort and energy consumption.
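To make this resemblance concrete, the sketch below writes one tempered SGHMC update (after Chen et al., 2014) in PyTorch; the discretization and variable names are our own simplification, not the exact released implementation.

```python
# One tempered SGHMC step: SGD with momentum plus temperature-scaled Gaussian
# noise. At temperature 0 the noise vanishes and this is exactly momentum SGD.
import torch

def sghmc_step(params, momenta, grads, lr, friction=0.1, temperature=1e-5):
    """params, momenta, grads: lists of matching tensors; momenta start at zero."""
    for p, m, g in zip(params, momenta, grads):
        noise = torch.randn_like(p) * (2.0 * friction * lr * temperature) ** 0.5
        m.mul_(1.0 - friction).add_(-lr * g + noise)  # momentum update
        p.data.add_(m)                                # parameter update

# Posterior "sampling" then reduces to checkpointing during training, e.g.
#   if epoch % thinning == 0:
#       samples.append(copy.deepcopy(model.state_dict()))
```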

In practice, HMC hyperparameters such as batch size and learning rate directly affect the simulation (training) dynamics: the batch size determines the strength of the gradient-estimation noise, and the learning rate controls the range of exploration. We used an aggressive learning rate to promote diversity, but the influence of batch size and learning rate on the final performance has not been fully investigated. We see this as a limitation and an interesting direction for future work.

With the proposed image-level confidence score, we can automatically detect possible segmentation failures on each image. We showed that automatic failure detection is highly robust on both in-domain cine images and QMRI images with a strong domain shift, with AUC above 86%. The HMC approach is especially robust to large domain shifts such as cine to QMRI, being the most robust at detecting myocardial segmentation failure in quantitative CMR, with an AUC of 91%. When integrated into the DL workflow, automatic failure detection can improve the trustworthiness of DL models deployed in large-scale clinical studies or daily clinical practice.
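As an illustration only, an image-level confidence score of this kind can be aggregated from the ensemble output as below; the foreground-averaged maximum probability is our assumed aggregation, and the exact score used in the paper may differ.

```python
# A hypothetical image-level confidence score: the mean voxel-wise maximum
# probability over the predicted foreground of the ensemble-averaged output.
import numpy as np

def image_confidence(mean_probs):
    """mean_probs: (C, H, W) ensemble-averaged softmax map, class 0 = background."""
    pred = mean_probs.argmax(axis=0)
    conf = mean_probs.max(axis=0)
    fg = pred > 0
    return float(conf[fg].mean()) if fg.any() else float(conf.mean())
```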

In conclusion, we have proposed a Bayesian learning framework for uncertainty estimation in medical image segmentation, based on Hamiltonian Monte Carlo with a cold posterior (HMC-CP). HMC-CP is theoretically grounded, computationally efficient, and scalable to large medical image segmentation networks. Our extensive experiments on both in-domain and out-of-domain data showed that HMC-CP yields more reliable uncertainty estimation, as well as more accurate image segmentation, than a range of state-of-the-art baselines. Importantly, through reliable uncertainty estimation, our method provides a promising direction for improving the trustworthiness of DL models in clinical applications.


Acknowledgments

YZ and QT acknowledge the TU Delft AI Initiative for financial support, and the NVIDIA Applied Research Accelerator Program for granting high-performance computing resources. JT and SW are supported by the NWO Start-up grant STU.019.024, the TU Delft - Erasmus MC Convergence Impulse initiative, the European Union (ERC, Vascular ID, 101078711), and the Netherlands Heart Foundation Dekker Grant. TAT is supported by the British Heart Foundation (IRF FS/19/35/34374). IP and TAT are directly and indirectly supported by the University College London Hospitals NIHR Biomedical Research Centre and Biomedical Research Unit at Barts Hospital respectively.


Ethical Standards

The work follows appropriate ethical standards in conducting research and writing the manuscript, following all applicable laws and regulations regarding the treatment of animals or human subjects.


Conflicts of Interest

We declare that we have no conflicts of interest.

References

  • Abdar et al. (2021) Moloud Abdar, Farhad Pourpanah, Sadiq Hussain, Dana Rezazadegan, Li Liu, Mohammad Ghavamzadeh, Paul Fieguth, Xiaochun Cao, Abbas Khosravi, U Rajendra Acharya, et al. A review of uncertainty quantification in deep learning: Techniques, applications and challenges. Information Fusion, 76:243–297, 2021.
  • Baumgartner et al. (2019) Christian F Baumgartner, Kerem C Tezcan, Krishna Chaitanya, Andreas M Hötker, Urs J Muehlematter, Khoschy Schawkat, Anton S Becker, Olivio Donati, and Ender Konukoglu. Phiseg: Capturing uncertainty in medical image segmentation. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pages 119–127. Springer, 2019.
  • Bernard et al. (2018) Olivier Bernard, Alain Lalande, Clement Zotti, Frederick Cervenansky, Xin Yang, Pheng-Ann Heng, Irem Cetin, Karim Lekadir, Oscar Camara, Miguel Angel Gonzalez Ballester, et al. Deep learning techniques for automatic mri cardiac multi-structures segmentation and diagnosis: is the problem solved? IEEE transactions on medical imaging, 37(11):2514–2525, 2018.
  • Blei et al. (2017) David M Blei, Alp Kucukelbir, and Jon D McAuliffe. Variational inference: A review for statisticians. Journal of the American statistical Association, 112(518):859–877, 2017.
  • Blundell et al. (2015) Charles Blundell, Julien Cornebise, Koray Kavukcuoglu, and Daan Wierstra. Weight uncertainty in neural network. In International Conference on Machine Learning, pages 1613–1622. PMLR, 2015.
  • Brier et al. (1950) Glenn W Brier et al. Verification of forecasts expressed in terms of probability. Monthly weather review, 78(1):1–3, 1950.
  • Campello et al. (2021) Víctor M Campello, Polyxeni Gkontra, Cristian Izquierdo, Carlos Martín-Isla, Alireza Sojoudi, Peter M Full, Klaus Maier-Hein, Yao Zhang, Zhiqiang He, Jun Ma, et al. Multi-centre, multi-vendor and multi-disease cardiac segmentation: the m&ms challenge. IEEE Transactions on Medical Imaging, 40(12):3543–3554, 2021.
  • Carvalho et al. (2020) Eduardo DC Carvalho, Ronald Clark, Andrea Nicastro, and Paul HJ Kelly. Scalable uncertainty for computer vision with functional variational inference. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 12003–12013, 2020.
  • Chen et al. (2020) Chen Chen, Chen Qin, Huaqi Qiu, Giacomo Tarroni, Jinming Duan, Wenjia Bai, and Daniel Rueckert. Deep learning for cardiac image segmentation: a review. Frontiers in Cardiovascular Medicine, 7:25, 2020.
  • Chen et al. (2014) Tianqi Chen, Emily Fox, and Carlos Guestrin. Stochastic gradient hamiltonian monte carlo. In International conference on machine learning, pages 1683–1691. PMLR, 2014.
  • Chlap et al. (2021) Phillip Chlap, Hang Min, Nym Vandenberg, Jason Dowling, Lois Holloway, and Annette Haworth. A review of medical image data augmentation techniques for deep learning applications. Journal of Medical Imaging and Radiation Oncology, 65(5):545–563, 2021.
  • Czolbe et al. (2021) Steffen Czolbe, Kasra Arnavaz, Oswin Krause, and Aasa Feragen. Is segmentation uncertainty useful? In Information Processing in Medical Imaging: 27th International Conference, IPMI 2021, Virtual Event, June 28–June 30, 2021, Proceedings 27, pages 715–726. Springer, 2021.
  • Der Kiureghian and Ditlevsen (2009) Armen Der Kiureghian and Ove Ditlevsen. Aleatory or epistemic? does it matter? Structural safety, 31(2):105–112, 2009.
  • Folgoc et al. (2021) Loic Le Folgoc, Vasileios Baltatzis, Sujal Desai, Anand Devaraj, Sam Ellis, Octavio E Martinez Manzanera, Arjun Nair, Huaqi Qiu, Julia Schnabel, and Ben Glocker. Is mc dropout bayesian? arXiv preprint arXiv:2110.04286, 2021.
  • Fort et al. (2019) Stanislav Fort, Huiyi Hu, and Balaji Lakshminarayanan. Deep ensembles: A loss landscape perspective. arXiv preprint arXiv:1912.02757, 2019.
  • Gal and Ghahramani (2016) Yarin Gal and Zoubin Ghahramani. Dropout as a bayesian approximation: Representing model uncertainty in deep learning. In international conference on machine learning, pages 1050–1059. PMLR, 2016.
  • Gao et al. (2023) Shangqi Gao, Hangqi Zhou, Yibo Gao, and Xiahai Zhuang. Bayeseg: Bayesian modeling for medical image segmentation with interpretable generalizability. arXiv preprint arXiv:2303.01710, 2023.
  • Garipov et al. (2018) Timur Garipov, Pavel Izmailov, Dmitrii Podoprikhin, Dmitry P Vetrov, and Andrew G Wilson. Loss surfaces, mode connectivity, and fast ensembling of dnns. Advances in neural information processing systems, 31, 2018.
  • Giri et al. (2009) Shivraman Giri, Yiu-Cho Chung, Ali Merchant, Georgeta Mihai, Sanjay Rajagopalan, Subha V Raman, and Orlando P Simonetti. T2 quantification for improved detection of myocardial edema. Journal of cardiovascular magnetic resonance, 11(1):1–13, 2009.
  • Gonzalez and Mukhopadhyay (2021) Camila Gonzalez and Anirban Mukhopadhyay. Self-supervised out-of-distribution detection for cardiac cmr segmentation. In Medical Imaging with Deep Learning, 2021.
  • Gonzalez et al. (2021) Camila Gonzalez, Karol Gotkowski, Andreas Bucher, Ricarda Fischbach, Isabel Kaltenborn, and Anirban Mukhopadhyay. Detecting when pre-trained nnu-net models fail silently for covid-19 lung lesion segmentation. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pages 304–314. Springer, 2021.
  • Guo et al. (2017) Chuan Guo, Geoff Pleiss, Yu Sun, and Kilian Q Weinberger. On calibration of modern neural networks. In International Conference on Machine Learning, pages 1321–1330. PMLR, 2017.
  • Gustafsson et al. (2020) Fredrik K Gustafsson, Martin Danelljan, and Thomas B Schon. Evaluating scalable bayesian deep learning methods for robust computer vision. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition workshops, pages 318–319, 2020.
  • Hammam et al. (2021) Ahmed Hammam, Seyed Eghbal Ghobadi, Frank Bonarens, and Christoph Stiller. Real-time uncertainty estimation based on intermediate layer variational inference. In Proceedings of the 5th ACM Computer Science in Cars Symposium, pages 1–9, 2021.
  • Hammersley (2013) John Hammersley. Monte carlo methods. Springer Science & Business Media, 2013.
  • Hastings (1970) W Keith Hastings. Monte carlo sampling methods using markov chains and their applications. Biometrika, 57(1):97–109, 1970.
  • Hora (1996) Stephen C Hora. Aleatory and epistemic uncertainty in probability elicitation with an example from hazardous waste management. Reliability Engineering & System Safety, 54(2-3):217–223, 1996.
  • Huang et al. (2017) Gao Huang, Yixuan Li, Geoff Pleiss, Zhuang Liu, John E Hopcroft, and Kilian Q Weinberger. Snapshot ensembles: Train 1, get m for free. arXiv preprint arXiv:1704.00109, 2017.
  • Hüllermeier and Waegeman (2021) Eyke Hüllermeier and Willem Waegeman. Aleatoric and epistemic uncertainty in machine learning: An introduction to concepts and methods. Machine Learning, 110(3):457–506, 2021.
  • Isensee et al. (2021) Fabian Isensee, Paul F Jaeger, Simon AA Kohl, Jens Petersen, and Klaus H Maier-Hein. nnu-net: a self-configuring method for deep learning-based biomedical image segmentation. Nature methods, 18(2):203–211, 2021.
  • Izmailov et al. (2018) Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov, and Andrew Gordon Wilson. Averaging weights leads to wider optima and better generalization. arXiv preprint arXiv:1803.05407, 2018.
  • Izmailov et al. (2021) Pavel Izmailov, Sharad Vikram, Matthew D Hoffman, and Andrew Gordon Gordon Wilson. What are bayesian neural network posteriors really like? In International conference on machine learning, pages 4629–4640. PMLR, 2021.
  • Jospin et al. (2022) Laurent Valentin Jospin, Hamid Laga, Farid Boussaid, Wray Buntine, and Mohammed Bennamoun. Hands-on bayesian neural networks—a tutorial for deep learning users. IEEE Computational Intelligence Magazine, 17(2):29–48, 2022.
  • Jungo and Reyes (2019) Alain Jungo and Mauricio Reyes. Assessing reliability and challenges of uncertainty estimations for medical image segmentation. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pages 48–56. Springer, 2019.
  • Jungo et al. (2020) Alain Jungo, Fabian Balsiger, and Mauricio Reyes. Analyzing the quality and challenges of uncertainty estimations for brain tumor segmentation. Frontiers in neuroscience, page 282, 2020.
  • Kapoor et al. (2022) Sanyam Kapoor, Wesley J Maddox, Pavel Izmailov, and Andrew Gordon Wilson. On uncertainty, tempering, and data augmentation in bayesian classification. arXiv preprint arXiv:2203.16481, 2022.
  • Kendall and Gal (2017) Alex Kendall and Yarin Gal. What uncertainties do we need in bayesian deep learning for computer vision? Advances in neural information processing systems, 30, 2017.
  • Kendall et al. (2015) Alex Kendall, Vijay Badrinarayanan, and Roberto Cipolla. Bayesian segnet: Model uncertainty in deep convolutional encoder-decoder architectures for scene understanding. arXiv preprint arXiv:1511.02680, 2015.
  • Kohl et al. (2018) Simon AA Kohl, Bernardino Romera-Paredes, Clemens Meyer, Jeffrey De Fauw, Joseph R Ledsam, Klaus H Maier-Hein, SM Eslami, Danilo Jimenez Rezende, and Olaf Ronneberger. A probabilistic u-net for segmentation of ambiguous images. arXiv preprint arXiv:1806.05034, 2018.
  • Lakshminarayanan et al. (2016) Balaji Lakshminarayanan, Alexander Pritzel, and Charles Blundell. Simple and scalable predictive uncertainty estimation using deep ensembles. arXiv preprint arXiv:1612.01474, 2016.
  • Lampinen and Vehtari (2001) Jouko Lampinen and Aki Vehtari. Bayesian approach for neural networks—review and case studies. Neural networks, 14(3):257–274, 2001.
  • Larrazabal et al. (2021) Agostina J Larrazabal, César Martínez, Jose Dolz, and Enzo Ferrante. Orthogonal ensemble networks for biomedical image segmentation. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pages 594–603. Springer, 2021.
  • Lifshitz and Landau (1984) EM Lifshitz and Lev Davidovich Landau. Statistical physics (course of theoretical physics, volume 5), 1984.
  • MacKay (1995) David JC MacKay. Bayesian neural networks and density networks. Nuclear Instruments and Methods in Physics Research Section A: Accelerators, Spectrometers, Detectors and Associated Equipment, 354(1):73–80, 1995.
  • Maddox et al. (2019) Wesley J Maddox, Pavel Izmailov, Timur Garipov, Dmitry P Vetrov, and Andrew Gordon Wilson. A simple baseline for bayesian uncertainty in deep learning. Advances in Neural Information Processing Systems, 32, 2019.
  • Marcot and Penman (2019) Bruce G Marcot and Trent D Penman. Advances in bayesian network modelling: Integration of modelling technologies. Environmental modelling & software, 111:386–393, 2019.
  • Mehrtash et al. (2020) Alireza Mehrtash, William M Wells, Clare M Tempany, Purang Abolmaesumi, and Tina Kapur. Confidence calibration and predictive uncertainty estimation for deep medical image segmentation. IEEE transactions on medical imaging, 39(12):3868–3878, 2020.
  • Messroghli et al. (2004) Daniel R Messroghli, Aleksandra Radjenovic, Sebastian Kozerke, David M Higgins, Mohan U Sivananthan, and John P Ridgway. Modified look-locker inversion recovery (molli) for high-resolution t1 mapping of the heart. Magnetic Resonance in Medicine: An Official Journal of the International Society for Magnetic Resonance in Medicine, 52(1):141–146, 2004.
  • Minderer et al. (2021) Matthias Minderer, Josip Djolonga, Rob Romijnders, Frances Hubis, Xiaohua Zhai, Neil Houlsby, Dustin Tran, and Mario Lucic. Revisiting the calibration of modern neural networks. Advances in Neural Information Processing Systems, 34:15682–15694, 2021.
  • Mingard et al. (2021) Chris Mingard, Guillermo Valle-Pérez, Joar Skalse, and Ard A Louis. Is sgd a bayesian sampler? well, almost. Journal of Machine Learning Research, 22, 2021.
  • Nabarro et al. (2021) Seth Nabarro, Stoil Ganev, Adria Garriga-Alonso, Vincent Fortuin, Mark van der Wilk, and Laurence Aitchison. Data augmentation in bayesian neural networks and the cold posterior effect. arXiv preprint arXiv:2106.05586, 2021.
  • Neal (2012) Radford M Neal. Bayesian learning for neural networks, volume 118. Springer Science & Business Media, 2012.
  • Neal et al. (2011) Radford M Neal et al. Mcmc using hamiltonian dynamics. Handbook of markov chain monte carlo, 2(11):2, 2011.
  • Ouyang et al. (2022) Cheng Ouyang, Shuo Wang, Chen Chen, Zeju Li, Wenjia Bai, Bernhard Kainz, and Daniel Rueckert. Improved post-hoc probability calibration for out-of-domain mri segmentation. In International Workshop on Uncertainty for Safe Utilization of Machine Learning in Medical Imaging, pages 59–69. Springer, 2022.
  • Ovadia et al. (2019) Yaniv Ovadia, Emily Fertig, Jie Ren, Zachary Nado, David Sculley, Sebastian Nowozin, Joshua V Dillon, Balaji Lakshminarayanan, and Jasper Snoek. Can you trust your model’s uncertainty? evaluating predictive uncertainty under dataset shift. arXiv preprint arXiv:1906.02530, 2019.
  • Risken (1989) Hannes Risken. The fokker-planck equation: methods of solution and applications. Springer Series in Synergetics, volume 18. Springer, 1989.
  • Ronneberger et al. (2015) Olaf Ronneberger, Philipp Fischer, and Thomas Brox. U-net: Convolutional networks for biomedical image segmentation. In International Conference on Medical image computing and computer-assisted intervention, pages 234–241. Springer, 2015.
  • Salahuddin et al. (2023) Zohaib Salahuddin, Yi Chen, Xian Zhong, Henry C Woodruff, Nastaran Mohammadian Rad, Shruti Atul Mali, and Philippe Lambin. From head and neck tumour and lymph node segmentation to survival prediction on pet/ct: An end-to-end framework featuring uncertainty, fairness, and multi-region multi-modal radiomics. Cancers, 15(7):1932, 2023.
  • Särkkä and Solin (2019) Simo Särkkä and Arno Solin. Applied stochastic differential equations, volume 10. Cambridge University Press, 2019.
  • Speagle (2019) Joshua S Speagle. A conceptual introduction to markov chain monte carlo methods. arXiv preprint arXiv:1909.12313, 2019.
  • Wang et al. (2021) Deng-Bao Wang, Lei Feng, and Min-Ling Zhang. Rethinking calibration of deep neural networks: Do not be afraid of overconfidence. Advances in Neural Information Processing Systems, 34:11809–11820, 2021.
  • Wang et al. (2019) Guotai Wang, Wenqi Li, Michael Aertsen, Jan Deprest, Sébastien Ourselin, and Tom Vercauteren. Aleatoric uncertainty estimation with test-time augmentation for medical image segmentation with convolutional neural networks. Neurocomputing, 338:34–45, 2019.
  • Wang and Yeung (2020) Hao Wang and Dit-Yan Yeung. A survey on bayesian deep learning. ACM Computing Surveys (CSUR), 53(5):1–37, 2020.
  • Wenzel et al. (2020) Florian Wenzel, Kevin Roth, Bastiaan S Veeling, Jakub Swiatkowski, Linh Tran, Stephan Mandt, Jasper Snoek, Tim Salimans, Rodolphe Jenatton, and Sebastian Nowozin. How good is the bayes posterior in deep neural networks really? arXiv preprint arXiv:2002.02405, 2020.
  • Yan et al. (2020) W. Yan, L. Huang, L. Xia, S. Gu, and Q. Tao. Mri manufacturer shift and adaptation: Increasing the generalizability of deep learning segmentation for mr images acquired with different scanners. 2020.
  • Yan et al. (2019) Wenjun Yan, Yuanyuan Wang, Shengjia Gu, Lu Huang, Fuhua Yan, Liming Xia, and Qian Tao. The domain shift problem of medical image segmentation and vendor-adaptation by unet-gan. In Medical Image Computing and Computer Assisted Intervention–MICCAI 2019: 22nd International Conference, Shenzhen, China, October 13–17, 2019, Proceedings, Part II 22, pages 623–631. Springer, 2019.
  • Zhang et al. (2019) Ruqi Zhang, Chunyuan Li, Jianyi Zhang, Changyou Chen, and Andrew Gordon Wilson. Cyclical stochastic gradient mcmc for bayesian deep learning. arXiv preprint arXiv:1902.03932, 2019.
  • Zhao et al. (2022) Yidong Zhao, Changchun Yang, Artur Schweidtmann, and Qian Tao. Efficient bayesian uncertainty estimation for nnu-net. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pages 535–544. Springer, 2022.

A Detailed evaluation results

Figure 15: ROC curves of all classes on the ACDC, M&M, and QMRI datasets.