Inspiration
Having worked on deep learning research for the past four years, one problem that has always bothered us is the usability and interpretability of deep learning models. Although past research has produced visualization techniques such as Grad-CAM, they remain hard to use for developers with little experience in the area.
We propose a web application that lets developers, or in fact any user, understand and interpret deep learning models and their misclassifications by interacting with three coordinated views.
How we built it
The visualization framework in this project relies on three backend machine learning/deep learning models: the Teacher Network, the Student Network, and the Variational Autoencoder (VAE) Network. The Teacher model is a 50-layer residual convolutional network (ResNet-50) that learns both spatial and individual pixel-based features from the input images. The model is trained on 10,000 MNIST images and tested on the remaining 60,000: training on all 60k images produced a near-perfect network with almost no false predictions and hence nothing interesting to visualize, so we deliberately undertrained the model on a smaller dataset while testing it on a larger one.
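The deliberately inverted split can be sketched as follows. This is a minimal NumPy illustration with random placeholder arrays standing in for the 70,000 MNIST images; the actual teacher is a ResNet-50 trained in TensorFlow 2.0.

```python
import numpy as np

# Placeholder arrays standing in for the 70,000 MNIST images and labels.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(70_000, 28, 28), dtype=np.uint8)
labels = rng.integers(0, 10, size=70_000)

# Inverted split: train on only 10k images to undertrain the teacher,
# and test on the remaining 60k to surface plenty of misclassifications.
train_images, train_labels = images[:10_000], labels[:10_000]
test_images, test_labels = images[10_000:], labels[10_000:]
```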
The second network, the VAE, is also trained on the MNIST dataset. It is a variational autoencoder with a standard encoder-decoder architecture. One major difference between a VAE and other autoencoders is that the VAE's latent space is continuous, which allows easy random sampling and interpolation and thus lets users play with the latent parameters while still generating sensible neighbour images.
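The interpolation property can be sketched with a few lines of NumPy. The latent dimensionality (16) and the decoder itself are hypothetical here; the point is only that any convex combination of two latent codes is itself a valid code the VAE can decode into a neighbour image.

```python
import numpy as np

def interpolate_latents(z_a, z_b, steps=5):
    """Linearly interpolate between two latent vectors. Because the VAE's
    latent space is continuous, each intermediate point decodes to a
    sensible 'neighbour' image rather than noise."""
    alphas = np.linspace(0.0, 1.0, steps)[:, None]
    return (1 - alphas) * z_a + alphas * z_b

# Hypothetical 16-dimensional latent codes for two digit images.
z_a = np.zeros(16)
z_b = np.ones(16)
path = interpolate_latents(z_a, z_b, steps=5)  # 5 codes from z_a to z_b
```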
The third network, the Student Network, is a linear (one-layer) neural network, i.e. a multi-class perceptron. It is trained using images generated by the VAE as input and their corresponding labels from the Teacher network as targets. The Student Network helps observe the boundary between false positives and false negatives using the neighbour images generated by the VAE. These three networks work as a pipeline and help the user visualize the reason for false predictions in a deep learning model.
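The student can be sketched as a single softmax layer trained by gradient descent on cross-entropy. This is an illustrative NumPy version under assumed hyperparameters (learning rate, epoch count), not the project's actual training code:

```python
import numpy as np

def softmax(logits):
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def train_student(vae_images, teacher_labels, n_classes=10, lr=0.1, epochs=100):
    """One-layer (linear) student: W maps flattened pixels to class logits.
    Trained on VAE-generated images with the teacher's predictions as targets."""
    X = vae_images.reshape(len(vae_images), -1)
    Y = np.eye(n_classes)[teacher_labels]      # one-hot teacher labels
    W = np.zeros((X.shape[1], n_classes))
    b = np.zeros(n_classes)
    for _ in range(epochs):
        P = softmax(X @ W + b)
        W -= lr * (X.T @ (P - Y)) / len(X)     # cross-entropy gradient step
        b -= lr * (P - Y).mean(axis=0)
    return W, b
```

Because the student is linear, each column of `W` can be reshaped back to the image size and displayed directly, which is what makes the per-class weight images in the student view interpretable.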
We used GCP Compute Engine and TensorFlow 2.0 to run our backend deep learning models, and the entire project is hosted on the same Google Cloud instance.
Challenges we ran into
Integrating the front end with the Flask API: there were multiple compatibility issues between the JavaScript front end and the Python backend.
What we learned
The proposed visualization framework enables users to understand and interpret deep learning models and their misclassifications by interacting with the three views. The softmax values for a misclassified image show how confident the model was in making the prediction. The confusion matrix lets users see which classes are misclassified and choose a cell of interest. The parallel-coordinates plot lets users play with the VAE's parameters and exploit its main advantage: flexible interpolation and easy random sampling. Text boxes and buttons let users set the hyperparameters of the Student and VAE models. The 20 weight images in the student view help visualize which pixels the model focuses on.
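The confusion matrix underlying the second view is simple to sketch: each cell counts (true, predicted) pairs, and the off-diagonal cells are the misclassifications a user can click to drill into. A minimal NumPy version, assuming integer labels:

```python
import numpy as np

def confusion_matrix(true_labels, pred_labels, n_classes=10):
    """cm[t, p] counts images of true class t predicted as class p;
    off-diagonal cells are the misclassifications of interest."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    for t, p in zip(true_labels, pred_labels):
        cm[t, p] += 1
    return cm

# Toy example: one '0' correct, one '0' predicted as '1', one '1' correct.
cm = confusion_matrix([0, 0, 1], [0, 1, 1], n_classes=2)
```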
What's next for Visualize it Deeper
Although the three views seamlessly complement each other, many improvements can be made to the framework in the future. First, a chord diagram could provide inter-class and intra-class comparisons. Second, class activation maps could be added to the teacher view to show a heatmap of the pixels contributing to the final prediction. Third, we could visualize the filters with techniques such as occlusion instead of simply resizing and displaying the weights.
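The occlusion idea can be sketched as follows: slide a grey patch over the image and record how much the model's confidence in its original prediction drops at each position. The `predict_prob` callback and patch size are assumptions for illustration, not part of the current implementation.

```python
import numpy as np

def occlusion_map(image, predict_prob, patch=4):
    """Slide a grey patch over `image` and record the confidence drop at
    each position; large drops mark pixels the model relies on.
    `predict_prob(img)` is assumed to return the probability of the
    originally predicted class."""
    base = predict_prob(image)
    h, w = image.shape
    heat = np.zeros((h // patch, w // patch))
    for i in range(0, h, patch):
        for j in range(0, w, patch):
            occluded = image.copy()
            occluded[i:i + patch, j:j + patch] = 0.5  # grey out one patch
            heat[i // patch, j // patch] = base - predict_prob(occluded)
    return heat
```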
Website:
