Source code for dltk.models.gan.dcgan

from __future__ import division
from __future__ import absolute_import
from __future__ import print_function

import tensorflow as tf
import numpy as np
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.training import moving_averages

from dltk.core.modules import *

class DCGAN(AbstractModule):
    """Deep Convolutional Generative Adversarial Network (DCGAN)

    This module builds a generator, which maps a noise tensor to generated samples, and a discriminator,
    which scores samples as real or generated. Calling the module returns the outputs of both networks
    together with the discriminator and generator losses.
    """

    def __init__(self, discriminator_filters=(64, 128, 256, 512), generator_filters=(512, 256, 128, 64, 1),
                 discriminator_strides=((1, 1, 1), (2, 2, 2), (1, 1, 1), (2, 2, 2)),
                 generator_strides=((7, 7, 7), (2, 2, 2), (2, 2, 2)), relu_leakiness=0.01,
                 generator_activation=tf.identity, name='dcgan'):
        """Deep Convolutional Generative Adversarial Network

        Parameters
        ----------
        discriminator_filters : list or tuple
            list of filters used for the discriminator
        generator_filters : list or tuple
            list of filters used for the generator. The last entry must equal the number of channels of
            the real samples (asserted in ``_build``).
        discriminator_strides : list or tuple
            list of strides used for the discriminator, one per entry of ``discriminator_filters``
        generator_strides : list or tuple
            list of strides used for the generator, one per entry of ``generator_filters``
        relu_leakiness : float
            leakiness of the relus used in the discriminator
        generator_activation : function
            function to be used as activation for the generator
        name : string
            name of the network used for scoping

        Notes
        -----
        The default ``generator_filters`` (5 entries) and ``generator_strides`` (3 entries) differ in
        length, while ``Generator.__init__`` asserts that both have the same length. Pass matching
        sequences explicitly, otherwise building the network fails with an ``AssertionError``.
        """
        self.discriminator_filters = discriminator_filters
        self.discriminator_strides = discriminator_strides
        self.generator_filters = generator_filters
        self.generator_strides = generator_strides
        self.relu_leakiness = relu_leakiness
        self.generator_activation = generator_activation

        # number of sample channels; fixed on the first call of _build
        self.in_filter = None

        assert len(discriminator_filters) == len(discriminator_strides)

        super(DCGAN, self).__init__(name)

    class Discriminator(AbstractModule):
        """Discriminator network of the DCGAN"""

        def __init__(self, filters, strides, relu_leakiness, name):
            """Constructs the discriminator of a DCGAN

            Parameters
            ----------
            filters : list or tuple
                filters for convolutional layers
            strides : list or tuple
                strides to be used for convolutions
            relu_leakiness : float
                leakiness of relu nonlinearity
            name : string
                name of the network
            """
            self.filters = filters
            self.strides = strides
            self.relu_leakiness = relu_leakiness

            # number of input channels; fixed on the first call of _build
            self.in_filter = None

            assert len(strides) == len(filters)

            super(DCGAN.Discriminator, self).__init__(name)

        def _build(self, x, is_training=True):
            """Builds the discriminator graph

            Parameters
            ----------
            x : tf.Tensor
                samples to be classified; the size of the last dimension is recorded on the first call
                and must not change between calls
            is_training : bool
                flag to specify whether this is training - passed to batch normalization

            Returns
            -------
            dict
                output dictionary containing:
                    - `logits` - output of the final linear layer
                    - `probs` - sigmoid of the logits
                    - `pred` - boolean tensor, true where `probs` is greater than 0.5
            """
            if self.in_filter is None:
                self.in_filter = x.get_shape().as_list()[-1]
            assert self.in_filter == x.get_shape().as_list()[-1], 'Network was built for a different input shape'

            out = {}

            # NOTE: only the first len(filters) - 1 filter/stride pairs are used by convolutional layers;
            # the last pair is not used because the final layer is a linear layer on the flattened features.
            for i, (filters, strides) in enumerate(zip(self.filters[:-1], self.strides[:-1])):
                with tf.variable_scope('l{}'.format(i)):
                    x = Convolution(filters, 4, strides)(x)
                    # assumes dltk's BatchNorm takes the training flag as second call argument
                    # - TODO confirm against dltk.core.modules
                    x = BatchNorm()(x, is_training)
                    x = leaky_relu(x, self.relu_leakiness)

            with tf.variable_scope('final'):
                # flatten everything but the batch dimension
                x = tf.reshape(x, (tf.shape(x)[0], np.prod(x.get_shape().as_list()[1:])))
                x = Linear(1)(x)
                out['logits'] = x
                x = tf.nn.sigmoid(x)
                out['probs'] = x
                out['pred'] = tf.greater(x, 0.5)

            return out

    class Generator(AbstractModule):
        """Generator network of the DCGAN"""

        def __init__(self, filters, strides, output_activation, name):
            """Constructs the generator of a DCGAN

            Parameters
            ----------
            filters : list or tuple
                filters for transposed convolutional layers
            strides : list or tuple
                strides to be used for transposed convolutions
            output_activation : function
                function applied to the output of the final transposed convolution
            name : string
                name of the network
            """
            self.filters = filters
            self.strides = strides

            # number of noise channels; fixed on the first call of _build
            self.in_filter = None
            self.output_activation = output_activation

            assert len(strides) == len(filters)

            super(DCGAN.Generator, self).__init__(name)

        def _build(self, x, is_training=True):
            """Builds the generator graph

            Parameters
            ----------
            x : tf.Tensor
                noise tensor; the size of the last dimension is recorded on the first call and must not
                change between calls
            is_training : bool
                flag to specify whether this is training - passed to batch normalization

            Returns
            -------
            dict
                output dictionary containing:
                    - `gen` - generated sample
            """
            if self.in_filter is None:
                self.in_filter = x.get_shape().as_list()[-1]
            assert self.in_filter == x.get_shape().as_list()[-1], 'Network was built for a different input shape'

            # reshape the noise to [batch, 1, ..., 1, channels] with one singleton dimension per stride entry
            x = tf.reshape(x, [tf.shape(x)[0]] + [1] * len(self.strides[0]) + [self.in_filter])

            out = {}

            for i, (filters, strides) in enumerate(zip(self.filters[:-1], self.strides[:-1])):
                with tf.variable_scope('l{}'.format(i)):
                    x = TransposedConvolution(filters, strides=strides)(x)
                    # assumes dltk's BatchNorm takes the training flag as second call argument
                    # - TODO confirm against dltk.core.modules
                    x = BatchNorm()(x, is_training)
                    x = tf.nn.relu(x)

            with tf.variable_scope('final'):
                x = TransposedConvolution(self.filters[-1], strides=self.strides[-1])(x)
                x = self.output_activation(x)
                out['gen'] = x

            return out

    def _build(self, noise, samples, is_training=True):
        """Constructs a DCGAN

        Parameters
        ----------
        noise : tf.Tensor
            noise tensor for the generator to generate fake samples
        samples : tf.Tensor
            real samples used by the discriminator
        is_training : bool
            flag to specify whether this is training - forwarded to the generator and the discriminator,
            which pass it to batch normalization

        Returns
        -------
        dict
            output dictionary containing:
                - `gen` - generator output dictionary
                    - `gen` - generated sample
                - `disc_gen` - discriminator output dictionary for generated sample
                - `disc_sample` - discriminator output dictionary for real sample
                - `d_loss` - discriminator loss
                - `g_loss` - generator loss
        """
        if self.in_filter is None:
            self.in_filter = samples.get_shape().as_list()[-1]
        assert self.in_filter == samples.get_shape().as_list()[-1], 'Network was built for a different input shape'
        assert self.in_filter == self.generator_filters[-1], 'Generator was built for a different sample shape'

        out = {}

        self.disc = self.Discriminator(self.discriminator_filters, self.discriminator_strides,
                                       self.relu_leakiness, 'disc')
        self.gen = self.Generator(self.generator_filters, self.generator_strides,
                                  self.generator_activation, 'gen')

        out['gen'] = self.gen(noise, is_training=is_training)
        out['disc_gen'] = self.disc(out['gen']['gen'], is_training=is_training)
        out['disc_sample'] = self.disc(samples, is_training=is_training)

        real_logits = out['disc_sample']['logits']
        fake_logits = out['disc_gen']['logits']

        # The losses are computed from the logits rather than as log(sigmoid(logits)): the values are
        # mathematically the same (-log(p) for label 1, -log(1 - p) for label 0), but this form does not
        # produce inf/NaN when the sigmoid saturates.
        d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(real_logits), logits=real_logits))
        d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(fake_logits), logits=fake_logits))

        out['d_loss'] = d_loss_real + d_loss_fake
        out['g_loss'] = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(fake_logits), logits=fake_logits))

        return out