## An attempt to understand deep neural nets using the Information Bottleneck theory

Posted by Shubhojit Mallick in A.I. / Data Science on October 11, 2017

At the very core of Deep Learning lies a silent demon called the optimization problem. While deep learning has started reshaping businesses, business processes and our lives in various ways, many of us still don’t have a clear understanding of the algorithm(s) governing it. This article is an attempt to understand how Deep Neural Nets (DNNs) understand and comprehend data. We will delve deeper into this problem by drawing parallels to information theory, as proposed in Naftali Tishby’s work.

Imagine that we have a typical feedforward network with multiple hidden layers separating the input from the output.

X : is a high entropy variable; usually these are pixels of the input images

Y : labels of images. In the classical information theory setup, Y can be 1 bit or a couple of bits.

h1, h2…hm are the hidden layers

Ŷ : is the output layer which typically is a linear perceptron

The figure above represents a sample of the joint distribution of X and Y, which is cascaded through the hidden layers. Within the hidden layers, the internal representation of the inputs is changed, and after many transformations we arrive at an output by calculating a probabilistic value for each set of input values. An important aspect of the figure that needs highlighting is that Y is the desired output, which is used only during the training phase. Once training is done, the DNN receives an input X, which is processed successively through the layers. Ŷ is the predicted output, which will help us calculate how much information is captured by the network. Therefore, Y is not an input in the traditional sense.

Images from the activation maps in the hidden layers are shown below (colour map: jet):

*( Images have been taken from SigTuple’s internal database)*

Let’s now try to figure out how independent the variables X and Y are, using the KL divergence, which defines the divergence D between two distributions p and q as:

$$D(p \,\|\, q) = \sum_{x} p(x) \log \frac{p(x)}{q(x)}$$

Where:

p is the probability distribution of the input data.

q is the approximating distribution.

In simple terms, we are trying to find out how much information there is in X about the label Y.

Essentially, the KL divergence is the expectation, taken under p, of the log difference between the probability of the data in the original distribution and in the approximating distribution.
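To make the divergence formula concrete, here is a minimal sketch in Python with NumPy. The two distributions are made-up numbers, and the base-2 logarithm is an assumption chosen so that the result reads in bits; the formula itself does not fix the base.

```python
import numpy as np

def kl_divergence(p, q):
    """D(p || q) in bits for two discrete distributions over the same outcomes.

    Assumes q(x) > 0 wherever p(x) > 0; otherwise the divergence is infinite.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    mask = p > 0  # outcomes with p(x) = 0 contribute nothing to the sum
    return float(np.sum(p[mask] * np.log2(p[mask] / q[mask])))

p = np.array([0.5, 0.25, 0.25])    # "original" distribution (illustrative)
q = np.array([1/3, 1/3, 1/3])      # uniform approximating distribution

print(kl_divergence(p, q))  # about 0.085 bits
print(kl_divergence(p, p))  # 0.0: a distribution does not diverge from itself
print(kl_divergence(q, p))  # differs from D(p || q): the divergence is not symmetric
```

Note that swapping the arguments changes the value, which is why the KL divergence is called a divergence rather than a distance.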

The second concept stems from the fact that the representations within the hidden layers form a Markov chain. Quite remarkably, the underlying structure of a DNN is very similar to that of a Markov chain, i.e. each layer depends only on the layer before it.

Based on the image above, let’s denote X’ as the simplest mapping of X which captures the maximum information about Y. Thus, we arrive at a Markov chain as follows:

Y ➡ X ➡ X’

And, **in a Markov chain, the information between Y and X’ cannot be larger than the information between Y and X, nor larger than the information between X and X’.**

Finally, in order to understand how dependent the variables X and Y are, we calculate the mutual information, which is given by the KL divergence between the joint distribution of the variables and the product of their marginals:

$$I(X;Y) = \sum_{x,y} p(x,y) \log \frac{p(x,y)}{p(x)\,p(y)}$$

So if the variables are completely independent, mutual information is 0.
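The following sketch computes mutual information from a joint probability table and uses it for two checks: that independent variables give 0, and that a chain Y ➡ X ➡ X’ obeys the data processing inequality, I(Y;X’) ≤ I(Y;X) and I(Y;X’) ≤ I(X;X’). All probability tables are made-up toy numbers, and `mutual_information` is a helper name introduced here for illustration.

```python
import numpy as np

def mutual_information(joint):
    """I(A;B) in bits from a joint probability table joint[a, b]."""
    joint = np.asarray(joint, dtype=float)
    pa = joint.sum(axis=1, keepdims=True)   # marginal p(a) as a column
    pb = joint.sum(axis=0, keepdims=True)   # marginal p(b) as a row
    product = pa * pb                       # p(a) p(b) for every cell
    mask = joint > 0                        # zero cells contribute nothing
    return float(np.sum(joint[mask] * np.log2(joint[mask] / product[mask])))

# Independent variables: the joint equals the product of the marginals.
independent = np.outer([0.5, 0.5], [0.25, 0.75])
mi_independent = mutual_information(independent)

# Toy Markov chain Y -> X -> X' (illustrative numbers only).
p_y = np.array([0.5, 0.5])
p_x_given_y = np.array([[0.7, 0.2, 0.1, 0.0],
                        [0.0, 0.1, 0.2, 0.7]])   # rows: y, columns: x
p_xp_given_x = np.array([[1.0, 0.0],
                         [0.8, 0.2],
                         [0.2, 0.8],
                         [0.0, 1.0]])            # rows: x, columns: x'

joint_yx = p_y[:, None] * p_x_given_y            # p(y, x)
joint_yxp = joint_yx @ p_xp_given_x              # p(y, x') = sum_x p(y, x) p(x'|x)
p_x = joint_yx.sum(axis=0)                       # marginal p(x)
joint_xxp = p_x[:, None] * p_xp_given_x          # p(x, x')

i_yx = mutual_information(joint_yx)
i_yxp = mutual_information(joint_yxp)
i_xxp = mutual_information(joint_xxp)

print("I(independent) =", mi_independent)
print("I(Y;X) =", i_yx, " I(Y;X') =", i_yxp, " I(X;X') =", i_xxp)
```

Because X’ is computed from X alone, it cannot know more about Y than X does, which is what the two inequalities express.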

As we go across the hidden layers, each hidden layer represents one random variable. For each layer, we now try to calculate how much of the information in the input it retains and how much information it provides about the output.
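One common way to estimate these two quantities for a layer is to discretize its activations into bins and count. The sketch below does this for a single tanh layer with random, untrained weights. The 6-bit inputs, the labelling rule, the layer size and the number of bins are all assumptions made for illustration, not details taken from the experiments discussed here.

```python
from collections import Counter
import numpy as np

def entropy_bits(symbols):
    """Empirical entropy in bits of a list of hashable symbols."""
    counts = np.array(list(Counter(symbols).values()), dtype=float)
    p = counts / counts.sum()
    return float(-np.sum(p * np.log2(p)))

def mutual_information_bits(a, b):
    """I(A;B) = H(A) + H(B) - H(A,B), estimated from paired samples."""
    return entropy_bits(a) + entropy_bits(b) - entropy_bits(list(zip(a, b)))

rng = np.random.default_rng(0)
n_bits = 6

# Inputs: every 6-bit pattern once, so X is uniform over 64 values.
X = np.array([[(i >> b) & 1 for b in range(n_bits)]
              for i in range(2 ** n_bits)], dtype=float)
# Labels: a made-up rule (1 if more than half of the bits are set).
y = (X.sum(axis=1) > n_bits / 2).astype(int)

# One hidden layer with 3 tanh units and random weights (not trained).
W = rng.normal(size=(n_bits, 3))
hidden = np.tanh(X @ W)

# Discretize each activation in [-1, 1] into 4 equal-width bins.
edges = np.linspace(-1.0, 1.0, 5)
binned = np.digitize(hidden, edges[1:-1])   # integer bin index 0..3 per unit

x_ids = [tuple(row) for row in X.astype(int)]
t_ids = [tuple(row) for row in binned]      # the layer as one discrete variable T
y_ids = y.tolist()

i_xt = mutual_information_bits(x_ids, t_ids)   # information kept about the input
i_ty = mutual_information_bits(t_ids, y_ids)   # information provided about the label

print("I(X;T) =", i_xt, "bits, I(T;Y) =", i_ty, "bits")
```

The estimate depends on the bin width, so the absolute numbers should be read with care; repeating the measurement for each layer over the course of training is what reveals how the two quantities evolve.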

Now, the Information Bottleneck method states that, as the data pass through the hidden layers, the network performs a kind of ‘empirical risk minimization’, i.e. it tries to get rid of a chunk of data-points while still trying to maintain the input information. Let us assume that we have two spaces of objects, I and J, and that our hypothesis function hyp: I → J outputs an object j ∈ J given i ∈ I.

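Empirical risk R(hyp) is an approximation of the true risk that is computed by averaging the loss function over the training set. Thus,

$$R(\mathrm{hyp}) = \frac{1}{m} \sum_{n=1}^{m} L\big(\mathrm{hyp}(i_n),\, j_n\big)$$

where m is the number of training examples (i_n, j_n) and L is the loss function, which measures how different the prediction of the hypothesis is from the true outcome.

Empirical risk minimization (ERM) then selects the hypothesis with the smallest empirical risk:

$$\mathrm{hyp}^{*} = \arg\min_{\mathrm{hyp}} R(\mathrm{hyp})$$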
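As a small illustration of ERM, the sketch below averages a 0-1 loss over a toy training set and picks the best hypothesis from a tiny class of threshold classifiers. The data, the thresholds and the helper names (`empirical_risk`, `zero_one_loss`) are hypothetical and chosen only to keep the example short.

```python
import numpy as np

def empirical_risk(hyp, inputs, targets, loss):
    """Average loss of hypothesis `hyp` over the m training pairs."""
    return float(np.mean([loss(hyp(i), j) for i, j in zip(inputs, targets)]))

def zero_one_loss(prediction, target):
    """1 if the prediction is wrong, 0 if it is right."""
    return float(prediction != target)

# Toy training set: scalar inputs with binary labels (illustrative).
inputs = np.array([0.1, 0.4, 0.35, 0.8, 0.9, 0.65])
targets = np.array([0, 0, 0, 1, 1, 1])

# Hypothesis class: threshold classifiers hyp_t(i) = 1 if i > t else 0.
thresholds = [0.2, 0.5, 0.7]
hypotheses = {t: (lambda i, t=t: int(i > t)) for t in thresholds}

# Empirical risk of every hypothesis, then the arg min over the class.
risks = {t: empirical_risk(h, inputs, targets, zero_one_loss)
         for t, h in hypotheses.items()}
best_threshold = min(risks, key=risks.get)

print(risks)            # threshold 0.5 makes no mistakes on this training set
print(best_threshold)   # 0.5
```

A neural network does the same thing over a far larger hypothesis class, with gradient descent replacing the exhaustive search.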

ERM, therefore, first allows the data to be fitted, and then the representation is compressed. **This means that if we can compress the representation by k bits, it is essentially equivalent to losing $2^{k}$ data points.** Hence, compression is a lot more powerful than reducing the dimensionality of the hypothesis class, because the information loss is exponential.

The error is therefore bounded by the exponent of the compression; mathematically it approximates to:

$$\epsilon^{2} < 2^{h}$$

where h is the mutual information between X and Y within the hidden layer. This error governs how much information the DNN can afford to lose without losing consistency of the data.
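Taking the approximate bound above at face value, a short worked step shows why compression acts exponentially. The values h = 12 and k = 3 are arbitrary numbers picked for illustration.

```latex
% Compressing the representation by k bits lowers h to h - k:
\epsilon^{2} < 2^{h}
\quad\longrightarrow\quad
\epsilon^{2} < 2^{h-k} = \frac{2^{h}}{2^{k}}

% Example with h = 12 and k = 3:
2^{12} = 4096
\quad\longrightarrow\quad
2^{12-3} = 2^{9} = 512 = \frac{4096}{8}
```

Each bit removed from h halves the right-hand side, so k bits shrink it by a factor of 2^k.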

As a consequence, we see that most of the training time is spent on compression, so as to reduce the mutual information that each layer has about the input. In the process, this causes an exponential loss in data-points, as shown above. Thus, DNNs tend to do curve fitting very quickly in the initial epochs of training, and after that they spend a lot of time ‘forgetting’ that information via compression. By virtue of this compression, if there isn’t enough data, the compression actually decreases the information about the label, which makes it harder for the DNN to generalize.

But if data compression is the sole criterion, why does a DNN require multiple layers? The answer lies in the fact that a DNN introduces noise into the data. While this seems counter-intuitive at the outset, information theory can be used to show that the signal-to-noise ratio becomes small because the hidden layers add noise while keeping the training error as small as possible. We’ll discuss more about this in another segment.

Food for thought: Based on the discussions so far, is it possible to extend the concepts derived on such simplistic neural nets to a network with orders of magnitude more weights?