Using Unbalanced Datasets for Deep Learning in Medicine


Deep learning is now often treated as a panacea for classification problems, especially those involving images. This is particularly true in medical diagnostics, a field that generates several petabytes of data each year. As medicine moves toward evidence-based diagnosis, that volume will only grow.

Making meaningful use of all this data is therefore a growing challenge for data scientists. A major use case is to augment clinicians with intelligent tools that reduce the turnaround time for diagnosis and treatment. One example is analysing chest X-rays (potentially a billion images generated every year) to classify abnormalities, thus reducing the time a radiologist spends per image. In such a case, a major concern is the data, or rather its distribution. Most machine learning algorithms implicitly assume a balanced dataset, because an unweighted loss function counts every sample equally and is therefore dominated by the most frequent classes. How would you get a good number of samples for each abnormality, given how rarely some of them occur? And since people tend to go for tests only when seriously ill, a normal sample is far less likely to be encountered. So how do we train deep convolutional models, knowing a priori that the majority of our cases will belong to a subset of our classes?

The common ways of balancing a dataset are undersampling the over-represented class and oversampling the under-represented class. Undersampling may not work when the dataset itself is small, or when intra-class variability is high enough that discarding samples of a class leaves that class not fully learnt. Oversampling can be done by duplicating samples or by adding slightly perturbed variants. Deep networks, however, are generally trained on the assumption that the images are decorrelated, so adding a perturbed copy of an existing image brings little benefit.

So what is the solution? One way is to weight the loss function a priori: we penalize the misclassification of each class differently, based on how often that class occurs in the dataset. This is referred to as the Bayesian Cross-Entropy Loss [1].
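To make the idea of a class-weighted loss concrete, here is a minimal sketch in plain Python. It assumes inverse-frequency weights, which is one common choice and not necessarily the exact formulation used in [1]. The function names, the label strings and the toy "lazy" model are all hypothetical and chosen only for illustration.

```python
import math
from collections import Counter


def inverse_frequency_weights(labels):
    """Weight each class by n_samples / (n_classes * class_count).

    Assumption: inverse-frequency weighting; rarer classes get larger weights.
    """
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * count) for cls, count in counts.items()}


def weighted_cross_entropy(probs, labels, weights):
    """Class-weighted cross-entropy, normalized by the sum of the weights used.

    probs  : list of dicts mapping class name -> predicted probability
    labels : list of true class names, same length as probs
    weights: dict mapping class name -> weight
    """
    eps = 1e-12  # guards against log(0)
    total = 0.0
    norm = 0.0
    for p, y in zip(probs, labels):
        total += -weights[y] * math.log(max(p[y], eps))
        norm += weights[y]
    return total / norm


# Toy data: 90 abnormal studies and 10 normal ones.
labels = ["abnormal"] * 90 + ["normal"] * 10
weights = inverse_frequency_weights(labels)

# A lazy model that predicts "abnormal" with 90% confidence for every sample.
lazy = [{"abnormal": 0.9, "normal": 0.1}] * len(labels)

uniform = {"abnormal": 1.0, "normal": 1.0}
plain_loss = weighted_cross_entropy(lazy, labels, uniform)
weighted_loss = weighted_cross_entropy(lazy, labels, weights)
```

With these weights, each class contributes the same total weight to the loss regardless of its size, so the lazy model is penalized more heavily under the weighted loss than under the plain one.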


Weighting the loss in this way gives the under-represented class a chance to be reflected in the learnt weights. We can augment this approach by balancing the mini-batches used for learning. For very highly imbalanced datasets, the question of model convergence remains.
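Balancing the mini-batches can be sketched as follows. This is one possible scheme, assuming every batch draws the same number of indices from each class and that sampling is done with replacement so that rare classes can always fill their quota. The generator name and its parameters are hypothetical.

```python
import random
from collections import defaultdict


def balanced_batches(labels, batch_size, n_batches, seed=0):
    """Yield lists of sample indices with an equal share of every class.

    Assumption: batch_size is a multiple of the number of classes;
    any remainder is simply dropped.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    classes = sorted(by_class)
    per_class = batch_size // len(classes)

    for _ in range(n_batches):
        batch = []
        for cls in classes:
            # Sampling with replacement: a rare class may repeat within a batch.
            batch.extend(rng.choices(by_class[cls], k=per_class))
        rng.shuffle(batch)
        yield batch


# Toy data: 95 abnormal studies and 5 normal ones.
labels = ["abnormal"] * 95 + ["normal"] * 5
batches = list(balanced_batches(labels, batch_size=8, n_batches=3))
```

Because rare samples are repeated often under this scheme, it carries the same caveat as oversampling by duplication, so it is probably best seen as a complement to the weighted loss and not a replacement for it.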


Another question is which metric to use to select the best model during training. The choice depends on the model’s use case. For medical screening tests, it is prudent to favour a high sensitivity so that potentially abnormal cases are not missed. Thus, sensitivity on the validation set, together with the F-score, may prove to be a good measure of a model’s effectiveness. In other cases, e.g. segmentation, intersection over union (IoU) may be the more appropriate measure.
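For reference, sensitivity and the F-score can be computed from validation predictions as in the sketch below. It assumes a binary screening setting in which "abnormal" is the positive class. The function name, the label strings and the toy predictions are hypothetical.

```python
def screening_metrics(y_true, y_pred, positive="abnormal", beta=1.0):
    """Return (sensitivity, precision, F-beta) for the positive class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)

    # Sensitivity (recall): fraction of truly abnormal cases that were caught.
    sensitivity = tp / (tp + fn) if (tp + fn) else 0.0
    # Precision: fraction of flagged cases that were truly abnormal.
    precision = tp / (tp + fp) if (tp + fp) else 0.0

    b2 = beta ** 2
    denom = b2 * precision + sensitivity
    f_beta = (1 + b2) * precision * sensitivity / denom if denom else 0.0
    return sensitivity, precision, f_beta


# Toy validation set: 4 abnormal and 2 normal studies.
y_true = ["abnormal", "abnormal", "abnormal", "abnormal", "normal", "normal"]
y_pred = ["abnormal", "abnormal", "abnormal", "normal", "abnormal", "normal"]
sensitivity, precision, f1 = screening_metrics(y_true, y_pred)
```

Setting `beta` above 1 makes the F-score favour sensitivity over precision, which may suit screening use cases where a missed abnormal case costs more than a false alarm.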

Thus, we can improve what a model learns even from unbalanced datasets by choosing the right loss function and combining it with the methods described above.

In terms of newer methods, recent implementations of Neural Turing Machines (NTMs) [2] and the addition of attention to neural networks show promise in solving one-shot or zero-shot problems, where the number of samples is extremely low or intra-class variability is extremely high. NTMs in particular have been promising in learning and storing discriminative representations for the Omniglot dataset, with its large number of classes and low number of samples per class. Their application to medical datasets, though, still needs significant analysis.


[1] Dalyac, A., Tackling Class Imbalance with Deep Convolutional Neural Networks

[2] Graves, A., Wayne, G., Danihelka, I., Neural Turing Machines