from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
import tensorflow as tf
# residual_unit_3d, se_residual_unit_3d, resnext_unit_3d
from dltk.core.units.residual_unit import se_resnext_unit_3d
from dltk.core.upsample import linear_upsample_3d
def upscore_layer_3d(inputs,
                     inputs2,
                     out_filters,
                     in_filters=None,
                     strides=(2, 2, 2),
                     mode=tf.estimator.ModeKeys.EVAL,
                     use_bias=False,
                     kernel_initializer=tf.initializers.variance_scaling(distribution='uniform'),
                     bias_initializer=tf.zeros_initializer(),
                     kernel_regularizer=None,
                     bias_regularizer=None):
    """Upscore layer in the spirit of [1]: upsample and fuse with a skip path.

    The low resolution `inputs` are (if necessary) projected to `out_filters`
    channels with a 1x1x1 convolution and upsampled with
    `linear_upsample_3d`. The skip features `inputs2` are projected to
    `out_filters` channels with a 1x1x1 convolution followed by batch
    normalization. Both paths are summed element-wise.

    [1] J. Long et al. Fully convolutional networks for semantic segmentation.
        CVPR 2015.

    Args:
        inputs (tf.Tensor): Rank 5 input features to be upscored.
        inputs2 (tf.Tensor): Higher resolution features from the encoder to
            add. Must have the same rank as `inputs`.
        out_filters (int): Number of output filters (typically, number of
            segmentation classes).
        in_filters (None, optional): None or number of input filters. If
            None, it is read from the static shape of `inputs`.
        strides (tuple, optional): Upsampling factor forwarded to
            `linear_upsample_3d`.
        mode (TYPE, optional): One of the tf.estimator.ModeKeys strings: TRAIN,
            EVAL or PREDICT. Batch normalization runs in training mode only
            for TRAIN.
        use_bias (bool, optional): Boolean, whether the convolutions use a
            bias.
        kernel_initializer (TYPE, optional): An initializer for the
            convolution kernels.
        bias_initializer (TYPE, optional): An initializer for the bias vector.
        kernel_regularizer (None, optional): Optional regularizer for the
            convolution kernels.
        bias_regularizer (None, optional): Optional regularizer for the bias
            vector.

    Returns:
        tf.Tensor: Upscore tensor
    """
    # Keyword arguments shared by both 1x1x1 convolutions of this layer.
    shared_conv_kwargs = dict(use_bias=use_bias,
                              kernel_initializer=kernel_initializer,
                              bias_initializer=bias_initializer,
                              kernel_regularizer=kernel_regularizer,
                              bias_regularizer=bias_regularizer)
    unit = (1, 1, 1)
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    # The number of input filters has to be known statically.
    if in_filters is None:
        in_filters = inputs.get_shape().as_list()[-1]

    assert len(inputs.get_shape().as_list()) == 5, (
        'inputs are required to have a rank of 5.')
    assert (len(inputs.get_shape().as_list()) ==
            len(inputs2.get_shape().as_list())), (
        'Ranks of input and input2 differ')

    # Low resolution path: project to `out_filters` channels only when the
    # channel counts differ, then upsample.
    low_res = inputs
    if in_filters != out_filters:
        low_res = tf.layers.conv3d(inputs=inputs,
                                   filters=out_filters,
                                   kernel_size=unit,
                                   strides=unit,
                                   padding='same',
                                   name='filter_conversion',
                                   **shared_conv_kwargs)
    upsampled = linear_upsample_3d(inputs=low_res, strides=strides)

    # Skip path: 1x1x1 projection of the encoder features + batch norm.
    skip = tf.layers.conv3d(inputs=inputs2,
                            filters=out_filters,
                            kernel_size=unit,
                            strides=unit,
                            padding='same',
                            **shared_conv_kwargs)
    skip = tf.layers.batch_normalization(skip, training=is_training)

    # Fuse both paths by element-wise addition.
    return tf.add(upsampled, skip)
def residual_fcn_3d(inputs,
                    num_classes,
                    num_res_units=1,
                    filters=(16, 32, 64, 128),
                    strides=((1, 1, 1), (2, 2, 2), (2, 2, 2), (2, 2, 2)),
                    mode=tf.estimator.ModeKeys.EVAL,
                    use_bias=False,
                    activation=tf.nn.relu6,
                    kernel_initializer=tf.initializers.variance_scaling(distribution='uniform'),
                    bias_initializer=tf.zeros_initializer(),
                    kernel_regularizer=None,
                    bias_regularizer=None):
    """Image segmentation network based on an FCN architecture [1].

    Encoder: an initial 3x3x3 convolution with filters[0], followed for every
    further resolution scale s by a strided "transition" convolution with
    filters[s] and `num_res_units` `se_resnext_unit_3d` feature extractors
    (residual-style units, see [2] for background).
    Decoder: `upscore_layer_3d` layers upsample the prediction scale by scale
    and add the encoder features saved at the matching resolution. A final
    1x1x1 convolution produces the logits.

    [1] J. Long et al. Fully convolutional networks for semantic segmentation.
        CVPR 2015.
    [2] K. He et al. Identity Mappings in Deep Residual Networks. ECCV 2016.

    Args:
        inputs (tf.Tensor): Input feature tensor to the network (rank 5
            required).
        num_classes (int): Number of output classes.
        num_res_units (int, optional): Number of units at each resolution
            scale.
        filters (tuple, optional): Number of filters at each resolution scale.
            Must have the same length as `strides`.
        strides (tuple, optional): strides[0] is the stride of the initial
            convolution; strides[s] (s >= 1) is the stride of the transition
            convolution entering resolution scale s and the upsampling factor
            used by the decoder to leave it.
        mode (TYPE, optional): One of the tf.estimator.ModeKeys strings:
            TRAIN, EVAL or PREDICT
        use_bias (bool, optional): Boolean, whether the convolutions use a
            bias.
        activation (optional): A function to use as activation function.
        kernel_initializer (TYPE, optional): An initializer for the
            convolution kernels.
        bias_initializer (TYPE, optional): An initializer for the bias vector.
        kernel_regularizer (None, optional): Optional regularizer for the
            convolution kernels.
        bias_regularizer (None, optional): Optional regularizer for the bias
            vector.

    Returns:
        dict: dictionary of output tensors with the keys
            'logits': output of the last convolution,
            'y_prob': class probabilities (softmax over the last axis for
                num_classes > 1, sigmoid for a single output channel),
            'y_': hard predictions (argmax over the last axis for
                num_classes > 1, otherwise `y_prob >= 0.5` cast to int32).
    """
    outputs = {}
    assert len(strides) == len(filters), \
        'strides and filters are required to have the same length.'
    assert len(inputs.get_shape().as_list()) == 5, \
        'inputs are required to have a rank of 5.'

    conv_params = {'use_bias': use_bias,
                   'kernel_initializer': kernel_initializer,
                   'bias_initializer': bias_initializer,
                   'kernel_regularizer': kernel_regularizer,
                   'bias_regularizer': bias_regularizer}

    x = inputs

    # Initial convolution with filters[0]
    x = tf.layers.conv3d(inputs=x,
                         filters=filters[0],
                         kernel_size=(3, 3, 3),
                         strides=strides[0],
                         padding='same',
                         **conv_params)
    tf.logging.info('Init conv tensor shape {}'.format(x.get_shape()))

    # Feature encoding blocks with num_res_units at different resolution
    # scales. The output of every scale is kept for the decoder skip paths.
    res_scales = [x]
    saved_strides = []
    for res_scale in range(1, len(filters)):

        # Features are downsampled via strided convolutions. These are defined
        # in `strides` and subsequently saved
        with tf.variable_scope('transition_unit_{}'.format(res_scale)):
            # Along strided dimensions the kernel is twice the stride,
            # otherwise the default kernel size of 3 is used.
            kernel_size = [s * 2 if s > 1 else 3
                           for s in strides[res_scale]]
            x = tf.layers.conv3d(inputs=x,
                                 filters=filters[res_scale],
                                 kernel_size=kernel_size,
                                 strides=strides[res_scale],
                                 padding='same',
                                 **conv_params)
        saved_strides.append(strides[res_scale])

        for i in range(0, num_res_units):
            with tf.variable_scope('unit_{}_{}'.format(res_scale, i)):
                x = se_resnext_unit_3d(
                    inputs=x,
                    out_filters=filters[res_scale],
                    activation=activation,
                    mode=mode)
        res_scales.append(x)

        tf.logging.info('Encoder at res_scale {} tensor shape: {}'.format(
            res_scale, x.get_shape()))

    # Upscore layers [1] reconstruct the predictions to higher resolution
    # scales. saved_strides[s] holds strides[s + 1], i.e. the factor that was
    # used to go from scale s down to scale s + 1.
    for res_scale in range(len(filters) - 2, -1, -1):
        with tf.variable_scope('upscore_{}'.format(res_scale)):
            x = upscore_layer_3d(
                inputs=x,
                inputs2=res_scales[res_scale],
                out_filters=num_classes,
                strides=saved_strides[res_scale],
                mode=mode,
                **conv_params)
        tf.logging.info('Decoder at res_scale {} tensor shape: {}'.format(
            res_scale, x.get_shape()))

    # Last convolution
    with tf.variable_scope('last'):
        x = tf.layers.conv3d(inputs=x,
                             filters=num_classes,
                             kernel_size=(1, 1, 1),
                             strides=(1, 1, 1),
                             padding='same',
                             **conv_params)
    tf.logging.info('Output tensor shape {}'.format(x.get_shape()))

    # Define the outputs
    outputs['logits'] = x

    with tf.variable_scope('pred'):
        if num_classes > 1:
            y_prob = tf.nn.softmax(x)
            y_ = tf.argmax(x, axis=-1)
        else:
            # A softmax over a single channel is identically 1.0, so a
            # sigmoid is used to obtain a meaningful foreground probability.
            # The hard prediction thresholds that probability at 0.5 (which
            # is equivalent to a logit >= 0).
            y_prob = tf.nn.sigmoid(x)
            y_ = tf.cast(tf.greater_equal(y_prob[..., 0], 0.5), tf.int32)

        outputs['y_prob'] = y_prob
        outputs['y_'] = y_

    return outputs