AI can do a lot of cool stuff: tell apart pictures of cats and dogs, translate between languages, turn pictures of cloudy days into sunny days, and even defeat the world’s best StarCraft players!
What’s interested me recently is text-to-image generation. Yes, you could just search Google Images and it would give you beautiful pictures of cats, pizza, and landscapes, but what if you were specifically looking for a small bird with a red belly, a small bill, and red wings? I literally typed this into Google, and these were the first results.
Text-to-image generation algorithms, as their name implies, actually generate images from a text description rather than searching through a database, which opens up all sorts of new possibilities.
Imagine simply talking and having a computer generate an image of your dream kitchen or a mockup of your latest app. The technique could also be extended to designing 3D objects, making CAD design a breeze.
More importantly, though, it’s really cool 😎.
All current state-of-the-art text-to-image algorithms use a neural network architecture called a Generative Adversarial Network (GAN).
GANs are made up of two parts, a generator and a discriminator, which battle against each other. The discriminator’s goal is to distinguish between real images and those produced by the generator, and the generator’s goal is to fool the discriminator into thinking the images it generates are real.
Over time, the generator and discriminator keep trying to one-up each other and get better and better at their tasks. After training, we usually throw out the discriminator and use the generator to create new, realistic samples that resemble the training data.
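Both the generator and discriminator are neural networks, trained with gradient-based optimization on a shared objective, specifically:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$

The first term is the log of the discriminator’s prediction $D(x)$ for a real image $x$, and the second term is the log of 1 minus its prediction $D(G(z))$ for an image generated from noise $z$. The discriminator tries to maximize this value, while the generator tries to minimize it.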
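To make the objective concrete, here is a minimal NumPy sketch (my own illustration, not code from any paper) that estimates the value function $V(D, G) = \mathbb{E}[\log D(x)] + \mathbb{E}[\log(1 - D(G(z)))]$ from a batch of discriminator outputs. The function name `gan_losses` and the clipping constant `eps` are hypothetical choices.

```python
import numpy as np

def gan_losses(d_real, d_fake, eps=1e-12):
    """Minibatch estimate of the GAN value function V(D, G).

    d_real: discriminator outputs D(x) on real images, each in (0, 1)
    d_fake: discriminator outputs D(G(z)) on generated images, each in (0, 1)
    Returns (value, d_loss, g_loss).
    """
    d_real = np.clip(np.asarray(d_real, dtype=float), eps, 1 - eps)
    d_fake = np.clip(np.asarray(d_fake, dtype=float), eps, 1 - eps)
    # V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))], estimated by batch means
    value = np.mean(np.log(d_real)) + np.mean(np.log(1 - d_fake))
    # D maximizes V, so gradient descent on D minimizes -V
    d_loss = -value
    # G minimizes V; only the second term depends on G
    g_loss = np.mean(np.log(1 - d_fake))
    return value, d_loss, g_loss

# A confident discriminator achieves a higher V than one that can't tell real from fake.
sharp = gan_losses([0.9, 0.95], [0.1, 0.05])
confused = gan_losses([0.5, 0.5], [0.5, 0.5])
print(sharp[0] > confused[0])  # True
```

In practice the generator is often trained to minimize $-\log D(G(z))$ instead, which gives stronger gradients early in training (a trick suggested in the original GAN paper); MirrorGAN’s generator loss uses this $-\log D$ form.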
We integrate the text description using conditional GANs, where both the generator and discriminator take an extra input $y$, which can be anything: a photo, text, etc. The GAN is then trained to generate images based on its conditioning (e.g., the input “black” should produce a picture of a black bird, and the input “blue” a picture of a blue bird).
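One common way to condition a GAN is to concatenate $y$ to the generator’s noise vector; this is an assumption for illustration, since conditional GANs can inject $y$ in other ways too. A minimal sketch with a hypothetical one-hot color label:

```python
import numpy as np

def conditional_input(z, y):
    """Concatenate noise z (batch, z_dim) with condition y (batch, y_dim)."""
    return np.concatenate([z, y], axis=1)

colors = ["black", "blue", "red"]   # hypothetical condition vocabulary
# One-hot rows for the conditions "black" and "blue"
y = np.eye(len(colors))[[colors.index("black"), colors.index("blue")]]
rng = np.random.default_rng(0)
z = rng.standard_normal((2, 100))   # one 100-dim noise vector per sample
gen_in = conditional_input(z, y)
print(gen_in.shape)  # (2, 103)
```

The discriminator would receive $y$ alongside (features of) the image in the same spirit, so it can penalize images that don’t match their condition. In MirrorGAN, the condition is the text embedding rather than a one-hot label.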
That’s all the preliminaries you need to understand MirrorGAN, currently one of the best text-to-image generation algorithms. The original paper is “MirrorGAN: Learning Text-to-image Generation by Redescription” (Qiao et al., CVPR 2019).
The model is divided into three modules:
- STEM (semantic text embedding module): takes in the text description and represents it using numbers.
- GLAM (global-local collaborative attentive module): this is where the magic happens. It’s multiple feature generators stacked on top of each other, which use an attention model (I’ll get to that later).
- STREAM (semantic text regeneration and alignment module): takes the images generated by GLAM and creates a text description of them. The similarity between STREAM’s regenerated text and the original text description is an additional loss term.
Computers are really good at working with numbers, but not words, so the first thing we have to do is represent the text description with numbers, which is called an embedding. To do that, we use a pretrained recurrent neural network (RNN), which takes the text description as input and outputs word-level and sentence-level embeddings:

$$w, s = \text{RNN}(T)$$

where $T$ is the text description, $w$ is the word-level embedding, and $s$ is the sentence-level embedding.
The word level embedding w is a DxL matrix, where L is the number of words, and D is the dimension of the hidden state of each word. The sentence level embedding s is a vector with D dimensions. In the paper, they set D equal to 256, but it’s a hyperparameter which can be changed.
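To make those shapes tangible, here is a toy stand-in for STEM’s encoder. It is an assumption throughout: a single-layer tanh RNN with random weights replaces the paper’s pretrained network, and `toy_text_encoder`, the token ids, and `emb_dim` are hypothetical. Only the output shapes ($D \times L$ and $D$) match the text.

```python
import numpy as np

def toy_text_encoder(token_ids, vocab_size, D=256, emb_dim=32, seed=0):
    """Toy RNN encoder returning w (D x L) and s (D,).

    Random weights, so only the shapes are meaningful; STEM itself uses a
    pretrained recurrent network.
    """
    rng = np.random.default_rng(seed)
    E = rng.standard_normal((vocab_size, emb_dim)) * 0.1   # word embedding table
    W_x = rng.standard_normal((D, emb_dim)) * 0.1
    W_h = rng.standard_normal((D, D)) * 0.1
    h = np.zeros(D)
    hidden_states = []
    for t in token_ids:
        h = np.tanh(W_x @ E[t] + W_h @ h)   # update the hidden state word by word
        hidden_states.append(h)
    w = np.stack(hidden_states, axis=1)     # word level: one D-dim column per word
    s = h                                    # sentence level: final hidden state
    return w, s

tokens = [4, 17, 2, 9, 11]   # hypothetical ids for a 5-word description
w, s = toy_text_encoder(tokens, vocab_size=50)
print(w.shape, s.shape)  # (256, 5) (256,)
```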
To create more training data, the sentence-level embedding is passed through conditioning augmentation, $s_{ca} = F_{ca}(s)$. This samples a new sentence embedding from the Gaussian distribution $\mathcal{N}(\mu(\varphi_t), \Sigma(\varphi_t))$, where $\varphi_t$ is the text embedding (here, the sentence embedding $s$), $\mu(\varphi_t)$ is a mean vector, and $\Sigma(\varphi_t)$ is a diagonal covariance matrix, both computed from $\varphi_t$ by a learned layer. The resulting conditioned sentence embedding $s_{ca}$ is a $D'$-dimensional vector, with $D' = 100$.
This technique creates more training pairs and makes the model more robust to slight variations in sentences.
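The sampling step can be sketched with the reparameterization trick, which keeps the sample differentiable with respect to $\mu$ and $\Sigma$. This is a minimal illustration under stated assumptions: `W_mu` and `W_logvar` are hypothetical linear layers that predict the mean and log-variance, and all weights are random.

```python
import numpy as np

def conditioning_augmentation(s, W_mu, W_logvar, rng):
    """Sample s_ca ~ N(mu(s), diag(sigma(s)^2)) via the reparameterization trick.

    W_mu, W_logvar: (D_prime, D) weights of hypothetical linear layers that
    predict the mean and the log-variance from the sentence embedding s.
    """
    mu = W_mu @ s
    logvar = W_logvar @ s
    sigma = np.exp(0.5 * logvar)           # diagonal covariance -> per-dimension std
    eps = rng.standard_normal(mu.shape)    # fresh noise on every call
    return mu + sigma * eps, mu, logvar

D, D_prime = 256, 100
rng = np.random.default_rng(0)
s = rng.standard_normal(D)                 # stand-in sentence embedding from STEM
W_mu = rng.standard_normal((D_prime, D)) * 0.01
W_logvar = rng.standard_normal((D_prime, D)) * 0.01
s_ca_1, mu, logvar = conditioning_augmentation(s, W_mu, W_logvar, rng)
s_ca_2, _, _ = conditioning_augmentation(s, W_mu, W_logvar, rng)
print(s_ca_1.shape)                  # (100,)
print(np.allclose(s_ca_1, s_ca_2))   # False: same sentence, two different samples
```

StackGAN, where conditioning augmentation comes from, also adds a KL-divergence penalty between this Gaussian and the standard normal $\mathcal{N}(0, I)$ to the generator’s loss so the sampled embeddings stay well behaved.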
This is the main module, where the magic happens. It’s made up of multiple visual feature transformers (neural networks that extract visual features) stacked on top of each other, each stage building on the output of the one before it. At each stage, generator $G_i$ takes the visual feature $f_i$ produced by transformer $F_i$ and constructs an image $I_i = G_i(f_i)$, which is judged by discriminator $D_i$.
The first visual feature transformer $F_0$ simply takes the conditioned sentence embedding $s_{ca}$ and $z$, a random noise vector: $f_0 = F_0(z, s_{ca})$. The visual feature $f_0$ is then fed into generator $G_0$, which creates a 64×64 image. Discriminator $D_0$ is first trained on a batch of real and fake images; then the discriminator’s weights are frozen and the generator is trained.
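The real $F_0$ and $G_0$ are learned convolutional networks. The sketch below replaces them with hypothetical stand-ins (a random linear layer, nearest-neighbour upsampling, and a 1×1 projection to RGB) just to show how $z$ and $s_{ca}$ become a 64×64 image; the 32 channels and the 4×4 starting resolution are assumptions.

```python
import numpy as np

def F0(z, s_ca, W, channels=32, base=4, out_size=64):
    """Toy first visual feature transformer: f0 = F0(z, s_ca).

    Concatenates noise and the conditioned sentence embedding, projects with a
    hypothetical linear layer W, reshapes to a small feature map, and upsamples
    (nearest neighbour) to out_size x out_size.
    """
    h = W @ np.concatenate([z, s_ca])                    # (channels * base * base,)
    h = np.maximum(h, 0).reshape(channels, base, base)   # ReLU, then a 4x4 feature map
    scale = out_size // base
    return h.repeat(scale, axis=1).repeat(scale, axis=2)  # (channels, 64, 64)

rng = np.random.default_rng(0)
z, s_ca = rng.standard_normal(100), rng.standard_normal(100)
W = rng.standard_normal((32 * 4 * 4, 200)) * 0.05
f0 = F0(z, s_ca, W)
print(f0.shape)  # (32, 64, 64)

# Toy G0: a 1x1 projection from 32 feature channels to 3 RGB channels, squashed by tanh.
W_img = rng.standard_normal((3, 32)) * 0.1
image = np.tanh(np.einsum("kc,chw->khw", W_img, f0))
print(image.shape)  # (3, 64, 64)
```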
Along with the previous visual feature $f_{i-1}$, the second and third visual feature transformers ($F_1$ and $F_2$) are also fed the output of $F^{att}_i$, the proposed global-local attention network.
$F^{att}_i$ is the concatenation of $Att^w_{i-1}$ (word-level attention) and $Att^s_{i-1}$ (sentence-level attention):

$$F^{att}_i(f_{i-1}, w, s_{ca}) = \text{concat}(Att^w_{i-1}, Att^s_{i-1})$$
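$Att^w$ takes the word-level embedding $w$ and the visual feature $f$ as input. First, each word vector $w^l$ is fed through a perceptron layer $U$ (a one-layer neural network) to convert it into the same semantic space as the visual feature. The visual feature is then multiplied by this output, and a softmax turns the result into attention scores.
Finally, to get $Att^w$, we multiply each converted word vector $U w^l$ by its attention scores and sum over the words. The final equation looks like this:

$$Att^w_{i-1} = \sum_{l=0}^{L-1} (U_{i-1} w^l)\left(\operatorname{softmax}\left(f_{i-1}^T (U_{i-1} w^l)\right)\right)^T$$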
Attw has the same dimension as f.
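Here is a minimal NumPy sketch of word-level attention, under two assumptions of mine: the visual feature is flattened to shape (channels, $N$) with $N$ image subregions, and, as in AttnGAN (which MirrorGAN builds on), the softmax normalizes over words for each subregion. The weights are random placeholders.

```python
import numpy as np

def softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def word_attention(f, w, U):
    """Word-level attention Att^w.

    f: (D_hat, N) visual feature, N = number of image subregions (H * W)
    w: (D, L) word-level embedding, L = number of words
    U: (D_hat, D) perceptron layer mapping words into the visual feature space
    """
    e = U @ w                           # (D_hat, L): every word in the visual space
    scores = softmax(f.T @ e, axis=1)   # (N, L): per subregion, weights over words (assumed axis)
    # sum over l of (U w^l)(score_l)^T, written as a single matrix product
    return e @ scores.T                 # (D_hat, N), same shape as f

rng = np.random.default_rng(0)
D_hat, N, D, L = 32, 64 * 64, 256, 5
f = rng.standard_normal((D_hat, N))
w = rng.standard_normal((D, L))
U = rng.standard_normal((D_hat, D)) * 0.05
att_w = word_attention(f, w, U)
print(att_w.shape)  # (32, 4096)
```

Writing the sum over words as one matrix product `e @ scores.T` is equivalent to summing the $L$ outer products in the equation, just faster.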
Similarly, to get $Att^s$, the conditioned sentence embedding $s_{ca}$ is fed through a perceptron layer $V$ so it’s in the same semantic space as the visual feature. It is then multiplied element-wise with the visual feature $f$, and a softmax turns the result into attention scores. Finally, $Att^s$ is the element-wise product of the attention scores and $V$’s output:

$$Att^s_{i-1} = (V_{i-1} s_{ca}) \circ \operatorname{softmax}\left(f_{i-1} \circ (V_{i-1} s_{ca})\right)$$
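A matching sketch of sentence-level attention. The paper’s formula doesn’t pin down the softmax axis, so normalizing over subregions for each channel is my assumption, as are the random weights and the flattened (channels, $N$) layout.

```python
import numpy as np

def softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def sentence_attention(f, s_ca, V):
    """Sentence-level attention Att^s.

    f: (D_hat, N) visual feature
    s_ca: (D_prime,) conditioned sentence embedding
    V: (D_hat, D_prime) perceptron layer mapping the sentence into the visual space
    """
    v = (V @ s_ca)[:, None]            # (D_hat, 1), broadcast across subregions
    scores = softmax(f * v, axis=1)    # element-wise product, normalized over subregions (assumed axis)
    return v * scores                  # element-wise product -> (D_hat, N)

rng = np.random.default_rng(0)
D_hat, N, D_prime = 32, 64 * 64, 100
f = rng.standard_normal((D_hat, N))
s_ca = rng.standard_normal(D_prime)
V = rng.standard_normal((D_hat, D_prime)) * 0.1
att_s = sentence_attention(f, s_ca, V)
print(att_s.shape)  # (32, 4096)
```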
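The transformer $F_i$ then computes the $i$-th visual feature from the concatenation of $f_{i-1}$, $Att^w_{i-1}$, and $Att^s_{i-1}$:

$$f_i = F_i\left(f_{i-1}, F^{att}_i(f_{i-1}, w, s_{ca})\right)$$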
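The concatenation step on its own is a one-liner. In this sketch, the channel axis as the concatenation axis and the random stand-in tensors are assumptions; attention outputs of shape (channels, $N$) would first be reshaped to (channels, $H$, $W$) with $N = H \times W$, and the learned transformer $F_i$ itself is not shown.

```python
import numpy as np

def next_stage_input(f_prev, att_w, att_s):
    """Stack f_{i-1}, Att^w_{i-1} and Att^s_{i-1} along the channel axis.

    Each input is (D_hat, H, W); the result is (3 * D_hat, H, W), which the
    learned transformer F_i (not shown) would turn into f_i.
    """
    return np.concatenate([f_prev, att_w, att_s], axis=0)

rng = np.random.default_rng(0)
D_hat, H, W = 32, 64, 64
# Random stand-ins for the previous visual feature and the two attention maps.
f_prev, att_w, att_s = (rng.standard_normal((D_hat, H, W)) for _ in range(3))
h = next_stage_input(f_prev, att_w, att_s)
print(h.shape)  # (96, 64, 64)
```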
The STREAM module generates a text description of the generated image and checks how similar its description is to the original one.
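The module uses an encoder-decoder architecture. The encoder is a CNN pretrained on ImageNet, and the decoder is an RNN. The final equations look like this:

$$x_{-1} = \text{CNN}(I_{m-1})$$
$$x_t = W_e T_t, \quad t \in \{0, \dots, L-1\}$$
$$p_{t+1} = \text{RNN}(x_t), \quad t \in \{0, \dots, L-1\}$$

where $I_{m-1}$ is the final generated image, $x_{-1}$ is the visual feature fed into the RNN at the first step, $T_t$ is the $t$-th word of the description, and $W_e$ is the word embedding matrix, which maps word features into the visual feature space. $p_{t+1}$ is a probability distribution over the words.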
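The decoding loop can be sketched with teacher forcing: the image feature goes in first, then the embeddings of the original words. This is my own illustration, not the paper’s code; a toy tanh RNN with random weights stands in for the real decoder, and a random vector stands in for the CNN encoding.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def stream_decode(image_feature, tokens, W_e, W_x, W_h, W_out):
    """Teacher-forced caption decoder in the spirit of STREAM.

    image_feature: x_{-1}, the CNN encoding of the generated image, shape (E,)
    tokens: ground-truth word ids T_0..T_{L-1}
    W_e: (E, vocab) word embedding matrix, so x_t = W_e[:, T_t]
    W_x, W_h, W_out: weights of a toy tanh RNN (stand-in for the real decoder)
    Returns [p_0, ..., p_L], one distribution over the vocabulary per input.
    """
    h = np.zeros(W_h.shape[0])
    inputs = [image_feature] + [W_e[:, t] for t in tokens]   # x_{-1}, x_0, ..., x_{L-1}
    probs = []
    for x in inputs:
        h = np.tanh(W_x @ x + W_h @ h)
        probs.append(softmax(W_out @ h))    # x_{t} produces p_{t+1}
    return probs

rng = np.random.default_rng(0)
vocab, E, H = 50, 64, 128
W_e = rng.standard_normal((E, vocab)) * 0.1
W_x = rng.standard_normal((H, E)) * 0.1
W_h = rng.standard_normal((H, H)) * 0.1
W_out = rng.standard_normal((vocab, H)) * 0.1
x_minus1 = rng.standard_normal(E)           # stand-in for CNN(I_{m-1})
tokens = [4, 17, 2, 9, 11]                  # hypothetical word ids
probs = stream_decode(x_minus1, tokens, W_e, W_x, W_h, W_out)
print(len(probs))  # 6
```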
The STREAM module is pretrained, and its weights are kept fixed while the other two modules train. The authors of the paper say doing so makes the training process more stable and much less expensive in terms of memory and time.
The generator’s loss function is made up of two parts: how well it’s able to fool the discriminator, and how closely the original description matches the description generated by STREAM.
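The first part (representing the ability to fool the discriminator) is further broken up into two terms:

$$\mathcal{L}_{G_i} = -\frac{1}{2}\mathbb{E}_{I_i \sim p_{I_i}}[\log D_i(I_i)] - \frac{1}{2}\mathbb{E}_{I_i \sim p_{I_i}}[\log D_i(I_i, s)]$$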
The first term measures how realistic the discriminator thought the generator’s image was. Since there’s a negative sign in front of the first term, minimizing it means maximizing the log function log(D(I)), which means the generator’s trying to get the discriminator to think its image is real.
The second term measures how closely the original text description’s meaning matches that of the generated image.
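A minimal sketch of this adversarial loss for one stage, $\mathcal{L}_{G_i} = -\frac{1}{2}\mathbb{E}[\log D_i(I_i)] - \frac{1}{2}\mathbb{E}[\log D_i(I_i, s)]$, estimated over a batch. The discriminator scores are made-up numbers, and the function name and `eps` clipping are my own choices.

```python
import numpy as np

def generator_adv_loss(d_uncond, d_cond, eps=1e-12):
    """L_{G_i} = -1/2 E[log D_i(I_i)] - 1/2 E[log D_i(I_i, s)], estimated over a batch.

    d_uncond: D_i(I_i), scores for "this image looks real"
    d_cond:   D_i(I_i, s), scores for "this image matches the sentence s"
    """
    d_uncond = np.clip(np.asarray(d_uncond, dtype=float), eps, 1.0)
    d_cond = np.clip(np.asarray(d_cond, dtype=float), eps, 1.0)
    return -0.5 * np.mean(np.log(d_uncond)) - 0.5 * np.mean(np.log(d_cond))

fooled = generator_adv_loss([0.9, 0.8], [0.9, 0.7])   # D mostly believes the images
caught = generator_adv_loss([0.1, 0.2], [0.1, 0.3])   # D sees through them
print(fooled < caught)  # True
```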
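The paper proposes another loss function, which checks how closely the semantics of the original text description match those of the description generated by the STREAM module. It is the cross-entropy of STREAM’s word predictions against the original words:

$$\mathcal{L}_{\text{stream}} = -\sum_{t=0}^{L-1} \log p_t(T_t)$$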
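As a sketch, $\mathcal{L}_{\text{stream}} = -\sum_t \log p_t(T_t)$ is just the negative log-probability STREAM assigns to each original word, summed over the sentence. The distributions below are made up for illustration rather than produced by a real decoder.

```python
import numpy as np

def stream_loss(probs, tokens, eps=1e-12):
    """L_stream = -sum_t log p_t(T_t) over the original description's words."""
    return -sum(np.log(max(probs[t][tok], eps)) for t, tok in enumerate(tokens))

vocab_size = 4
tokens = [2, 0, 3]   # hypothetical word ids T_0..T_2
uniform = [np.full(vocab_size, 1 / vocab_size) for _ in tokens]
# 0.97 on the correct word, 0.01 on each of the other three
confident = [np.eye(vocab_size)[tok] * 0.96 + 0.01 for tok in tokens]
print(stream_loss(confident, tokens) < stream_loss(uniform, tokens))  # True
```

A caption decoder that is confident about the right words gives a loss near zero, while one that guesses uniformly pays $\log(\text{vocabulary size})$ per word.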
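Hence the final generator loss function is

$$\mathcal{L}_G = \sum_{i=0}^{m-1} \mathcal{L}_{G_i} + \lambda \mathcal{L}_{\text{stream}}$$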
λ is a hyperparameter which can be tuned to balance adversarial loss with text-semantic reconstruction loss.
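The combination itself is simple arithmetic: $\mathcal{L}_G = \sum_i \mathcal{L}_{G_i} + \lambda \mathcal{L}_{\text{stream}}$. The stage losses and $\lambda$ values below are made-up numbers (chosen to be exact in floating point), not values from the paper.

```python
def total_generator_loss(stage_losses, l_stream, lam):
    """L_G = sum_i L_{G_i} + lambda * L_stream."""
    return sum(stage_losses) + lam * l_stream

stage_losses = [0.5, 0.75, 1.25]   # made-up L_{G_i} values for three stages
l_stream = 4.0                      # made-up STREAM loss
print(total_generator_loss(stage_losses, l_stream, lam=0.0))   # 2.5
print(total_generator_loss(stage_losses, l_stream, lam=0.25))  # 3.5
```

With $\lambda = 0$ the model reduces to a plain stacked text-conditioned GAN; raising $\lambda$ makes caption agreement a bigger share of what the generator optimizes.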
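The discriminator $D_i$ at each stage is trained to minimize the loss function:

$$\mathcal{L}_{D_i} = -\frac{1}{2}\mathbb{E}_{I^{GT}_i \sim p_{I^{GT}_i}}[\log D_i(I^{GT}_i)] - \frac{1}{2}\mathbb{E}_{I_i \sim p_{I_i}}[\log(1 - D_i(I_i))] - \frac{1}{2}\mathbb{E}_{I^{GT}_i \sim p_{I^{GT}_i}}[\log D_i(I^{GT}_i, s)] - \frac{1}{2}\mathbb{E}_{I_i \sim p_{I_i}}[\log(1 - D_i(I_i, s))]$$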
The first two terms train the discriminator to distinguish between images that look real and those that look fake, by maximizing its score for real images $I^{GT}$ and minimizing its score for generated images $I$. The last two terms train the discriminator to identify whether an image corresponds to its text description, by maximizing its score for $(I^{GT}, s)$ and minimizing its score for $(I, s)$.
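A minimal sketch of the four-term discriminator loss: two unconditional terms for real versus fake, and two conditional terms for whether the image matches the sentence $s$. The scores are hypothetical inputs; a real discriminator would compute them from images and embeddings.

```python
import numpy as np

def _log(p, eps=1e-12):
    """Clipped log to avoid log(0)."""
    return np.log(np.clip(np.asarray(p, dtype=float), eps, 1.0))

def discriminator_loss(d_real, d_fake, d_real_s, d_fake_s):
    """L_{D_i}: unconditional (real vs. fake) plus conditional (image-text match) terms.

    d_real = D_i(I^GT), d_fake = D_i(I_i), d_real_s = D_i(I^GT, s), d_fake_s = D_i(I_i, s)
    """
    uncond = -0.5 * np.mean(_log(d_real)) - 0.5 * np.mean(_log(1 - np.asarray(d_fake, dtype=float)))
    cond = -0.5 * np.mean(_log(d_real_s)) - 0.5 * np.mean(_log(1 - np.asarray(d_fake_s, dtype=float)))
    return uncond + cond

perfect = discriminator_loss([1.0], [0.0], [1.0], [0.0])    # D is right about everything
confused = discriminator_loss([0.5], [0.5], [0.5], [0.5])   # D is guessing
print(perfect < confused)  # True
```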
That’s it! Now that we have our architecture and loss functions, we simply train the STREAM module first, then freeze its weights and train the STEM and GLAM modules!
The paper’s results are quite impressive, outperforming AttnGAN and StackGAN.
The two major datasets used to train these algorithms are CUB (pictures of birds) and COCO (pictures of basically everything). Although the algorithms do well on CUB, their results range from mediocre to absolute trash when trained on COCO: it’s pretty hard to get a single model to generate convincing pictures of people, food, and landscapes.
Clearly, there’s a long way to go before we can talk to Photoshop and tell it to draw something. I’m really excited to see what sorts of algorithms people come up with to bridge this gap and allow the model to generalize better.