Source code for dltk.core.upsample

from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import

import tensorflow as tf
import numpy as np

def _linear_upsampling_weights(kernel_spatial_shape, out_filters, in_filters):
    """Compute the linear upsampling weights as a plain NumPy array.

    Args:
        kernel_spatial_shape (list or tuple): Spatial size of the kernel,
            one entry per spatial dimension.
        out_filters (int): Size of the second-to-last kernel axis.
        in_filters (int): Size of the last kernel axis.

    Returns:
        np.ndarray: Array of shape
            ``tuple(kernel_spatial_shape) + (out_filters, in_filters)`` that
            is zero everywhere except on the ``[..., i, i]`` filter diagonal,
            which holds the separable linear (triangle) filter.
    """
    size = list(kernel_spatial_shape)
    rank = len(size)
    factor = (np.array(size) + 1) // 2

    # The filter is separable: build one 1D triangle profile per axis and
    # combine them by broadcasting instead of special-casing 2D and 3D.
    filt = np.ones(size)
    for axis in range(rank):
        # Odd-sized kernels are centred on a sample, even-sized kernels
        # half-way between the two middle samples.
        if size[axis] % 2 == 1:
            center = factor[axis] - 1
        else:
            center = factor[axis] - 0.5

        profile = 1 - np.abs(np.arange(size[axis]) - center) / float(factor[axis])

        broadcast_shape = [1] * rank
        broadcast_shape[axis] = size[axis]
        filt = filt * profile.reshape(broadcast_shape)

    weights = np.zeros(tuple(size) + (out_filters, in_filters))
    # Only the filter diagonal is populated; min() keeps the indices in
    # bounds when the two filter counts differ.
    for i in range(min(out_filters, in_filters)):
        weights[..., i, i] = filt

    return weights


def get_linear_upsampling_kernel(kernel_spatial_shape, out_filters, in_filters, trainable=False):
    """Builds a kernel for linear upsampling with the shape
        [kernel_spatial_shape] + [out_filters, in_filters]. Can be set to
        trainable to potentially learn a better upsampling.

    Args:
        kernel_spatial_shape (list or tuple): Spatial dimensions of the
            upsampling kernel. Is required to be of rank 2 or 3,
            (i.e. [dim_x, dim_y] or [dim_x, dim_y, dim_z])
        out_filters (int): Number of output filters.
        in_filters (int): Number of input filters.
        trainable (bool, optional): Flag to set the returned tf.Variable
            to be trainable or not.

    Returns:
        tf.Variable: Linear upsampling kernel

    Raises:
        AssertionError: If ``kernel_spatial_shape`` is not of rank 2 or 3.
    """
    # Normalise once so that both lists and tuples are accepted.
    kernel_spatial_shape = list(kernel_spatial_shape)
    rank = len(kernel_spatial_shape)
    assert 1 < rank < 4, \
        'Transposed convolutions are only supported in 2D and 3D'

    weights = _linear_upsampling_weights(
        kernel_spatial_shape, out_filters, in_filters)

    init = tf.constant_initializer(value=weights, dtype=tf.float32)

    return tf.get_variable(name="linear_up_kernel",
                           initializer=init,
                           shape=weights.shape,
                           trainable=trainable)
def linear_upsample_3d(inputs, strides=(2, 2, 2), use_bias=False, trainable=False, name='linear_upsample_3d'):
    """Linear upsampling layer in 3D using strided transpose convolutions. The
        upsampling kernel size will be automatically computed to avoid
        information loss.

    Args:
        inputs (tf.Tensor): Input tensor to be upsampled
        strides (tuple, optional): The strides determine the upsampling
            factor in each dimension.
        use_bias (bool, optional): Flag to train an additional bias.
            NOTE: accepted as part of the signature but not referenced by
            this function's body.
        trainable (bool, optional): Flag to set the variables to be trainable
            or not.
        name (str, optional): Name of the layer.
            NOTE: accepted as part of the signature but not referenced by
            this function's body.

    Returns:
        tf.Tensor: Upsampled Tensor
    """
    static_inp_shape = tuple(inputs.get_shape().as_list())
    dyn_inp_shape = tf.shape(inputs)
    rank = len(static_inp_shape)

    num_filters = static_inp_shape[-1]
    # Leave the first and last axes untouched: stride 1 on both.
    strides_5d = [1, ] + list(strides) + [1, ]
    # Kernel extent is twice the stride along upsampled axes, 1 elsewhere.
    kernel_size = [2 * s if s > 1 else 1 for s in strides]

    kernel = get_linear_upsampling_kernel(
        kernel_spatial_shape=kernel_size,
        out_filters=num_filters,
        in_filters=num_filters,
        trainable=trainable)

    # Output shape built from the runtime shape; passed to the transposed
    # convolution below.
    dyn_out_shape = [dyn_inp_shape[i] * strides_5d[i] for i in range(rank)]
    dyn_out_shape[-1] = num_filters

    # Statically known counterpart; dimensions that are not ints (unknown)
    # stay None.
    static_out_shape = [static_inp_shape[i] * strides_5d[i]
                        if isinstance(static_inp_shape[i], int)
                        else None for i in range(rank)]
    static_out_shape[-1] = num_filters

    tf.logging.info('Upsampling from {} to {}'.format(
        static_inp_shape, static_out_shape))

    upsampled = tf.nn.conv3d_transpose(
        value=inputs,
        filter=kernel,
        output_shape=dyn_out_shape,
        strides=strides_5d,
        padding='SAME',
        name='upsample')

    # Re-attach the static shape information computed above to the result.
    upsampled.set_shape(static_out_shape)

    return upsampled