Source code for dltk.models.gan.wdcgan

from __future__ import division
from __future__ import absolute_import
from __future__ import print_function

import tensorflow as tf

from dltk.core.modules import *
from dltk.models.gan.dcgan import DCGAN


class WDCGAN(DCGAN):
    """Deep convolutional generative network trained with a Wasserstein loss,
    based on https://github.com/igul222/improved_wgan_training
    """

    def __init__(self, discriminator_filters=(64, 128, 256, 512),
                 generator_filters=(512, 256, 128, 64, 1),
                 discriminator_strides=((1, 1, 1), (2, 2, 2), (1, 1, 1), (2, 2, 2)),
                 generator_strides=((7, 7, 7), (2, 2, 2), (2, 2, 2)),
                 relu_leakiness=0.01, generator_activation=tf.identity,
                 clip_val=0.01, improved=True, name='wdcgan'):
        """Wasserstein Deep Convolutional Generative Adversarial Network

        Parameters
        ----------
        discriminator_filters : list or tuple
            list of filters used for the discriminator
        generator_filters : list or tuple
            list of filters used for the generator
        discriminator_strides : list or tuple
            list of strides used for the discriminator
        generator_strides : list or tuple
            list of strides used for the generator
        relu_leakiness : float
            leakiness of the ReLUs used in the discriminator
        generator_activation : function
            function to be used as activation for the generator
        clip_val : float
            if improved, the factor for the gradient penalty;
            otherwise, the value to clip the discriminator weights by
        improved : bool
            flag to toggle the improved Wasserstein GAN (gradient penalty)
        name : string
            name of the network used for scoping
        """
        self.batch_size = tf.placeholder(tf.int32, shape=[])
        self.improved = improved
        self.clip_val = clip_val

        super(WDCGAN, self).__init__(discriminator_filters, generator_filters,
                                     discriminator_strides, generator_strides,
                                     relu_leakiness, generator_activation, name)

    def _build(self, noise, samples, is_training=True):
        """Constructs a Wasserstein GAN

        Parameters
        ----------
        noise : tf.Tensor
            noise tensor for the generator to generate fake samples
        samples : tf.Tensor
            real samples used by the discriminator
        is_training : bool
            flag to specify whether this is training - passed to batch normalization

        Returns
        -------
        dict
            output dictionary containing:
                - `gen` - generator output dictionary
                    - `gen` - generated sample
                - `disc_gen` - discriminator output dictionary for generated sample
                - `disc_gen_logits` - mean of logit values of discriminator on generated samples
                - `disc_sample_logits` - mean of logit values of discriminator on real samples
                - `disc_sample` - discriminator output dictionary for real sample
                - `d_loss` - discriminator loss
                - `g_loss` - generator loss
                - `clip_ops` - None if improved WGAN, else the op to run to clip the critic/discriminator weights
        """
        if self.in_filter is None:
            self.in_filter = samples.get_shape().as_list()[-1]
        assert self.in_filter == samples.get_shape().as_list()[-1], \
            'Network was built for a different input shape'
        assert self.in_filter == self.generator_filters[-1], \
            'Generator was built for a different sample shape'

        out = {}

        self.disc = self.Discriminator(self.discriminator_filters, self.discriminator_strides,
                                       self.relu_leakiness, 'disc')
        self.gen = self.Generator(self.generator_filters, self.generator_strides,
                                  self.generator_activation, 'gen')

        out['gen'] = self.gen(noise)
        out['disc_gen'] = self.disc(out['gen']['gen'])
        out['disc_sample'] = self.disc(samples)

        out['disc_gen_logits'] = tf.reduce_mean(out['disc_gen']['logits'])
        out['disc_sample_logits'] = tf.reduce_mean(out['disc_sample']['logits'])

        # Wasserstein critic objective: the critic maximizes the gap between
        # its scores on real and generated samples
        out['d_loss'] = out['disc_gen_logits'] - out['disc_sample_logits']
        out['g_loss'] = - out['disc_gen_logits']

        if self.improved:
            with tf.variable_scope('gradient_penalty'):
                # sample random points on the lines between real and generated samples
                alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1],
                                          minval=0., maxval=1.)
                differences = out['gen']['gen'] - samples
                interpolates = samples + alpha * differences
                gradients = tf.gradients(self.disc(interpolates)['logits'],
                                         [interpolates])[0]
                slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients),
                                               reduction_indices=[1]))
                gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)

                # clip_val is overloaded here as the gradient penalty factor
                out['d_loss'] += self.clip_val * gradient_penalty
            out['clip_ops'] = None
        else:
            with tf.variable_scope('clipping'):
                # enforce the Lipschitz constraint by clipping the critic weights
                clip_ops = []
                for var in self.disc.get_variables():
                    clip_ops.append(
                        tf.assign(var,
                                  tf.clip_by_value(var, -self.clip_val, self.clip_val)))
                out['clip_ops'] = tf.group(*clip_ops)

        return out
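

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the DLTK source). It assumes
# the dltk.core.modules convention that calling a module instance invokes
# _build, that the generator exposes the same get_variables() helper as the
# discriminator, and that the placeholder shapes below (28^3 single-channel
# volumes, a 100-dimensional latent code), penalty factor, and learning rate
# are arbitrary example choices.
if __name__ == '__main__':
    net = WDCGAN(improved=True, clip_val=10.)  # clip_val acts as the penalty factor

    noise = tf.placeholder(tf.float32, shape=[None, 1, 1, 1, 100])
    samples = tf.placeholder(tf.float32, shape=[None, 28, 28, 28, 1])
    out = net(noise, samples)

    # separate optimizers, each restricted to its own network's variables
    d_train_op = tf.train.AdamOptimizer(1e-4).minimize(
        out['d_loss'], var_list=net.disc.get_variables())
    g_train_op = tf.train.AdamOptimizer(1e-4).minimize(
        out['g_loss'], var_list=net.gen.get_variables())

    # net.batch_size is a placeholder and must be fed together with the inputs
    # when improved=True, since the gradient penalty draws its interpolation
    # coefficients with that shape. With improved=False, run out['clip_ops']
    # after every critic update to apply the weight clipping instead.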