from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
import tensorflow as tf
import numpy as np

def get_linear_upsampling_kernel(kernel_spatial_shape,
                                 out_filters,
                                 in_filters,
                                 trainable=False):
    """Builds a kernel for linear upsampling with the shape
    [kernel_spatial_shape] + [out_filters, in_filters]. Can be set to
    trainable to potentially learn a better upsampling.

    Args:
        kernel_spatial_shape (list or tuple): Spatial dimensions of the
            upsampling kernel. Required to be of rank 2 or 3
            (i.e. [dim_x, dim_y] or [dim_x, dim_y, dim_z]).
        out_filters (int): Number of output filters.
        in_filters (int): Number of input filters.
        trainable (bool, optional): Flag to set the returned tf.Variable
            to be trainable or not.

    Returns:
        tf.Variable: Linear upsampling kernel
    """
    rank = len(list(kernel_spatial_shape))
    assert 1 < rank < 4, \
        'Transposed convolutions are only supported in 2D and 3D'

    # Cast to list so tuple inputs concatenate correctly with the filter dims.
    kernel_shape = tuple(list(kernel_spatial_shape) + [out_filters, in_filters])
    size = kernel_spatial_shape

    # Upsampling factor and (sub-pixel) centre of the kernel in each spatial
    # dimension; even kernel sizes place the centre between two samples.
    factor = (np.array(size) + 1) // 2
    center = np.zeros_like(factor, dtype=float)
    for i in range(len(factor)):
        if size[i] % 2 == 1:
            center[i] = factor[i] - 1
        else:
            center[i] = factor[i] - 0.5
    # Build a separable tent (triangle) filter and place it on the diagonal of
    # the channel dimensions, so each channel is upsampled independently.
    weights = np.zeros(kernel_shape)
    if rank == 2:
        og = np.ogrid[:size[0], :size[1]]
        x_filt = (1 - abs(og[0] - center[0]) / float(factor[0]))
        y_filt = (1 - abs(og[1] - center[1]) / float(factor[1]))
        filt = x_filt * y_filt

        for i in range(out_filters):
            weights[:, :, i, i] = filt
    else:
        og = np.ogrid[:size[0], :size[1], :size[2]]
        x_filt = (1 - abs(og[0] - center[0]) / float(factor[0]))
        y_filt = (1 - abs(og[1] - center[1]) / float(factor[1]))
        z_filt = (1 - abs(og[2] - center[2]) / float(factor[2]))
        filt = x_filt * y_filt * z_filt

        for i in range(out_filters):
            weights[:, :, :, i, i] = filt
    init = tf.constant_initializer(value=weights, dtype=tf.float32)
    return tf.get_variable(name="linear_up_kernel",
                           initializer=init,
                           shape=weights.shape,
                           trainable=trainable)
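

# A minimal verification sketch, not part of the original module: builds the
# 2D bilinear kernel for a single channel and evaluates it. It assumes a
# TensorFlow 1.x environment, consistent with the tf.get_variable call above;
# the function and scope names below are illustrative only.
def _demo_linear_kernel_2d():
    with tf.variable_scope('linear_kernel_demo'):
        kernel = get_linear_upsampling_kernel(kernel_spatial_shape=[4, 4],
                                              out_filters=1,
                                              in_filters=1)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # For a 4x4 kernel the single-channel slice is the separable tent
        # filter with values between 0.0625 and 0.5625.
        print(sess.run(kernel)[:, :, 0, 0])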


def linear_upsample_3d(inputs,
                       strides=(2, 2, 2),
                       use_bias=False,
                       trainable=False,
                       name='linear_upsample_3d'):
"""Linear upsampling layer in 3D using strided transpose convolutions. The
upsampling kernel size will be automatically computed to avoid
information loss.
Args:
inputs (tf.Tensor): Input tensor to be upsampled
strides (tuple, optional): The strides determine the upsampling factor
in each dimension.
use_bias (bool, optional): Flag to train an additional bias.
trainable (bool, optional): Flag to set the variables to be trainable or not.
name (str, optional): Name of the layer.
Returns:
tf.Tensor: Upsampled Tensor
"""
    static_inp_shape = tuple(inputs.get_shape().as_list())
    dyn_inp_shape = tf.shape(inputs)
    rank = len(static_inp_shape)
    num_filters = static_inp_shape[-1]

    strides_5d = [1, ] + list(strides) + [1, ]

    # A kernel of twice the stride gives overlapping contributions and hence a
    # true linear interpolation; a stride of 1 leaves that dimension untouched.
    kernel_size = [2 * s if s > 1 else 1 for s in strides]
    kernel = get_linear_upsampling_kernel(
        kernel_spatial_shape=kernel_size,
        out_filters=num_filters,
        in_filters=num_filters,
        trainable=trainable)

    # Scale the spatial dimensions by the strides to get the dynamic and
    # static output shapes; unknown static dimensions stay None.
    dyn_out_shape = [dyn_inp_shape[i] * strides_5d[i] for i in range(rank)]
    dyn_out_shape[-1] = num_filters

    static_out_shape = [static_inp_shape[i] * strides_5d[i]
                        if isinstance(static_inp_shape[i], int)
                        else None for i in range(rank)]
    static_out_shape[-1] = num_filters

    tf.logging.info('Upsampling from {} to {}'.format(
        static_inp_shape, static_out_shape))
    upsampled = tf.nn.conv3d_transpose(
        value=inputs,
        filter=kernel,
        output_shape=dyn_out_shape,
        strides=strides_5d,
        padding='SAME',
        name='upsample')
    upsampled.set_shape(static_out_shape)

    return upsampled
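

# A minimal usage sketch, not part of the original module: upsamples a small
# 5D tensor by a factor of 2 in each spatial dimension and checks the output
# shape. It assumes a TensorFlow 1.x environment; the tensor sizes and the
# function name below are illustrative only.
def _demo_linear_upsample_3d():
    inputs = tf.constant(np.random.rand(1, 4, 8, 8, 3), dtype=tf.float32)
    upsampled = linear_upsample_3d(inputs, strides=(2, 2, 2))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        out = sess.run(upsampled)
        # Expected output shape: (1, 8, 16, 16, 3)
        print(out.shape)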