My undergrad thesis project was to develop two neural networks to classify 4D fMRI data into brain disease categories, evaluate those classifiers, and explain their predictions through tailored visualization techniques.

In 2017, I completed an undergraduate thesis-equivalent project focused on a problem in medical imaging. At the time, automated diagnosis of diseases directly from imaging data using deep learning was relatively new. In particular, doctors had a significant confidence gap regarding the effectiveness of these methods because there was no clear understanding of how the deep learning classifiers arrived at their conclusions. This opacity is why deep learning classifiers are often referred to as “black boxes”. My thesis aimed to address this gap on a specific task by both diagnosing brain diseases and at the same time providing explanatory visualizations. The specific task was classifying fMRI brain images from a UMass Medical dataset of youth brain disease into brain-disease categories. Once the neural network architectures were developed, the visualization component applied what were at the time state-of-the-art techniques in neural network visualization.

My team (Miya Gaskell, Ezra Davis, and I) worked under the supervision of Prof. Xiangnan Kong. For all three of us, the project was our first experience with neural networks, convolutional layers, and visualization. Our deliverables were a thesis paper and a poster, which we presented to the WPI faculty on project presentation day.

Dataset

We were fortunate to be able to work with UMass Medical for this project through an ongoing collaboration Prof. Kong had with them. All of the data was provided to us by UMass Medical School’s Center for Comparative Neuroimaging, and the subjects were youths enrolled in a brain-disease study. The dataset of 88 patients was split into training and test sets patient-wise, i.e. all of the fMRI frames from an individual patient were contained entirely in either the training set or the test set in order to prevent data leakage. This is important because one of our architectures classified individual frames, which dramatically increased the number of samples; had we split frame-wise instead, frames from the same patient would have landed on both sides of the split and we would not have been able to properly evaluate our classifier’s effectiveness. The dataset was high quality but small, and we acknowledged that more data would be needed for the study to truly take advantage of the power of a deep-learned classifier. The small sample size also makes it difficult to extrapolate our results to the clinical setting. Nevertheless, our approach, including the visualization techniques, could be applied directly to any large fMRI classification dataset.
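To make the patient-wise split concrete, here is a minimal sketch in plain Python. It is not the code we used in the project: the function name, the 20% test fraction, and the toy data (88 hypothetical patients with 5 frames each) are all assumptions for illustration.

```python
import random

def patient_wise_split(frame_patient_ids, test_fraction=0.2, seed=0):
    """Split frame indices so each patient's frames land in exactly one set.

    frame_patient_ids[i] is the (hypothetical) patient ID of frame i.
    """
    patients = sorted(set(frame_patient_ids))
    rng = random.Random(seed)
    rng.shuffle(patients)

    # Choose whole patients for the test set, never individual frames.
    n_test = max(1, round(len(patients) * test_fraction))
    test_patients = set(patients[:n_test])

    train_idx = [i for i, p in enumerate(frame_patient_ids) if p not in test_patients]
    test_idx = [i for i, p in enumerate(frame_patient_ids) if p in test_patients]
    return train_idx, test_idx, test_patients

# Toy data: 88 hypothetical patients, 5 frames each (the frame count is made up).
frame_patient_ids = [p for p in range(88) for _ in range(5)]
train_idx, test_idx, test_patients = patient_wise_split(frame_patient_ids)

train_patients = {frame_patient_ids[i] for i in train_idx}
print(len(train_idx), len(test_idx), train_patients & test_patients)
```

The key design choice is that shuffling happens over patient IDs rather than frames, so no patient can contribute frames to both sets.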

Background

Although at this point neural networks are somewhat common knowledge, it is important to describe what we learned over the course of this project. To do that I will define the terms informally, using language from the report. Credit goes to Ezra for this specific explanation.

Artificial neural networks are one of many machine learning techniques that are primarily used for classifying data. Originally, this technique was inspired by biological neural networks. A common example: given a dataset of many images of cats and dogs, you could use an artificial neural network to recognize whether a new photo is of a cat or a dog. For our purposes, we use artificial neural networks to help diagnose patients based on their brain scans. Neural networks form their predictions through a training process: you show examples of both the input and output (e.g. patient A has disorder 1, patient B has disorder 2, etc.) again and again, and the network gets better at predicting the output given the inputs as its guesses get more accurate. The training process is rarely perfect, but it can create a neural network that is quite good at predicting the output for any new input that’s not in the training set.

Convolutional neural networks are a specific type of neural network that is usually used for classifying images. Traditional neural networks, when given an image, consider each pixel’s location as a separate input, so if the image is shifted one pixel over, it looks like an entirely different image. Convolutional neural networks (CNNs), however, function differently. CNNs (like many modern traditional neural networks) are typically composed of several different layers, each of which does some small piece of the analysis, with the first layer performing simple analysis and feeding its results into the next layer. This is called deep learning. Each convolutional layer consists of a number of filters (typically 16 or 32) that look at one part of the image at a time (commonly 5×5 or 3×3 pixels; smaller areas compute faster). Each filter lights up under different conditions: for instance, one filter could detect edges in an image, whereas its neighboring filter detects circles. The outputs of these filters (each of which is roughly the size of the original image) are fed into the next convolutional layer (typically with simpler non-convolutional layers in between), which has filters that can detect more complex features (e.g. faces of cats versus faces of dogs).
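As a rough illustration of a single filter "lighting up", the sketch below slides a hand-written 3×3 vertical-edge filter over a tiny image using NumPy. In a real CNN the filter weights are learned rather than written by hand; the filter values, image, and function name here are assumptions chosen for the example.

```python
import numpy as np

def conv2d_valid(image, kernel):
    """Slide `kernel` over `image` (no padding) and return the activation map.

    This is cross-correlation, which is what deep learning libraries
    usually mean by "convolution".
    """
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w))
    for r in range(out_h):
        for c in range(out_w):
            out[r, c] = np.sum(image[r:r + kh, c:c + kw] * kernel)
    return out

# Hand-written filter that responds to dark-to-bright vertical edges.
vertical_edge = np.array([[-1.0, 0.0, 1.0],
                          [-1.0, 0.0, 1.0],
                          [-1.0, 0.0, 1.0]])

image = np.zeros((6, 6))
image[:, 3:] = 1.0      # edge between columns 2 and 3

shifted = np.zeros((6, 6))
shifted[:, 4:] = 1.0    # same edge, moved one pixel to the right

act_original = conv2d_valid(image, vertical_edge)
act_shifted = conv2d_valid(shifted, vertical_edge)
print(act_original[0])  # the filter responds only near the edge
print(act_shifted[0])   # the same response, moved one pixel over
```

Shifting the image simply shifts the filter's response, which is the property that lets a CNN treat a shifted image as essentially the same image.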

In our cats and dogs example, the algorithm is ultimately trying to predict whether the photo contains a cat or a dog. In more traditional image analysis techniques, the programmer might explicitly set up a set of rules for differentiating pictures of cats from dogs, such as the shape of the ears or whether the animal is wearing a collar. In an artificial neural network, the computer decides which features are important on its own, without communicating these features back to the programmer. Artificial neural networks act like a black box: we can see the inputs and outputs, but the internal workings are hidden from view.

Convolutional neural networks are great at classifying images and other spatial data (such as fMRI brain scans), but it can be hard to understand how they really work. Because the training process is intentionally randomized and the “learning” process is automated, it can be hard to figure out exactly what each filter is trying to detect. There are a number of visualization techniques for exactly that purpose (such as those in Zeiler and Fergus, 2013, Visualizing and Understanding Convolutional Networks), but before our project, we did not find any for 3D data.

Approach

We ultimately settled on two architectures for our deep neural networks.

The first was a 3D CNN that we ran on each frame of the fMRI sequence independently (no notion of time, but excellent spatial perception); the final prediction was the mean of the per-frame class probabilities across the whole time series. We developed a max-activation map and a guided saliency map video to enhance confidence in the accuracy of the method. Because the guided saliency map was based upon a per-frame classification with no constraint that the classes be temporally consistent, we were surprised to find that the video did not have a lot of flickering or other temporal artifacts and usually resembled small portions of salient activation moving smoothly around the brain volume.
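The averaging step can be sketched as follows. This is only an illustration in plain Python, not our project code; the function name and the per-frame probabilities are made up, standing in for the output of the 3D CNN on four frames of one scan.

```python
def aggregate_frame_predictions(frame_probs):
    """Average per-frame class probabilities into one prediction per scan.

    frame_probs is a list with one entry per frame; each entry is a list of
    class probabilities (hypothetical classifier output).
    """
    n_frames = len(frame_probs)
    n_classes = len(frame_probs[0])
    mean_probs = [sum(frame[k] for frame in frame_probs) / n_frames
                  for k in range(n_classes)]
    predicted_class = max(range(n_classes), key=lambda k: mean_probs[k])
    return mean_probs, predicted_class

# Made-up probabilities for a 4-frame scan and 3 classes.
frame_probs = [
    [0.6, 0.3, 0.1],
    [0.2, 0.5, 0.3],
    [0.7, 0.2, 0.1],
    [0.5, 0.4, 0.1],
]
mean_probs, predicted_class = aggregate_frame_predictions(frame_probs)
print(mean_probs, predicted_class)
```

Note that the second frame on its own would vote for class 1, but the averaged prediction is still class 0, which is the point of pooling over the time series.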

The second was a 1D CNN that operated on the mean activations of expert-chosen brain regions over time. This had the advantage of explicitly incorporating interactions between regions as well as temporal patterns, but the disadvantage of having no within-region resolution. Ultimately, this performed worse than the 3D CNN approach in pure accuracy. The visualization for this architecture took advantage of the region mapping: by thresholding a saliency map, it produced a graph of key activation-interactions overlaid on the brain itself.
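The reduction from a 4D scan to one time series per region might look like the NumPy sketch below. The array layout (time first), the integer label volume with 0 as background, and the function name are assumptions for illustration; the real region definitions came from experts and are not reproduced here.

```python
import numpy as np

def region_mean_time_series(scan, region_labels, n_regions):
    """Collapse a 4D scan into one mean-activation time series per region.

    scan:          array of shape (T, X, Y, Z)
    region_labels: integer array of shape (X, Y, Z); 0 is background and
                   1..n_regions are (hypothetical) region IDs
    Returns an array of shape (n_regions, T).
    """
    series = np.zeros((n_regions, scan.shape[0]))
    for r in range(1, n_regions + 1):
        mask = region_labels == r
        # scan[:, mask] has shape (T, voxels_in_region); average over voxels.
        series[r - 1] = scan[:, mask].mean(axis=1)
    return series

# Toy scan: 3 time points over a 2x2x2 volume split into two regions.
scan = np.arange(3 * 2 * 2 * 2, dtype=float).reshape(3, 2, 2, 2)
region_labels = np.array([[[1, 1], [1, 1]],
                          [[2, 2], [2, 2]]])
series = region_mean_time_series(scan, region_labels, n_regions=2)
print(series)
```

Each row of the result is a 1D signal that a 1D CNN can consume, at the cost of discarding everything that happens inside a region.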

Illustrative examples

Expanding upon the brief summaries of what we delivered above, I’d like to show some concrete examples.

First, our max-patch activation map. Each feature map (collection of filters) in our convolutional layers is a linear function of its input. For each image patch (across all input channels) to which it is applied, the feature map produces a single output activation. The strength of this activation is determined by how closely the patch aligns with the feature map’s weights. Given a specific layer in our neural network, our max-patch visualization shows the location in the image that produces the highest activation for each feature map at that layer. Early work on neural network visualization, such as that by Chris Olah, has shown that applying this to multiple samples can give a feel for what is qualitatively important, in the sense of e.g. “this is a neuron that looks for sharp edges” or “this is a neuron that looks for thin squiggles”. While it may be overly reductionist now, a common viewpoint is that earlier layers in the network detect features that are strictly less abstract (lower level) than those detected by later layers, which build upon their representations.

Shown above is an example of max-patch on image data from a digit classifier (MNIST).

Because a single patch is not sufficient to reason about a single sample, we chose instead to look at the top 10 patches, to build confidence that the network is zeroing in on salient parts of an fMRI frame. The idea is that a doctor would compare these patches with the regions they personally consider potentially salient, to determine whether the network is focusing on a reasonable set of locations in the brain.
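Finding the top patches for one feature map can be sketched with NumPy as below. This is an illustrative sketch rather than our implementation: it assumes an unpadded, stride-1 convolution so that an activation's index is also the corner of its input patch, and the activation values and function names are made up.

```python
import numpy as np

def top_k_patch_locations(activation_map, k):
    """Return the indices of the k largest activations, strongest first."""
    order = np.argsort(activation_map, axis=None)[::-1][:k]
    coords = np.unravel_index(order, activation_map.shape)
    return [tuple(int(axis[i]) for axis in coords) for i in range(len(order))]

def extract_patch(volume, corner, patch_size):
    """Cut the cubic input patch whose first voxel is at `corner`.

    Assumes an unpadded, stride-1 convolution, so activation index == corner.
    """
    slices = tuple(slice(c, c + patch_size) for c in corner)
    return volume[slices]

# Hypothetical activation map of one feature map over a 6x6x6 frame
# with a 3x3x3 filter (so the map is 4x4x4).
activation_map = np.zeros((4, 4, 4))
activation_map[1, 2, 3] = 5.0
activation_map[3, 3, 3] = 4.0
activation_map[0, 0, 1] = 3.0

frame = np.random.default_rng(0).normal(size=(6, 6, 6))
locations = top_k_patch_locations(activation_map, k=3)
patches = [extract_patch(frame, corner, patch_size=3) for corner in locations]
print(locations)
```

With `k=10` and a real activation map, the resulting patches are the ones a doctor would inspect.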

Second, I will outline our guided saliency map. Saliency maps reuse the method by which neural networks are typically trained (back-propagation) to show which features in the input space would most quickly change the class assigned to an image. We back-propagate a unit gradient in only the class of interest to arrive at a map telling us which features we should change to most quickly increase the log-probability our network assigns to that class. We also back-propagate a leave-one-out map of all but that feature and visualize that using the same technique. Because neural networks are highly nonlinear, these maps can be nonsensical by the time they reach the image and are unlikely to be sparse. A quick trick is to threshold the gradient as we back-propagate so that only positive values are kept. The idea relies on the fact that each linear transformation is followed by a nonlinearity that thresholds its output to be non-negative. Positive gradients therefore do not have the same propagation issues as negative ones: there is no upper bound on how much more positive a representation could become, whereas it cannot be pushed below zero. Additionally, by “choosing a direction”, positive-only gradients sparsify the saliency maps.
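The difference between a plain saliency map and the guided version can be seen on a toy network. The NumPy sketch below uses a tiny fully connected network with one ReLU layer instead of our 3D CNN, starts from a unit gradient on the class score, and uses made-up weights; all of these are assumptions for illustration.

```python
import numpy as np

def saliency(x, W1, W2, target_class, guided=False):
    """Back-propagate a unit gradient for one class down to the input.

    Network (hypothetical): logits = W2 @ relu(W1 @ x).
    With guided=True, negative gradients are also zeroed at the ReLU.
    """
    pre = W1 @ x                      # pre-activation of the hidden layer

    grad_logits = np.zeros(W2.shape[0])
    grad_logits[target_class] = 1.0   # unit gradient in the class of interest

    grad_hidden = W2.T @ grad_logits
    grad_pre = grad_hidden * (pre > 0)            # ordinary ReLU backward pass
    if guided:
        grad_pre = grad_pre * (grad_hidden > 0)   # keep positive gradients only
    return W1.T @ grad_pre

# Made-up weights and input.
W1 = np.array([[1.0, -1.0, 0.0],
               [0.5, 0.5, -1.0],
               [-1.0, 0.0, 1.0]])
W2 = np.array([[1.0, -2.0, 1.0],
               [-1.0, 1.0, 0.5]])
x = np.array([1.0, 0.5, 0.2])

plain_map = saliency(x, W1, W2, target_class=0)
guided_map = saliency(x, W1, W2, target_class=0, guided=True)
print(plain_map, guided_map)
```

In this toy case the second hidden unit is active but carries a negative gradient, so the guided pass drops its contribution and the resulting map reflects only the unit that argues for the class.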

After creating the saliency maps and guided-back-propagating them all the way to the input frame for each input in a sequence, the new sequence of maps can be transformed to a video by choosing a view (in our case a standard cross-sectional diagnostic view) and rendering a video.

Shown above is a single frame from one such video, showing how the saliency map suggests the most salient region of a frame from our classifier’s perspective.

Lastly, for the time-series representation, our activation-interaction maps, displayed below, are useful for ascertaining which correlations between regions our classifier looks at. Because each region is a 1D time series, and our network is run on all time series at the same time, we can determine which correlations between the series are important by looking at large activations at a specific point in our architecture. Specifically, the architecture begins with isolated convolutional layers that only apply to one region at a time. After a certain point, we consider all pairs of regions and produce n-choose-2 channels by combining each pair of input channels using a weighted sum (with weights learned for every pair). By looking at the strength of the resulting activations and keeping only those outside a band around zero, important connections can be separated from connections that have comparatively low-strength activations.
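The all-pairs combination and thresholding can be sketched as below with NumPy. This is a simplified illustration, not our architecture: the pair weights are fixed by hand rather than learned, "strength" is summarized as the mean activation over time (an assumption), and the function names are hypothetical.

```python
import itertools
import numpy as np

def pairwise_channels(region_features, pair_weights):
    """Combine every pair of region channels with a weighted sum.

    region_features: array of shape (n_regions, T)
    pair_weights:    array of shape (n_pairs, 2), one weight per member of
                     each pair (learned in a real network, fixed here)
    Returns the list of pairs and an array of shape (n_pairs, T).
    """
    pairs = list(itertools.combinations(range(region_features.shape[0]), 2))
    channels = np.stack([
        pair_weights[k, 0] * region_features[i] + pair_weights[k, 1] * region_features[j]
        for k, (i, j) in enumerate(pairs)
    ])
    return pairs, channels

def strong_connections(pairs, channels, threshold):
    """Keep pairs whose mean activation lies outside [-threshold, threshold]."""
    strength = channels.mean(axis=1)
    return [(pair, float(s)) for pair, s in zip(pairs, strength) if abs(s) > threshold]

# Toy features for 3 regions over 4 time steps (made-up values).
region_features = np.array([[1.0, 1.0, 1.0, 1.0],
                            [0.0, 0.0, 0.0, 0.0],
                            [-1.0, -1.0, -1.0, -1.0]])
pair_weights = np.ones((3, 2))   # n-choose-2 = 3 pairs for 3 regions

pairs, channels = pairwise_channels(region_features, pair_weights)
edges = strong_connections(pairs, channels, threshold=0.5)
print(edges)
```

The surviving pairs are the edges that would be drawn on the brain; the pair whose activations cancel out falls inside the band around zero and is dropped.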

Results

Ultimately, we were able to show on our held-out test set that our 3D approach classified the 4D fMRI data more accurately than the reported median accuracy obtained by doctors. This comes with two sizeable caveats: we are comparing our specific brain-disease classification task with diagnosing all brain diseases, and our results are highly limited by the size of our dataset. Because all patients came from the same study, with the same setup and people involved, it is unlikely that we fit to particularities of the setup itself.

A more interesting result is that our guided backpropagation focused on a particular brain region (the parietal lobe), and on motifs within that region, which matches research from 2013 on predicting bipolar depression. This finding came from three undergraduate students without prior medical training, which made us excited about the potential of neural network visualization techniques for hypothesis generation in the medical field.

Acknowledgments

I want to thank my teammates Miya Gaskell and Ezra Davis, and our advisor Professor Kong, for all the hard work they put in to make the project work. Additionally, I want to thank Constance Moore, UMass Medical School’s Center for Comparative Neuroimaging, and the patients and medical staff for providing the brain scans that made this research possible.