Source code for dltk.models.segmentation.unet

from __future__ import division
from __future__ import absolute_import
from __future__ import print_function

import tensorflow as tf

from dltk.core.modules import (AbstractModule, SaveableModule, BilinearUpsample,
                               Convolution, VanillaResidualUnit)


class UpsampleAndConcat(AbstractModule):
    """UNET upsampling module according to O. Ronneberger."""

    def __init__(self, strides, name='upandconcat'):
        """Constructs an UpsampleAndConcat module

        Parameters
        ----------
        strides : list or tuple
            strides to use for upsampling
        name : string
            name of the module
        """
        self.strides = strides
        super(UpsampleAndConcat, self).__init__(name)

    def _build(self, x, x_up):
        """Applies the UpsampleAndConcat operation

        Parameters
        ----------
        x : tf.Tensor
            tensor to be upsampled
        x_up : tf.Tensor
            tensor from the same scale to be concatenated with the
            upsampled tensor

        Returns
        -------
        tf.Tensor
            output of the operation
        """
        # bilinearly upsample the coarse features by self.strides, then
        # concatenate the skip connection along the channel axis
        t_conv = BilinearUpsample(strides=self.strides)(x)
        return tf.concat(axis=-1, values=[x_up, t_conv])
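
# Usage sketch (added for illustration; not part of the DLTK source). The
# call pattern mirrors the decoder below; the shapes are assumptions chosen
# so that upsampling by the given strides matches the skip connection:
#
#     coarse = tf.placeholder(tf.float32, shape=[1, 8, 8, 8, 64])   # decoder features
#     skip = tf.placeholder(tf.float32, shape=[1, 16, 16, 16, 32])  # encoder skip connection
#     fused = UpsampleAndConcat(strides=(2, 2, 2))(coarse, skip)
#     # fused has shape [1, 16, 16, 16, 96]: 32 skip + 64 upsampled channels
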
class ResUNET(SaveableModule):
    """ResUNET module with residual encoder

    This module builds a UNET for segmentation using a residual encoder.
    """
    output_keys = ['logits', 'y_prob', 'y_']

    def __init__(self, num_classes, num_residual_units=3, filters=(16, 64, 128, 256, 512),
                 strides=((1, 1, 1), (2, 2, 2), (2, 2, 2), (2, 2, 2), (1, 1, 1)),
                 relu_leakiness=0.1, name='resnetfcn'):
        """Builds a residual UNET for segmentation

        Parameters
        ----------
        num_classes : int
            number of classes to segment
        num_residual_units : int
            number of residual units per scale
        filters : tuple or list
            number of filters per scale. The first is used for the initial
            convolution without residual connections
        strides : tuple or list
            strides per scale. The first is used for the initial convolution
            without residual connections
        relu_leakiness : float
            leakiness of the ReLUs used
        name : string
            name of the network
        """
        self.num_classes = num_classes
        self.num_residual_units = num_residual_units
        self.filters = filters
        self.strides = strides
        self.relu_leakiness = relu_leakiness
        self.input_filters = None
        self.rank = None
        super(ResUNET, self).__init__(name)

    def _build_input_placeholder(self):
        """Builds the input placeholder matching the configured rank and filters"""
        assert self.input_filters is not None, 'self.input_filters must be defined'

        self.input_placeholders = [tf.placeholder(
            tf.float32,
            shape=[None, ] * (1 + len(self.strides[0])) + [self.input_filters])]

    def _build(self, inp, is_training=True):
        """Constructs a ResUNET using the input tensor

        Parameters
        ----------
        inp : tf.Tensor
            input tensor
        is_training : bool
            flag to specify whether this is training - passed to batch
            normalization

        Returns
        -------
        dict
            output dictionary containing:
                - `logits` - logits of the classification
                - `y_prob` - classification probabilities
                - `y_` - prediction of the classification
        """
        outputs = {}
        filters = self.filters
        strides = self.strides
        assert len(strides) == len(filters)

        if self.input_filters is None:
            self.input_filters = inp.get_shape().as_list()[-1]
            self._build_input_placeholder()
        assert self.input_filters == inp.get_shape().as_list()[-1]

        if self.rank is None:
            self.rank = len(strides[0])
        assert len(inp.get_shape().as_list()) == self.rank + 2, \
            'Strides give rank {} but input is rank {}'.format(
                self.rank, len(inp.get_shape().as_list()) - 2)

        # initial convolution without residual connections
        x = inp
        x = Convolution(filters[0], strides=strides[0])(x)
        tf.logging.info(x.get_shape())

        # residual feature encoding blocks with num_residual_units at
        # different scales, defined via strides
        scales = [x]
        saved_strides = []
        for scale in range(1, len(filters)):
            with tf.variable_scope('unit_%d_0' % (scale)):
                x = VanillaResidualUnit(
                    filters[scale], stride=strides[scale])(x, is_training=is_training)
            saved_strides.append(strides[scale])
            for i in range(1, self.num_residual_units):
                with tf.variable_scope('unit_%d_%d' % (scale, i)):
                    x = VanillaResidualUnit(
                        filters[scale], stride=[1] * self.rank)(x, is_training=is_training)
            scales.append(x)
            tf.logging.info('feat_scale_%d shape %s', scale, x.get_shape())

        # decoder: upsample, concatenate the encoder features of the same
        # scale and refine with a residual unit
        for scale in range(len(filters) - 2, -1, -1):
            with tf.variable_scope('upsample_%d' % scale):
                tf.logging.info('Building upsampling for scale %d with x (%s) x_up (%s) stride (%s)'
                                % (scale, x.get_shape().as_list(),
                                   scales[scale].get_shape().as_list(), saved_strides[scale]))
                x = UpsampleAndConcat(saved_strides[scale])(x, scales[scale])
            with tf.variable_scope('up_unit_%d_0' % (scale)):
                x = VanillaResidualUnit(
                    filters[scale], stride=[1] * self.rank)(x, is_training=is_training)
            tf.logging.info('up_%d shape %s', scale, x.get_shape())

        # 1x1(x1) convolution mapping the features to num_classes logits
        with tf.variable_scope('last'):
            x = Convolution(self.num_classes, 1, strides=[1] * self.rank)(x)
            outputs['logits'] = x
            tf.logging.info('last conv shape %s', x.get_shape())

        # class probabilities and hard predictions; for the binary case the
        # single logit is thresholded directly
        with tf.variable_scope('pred'):
            y_prob = tf.nn.softmax(x)
            outputs['y_prob'] = y_prob
            y_ = tf.argmax(x, axis=-1) if self.num_classes > 1 \
                else tf.cast(tf.greater_equal(x[..., 0], 0.5), tf.int32)
            outputs['y_'] = y_

        return outputs
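
# Usage sketch (added for illustration; not part of the DLTK source). The
# input shape and class count below are assumptions; like the other modules
# in this file, the network is applied by calling the constructed instance.
# With the default strides, each spatial dimension must be divisible by 8.
#
#     net = ResUNET(num_classes=4)
#     x = tf.placeholder(tf.float32, shape=[None, 64, 64, 64, 1])  # [batch, z, y, x, channels]
#     out = net(x, is_training=True)
#     # out['logits'] : per-voxel class scores, shape [None, 64, 64, 64, 4]
#     # out['y_prob'] : softmax probabilities, same shape as the logits
#     # out['y_']     : hard segmentation from argmax, shape [None, 64, 64, 64]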