Introduction
Deep learning models, especially convolutional neural networks (CNNs), have achieved remarkable success in recent years, attaining state-of-the-art or even human-level performance on a variety of challenging medical imaging problems [1]–[3]. Typically, deep networks are trained and tested on datasets where all the images are sampled from the same distribution. Despite the risk of over-fitting, the models are able to produce highly accurate predictions on new test data from the same domain. However, it has been frequently observed that established models under-perform when tested on samples from a related but different target domain [4]–[6]. For medical image computing in the digital health field, such scenarios include test and training images coming from different sites [7], [8], different scanning protocols [9], [10], or even different imaging modalities [11], [12].
Different from natural images, which are generally obtained by optical cameras, a typical situation in the medical field is the use of various imaging modalities capturing different physical properties. These modalities play complementary roles in the clinical procedures of disease diagnosis and treatment. For example, Magnetic Resonance Imaging (MRI) and Computed Tomography (CT) have become indispensable tools for cardiac imaging. Specifically, MRI is free of ionizing radiation and captures great contrast between soft tissues with high temporal resolution [13]. It features multi-parametric assessment of myocardial contractility and viability. By contrast, CT allows rapid imaging of the cardiac morphology, myocardial viability and coronary calcification, with great spatial resolution [14].
In practice, the same image analysis task, such as segmentation or quantification of cardiac structures, is often required across modalities. Considering that annotation is prohibitively time-consuming and expensive (e.g., a whole heart segmentation from either MRI or CT takes up to 8 hours for a well-trained operator [15]), effectively adapting a model trained on one modality to the other holds clinical benefits. However, the appearances of cardiac MRI and CT are considerably different, with distinct contrasts and intensity histograms, as shown in Fig. 1. Unsupervised domain adaptation under such significant domain shift is very challenging and remains largely unexplored.
Fig. 1. Illustration of the severe domain shift existing in cross-modality images. The appearances of the cardiac structures (AA: ascending aorta, LV-blood: left ventricle blood cavity, LV-myo: left ventricle myocardium) look significantly different on MRI and CT images, though the segmentation masks look very similar.
Most methods for unsupervised domain adaptation focus on aligning the distributions in a latent feature space, by minimizing measures of distance between the features extracted from the source and target domains [16], [17]. For example, the Maximum Mean Discrepancy (MMD) is minimized together with a task-specific loss to learn domain-invariant and semantically meaningful features in [18]. Long et al. [19] minimize the MMD of domain features embedded in a reproducing kernel Hilbert space. Sun and Saenko [20] propose to align the feature covariances between domains. More recently, with the advancement of generative adversarial networks (GAN) [21] and their powerful extensions [22], [23], the latent feature spaces across domains can be implicitly aligned via adversarial learning. Notably, the DANN method extracts domain-invariant features by fully sharing the weights of the CNN encoder between domains [17]. Tzeng et al. [16] introduce a more untied adversarial learning framework, named ADDA, where each domain has a dedicated domain-specific encoder before the last classification layer shared by both domains. Most recent studies further propose to train with auxiliary tasks on both domains, which serve as additional constraints for feature alignment [24]. Alternatively, in terms of GAN-based domain adaptation, another stream of solutions aligns the input spaces of networks instead. These methods make use of unsupervised image-to-image translation, i.e., training the network with target-like synthetic source data, or testing with source-like target data [12], [25], [26]. Most of them are built on a CycleGAN [23] foundation, where bi-directional image translations are learned by two GANs separately, and a cycle-consistency constraint is enforced to preserve semantic information across the translations. Some other methods [27] combine feature alignment and image transformation; however, the established framework focuses on improving target-domain performance and is not flexible between the source and target domains.
For medical image computing, adversarial learning has presented inspiring efficacy on a wide variety of tasks [28]–[32]. In particular, towards domain adaptation, early attempts have recently been made with the aim of generalizing learned models to unseen target domains. For example, Zhang et al. [25] transformed unlabelled target X-ray images to resemble the source radiographs and then directly tested the transformed images with the source model trained on labelled source data only. Similarly based on the CycleGAN, Jiang et al. [12] proposed a two-stage approach to first transform target CT images to resemble source MRI images, then conducted semi-supervised tumor segmentation with both synthetic images and a limited number of real MRI. Meanwhile, following the spirit of aligning latent feature spaces, a set of works aim to extract domain-invariant representations from unpaired data. Degel et al. [33] minimized a segmentation loss together with a domain adversarial training loss to encourage feature domain-invariance across ultrasound datasets for left atrium segmentation. Ren et al. [34] utilized adversarial learning to align the feature distribution of target images to the source domain for classifying histology images obtained with different staining procedures. These works have demonstrated that imposing alignment in feature space helps to generalize deep models to new data from a different domain. One of the most closely related works is Kamnitsas et al. [9], who conducted unsupervised domain adaptation by adversarial learning in a multi-level feature space for brain lesion segmentation. The experimental setting is challenging, as the unlabelled target domain contains a new MRI sequence which is unseen to the source model. By sharing encoders and aligning multi-level features, the method achieved promising results in the target domain. However, designing a flexible framework for unsupervised domain adaptation that has minimal effect on the original source model, while mitigating the drastic domain shift between modalities, remains an open problem. Meanwhile, in terms of the highly challenging setting of cross-modality image segmentation from CT and MRI, related literature is, to the best of our knowledge, limited. Valindria et al. [35] developed a joint learning method for multi-organ segmentation using unpaired MRI and CT. Zhang et al. [30] proposed cross-modality image translation for improving cardiac segmentation with synthetic data. However, these works did not aim at our topic of unsupervised domain adaptation of CNNs, which is in principle much more difficult since annotations of the target domain are completely unavailable.
In this paper, we study the challenging topic of unsupervised cross-modality domain adaptation for a multi-class segmentation problem. We present a flexible plug-and-play adversarial domain adaptation network, called PnP-AdaNet, which effectively aligns the feature space of the target domain to that of the source domain. Specifically, the early encoders are replaced for target domain input, and the higher layers are shared between domains. For adversarial learning, we build two domain discriminators, respectively connected to multi-level features and segmentation predictions, for joint alignment of the multi-level feature spaces and output spaces between domains. This paper is a substantial extension of our prior work [11], by further modifying the method with a significant performance boost, presenting a completely new series of ablation analyses of our proposed PnP-AdaNet, adding a comprehensive comparison with state-of-the-art methods, and elaborating the literature review and discussion of unsupervised domain adaptation. Our main contributions are summarized as follows:
We address the challenging yet crucial task of unsupervised cross-modality domain adaptation for medical image segmentation. A novel PnP-AdaNet is proposed to enable a flexible adaptation of the segmentation CNNs by plug-and-play feature encoders.
Our model is learned with unpaired MRI and CT images via adversarial learning. To enhance the supervision from the discriminators, we aggregate multi-level features and segmentation mask predictions during the training process.
We extensively validate our method on multi-class cardiac segmentation with a public challenge dataset [36]. The mean Dice of four structures has been recovered from 13.2% to 63.9%, outperforming many state-of-the-art domain adaptation approaches. We also conduct comprehensive ablation studies on key method components. We release our code to facilitate the research community.
Methods
Fig. 2 is an overview of our proposed PnP-AdaNet method. With a standard segmentation CNN learned on the source domain, we replace its early layers with a domain adaptation module while retaining its higher layers, for testing on target domain data. Hence, we call our method a plug-and-play domain adaptation framework. The adaptation module maps the target images to the source domain in a latent feature space with aligned distribution. This process is trained with an adversarial loss in an unsupervised way.
Fig. 2. Overview of our proposed PnP-AdaNet (plug-and-play adversarial domain adaptation network), consisting of a source segmentation network, a domain adaptation module (DAM) and two discriminators. Multi-level activations and predicted segmentation masks are aggregated for alignment of the latent feature space. The domain router is for testing: it chooses which set of early layers to connect to the higher layers for the segmentation task. Specifically, when testing source data, it chooses the original source early layers; when testing target data, it chooses the DAM layers.
A. Segmentation Network Without Skip Connection
The essence of our proposed PnP-AdaNet is to establish an independent encoder for each domain and align their feature distributions in the latent space. Considering that only the early layers which compose the independent encoders are updated while those higher layers are fixed, the feature spaces at different layers need to be self-contained, i.e., not mixed-up with each other. This means that the network architectures using skip connections, e.g., the U-Net [37] and DenseNet [38], are not suitable choices. Otherwise, the plug-and-play setting would be problematic, because those domain-specific low-level features forwarded by skip-connections would affect the aligned high-level feature space (which is supposed to be shared across domains).
In this regard, we set up our segmentation model as a dilated network [39], which can extract representative features from large receptive fields while preserving the spatial acuity of the feature maps. Residual connections within local scopes are used to ease gradient flow. Specifically, as illustrated in Fig. 2, the input image is first forwarded into a convolutional layer, then proceeds through a series of residual modules (termed RM1-7), each consisting of stacked convolutional layers with a local residual connection; dilated convolutions are employed to enlarge the receptive field without sacrificing the resolution of the feature maps.
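As a concrete illustration, below is a minimal sketch of one such residual module, assuming a PyTorch implementation; the channel width, dilation rate and normalization choice are placeholders rather than the exact configuration of RM1-7.

```python
import torch.nn as nn

class DilatedResidualModule(nn.Module):
    """Illustrative residual module with dilated convolutions (placeholder sizes)."""

    def __init__(self, channels: int, dilation: int = 2):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3,
                               padding=dilation, dilation=dilation, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3,
                               padding=dilation, dilation=dilation, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Local residual connection eases gradient flow; no long-range skip
        # connection crosses the module boundary, keeping features self-contained.
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + x)
```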
Formally, we denote the annotated dataset of the source domain by $X^{s} = \{(x_{i}^{s}, y_{i}^{s})\}_{i=1}^{N^{s}}$, where $x_{i}^{s}$ is an image sample and $y_{i}^{s}$ its segmentation label. The segmentation network is trained on the source domain with a hybrid loss combining the multi-class Dice loss and a weighted cross-entropy term:
\begin{align*} \mathcal {L}_{\text {seg}} = &- \sum \limits _{c \in C} \frac { \sum _{i=1}^{N^{s}} 2\, y_{i,c}^{s}\, \hat {y}_{i,c}^{s} }{ \sum _{i=1}^{N^{s}} y_{i,c}^{s}\, y_{i,c}^{s} + \sum _{i=1}^{N^{s}} \hat {y}_{i,c}^{s}\, \hat {y}_{i,c}^{s} } \\ &- \lambda \sum \limits _{i=1}^{N^{s}} \sum \limits _{c \in C} w^{s}_{c} \cdot y_{i,c}^{s} \log (\hat {p}_{i,c}^{s}) + \beta \|W\|_{2}^{2}, \tag{1}\end{align*}
where $C$ is the set of classes, $y_{i,c}^{s}$ is the ground-truth label of pixel $i$ for class $c$, $\hat {y}_{i,c}^{s}$ and $\hat {p}_{i,c}^{s}$ denote the corresponding prediction and predicted probability, $w^{s}_{c}$ is a class-balancing weight, $\lambda$ balances the two terms, and $\beta$ weights the L2 regularization of the network parameters $W$.
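A minimal sketch of the hybrid loss in Eq. (1), assuming PyTorch, is given below; the hyper-parameters `lam`, `beta` and the class weights are placeholders, and in practice the L2 term is often realized through the optimizer's weight decay instead.

```python
import torch
import torch.nn.functional as F

def segmentation_loss(logits, labels, class_weights, lam=1.0, beta=1e-4,
                      params=None, eps=1e-6):
    """Hybrid loss of Eq. (1): multi-class Dice + weighted cross-entropy + L2.

    logits: (N, C, H, W) raw network outputs; labels: (N, H, W) integer label maps.
    """
    probs = F.softmax(logits, dim=1)
    onehot = F.one_hot(labels, num_classes=logits.shape[1]).permute(0, 3, 1, 2).float()

    # Dice term, computed per class over all pixels in the batch.
    inter = (probs * onehot).sum(dim=(0, 2, 3))
    denom = (onehot * onehot).sum(dim=(0, 2, 3)) + (probs * probs).sum(dim=(0, 2, 3))
    dice = -(2.0 * inter / (denom + eps)).sum()

    # Class-weighted cross-entropy term.
    wce = F.cross_entropy(logits, labels, weight=class_weights)

    # L2 regularization over the network parameters W.
    l2 = sum((p ** 2).sum() for p in params) if params is not None else 0.0
    return dice + lam * wce + beta * l2
```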
For ease of notation, we will omit the subscript index $i$ of samples in the following sections.
B. Plug-and-Play Adaptation Mechanism
After obtaining the segmentation network trained on the source domain, we next aim to adapt it to the target domain in an unsupervised manner. In conventional transfer learning, it is common to update the last several layers of a pre-trained network towards a new given task with a new label space. The supporting assumption is that the early layers in the network extract low-level features which are universal for vision tasks, while the higher layers are more task-specific and learn semantic-level features for conducting the defined predictions [41], [42]. In contrast, for domain adaptation, the defined task remains unchanged across domains. This means that the label spaces of the source and target domains are identical, e.g., we segment the same anatomical structures from unpaired MRI/CT images. Note that images from different domains are not required to be co-registered. Basically, the distribution shift between the cross-modality domains lies primarily in low-level characteristics (e.g., gray-scale intensities) rather than high-level ones (e.g., geometric or semantic structures).
In these regards, for our model, we design a plug-and-play adaptation mechanism, i.e., a set of early layers is replaced while the higher layers are reused for the new target domain. The underlying intuition is that the higher layers are closely correlated with the shared semantic labels, while the respective early-layer encoders perform distribution mappings in feature space for our unsupervised domain adaptation. Formally, the obtained source segmentation model $M^{s}$, composed of $n$ layers $\{l_{1}, \ldots, l_{n}\}$, maps a source image $x^{s}$ to its prediction via a composition of layer-wise mappings:
\begin{equation*} \hat {y}^{s} = M^{s}(x^{s})= M^{s}_{l_{1}:l_{n}}(x^{s}) = M^{s}_{l_{n}} \circ \cdots \circ M^{s}_{l_{1}} (x^{s}).\tag{2}\end{equation*}
For a target-domain input $x^{t}$, the layers up to the adaptation depth $d$ are replaced by the domain adaptation module (DAM) $\mathcal {M}$, while the remaining higher layers of the source model are reused with frozen parameters, yielding the prediction:
\begin{equation*} \hat {y}^{t} = M^{s}_{l_{d+1}:l_{n}} \circ \mathcal {M}(x^{t})= M^{s}_{l_{n}} \circ \cdots \circ M^{s}_{l_{d+1}} \circ \mathcal {M} (x^{t}).\tag{3}\end{equation*}
Overall, we can find that the proposed plug-and-play domain adaptation mechanism is elegant and rather flexible at testing. During inference for the target domain, the DAM directly replaces the first $d$ layers of the source model, while all the higher layers are reused with their parameters unchanged. The domain router (cf. Fig. 2) simply selects which early encoder to connect to the shared higher layers, according to the modality of the test image.
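The plug-and-play inference can be summarized by the following schematic sketch (assuming PyTorch; the module names are illustrative and not part of our released code):

```python
import torch.nn as nn

class PlugAndPlaySegmenter(nn.Module):
    """Schematic domain router: select the early encoder matching the test
    modality and feed it into the shared, frozen higher layers."""

    def __init__(self, source_early: nn.Module, dam: nn.Module, higher_layers: nn.Module):
        super().__init__()
        self.source_early = source_early   # M^s_{l_1:l_d}, original source encoder
        self.dam = dam                     # domain adaptation module for target data
        self.higher = higher_layers        # M^s_{l_{d+1}:l_n}, reused across domains
        for p in self.higher.parameters():
            p.requires_grad = False        # higher layers stay fixed during adaptation

    def forward(self, x, domain: str):
        feats = self.source_early(x) if domain == "source" else self.dam(x)
        return self.higher(feats)
```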
C. Adversarial Learning for Feature Space Alignment
We train our plug-and-play domain adaptation network via adversarial learning in an unsupervised manner. In the spirit of GAN, a generator and a discriminator form a minimax two-player game. The generator aims to capture the distribution of the real data, while the discriminator should identify whether a presented sample comes from the real or the learned distribution. In our PnP-AdaNet, the DAM serves as the generator which maps the input target image into the latent feature space of the source domain. The aim of the domain adaptation module is to encode representations which are aligned with those encoded from the source domain images. Hence, the fixed layers in the higher part of the source network can be reused to make semantic-level predictions of segmentation masks. As we have zero annotation for the target domain, the adaptation process is implicitly supervised by the discriminators, i.e., forming an adversarial learning game.
In our framework, we propose to employ two discriminators. Specifically, the input to the first discriminator (i.e., the green part in Fig. 2) is an array of aggregated feature maps of the segmenter. This input has a high dimension with a relatively complicated distribution. A natural choice here would be to connect only the output features of the DAM to the discriminator. However, a convolutional neural network has a hierarchical architecture: the features at a certain layer rely on activations from its previous layers and, in turn, affect the following layers. If we just monitor the encoded features immediately obtained from the DAM, the latent space alignment can be unstable. In other words, we have no guarantee that the activations at other depths of the network are well aligned. Also, a small shift that may still exist at the adaptation layer could be amplified as it propagates through the subsequent layers. We therefore choose to supervise the alignment over features from multiple depths of the network.
In practice, we aggregate the activations from multiple levels of layers, and reshape them to the same resolution for channel concatenation. Formally, we refer to the aggregated feature maps from the selected frozen higher layers as $F_{H}(\cdot)$, and to the output activations of the adaptation layers as $\mathcal {M}_{A}(x^{t})$ for the target domain and $M^{s}_{A}(x^{s})$ for the source domain. Denoting the resulting distributions of the aggregated features by $\mathbb {P}_{\text {feature}}^{s}$ and $\mathbb {P}_{\text {feature}}^{t}$, we align them by minimizing their Wasserstein distance:
\begin{align*} W(\mathbb {P}_{\text {feature}}^{s}, \mathbb {P}_{\text {feature}}^{t})=\inf \limits _{\gamma \in \prod (\mathbb {P}_{\text {feature}}^{s}, \mathbb {P}_{\text {feature}}^{t})} \mathbb {E}_{(\text {x},\text {y})\sim \gamma }[\Vert \text {x} - \text {y} \Vert], \tag{4}\end{align*}
where $\prod (\mathbb {P}_{\text {feature}}^{s}, \mathbb {P}_{\text {feature}}^{t})$ denotes the set of all joint distributions whose marginals are $\mathbb {P}_{\text {feature}}^{s}$ and $\mathbb {P}_{\text {feature}}^{t}$, respectively.
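A minimal sketch of this multi-level aggregation is shown below (assuming PyTorch); which layers are selected and the common resolution are design choices examined in our ablation study.

```python
import torch
import torch.nn.functional as F

def aggregate_features(feature_maps, target_size):
    """Resize multi-level activations to a common spatial resolution and
    concatenate them along the channel axis, forming the input of the
    feature discriminator."""
    resized = [F.interpolate(f, size=target_size, mode="bilinear", align_corners=False)
               for f in feature_maps]
    return torch.cat(resized, dim=1)

# Example usage (layer names and resolution are illustrative):
# feats_t = aggregate_features([dam_out, rm6_out, rm7_out, logits_t], target_size=(40, 40))
```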
Aligning the latent feature space by directly inputting the high-dimensional activations to a discriminator is effective and essential. However, feature alignment alone might be sufficient for classification tasks, but can be sub-optimal for segmentation tasks which require pixel-wise predictions with fine structures. This implies that well-aligned output spaces, in our case segmentation masks, between source and target are also critical. Early studies using GANs for segmentation applications (not necessarily under the domain adaptation setting) commonly input the predicted segmentation masks to the discriminator. When the shape or structure of the predicted segmentation mask looks distorted (i.e., not like a real mask), the discriminator imposes a penalty. For the specific problem of domain adaptation for segmentation, we also consider monitoring the shape of the predicted segmentation mask to be important. This serves as a correction mechanism for imperfections of the feature alignment in adversarial training.
To this end, we further include an auxiliary discriminator in our PnP-AdaNet, whose inputs are the predicted segmentation masks of the source and target domains. In this case, the input is more compact and the discriminator focuses purely on the outputs, compared with the first discriminator. We denote the distributions of the predicted segmentation masks for the source and target domains by $\mathbb {P}_{\text {mask}}^{s}$ and $\mathbb {P}_{\text {mask}}^{t}$, and analogously minimize the Wasserstein distance between them:
\begin{align*} W(\mathbb {P}_{\text {mask}}^{s}, \mathbb {P}_{\text {mask}}^{t})=\inf \limits _{\gamma \in \prod (\mathbb {P}_{\text {mask}}^{s}, \mathbb {P}_{\text {mask}}^{t})} \mathbb {E}_{(\text {x},\text {y})\sim \gamma }[\Vert \text {x} - \text {y} \Vert]. \tag{5}\end{align*}
The detailed network architectures of the discriminators are illustrated in Fig. 2. For the model configuration, the feature discriminator is relatively deeper than the mask discriminator.
D. Loss Functions and Training Strategies
In adversarial learning, the DAM is pitted against an adversary consisting of the above two discriminators. We represent the feature discriminator by $\mathcal {D}_{f}$ and the mask discriminator by $\mathcal {D}_{m}$, and denote the predicted segmentation mask of an input by $\mathcal {S}(\cdot)$. Following the Wasserstein formulation, the loss for updating the DAM (our generator) is:
\begin{align*} \mathcal {L}_{\mathcal {M}}= &-\mathbb {E}_{ (\mathcal {M}_{A}(x^{t}), F_{H}(x^{t})) \sim \mathbb {P}_{\text {feature}}^{t}} [\mathcal {D}_{f}(\mathcal {M}_{A}(x^{t}), F_{H}(x^{t}))] \\ &-\mathbb {E}_{ \mathcal {S}(x^{t}) \sim \mathbb {P}_{\text {mask}}^{t}} [\mathcal {D}_{m}(\mathcal {S}(x^{t}))]. \tag{6}\end{align*}
The two discriminators are trained to distinguish the source and target distributions, with their losses being:
\begin{align*} \mathcal {L}_{\mathcal {D}_{f}}=&\mathbb {E}_{(\mathcal {M}_{A}(x^{t}), F_{H}(x^{t})) \sim \mathbb {P}_{\text {feature}}^{t}}[\mathcal {D}_{f}(\mathcal {M}_{A}(x^{t}), F_{H}(x^{t}))] \\&-\mathbb {E}_{ (M^{s}_{A}(x^{s}), F_{H}(x^{s}))\sim \mathbb {P}_{\text {feature}}^{s}}[\mathcal {D}_{f}(M^{s}_{A}(x^{s}),F_{H}(x^{s}))], \\& \quad \text {s.t.}~\Vert \mathcal {D}_{f} \Vert _{L}\leq K, \tag{7}\\ \mathcal {L}_{\mathcal {D}_{m}}=&\mathbb {E}_{\mathcal {S}(x^{t}) \sim \mathbb {P}_{\text {mask}}^{t}}[\mathcal {D}_{m}(\mathcal {S}(x^{t}))] - \mathbb {E}_{\mathcal {S}(x^{s}) \sim \mathbb {P}_{\text {mask}}^{s}}[\mathcal {D}_{m}(\mathcal {S}(x^{s}))], \\& \quad \text {s.t.}~\Vert \mathcal {D}_{m} \Vert _{L}\leq K, \tag{8}\end{align*}
where $\Vert \cdot \Vert _{L}\leq K$ denotes the $K$-Lipschitz constraint imposed on the discriminators.
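For illustration, the losses in Eqs. (6)-(8) can be written compactly as below (assuming PyTorch). Enforcing the $K$-Lipschitz constraint via weight clipping follows the original Wasserstein GAN recipe; the clipping value here is a placeholder rather than the setting used in our experiments.

```python
import torch

def generator_loss(d_f, d_m, feats_t, mask_t):
    """Eq. (6): the DAM is updated so that target features/masks score high under the critics."""
    return -d_f(feats_t).mean() - d_m(mask_t).mean()

def critic_losses(d_f, d_m, feats_s, feats_t, mask_s, mask_t):
    """Eqs. (7)-(8): Wasserstein critic losses separating source from target."""
    loss_f = d_f(feats_t).mean() - d_f(feats_s).mean()
    loss_m = d_m(mask_t).mean() - d_m(mask_s).mean()
    return loss_f, loss_m

def clip_weights(critic, k=0.03):
    """Crude enforcement of the K-Lipschitz constraint via weight clipping."""
    with torch.no_grad():
        for p in critic.parameters():
            p.clamp_(-k, k)
```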
In practice, we first train the segmentation network on the source domain in a supervised manner with standard stochastic gradient descent. We employ the Adam optimizer with a batch size of 10 and a learning rate of
Dataset and Evaluation Metrics
A. Dataset for Cross-Modality Adaptation
Publicly available medical datasets which contain different imaging modalities of the same anatomical structures are rare, hindering progress in investigating the challenging task of MRI and CT cross-modality domain adaptation. Fortunately, the MICCAI 2017 Multi-Modality Whole Heart Segmentation (MM-WHS) challenge presented 20 MRI and 20 CT cardiac images with accurate manual segmentation annotations. The images are unpaired, with the MRI and CT data coming from different patients and different sites. We refer the readers to the original data description paper of Zhuang et al. [36] for more details about the data acquisition, such as the employed scanning protocols. For evaluating domain adaptation on cardiac segmentation, we include the following four structures: the ascending aorta (AA), the left atrium blood cavity (LA-blood), the left ventricle blood cavity (LV-blood), and the myocardium of the left ventricle (LV-myo). We randomly split each modality of the data into training (16 subjects) and testing (4 subjects) subsets in our experiments.
Our PnP-AdaNet is designed for unpaired cross-modality medical image segmentation; the MRI and CT images are not co-registered in our experiments, as this is not necessary. In pre-processing, the MRI and CT images are reoriented (in view direction), resized and cropped centering at the heart region, such that the views of the multi-modal images roughly correspond. We extract MRI and CT scans by 2D slices of size
B. Evaluation Metrics on Segmentation Performance
For the evaluation metrics, we follow the common practice for quantitatively evaluating segmentation methods [44]. The Dice coefficient ([%]) is used to assess the agreement between the predicted segmentation and the ground truth of the cardiac structures. We use the symmetric average surface distance (ASD, [voxel]) to measure the segmentation performance from the perspective of boundary agreement. A higher Dice and a lower ASD indicate better segmentation performance. Both metrics are presented in the format of mean±std, showing the average performance together with the cross-subject variation of the results. For some results, an N/A for the ASD means that at least one subject did not receive any correct prediction for that structure.
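For reference, a minimal sketch of how these two metrics can be computed for the binary mask of one structure is given below (using NumPy and SciPy); it follows the common definitions rather than a specific evaluation toolkit.

```python
import numpy as np
from scipy import ndimage

def dice_coefficient(pred, gt):
    """Dice overlap (in %) between two binary masks."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    denom = pred.sum() + gt.sum()
    return 100.0 * 2.0 * np.logical_and(pred, gt).sum() / denom if denom > 0 else 0.0

def average_surface_distance(pred, gt, spacing=(1.0, 1.0, 1.0)):
    """Symmetric average surface distance (in voxel units for unit spacing).
    Returns NaN if either mask is empty, corresponding to the N/A entries."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    if not pred.any() or not gt.any():
        return float("nan")
    pred_border = pred ^ ndimage.binary_erosion(pred)
    gt_border = gt ^ ndimage.binary_erosion(gt)
    # Distance from each border voxel of one mask to the nearest border voxel of the other.
    dist_to_gt = ndimage.distance_transform_edt(~gt_border, sampling=spacing)
    dist_to_pred = ndimage.distance_transform_edt(~pred_border, sampling=spacing)
    surface_dists = np.concatenate([dist_to_gt[pred_border], dist_to_pred[gt_border]])
    return surface_dists.mean()
```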
Experiments
A. Experimental Settings
In our experiments, we first set the source domain as MRI and target domain as CT. We conducted extensive experiments to demonstrate the severe cross-modality domain shift and the effectiveness of domain adaptation strategies. Specifically, we designed the following experimental configurations:
training and testing the segmentation network on the source domain (i.e., Seg-MRI);
training and testing the segmentation network on annotated target domain images, as an upper bound (i.e., Seg-CT);
directly testing the source domain segmenter on target data, with no domain adaptation (i.e., Seg-CT-noDA);
our PnP-AdaNet for unsupervised domain adaptation.
In addition, we optimized the practical configurations of our PnP-AdaNet. Specifically, our ablation studies investigated: i) how the balancing ratio in the loss functions between the two discriminators affects the domain adaptation performance, and ii) the importance of inputting multiple levels of feature maps to the feature discriminator $\mathcal{D}_{f}$.
B. Results of PnP-AdaNet for Unsupervised Domain Adaptation
Our employed segmentation network without skip connections serves as the basis for the subsequent domain adaptation procedures. Hence, we choose to first validate the performance of this dilated network architecture on the whole-heart segmentation task. Specifically, for the Seg-MRI setting, our model achieves an average Dice of 79.4% across the four structures. For comparison, we reference Payer et al. [45], which used two cascaded fully convolutional networks and achieved an average Dice of 80.2%, ranking first in the MICCAI 2017 MM-WHS Challenge. We also directly quote their reported U-Net results as a complementary comparison. As detailed in Table 1, our network's segmentation performance with standard training is comparable to the state of the art, which employed networks based on 3D convolutional kernels. Hence, we can safely regard our trained segmentation network as a standard baseline model for the cardiac segmentation task.
To observe the severe domain shift inherent in cross-modality medical images, we first directly deploy the segmentation model trained on the MRI domain to CT data. Unsurprisingly, the MRI segmenter completely fails on CT images, with an average Dice of merely 13.2% across all four structures. Specifically, Seg-CT-noDA only obtains a Dice of 2.7% for LA-blood and 3.4% for LV-blood. After domain adaptation, our PnP-AdaNet presents a great recovery of the segmentation performance on the target CT data compared with Seg-CT-noDA. More specifically, our method increases the average Dice across the four cardiac structures by 50.7 percentage points, achieving a score of 63.9%. As presented in the last column of Fig. 3, the predicted segmentation masks from PnP-AdaNet can successfully localize the cardiac structures and further capture their anatomical shapes. Notably, the segmentation performance on the aorta is significantly recovered after the adaptation process, almost approaching the fully supervised upper bound.
Fig. 3. Results of different methods for CT image segmentation. Each row presents one typical example; from left to right: (a) raw CT image, (b) ground truth labels, (c) supervised training on CT, (d) directly applying the MRI segmenter on CT data, (e) results of DANN, (f) results of ADDA, (g) results of CycleGAN, (h) results of our proposed PnP-AdaNet. The structures of AA, LA-blood, LV-blood and LV-myo are indicated by yellow, red, green and blue colors, respectively.
C. Ablation Study on Configurations
In this subsection, we extensively investigate the configurations of our PnP-AdaNet. Specifically, we observe the domain adaptation performance when adjusting two key properties: i) the ratio balancing the losses of the two discriminators $\mathcal{D}_{f}$ and $\mathcal{D}_{m}$, and ii) the choice of layers whose feature maps are input to the feature discriminator $\mathcal{D}_{f}$.
Specifically, we adjust the ratio of the loss weights between the mask discriminator and the feature discriminator, and observe the resulting segmentation performance on the target domain.
Having observed that aligning the feature space is essential in our method, we furthermore investigate which layers' features to input to the feature discriminator $\mathcal{D}_{f}$.
D. Comparison With State-of-the-Art Methods
As our investigated topic of unsupervised domain adaptation for cross-modality medical image segmentation is quite new in the field, there are few previous studies with which to compare. However, this problem deserves careful exploration, given that CNNs dominate current segmentation methods and their generalization capability matters considerably. Different from cross-site and cross-sequence domain shifts, we regard the cross-modality shift between CT and MRI as one of the most challenging settings.
To promote and facilitate future studies on cross-modality domain adaptation, in addition to our own prior work [11], we also implemented several state-of-the-art methods which are popular for unsupervised domain adaptation in natural computer vision. We used the same dataset and segmenter architecture settings for all the methods, with the results listed in Table 3. Specifically, the approach of DANN [17] encourages domain-invariance in the feature space; the source and target domains share the feature extractor, and a domain classifier is connected to the output of the encoder to monitor domain-invariance. Another alternative approach is ADDA [16], which also aligns the source and target domain distributions in the feature space. In this method, the source and target domains have their own feature encoders up to the last softmax layer, and a discriminator is used to differentiate which domain the features come from. The third method we include is CycleGAN [23], which produces impressive image-to-image translations. We transform the MRI images to the appearance of CT, and train a segmenter with the transformed images and the MRI labels. We demonstrate typical examples of the CT images generated from MRI using CycleGAN in Fig. 4.
Fig. 4. Examples of MRI to CT image-to-image translation with CycleGAN. Left to right: original MRI image, generated CT image, MRI with segmentation ground truth, and generated CT with the corresponding MRI ground truth.
Observing the performances of the domain adaptation methods, we can see that all of them are able to effectively recover the segmentation accuracy on the target CT data. More specifically, in terms of different cardiac structures, all the methods perform better for the ascending aorta and the left atrium blood cavity than for the left ventricle blood cavity and myocardium. The reason could be that the latter anatomical structures, especially the myocardium, have relatively complicated geometry, which increases the difficulty of unsupervised domain adaptation. In contrast, the aorta presents a more compact shape as well as clear boundaries, and reasonably receives the best recovery. Compared with DANN and ADDA, which also use feature-level adaptation, our proposed PnP-AdaNet achieves better performance. The reason is that we neither share nor separate all the feature encoding layers between the source and target domains. Instead, our plug-and-play mechanism regards low-level features as domain-specific, while high-level feature encoders are sharable. Moreover, connecting multi-level features to the discriminator also plays a crucial role. The CycleGAN presents highly competitive results. We attribute this to the pixel-wise supervision directed particularly towards the segmentation task. The explicit supervision, although it may be noisy, can encourage the predictions to consistently focus on the heart region. The good ASD results with CycleGAN also indicate clean masks. The deficiency of our feature-aligning method mainly appears at unclear boundaries between neighboring structures, or as wrong predictions on relatively homogeneous tissues away from the region of interest. With a very simple post-processing strategy, i.e., only retaining the largest 3D connected component for every class, we can reduce the ASD to 4.1, 5.4, 7.4 and 6.2 for the AA, LA-blood, LV-blood and LV-myo, respectively. Finally, compared with our own prior work [11], the performance improvement comes from the usage of dual discriminators, which impose a more explicit constraint on the shape of the segmentation.
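The post-processing mentioned above can be realized with a few lines of SciPy, as sketched below (the function name is illustrative):

```python
import numpy as np
from scipy import ndimage

def keep_largest_component(segmentation, num_classes):
    """For each foreground class, keep only the largest 3D connected component
    of its predicted mask and discard the smaller spurious regions."""
    cleaned = np.zeros_like(segmentation)
    for c in range(1, num_classes):
        mask = segmentation == c
        labeled, n = ndimage.label(mask)
        if n == 0:
            continue
        sizes = ndimage.sum(mask, labeled, index=range(1, n + 1))
        largest = np.argmax(sizes) + 1
        cleaned[labeled == largest] = c
    return cleaned
```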
E. Reverting Domain Adaptation Direction
A natural question to ask is whether domain adaptation for the segmentation task is symmetric with respect to modality, i.e., whether the reverse adaptation direction from CT to MRI can also be achieved. To investigate this, we apply the same model setting, only switching the source domain to CT and the target domain to MRI. The quantitative adaptation results are presented in Table 4. Directly applying the CT segmenter to MRI data also unsurprisingly fails. Our proposed PnP-AdaNet is able to recover the average segmentation Dice to 54.3%. Notably, the best recovered structure in this reverse setting is the LV-blood, not the same structure (AA) as in the MRI-to-CT direction. As a more interesting observation, its Dice increases from complete failure (3.4%) to a considerably high value (77.7%). The results of the other three structures are not as promising. Compared with adaptation from MRI to CT, the reverse direction yields lower performance, which can be expected: segmenting cardiac MRI is itself more difficult than segmenting cardiac CT. This is also notable from Table 1, where the CT segmentation Dice is higher than that of MRI across all four structures. In this regard, transferring a CT segmenter to MRI appears more challenging. Our experimental results indicate that cross-modality domain adaptation can be achieved in both directions; however, the difficulty seems to be asymmetric.
Discussions
In this paper, we tackle the domain adaptation problem under a very challenging yet important setting of cross-modality medical datasets. As deep learning has become the de-facto standard for solving segmentation and detection tasks, investigating its generalization capability and robustness is essential. The existing successful practice with deep networks is to train and test the models on the same data source. However, it has been frequently revealed in very recent works that such models can perform poorly on unseen datasets [7], [10], [26]. Resolving the domain adaptation issue holds great potential for applying trained deep learning models to wider clinical use, building more powerful networks using large-scale databases combining images from multiple sites, and helping to understand how networks capture data distributions to make recognition predictions. In these regards, we explore an effective unsupervised domain adaptation solution under one of the most challenging settings, namely the multi-class segmentation of cross-modality medical images.
The appearance differences between MRI and CT scans are apparent. Although the human eye can match the structures across the two modalities, from the perspective of image computing, the system "sees" distinct distributions of intensity values. For a concrete example, see Fig. 1 again: the intensity range of the myocardium and its contrast with nearby tissues are very different on MRI and CT. In this case, a model that learns discriminative features from one modality is naturally inapplicable to the data manifold of the other modality. Generally speaking, it is rather difficult for a model to generalize from one domain to another when the measure of their support intersection is almost zero. On the other hand, multi-modality data have become indispensable in modern clinical routine. Cross-modality synthesis and segmentation have been rapidly gaining research popularity [46], [47]. Considering cardiology as currently the most typical scenario which uses both MRI and CT, we conduct our study on cardiac images.
Towards domain adaptation, we present a novel framework, i.e., PnP-AdaNet, which is very flexible and light-weight in practical use. We replace the early layers of an established network with a DAM at test time. The assumption is that, in cross-modality medical images of the same organ, the low-level features are domain-specific, while the semantic feature compositions at higher layers are shared and can be re-used across different modalities. To learn the DAM, we encourage the extracted features of the target domain to be aligned with the source features in distribution. A notable setting in our PnP-AdaNet is the incorporation of multi-level adversarial learning, which makes the DAM tightly conditioned on the source distributions. Another hyper-parameter which needs to be considered when using the plug-and-play strategy is the adaptation depth $d$, i.e., up to which layer the early encoder is replaced by the DAM.
When thinking of solutions for domain adaptation, an alternative way is to take advantage of image-to-image translation. Networks using a cycle-consistency loss can generate plausible synthetic images, which can then be employed either for training or for testing. A potential risk of pixel-space translation is that distortions and artifacts might be propagated or amplified in downstream processing. For example, Cohen et al. [49] found that CycleGAN-based medical image translation models trained on imbalanced datasets could hide dangerous brain tumors in synthetic images. From our own experimental experience, the cross-modality translation of non-tumor medical images is feasible using CycleGAN, and the reconstructed images also look quite realistic. Using synthetic images can bring positive effects in the scenario of domain adaptation; however, it cannot solve the problem perfectly. We analyze that the reasons are at least two-fold. One is that there still exists a domain shift between synthetic and real images, although their appearances look somewhat similar [30]. The other is that the generated images may not match the original images at the pixel level. Therefore, the training pairs of image and label are imperfect. In contrast, feature-level adaptations, such as ours, do not suffer from these problems, as the adaptation and the task-specific training are not detached.
Although we achieved promising results, there are still limitations. The backbone of our network is a 2D CNN. In practice, we aggregate three adjacent slices as the input channels to the models. This is a common practice to overcome current single-GPU memory constraints. It may be beneficial to employ more carefully tailored networks for CT/MRI image analysis with 2.5D [50] or 3D CNNs [51]. In this work, we also tried to conduct the domain adaptation on the basis of 3D segmentation networks. However, we were faced with difficulties of memory consumption, given the use of one segmenter, one generator and two discriminators. Moreover, optimizing 3D networks in unsupervised adversarial learning is also very challenging. In these regards, we chose to first downgrade the network to 2D, so that we could focus on the domain adaptation part, which is the core of this paper. There are still not many works using 3D CNNs for GANs in medical applications. One recent work is Zhang et al. [30], which synthesized CT from MRI on cardiac data using an end-to-end 3D CNN. Another is Pan et al. [28], which used a 3D cycle-consistent GAN to synthesize missing PET from MRI for neuroimaging. To the best of our knowledge, the work of [9] was the first to employ 3D CNNs in domain adaptation for the medical image segmentation task, tackling different MRI sequences. Exploring the effectiveness of cross-modality domain adaptation approaches on the basis of 3D networks is planned as future work. In addition, we also plan to extend our method towards heterogeneous domain adaptation, tackling a more challenging situation where the target domain differs significantly from the source domain not only in low-level appearance but also in high-level semantic structures. For example, we may adapt a neural network trained on a healthy cohort of MRI to a severely diseased cohort of CT data, which holds promising clinical significance.
Conclusion
In conclusion, we investigate the challenging yet crucial task of unsupervised domain adaptation on cross-modality medical images. We propose a novel approach, called the plug-and-play adversarial domain adaptation network (PnP-AdaNet), which aligns the latent feature space of the target domain to that of the source domain. Extensive experiments with comparisons against state-of-the-art methods have validated the effectiveness of our method. Moreover, to facilitate future research on the cross-modality domain adaptation problem, we open-source our code using the well-organized cardiac segmentation challenge dataset. We believe that the cross-modality domain adaptation task will witness rapid development as an important topic in the scope of digital health.
ACKNOWLEDGMENT
(Qi Dou and Cheng Ouyang contributed equally to this work.)