Building apps for editing Face GANs with Dash and PyTorch Hub

In a pub in downtown Montreal (the same city where Plotly was founded), one evening in 2014, then-PhD student Ian Goodfellow thought up a game theory-inspired approach to generating realistic images with Deep Learning: what if we could pit two neural networks against each other in a game, where one (the generator) would generate images from random noise, and the other (the discriminator) would learn to predict whether an image is real (drawn from some training set) or fake (generated by its adversary)? If you train those models correctly, you end up with a generator good enough to fool a neural network into thinking its images are real; perhaps such a generator could fool humans, too. By midnight, Ian Goodfellow had implemented the first version of the Generative Adversarial Network (GAN), and he subsequently published the seminal paper that created an entirely new subfield of Deep Learning, in which researchers compete to produce more realistic, higher-resolution, and more controllable GANs that can fool both machines and humans into believing that images of animals, cars, or human faces are real.

The first faces generated by the original GAN. Retrieved from Goodfellow, 2014.

Over the next six years, many improvements were made to the original model. Subsequent work added convolutional networks, novel loss functions like the Wasserstein distance, regularization using gradient penalties, and progressive training for generating higher-resolution images. Nowadays, it is almost impossible to distinguish between real and fake human faces, as demonstrated by models like StyleGAN, which borrowed many ideas from this earlier work.

Nvidia’s StyleGAN lets you generate and interpolate high-resolution images of faces, often indistinguishable from real people.

Beyond improving the end result, researchers have also been interested in using GANs to modify existing images. For example, with StyleGAN it's possible to combine two images by retaining certain characteristics from the first image while taking all the other features from the second.

Other research projects have applied similar editing tricks to different scenarios. For example, with CycleGAN, you can edit images of horses to become zebras, or change a scene from summer to winter, and vice versa. In fact, it can even be applied efficiently to the individual frames of a video!

This example was retrieved from the author's repository.

Recently published research has made it possible to edit GANs through user interfaces. With GANPaint, we can modify generated images by adding or removing items such as trees, doors, and domes. GauGAN takes it one step further by letting us draw the scene we want with colored brushes and automatically generating an image of that scene.

Such advanced user interfaces (UIs) have traditionally been built with thousands of lines of HTML/CSS and JavaScript, and often require libraries like React.js. Having to learn new languages and libraries just to present your work makes it much harder to share exciting new ideas. For this reason, we worked hard to make Dash the best tool for ML engineers, data scientists, and researchers to build the perfect UI for showcasing their projects: zero HTML/CSS or JavaScript needed, just pure Python. To show this, we built an app that lets you edit AI-generated faces, with all the computations happening in real time. To do that, we integrated the PyTorch implementation of Progressive GANs (PGGAN) with the famous transparent latent GAN (TL-GAN).

This app lets you edit synthetically-generated faces using TL-GAN. All the components were built with our Design Kit, and it’s running the PGGAN model in real time on Dash Kubernetes. You can try it out now.

Making GANs transparent with TL-GAN

Created by Shaobo Guan as part of an Insight Data Science project, the transparent latent GAN (TL-GAN) introduces a simple but perfectly executed idea: shine a light on the random noise used to generate realistic faces by learning an association between the input noise and facial features such as age, sex, skin tone, and whether the person is smiling or wearing glasses, hats, necklaces, and more. To achieve this, the author first trained an ML model to classify images based on some 40 facial features (using labels from the CelebA dataset), then used this model to label hundreds of thousands of images generated by the officially released PGGAN. Finally, a linear regression was trained to predict the features output by the ML model from the latent vectors (i.e. the random noise), and the learned weights were used to steer the noise toward outputs that correlate more strongly with the desired features.
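To make that pipeline concrete, here is a minimal sketch of the regression step, not the author's original code; the file names, the attribute list, and the shift_latent helper are illustrative assumptions:

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Hypothetical files: latent vectors fed to PGGAN (N, 512) and the attribute
# scores predicted for the corresponding generated images (N, n_attributes).
latents = np.load("latents.npy")
attributes = np.load("attributes.npy")

# Illustrative subset of the ~40 CelebA attributes, in the same column order
# as the attributes array.
ATTRIBUTE_NAMES = ["Smiling", "Young", "Male", "Eyeglasses"]

# Fit a linear map from latent space to attribute space.
reg = LinearRegression().fit(latents, attributes)

# Each row of coef_ is a direction in latent space associated with one
# attribute; normalizing gives a unit "feature axis" to walk along.
feature_axes = {
    name: reg.coef_[i] / np.linalg.norm(reg.coef_[i])
    for i, name in enumerate(ATTRIBUTE_NAMES)
}

def shift_latent(z, feature, amount):
    """Move a latent vector along one learned feature axis, e.g. 'Smiling'."""
    return z + amount * feature_axes[feature]
```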

Integrating PyTorch Hub with Dash

PyTorch Hub is an incredible repository of pretrained models built in PyTorch, all of which can be imported and loaded in a few lines of code. For our own app, all we needed to do was load the PGGAN model from torch.hub (which is included in the official PyTorch release) at startup and start using it in our callbacks.
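For reference, the loading step looks roughly like the snippet below, which uses the PGAN entry point published by facebookresearch/pytorch_GAN_zoo on PyTorch Hub; the exact model name and keyword arguments may differ from the ones used in our app:

```python
import torch

# Download the pretrained Progressive GAN (or load it from the torch.hub cache).
# 'celebAHQ-512' selects the 512x512 face model; useGPU=False keeps inference on CPU.
model = torch.hub.load(
    'facebookresearch/pytorch_GAN_zoo:hub',
    'PGAN',
    model_name='celebAHQ-512',
    pretrained=True,
    useGPU=False,
)

# Sample latent noise and generate a batch of faces.
noise, _ = model.buildNoiseData(1)
with torch.no_grad():
    images = model.test(noise)  # tensor of shape (1, 3, 512, 512), values roughly in [-1, 1]
```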

Traditionally, if you wanted to deploy a model loaded from PyTorch Hub, you would need to design a REST API with Flask and then communicate with a front end built in a library like React.js. Since you are outputting images, you would also need to worry about encoding the image into a string using schemes like base64, and ensure that the component displaying the image is compatible.

With Dash, you don’t need to worry about JS, React components, or REST APIs. All you need to do is write a callback that calls the PGGAN model you loaded at startup, feed it the noise modified by TL-GAN, and display the result in a Plotly.py graph using either Plotly Express or image layouts. This way, you can easily zoom into the image to examine small details that might otherwise go unnoticed. Furthermore, with Dash Design Kit, you can effortlessly modify the graph component to have the exact background color and text font you want.

Plotly graphs can be easily integrated into Dash and are fully customizable by the Design Kit.
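As a rough illustration, assuming the images tensor from the torch.hub snippet above, turning a generated face into a zoomable Plotly figure only takes a few lines; the resulting figure can then be returned from a callback into a dcc.Graph component:

```python
import numpy as np
import plotly.express as px

# `images` is the (1, 3, H, W) tensor produced by the PGGAN snippet above.
img = images[0].permute(1, 2, 0).cpu().numpy()             # channels-last for plotting
img = ((img + 1) / 2 * 255).clip(0, 255).astype(np.uint8)  # rescale [-1, 1] -> [0, 255]

fig = px.imshow(img)
fig.update_xaxes(showticklabels=False)
fig.update_yaxes(showticklabels=False)
```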

Dynamically control components with Pattern Matching

When you use the app, you will notice that you can control the number of sliders and checkboxes displayed, and each of those components can be used to update the image.

To accomplish this, you can create a callback that checks whether new features have been added in the dropdown and adds them to the existing sliders without changing the values of the sliders or checkboxes. You only need to write a callback along these lines:
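Here is a minimal sketch of such a callback; the component ids, feature list, and slider ranges are illustrative assumptions rather than the app's exact code:

```python
from dash import Dash, dcc, html
from dash.dependencies import Input, Output, State, ALL

app = Dash(__name__)

# Illustrative subset of the facial attributes exposed in the dropdown.
FEATURES = ["Smiling", "Young", "Male", "Eyeglasses"]

app.layout = html.Div([
    dcc.Dropdown(
        id="feature-dropdown",
        options=[{"label": f, "value": f} for f in FEATURES],
        value=["Smiling"],
        multi=True,
    ),
    html.Div(id="slider-container", children=[]),
])

@app.callback(
    Output("slider-container", "children"),
    [Input("feature-dropdown", "value")],
    [State("slider-container", "children")],
)
def update_sliders(selected_features, existing_children):
    # Keep the sliders that are still selected (so their current values survive),
    # then append a fresh slider for every newly selected feature.
    existing = {child["props"]["id"]["index"]: child for child in existing_children}
    children = []
    for feature in selected_features:
        if feature in existing:
            children.append(existing[feature])
        else:
            children.append(dcc.Slider(
                id={"type": "feature-slider", "index": feature},
                min=-1, max=1, step=0.05, value=0,
            ))
    return children
```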

Using those dynamically generated sliders, you can update the image produced by the PGGAN with a callback function similar to this one:
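Here is a sketch of that second callback, which rebuilds the latent vector from the slider values and regenerates the face; it assumes the model loaded from torch.hub, the feature_axes dictionary from the TL-GAN sketch, and a dcc.Graph with id "face-graph" in the layout:

```python
import numpy as np
import plotly.express as px
import torch
from dash.dependencies import Input, Output, State, ALL

# Fixed base latent vector, so the sliders act as offsets from one face.
base_noise, _ = model.buildNoiseData(1)

@app.callback(
    Output("face-graph", "figure"),
    [Input({"type": "feature-slider", "index": ALL}, "value")],
    [State({"type": "feature-slider", "index": ALL}, "id")],
)
def update_face(slider_values, slider_ids):
    # Shift the base latent vector along each selected feature axis.
    z = base_noise.clone().numpy()
    for value, slider_id in zip(slider_values, slider_ids):
        z = z + value * feature_axes[slider_id["index"]]

    # Run the PGGAN on the modified noise and wrap the result in a Plotly figure.
    with torch.no_grad():
        img = model.test(torch.from_numpy(z).float())[0]
    img = img.permute(1, 2, 0).cpu().numpy()
    img = ((img + 1) / 2 * 255).clip(0, 255).astype(np.uint8)

    fig = px.imshow(img)
    fig.update_xaxes(showticklabels=False)
    fig.update_yaxes(showticklabels=False)
    return fig
```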

Each interaction is completely controlled by two callbacks using Pattern Matching. The GAN is running in real time on our servers.

Notice how we are using a different syntax for the id of the sliders and checkboxes, and how concise those callbacks are. This is because we are using the pattern-matching functionality that the Dash team recently released. Those new callbacks let us control a group of components using a single Input, Output, or State, which made it possible for us to dynamically add and remove components using only two callbacks.

In fact, it’s possible to create complex filters by grouping components based on a common role, or to select a subset of all the indices of a certain type by using MATCH instead of ALL. By looking at the docs, you can build increasingly advanced interactions between components, and pattern matching is completely open source thanks to the support of our sponsors.
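For instance, MATCH pairs each output with the input that shares the same index, so a per-slider caption could be wired up as in the hedged sketch below; it assumes a companion html.Div with id {"type": "slider-caption", "index": ...} was created next to each slider:

```python
from dash.dependencies import Input, Output, MATCH

# With MATCH, each caption is paired with the slider that shares its `index`,
# so every slider updates only its own caption.
@app.callback(
    Output({"type": "slider-caption", "index": MATCH}, "children"),
    [Input({"type": "feature-slider", "index": MATCH}, "value")],
)
def show_slider_value(value):
    return f"Current offset: {value:+.2f}"
```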

Optimized and scalable deployments through Dash Enterprise

The most significant change we made to the original TL-GAN project is that we used the PyTorch Hub version of PGGAN instead of the original TensorFlow version. As a result, we were able to considerably reduce the size of our repository, since we did not need to include the entire source code and pickle files. Instead, we rely on PyTorch Hub's caching mechanism and focus on improving the inference speed of the model. Each image is generated in less than one second on our CPU servers, and the app scales easily through Dash Kubernetes.

However, with the increasing need for faster inference and larger production models, traditional approaches might not be sufficient for you. With Dash, you can leverage the power of GPUs through our partnership with NVIDIA RAPIDS. If you are interested in productionizing Deep Learning and Computer Vision models, whether on CPUs or GPUs, reach out to learn how Dash can help you.