from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
import tensorflow as tf
import numpy as np
[docs]def convolutional_autoencoder_3d(inputs, num_convolutions=1,
num_hidden_units=128, filters=(16, 32, 64),
strides=((2, 2, 2), (2, 2, 2), (2, 2, 2)),
mode=tf.estimator.ModeKeys.TRAIN,
use_bias=False,
activation=tf.nn.relu6,
kernel_initializer=tf.initializers.variance_scaling(distribution='uniform'),
bias_initializer=tf.zeros_initializer(),
kernel_regularizer=None,
bias_regularizer=None):
"""Convolutional autoencoder with num_convolutions on len(filters)
resolution scales. The downsampling of features is done via strided
convolutions and upsampling via strided transpose convolutions. On each
resolution scale s are num_convolutions with filter size = filters[s].
strides[s] determine the downsampling factor at each resolution scale.
Args:
inputs (tf.Tensor): Input tensor to the network, required to be of
rank 5.
num_convolutions (int, optional): Number of convolutions per resolution
scale.
num_hidden_units (int, optional): Number of hidden units.
filters (tuple or list, optional): Number of filters for all
convolutions at each resolution scale.
strides (tuple or list, optional): Stride of the first convolution on a
resolution scale.
mode (str, optional): One of the tf.estimator.ModeKeys strings: TRAIN,
EVAL or PREDICT
use_bias (bool, optional): Boolean, whether the layer uses a bias.
activation (optional): A function to use as activation function.
kernel_initializer (TYPE, optional): An initializer for the convolution
kernel.
bias_initializer (TYPE, optional): An initializer for the bias vector.
If None, no bias will be applied.
kernel_regularizer (None, optional): Optional regularizer for the
convolution kernel.
bias_regularizer (None, optional): Optional regularizer for the bias
vector.
Returns:
dict: dictionary of output tensors
"""
outputs = {}
assert len(strides) == len(filters)
assert len(inputs.get_shape().as_list()) == 5, \
'inputs are required to have a rank of 5.'
conv_op = tf.layers.conv3d
tp_conv_op = tf.layers.conv3d_transpose
conv_params = {'padding': 'same',
'use_bias': use_bias,
'kernel_initializer': kernel_initializer,
'bias_initializer': bias_initializer,
'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer}
x = inputs
tf.logging.info('Input tensor shape {}'.format(x.get_shape()))
# Convolutional feature encoding blocks with num_convolutions at different
# resolution scales res_scales
for res_scale in range(0, len(filters)):
for i in range(0, num_convolutions - 1):
with tf.variable_scope('enc_unit_{}_{}'.format(res_scale, i)):
x = conv_op(inputs=x,
filters=filters[res_scale],
kernel_size=(3, 3, 3),
strides=(1, 1, 1),
**conv_params)
x = tf.layers.batch_normalization(
inputs=x,
training=mode == tf.estimator.ModeKeys.TRAIN)
x = activation(x)
tf.logging.info('Encoder at res_scale {} shape: {}'.format(
res_scale, x.get_shape()))
# Employ strided convolutions to downsample
with tf.variable_scope('enc_unit_{}_{}'.format(
res_scale,
num_convolutions)):
# Adjust the strided conv kernel size to prevent losing information
k_size = [s * 2 if s > 1 else 3 for s in strides[res_scale]]
x = conv_op(inputs=x,
filters=filters[res_scale],
kernel_size=k_size,
strides=strides[res_scale],
**conv_params)
x = tf.layers.batch_normalization(
x, training=mode == tf.estimator.ModeKeys.TRAIN)
x = activation(x)
tf.logging.info('Encoder at res_scale {} tensor shape: {}'.format(
res_scale, x.get_shape()))
# Densely connected layer of hidden units
x_shape = x.get_shape().as_list()
x = tf.reshape(x, (tf.shape(x)[0], np.prod(x_shape[1:])))
x = tf.layers.dense(inputs=x,
units=num_hidden_units,
use_bias=conv_params['use_bias'],
kernel_initializer=conv_params['kernel_initializer'],
bias_initializer=conv_params['bias_initializer'],
kernel_regularizer=conv_params['kernel_regularizer'],
bias_regularizer=conv_params['bias_regularizer'],
name='hidden_units')
outputs['hidden_units'] = x
tf.logging.info('Hidden units tensor shape: {}'.format(x.get_shape()))
x = tf.layers.dense(inputs=x,
units=np.prod(x_shape[1:]),
activation=activation,
use_bias=conv_params['use_bias'],
kernel_initializer=conv_params['kernel_initializer'],
bias_initializer=conv_params['bias_initializer'],
kernel_regularizer=conv_params['kernel_regularizer'],
bias_regularizer=conv_params['bias_regularizer'])
x = tf.reshape(x, [tf.shape(x)[0]] + list(x_shape)[1:])
tf.logging.info('Decoder input tensor shape: {}'.format(x.get_shape()))
# Decoding blocks with num_convolutions at different resolution scales
# res_scales
for res_scale in reversed(range(0, len(filters))):
# Employ strided transpose convolutions to upsample
with tf.variable_scope('dec_unit_{}_0'.format(res_scale)):
# Adjust the strided tp conv kernel size to prevent losing
# information
k_size = [s * 2 if s > 1 else 3 for s in strides[res_scale]]
x = tp_conv_op(inputs=x,
filters=filters[res_scale],
kernel_size=k_size,
strides=strides[res_scale],
**conv_params)
x = tf.layers.batch_normalization(
x, training=mode == tf.estimator.ModeKeys.TRAIN)
x = activation(x)
tf.logging.info('Decoder at res_scale {} tensor shape: {}'.format(
res_scale, x.get_shape()))
for i in range(1, num_convolutions):
with tf.variable_scope('dec_unit_{}_{}'.format(res_scale, i)):
x = conv_op(inputs=x,
filters=filters[res_scale],
kernel_size=(3, 3, 3),
strides=(1, 1, 1),
**conv_params)
x = tf.layers.batch_normalization(
x, training=mode == tf.estimator.ModeKeys.TRAIN)
x = activation(x)
tf.logging.info('Decoder at res_scale {} tensor shape: {}'.format(
res_scale, x.get_shape()))
# A final convolution reduces the number of output features to those of
# the inputs
x = conv_op(inputs=x,
filters=inputs.get_shape().as_list()[-1],
kernel_size=(1, 1, 1),
strides=(1, 1, 1),
**conv_params)
tf.logging.info('Output tensor shape: {}'.format(x.get_shape()))
outputs['x_'] = x
return outputs