Source code for dltk.networks.segmentation.unet

from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import

import tensorflow as tf

# Other feature extraction units available in this module include
# residual_unit_3d, se_residual_unit_3d and resnext_unit_3d.
from dltk.core.units.residual_unit import se_resnext_unit_3d
from dltk.core.upsample import linear_upsample_3d
from dltk.core.activations import leaky_relu


def upsample_and_concat(inputs, inputs2, strides=(2, 2, 2)):
    """Upsampling and concatenation layer according to [1].

    [1] O. Ronneberger et al. U-Net: Convolutional Networks for Biomedical
        Image Segmentation. MICCAI 2015.

    Args:
        inputs (tf.Tensor): Input features to be upsampled.
        inputs2 (tf.Tensor): Higher resolution features from the encoder to
            concatenate.
        strides (tuple, optional): Upsampling factor for a strided transpose
            convolution.

    Returns:
        tf.Tensor: Upsampled feature tensor
    """
    assert len(inputs.get_shape().as_list()) == 5, \
        'inputs are required to have a rank of 5.'
    assert len(inputs.get_shape().as_list()) == len(inputs2.get_shape().as_list()), \
        'Ranks of inputs and inputs2 differ'

    # Upsample inputs
    inputs = linear_upsample_3d(inputs, strides)

    return tf.concat(axis=-1, values=[inputs2, inputs])
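
# --- Illustrative usage sketch (not part of the original module) ---
# A minimal example of `upsample_and_concat`, assuming a TensorFlow 1.x
# graph context. The low-resolution decoder features `low` are upsampled
# by the stride factor and concatenated channel-wise with the encoder
# skip features `high`; all shapes below are arbitrary example values.
def _demo_upsample_and_concat():
    low = tf.placeholder(tf.float32, shape=[1, 8, 8, 8, 64])
    high = tf.placeholder(tf.float32, shape=[1, 16, 16, 16, 32])
    merged = upsample_and_concat(low, high, strides=(2, 2, 2))
    # Expected static shape: [1, 16, 16, 16, 96] (32 skip + 64 upsampled)
    print(merged.get_shape())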

def residual_unet_3d(inputs,
                     num_classes,
                     num_res_units=1,
                     filters=(16, 32, 64, 128),
                     strides=((1, 1, 1), (2, 2, 2), (2, 2, 2), (2, 2, 2)),
                     mode=tf.estimator.ModeKeys.EVAL,
                     use_bias=False,
                     activation=leaky_relu,
                     kernel_initializer=tf.initializers.variance_scaling(
                         distribution='uniform'),
                     bias_initializer=tf.zeros_initializer(),
                     kernel_regularizer=None,
                     bias_regularizer=None):
    """Image segmentation network based on a flexible UNET architecture [1]
    using residual units [2] (implemented here with se_resnext_unit_3d) as
    feature extractors. Downsampling and upsampling of features is done via
    strided convolutions and transpose convolutions, respectively. On each
    resolution scale s, num_res_units residual units with filters[s] filters
    are applied; strides[s] determines the downsampling factor at that scale.

    [1] O. Ronneberger et al. U-Net: Convolutional Networks for Biomedical
        Image Segmentation. MICCAI 2015.
    [2] K. He et al. Identity Mappings in Deep Residual Networks. ECCV 2016.

    Args:
        inputs (tf.Tensor): Input feature tensor to the network (rank 5
            required).
        num_classes (int): Number of output classes.
        num_res_units (int, optional): Number of residual units at each
            resolution scale.
        filters (tuple, optional): Number of filters for all residual units
            at each resolution scale.
        strides (tuple, optional): Stride of the first unit on a resolution
            scale.
        mode (str, optional): One of the tf.estimator.ModeKeys strings:
            TRAIN, EVAL or PREDICT.
        use_bias (bool, optional): Whether the layer uses a bias.
        activation (optional): A function to use as activation function.
        kernel_initializer (optional): An initializer for the convolution
            kernel.
        bias_initializer (optional): An initializer for the bias vector. If
            None, no bias will be applied.
        kernel_regularizer (optional): Optional regularizer for the
            convolution kernel.
        bias_regularizer (optional): Optional regularizer for the bias
            vector.

    Returns:
        dict: dictionary of output tensors
    """
    outputs = {}
    assert len(strides) == len(filters)
    assert len(inputs.get_shape().as_list()) == 5, \
        'inputs are required to have a rank of 5.'

    conv_params = {'padding': 'same',
                   'use_bias': use_bias,
                   'kernel_initializer': kernel_initializer,
                   'bias_initializer': bias_initializer,
                   'kernel_regularizer': kernel_regularizer,
                   'bias_regularizer': bias_regularizer}

    x = inputs

    # Initial convolution with filters[0]
    x = tf.layers.conv3d(inputs=x,
                         filters=filters[0],
                         kernel_size=(3, 3, 3),
                         strides=strides[0],
                         **conv_params)

    tf.logging.info('Init conv tensor shape {}'.format(x.get_shape()))

    # Residual feature encoding blocks with num_res_units at different
    # resolution scales res_scales
    res_scales = [x]
    saved_strides = []
    for res_scale in range(1, len(filters)):

        # Transition block:
        # adjust the strided conv kernel size to prevent losing information
        with tf.variable_scope('down_{}'.format(res_scale)):
            k = [s * 2 if s > 1 else k for k, s in zip(
                (3, 3, 3), strides[res_scale])]
            x = tf.layers.conv3d(inputs=x,
                                 filters=filters[res_scale],
                                 kernel_size=k,
                                 strides=strides[res_scale],
                                 **conv_params)

        # Features are downsampled via strided convolutions. These are
        # defined in `strides` and subsequently saved
        saved_strides.append(strides[res_scale])

        # Encoder blocks
        for i in range(0, num_res_units):
            with tf.variable_scope('enc_unit_{}_{}'.format(res_scale, i)):
                x = se_resnext_unit_3d(
                    inputs=x,
                    out_filters=filters[res_scale],
                    activation=activation,
                    mode=mode)
        res_scales.append(x)

        tf.logging.info('Encoder at res_scale {} tensor shape: {}'.format(
            res_scale, x.get_shape()))

    # Upsample and concat layers [1] reconstruct the predictions to higher
    # resolution scales
    for res_scale in range(len(filters) - 2, -1, -1):

        # Transition block for upsampling
        with tf.variable_scope('up_concat_{}'.format(res_scale)):
            x = upsample_and_concat(
                inputs=x,
                inputs2=res_scales[res_scale],
                strides=saved_strides[res_scale])

        # Decoder blocks
        for i in range(0, num_res_units):
            with tf.variable_scope('dec_unit_{}_{}'.format(res_scale, i)):
                x = se_resnext_unit_3d(
                    inputs=x,
                    out_filters=filters[res_scale],
                    activation=activation,
                    mode=mode)

        tf.logging.info('Decoder at res_scale {} tensor shape: {}'.format(
            res_scale, x.get_shape()))

    # Last convolution
    with tf.variable_scope('last'):
        x = tf.layers.conv3d(inputs=x,
                             filters=num_classes,
                             kernel_size=(1, 1, 1),
                             strides=(1, 1, 1),
                             **conv_params)

    tf.logging.info('Output tensor shape {}'.format(x.get_shape()))

    # Define the outputs
    outputs['logits'] = x

    with tf.variable_scope('pred'):
        y_prob = tf.nn.softmax(x)
        outputs['y_prob'] = y_prob

        y_ = tf.argmax(x, axis=-1) \
            if num_classes > 1 \
            else tf.cast(tf.greater_equal(x[..., 0], 0.5), tf.int32)
        outputs['y_'] = y_

    return outputs
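
# --- Illustrative usage sketch (not part of the original module) ---
# Building `residual_unet_3d` on a dummy volume, assuming TensorFlow 1.x.
# Inputs must be rank 5: [batch, depth, height, width, channels], with
# spatial dimensions divisible by the accumulated stride factors.
def _demo_residual_unet_3d():
    volume = tf.placeholder(tf.float32, shape=[1, 32, 64, 64, 1])
    net = residual_unet_3d(inputs=volume,
                           num_classes=3,
                           num_res_units=2,
                           mode=tf.estimator.ModeKeys.TRAIN)
    # net['logits'] has shape [1, 32, 64, 64, 3]; net['y_prob'] holds the
    # softmax probabilities and net['y_'] the argmax class map.
    print(net['logits'].get_shape())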

def asymmetric_residual_unet_3d(
        inputs,
        num_classes,
        num_res_units=1,
        filters=(16, 32, 64, 128),
        strides=((1, 1, 1), (2, 2, 2), (2, 2, 2), (2, 2, 2)),
        mode=tf.estimator.ModeKeys.EVAL,
        use_bias=False,
        activation=leaky_relu,
        kernel_initializer=tf.initializers.variance_scaling(
            distribution='uniform'),
        bias_initializer=tf.zeros_initializer(),
        kernel_regularizer=None,
        bias_regularizer=None):
    """Image segmentation network based on an asymmetric UNET architecture
    [1] using residual units [2] (implemented here with se_resnext_unit_3d)
    as feature extractors. Downsampling and upsampling of features is done
    via strided convolutions and transpose convolutions, respectively. The
    encoder applies num_res_units residual units with filters[s] filters on
    each resolution scale s, while the decoder applies a single unit per
    scale; strides[s] determines the downsampling factor at each scale.

    [1] O. Ronneberger et al. U-Net: Convolutional Networks for Biomedical
        Image Segmentation. MICCAI 2015.
    [2] K. He et al. Identity Mappings in Deep Residual Networks. ECCV 2016.

    Args:
        inputs (tf.Tensor): Input feature tensor to the network (rank 5
            required).
        num_classes (int): Number of output classes.
        num_res_units (int, optional): Number of residual units at each
            resolution scale.
        filters (tuple, optional): Number of filters for all residual units
            at each resolution scale.
        strides (tuple, optional): Stride of the first unit on a resolution
            scale.
        mode (str, optional): One of the tf.estimator.ModeKeys strings:
            TRAIN, EVAL or PREDICT.
        use_bias (bool, optional): Whether the layer uses a bias.
        activation (optional): A function to use as activation function.
        kernel_initializer (optional): An initializer for the convolution
            kernel.
        bias_initializer (optional): An initializer for the bias vector. If
            None, no bias will be applied.
        kernel_regularizer (optional): Optional regularizer for the
            convolution kernel.
        bias_regularizer (optional): Optional regularizer for the bias
            vector.

    Returns:
        dict: dictionary of output tensors
    """
    outputs = {}
    assert len(strides) == len(filters)
    assert len(inputs.get_shape().as_list()) == 5, \
        'inputs are required to have a rank of 5.'

    conv_params = {'padding': 'same',
                   'use_bias': use_bias,
                   'kernel_initializer': kernel_initializer,
                   'bias_initializer': bias_initializer,
                   'kernel_regularizer': kernel_regularizer,
                   'bias_regularizer': bias_regularizer}

    x = inputs

    # Initial convolution with filters[0]
    x = tf.layers.conv3d(inputs=x,
                         filters=filters[0],
                         kernel_size=(3, 3, 3),
                         strides=strides[0],
                         **conv_params)

    tf.logging.info('Init conv tensor shape {}'.format(x.get_shape()))

    # Residual feature encoding blocks with num_res_units at different
    # resolution scales res_scales
    res_scales = [x]
    saved_strides = []
    for res_scale in range(1, len(filters)):

        # Features are downsampled via strided convolutions. These are
        # defined in `strides` and subsequently saved
        with tf.variable_scope('enc_unit_{}_0'.format(res_scale)):
            x = se_resnext_unit_3d(
                inputs=x,
                out_filters=filters[res_scale],
                strides=strides[res_scale],
                activation=activation,
                mode=mode)
        saved_strides.append(strides[res_scale])

        for i in range(1, num_res_units):
            with tf.variable_scope('enc_unit_{}_{}'.format(res_scale, i)):
                x = se_resnext_unit_3d(
                    inputs=x,
                    out_filters=filters[res_scale],
                    strides=(1, 1, 1),
                    activation=activation,
                    mode=mode)
        res_scales.append(x)

        tf.logging.info('Encoder at res_scale {} tensor shape: {}'.format(
            res_scale, x.get_shape()))

    # Upsample and concat layers [1] reconstruct the predictions to higher
    # resolution scales
    for res_scale in range(len(filters) - 2, -1, -1):

        with tf.variable_scope('up_concat_{}'.format(res_scale)):
            x = upsample_and_concat(
                inputs=x,
                inputs2=res_scales[res_scale],
                strides=saved_strides[res_scale])

        with tf.variable_scope('dec_unit_{}'.format(res_scale)):
            x = se_resnext_unit_3d(
                inputs=x,
                out_filters=filters[res_scale],
                strides=(1, 1, 1),
                activation=activation,
                mode=mode)

        tf.logging.info('Decoder at res_scale {} tensor shape: {}'.format(
            res_scale, x.get_shape()))

    # Last convolution
    with tf.variable_scope('last'):
        x = tf.layers.conv3d(inputs=x,
                             filters=num_classes,
                             kernel_size=(1, 1, 1),
                             strides=(1, 1, 1),
                             **conv_params)

    tf.logging.info('Output tensor shape {}'.format(x.get_shape()))

    # Define the outputs
    outputs['logits'] = x

    with tf.variable_scope('pred'):
        y_prob = tf.nn.softmax(x)
        outputs['y_prob'] = y_prob

        y_ = tf.argmax(x, axis=-1) \
            if num_classes > 1 \
            else tf.cast(tf.greater_equal(x[..., 0], 0.5), tf.int32)
        outputs['y_'] = y_

    return outputs
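
# --- Illustrative usage sketch (not part of the original module) ---
# Wiring `asymmetric_residual_unet_3d` into a tf.estimator model_fn,
# assuming TensorFlow 1.x. The model_fn skeleton, the feature/label keys
# 'x' and 'y' and the optimizer choice are hypothetical, not DLTK API.
def _demo_model_fn(features, labels, mode, params):
    net = asymmetric_residual_unet_3d(
        inputs=features['x'],
        num_classes=params['num_classes'],
        mode=mode)

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=net)

    # Voxel-wise cross-entropy on the logits; labels are integer class maps
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels['y'], logits=net['logits']))

    train_op = tf.train.AdamOptimizer(1e-3).minimize(
        loss, global_step=tf.train.get_or_create_global_step())

    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)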