Source code for dltk.models.segmentation.fcn

from __future__ import division
from __future__ import absolute_import
from __future__ import print_function

import tensorflow as tf

from dltk.core.modules import *


class Upscore(AbstractModule):
    """Upscore module according to J. Long."""

    def __init__(self, out_filters, strides, name='upscore'):
        """Constructs an Upscore module

        Parameters
        ----------
        out_filters : int
            number of output filters
        strides : list or tuple
            strides to use for upsampling
        name : string
            name of the module
        """
        self.out_filters = out_filters
        self.strides = strides
        self.in_filters = None
        self.rank = None
        super(Upscore, self).__init__(name)

    def _build(self, x, x_up, is_training=True):
        """Applies the upscore operation

        Parameters
        ----------
        x : tf.Tensor
            tensor to be upsampled
        x_up : tf.Tensor
            tensor from the same scale to be convolved and added to the
            upsampled tensor
        is_training : bool
            flag for specifying whether this is training - passed to batch
            normalization

        Returns
        -------
        tf.Tensor
            output of the upscore operation
        """
        # Capture the static number of input filters on the first call and
        # validate it on subsequent calls.
        if self.in_filters is None:
            self.in_filters = x.get_shape().as_list()[-1]
        assert self.in_filters == x.get_shape().as_list()[-1], \
            'Module was initialised for a different input shape'

        if self.rank is None:
            self.rank = len(self.strides)
        assert len(x.get_shape().as_list()) == self.rank + 2, \
            'Stride gives rank {} but input is rank {}'.format(
                self.rank, len(x.get_shape().as_list()) - 2)

        # Account for differences in the number of input and output filters
        if self.in_filters != self.out_filters:
            x = Convolution(self.out_filters, name='up_score_filter_conv',
                            strides=[1] * self.rank)(x)

        # Upsample the coarse tensor and add the convolved, batch-normalised
        # same-scale skip tensor.
        t_conv = BilinearUpsample(strides=self.strides)(x)

        conv = Convolution(self.out_filters, 1, strides=[1] * self.rank)(x_up)
        conv = BatchNorm()(conv, is_training)

        return tf.add(t_conv, conv)
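
# Usage sketch for Upscore (illustrative only; the shapes and placeholder
# names below are assumptions, not part of the original module). A coarse
# feature map is upsampled by `strides` and fused with a skip tensor from
# the same scale, e.g. for 3D inputs in NDHWC layout:
#
#   x_low = tf.placeholder(tf.float32, [1, 8, 8, 8, 256])     # coarse features
#   x_skip = tf.placeholder(tf.float32, [1, 16, 16, 16, 64])  # same-scale skip
#   up = Upscore(out_filters=4, strides=[2, 2, 2])
#   y = up(x_low, x_skip, is_training=False)                  # -> [1, 16, 16, 16, 4]
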
class ResNetFCN(AbstractModule):
    """FCN module with residual encoder

    This module builds an FCN for segmentation using a residual encoder.
    """

    def __init__(self, num_classes, num_residual_units=3,
                 filters=(16, 64, 128, 256, 512),
                 strides=((1, 1, 1), (2, 2, 2), (2, 2, 2), (2, 2, 2), (1, 1, 1)),
                 relu_leakiness=0.1, name='resnetfcn'):
        """Builds a residual FCN for segmentation

        Parameters
        ----------
        num_classes : int
            number of classes to segment
        num_residual_units : int
            number of residual units per scale
        filters : tuple or list
            number of filters per scale. The first is used for the initial
            convolution without residual connections
        strides : tuple or list
            strides per scale. The first is used for the initial convolution
            without residual connections
        relu_leakiness : float
            leakiness of the relus used
        name : string
            name of the network
        """
        self.num_classes = num_classes
        self.num_residual_units = num_residual_units
        self.filters = filters
        self.strides = strides
        self.relu_leakiness = relu_leakiness
        self.rank = None
        super(ResNetFCN, self).__init__(name)

    def _build(self, inp, is_training=True):
        """Constructs a ResNetFCN using the input tensor

        Parameters
        ----------
        inp : tf.Tensor
            input tensor
        is_training : bool
            flag to specify whether this is training - passed to batch
            normalization

        Returns
        -------
        dict
            output dictionary containing:
                - `logits` - logits of the classification
                - `y_prob` - classification probabilities
                - `y_` - prediction of the classification
        """
        outputs = {}
        filters = self.filters
        strides = self.strides

        assert len(strides) == len(filters)
        if self.rank is None:
            self.rank = len(strides[0])
        assert len(inp.get_shape().as_list()) == self.rank + 2, \
            'Stride gives rank {} but input is rank {}'.format(
                self.rank, len(inp.get_shape().as_list()) - 2)

        x = inp
        x = Convolution(filters[0], strides=strides[0])(x)
        tf.logging.info('Init conv tensor shape %s', x.get_shape())

        # Residual feature encoding blocks with num_residual_units at
        # different scales, defined via strides
        scales = [x]
        saved_strides = []
        for scale in range(1, len(filters)):
            with tf.variable_scope('unit_%d_0' % (scale)):
                x = VanillaResidualUnit(
                    filters[scale], stride=strides[scale])(x, is_training=is_training)
            saved_strides.append(strides[scale])
            for i in range(1, self.num_residual_units):
                with tf.variable_scope('unit_%d_%d' % (scale, i)):
                    x = VanillaResidualUnit(
                        filters[scale], stride=[1] * self.rank)(x, is_training=is_training)
            scales.append(x)
            tf.logging.info('Encoder at scale %d tensor shape: %s', scale, x.get_shape())

        # Decoder / upscore: walk back up the scales, fusing each upsampled
        # feature map with the encoder output of the same scale
        for scale in range(len(filters) - 2, -1, -1):
            with tf.variable_scope('upscore_%d' % scale):
                x = Upscore(self.num_classes,
                            saved_strides[scale])(x, scales[scale], is_training=is_training)
            tf.logging.info('Decoder at scale %d tensor shape: %s', scale, x.get_shape())

        with tf.variable_scope('last'):
            x = Convolution(self.num_classes, 1, strides=[1] * self.rank)(x)
        outputs['logits'] = x
        tf.logging.info('Logits tensor shape %s', x.get_shape())

        with tf.variable_scope('pred'):
            y_prob = tf.nn.softmax(x)
            outputs['y_prob'] = y_prob
            y_ = tf.argmax(x, axis=-1) if self.num_classes > 1 \
                else tf.cast(tf.greater_equal(x[..., 0], 0.5), tf.int32)
            outputs['y_'] = y_

        return outputs
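

if __name__ == '__main__':
    # Minimal smoke test (an illustrative sketch, not part of the original
    # module): build the graph for a 4-class segmentation of a single-channel
    # 3D volume. The spatial dimensions must be divisible by the product of
    # the encoder strides (here 2 * 2 * 2 = 8) so that the upscored feature
    # maps align with the encoder skip tensors.
    images = tf.placeholder(tf.float32, shape=[1, 64, 64, 64, 1])
    net = ResNetFCN(num_classes=4)
    out = net(images, is_training=False)
    # Expected shapes: logits [1, 64, 64, 64, 4], y_ [1, 64, 64, 64]
    print(out['logits'].get_shape(), out['y_'].get_shape())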