from __future__ import division
from __future__ import absolute_import
from __future__ import print_function

import tensorflow as tf
import numpy as np

from tensorflow.python.ops import control_flow_ops
from tensorflow.python.training import moving_averages

from dltk.core.modules import *


class Upscore(AbstractModule):
"""Upscore module according to J. Long.
"""
def __init__(self, out_filters, strides, name='upscore'):
"""Constructs an Upscore module
Parameters
----------
out_filters : int
number of output filters
strides : list or tuple
strides to use for upsampling
name : string
name of the module
"""
self.out_filters = out_filters
self.strides = strides
self.in_filters = None
self.rank = None
super(Upscore, self).__init__(name)
    def _build(self, x, x_up, is_training=True):
        """Applies the upscore operation

        Parameters
        ----------
        x : tf.Tensor
            tensor to be upsampled
        x_up : tf.Tensor
            tensor from the same scale to be convolved and added to the upsampled tensor
        is_training : bool
            flag specifying whether this is training - passed to batch normalization

        Returns
        -------
        tf.Tensor
            output of the upscore operation
        """
        # Infer the number of input filters on the first call. Input filters are
        # required to be static so the convolutions below can be constructed.
        if self.in_filters is None:
            self.in_filters = x.get_shape().as_list()[-1]
        assert self.in_filters == x.get_shape().as_list()[-1], \
            'Module was initialised for a different input shape'

        # Infer the spatial rank from the given strides and validate the input rank
        if self.rank is None:
            self.rank = len(self.strides)
        assert len(x.get_shape().as_list()) == self.rank + 2, \
            'Strides give rank {} but input is rank {}'.format(self.rank, len(x.get_shape().as_list()) - 2)

        # Project x to the number of output filters if input and output filters differ
        if self.in_filters != self.out_filters:
            x = Convolution(self.out_filters, name='up_score_filter_conv', strides=[1] * self.rank)(x)

        # Bilinearly upsample x, then add the batch-normalised 1x1 convolution
        # of the skip connection x_up
        t_conv = BilinearUpsample(strides=self.strides)(x)

        conv = Convolution(self.out_filters, 1, strides=[1] * self.rank)(x_up)
        conv = BatchNorm()(conv, is_training)

        return tf.add(t_conv, conv)
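
# A minimal usage sketch, assuming 3-D feature maps in NDHWC layout and the
# call convention of the DLTK modules above; the placeholder names and shapes
# are illustrative only:
#
#     x_low = tf.placeholder(tf.float32, [1, 8, 8, 8, 256])     # coarse features
#     x_skip = tf.placeholder(tf.float32, [1, 16, 16, 16, 64])  # same-scale skip
#     y = Upscore(out_filters=2, strides=[2, 2, 2])(x_low, x_skip,
#                                                   is_training=False)
#
# Here x_low is projected to 2 channels, bilinearly upsampled by a factor of 2
# in each spatial dimension, and summed with the convolved skip tensor, so y
# should have shape [1, 16, 16, 16, 2].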


class ResNetFCN(AbstractModule):
"""FCN module with residual encoder
This module builds a FCN for segmentation using a residual encoder.
"""
def __init__(self, num_classes, num_residual_units=3, filters=(16, 64, 128, 256, 512),
strides=((1, 1, 1), (2, 2, 2), (2, 2, 2), (2, 2, 2), (1, 1, 1)), relu_leakiness=0.1,
name='resnetfcn'):
"""Builds a residual FCN for segmentation
Parameters
----------
num_classes : int
number of classes to segment
num_residual_units : int
number of residual units per scale
filters : tuple or list
number of filters per scale. The first is used for the initial convolution without residual connections
strides : tuple or list
strides per scale. The first is used for the initial convolution without residual connections
relu_leakiness : float
leakiness of the relus used
name : string
name of the network
"""
self.num_classes = num_classes
self.num_residual_units = num_residual_units
self.filters = filters
self.strides = strides
self.relu_leakiness = relu_leakiness
self.rank = None
super(ResNetFCN, self).__init__(name)
    def _build(self, inp, is_training=True):
        """Constructs a ResNetFCN using the input tensor

        Parameters
        ----------
        inp : tf.Tensor
            input tensor
        is_training : bool
            flag specifying whether this is training - passed to batch normalization

        Returns
        -------
        dict
            output dictionary containing:
                - `logits` - logits of the classification
                - `y_prob` - classification probabilities
                - `y_` - prediction of the classification
        """
        outputs = {}
        filters = self.filters
        strides = self.strides
        assert len(strides) == len(filters)

        # Infer the spatial rank from the given strides and validate the input rank
        if self.rank is None:
            self.rank = len(strides[0])
        assert len(inp.get_shape().as_list()) == self.rank + 2, \
            'Strides give rank {} but input is rank {}'.format(self.rank, len(inp.get_shape().as_list()) - 2)

        # Initial convolution without residual connections
        x = inp
        x = Convolution(filters[0], strides=strides[0])(x)
        tf.logging.info('Init conv tensor shape %s', x.get_shape())

        # Residual feature encoding blocks with num_residual_units at different
        # scales defined via strides; the output of each scale is kept for the
        # skip connections of the decoder
        scales = [x]
        saved_strides = []
        for scale in range(1, len(filters)):
            with tf.variable_scope('unit_%d_0' % (scale)):
                x = VanillaResidualUnit(filters[scale], stride=strides[scale])(x, is_training=is_training)
            saved_strides.append(strides[scale])
            for i in range(1, self.num_residual_units):
                with tf.variable_scope('unit_%d_%d' % (scale, i)):
                    x = VanillaResidualUnit(filters[scale], stride=[1] * self.rank)(x, is_training=is_training)
            scales.append(x)
            tf.logging.info('Encoder at scale %d tensor shape: %s', scale, x.get_shape())

        # Decoder: upscore scale by scale, fusing each upsampled tensor with the
        # corresponding encoder output until the input resolution is restored
        for scale in range(len(filters) - 2, -1, -1):
            with tf.variable_scope('upscore_%d' % scale):
                x = Upscore(self.num_classes, saved_strides[scale])(x, scales[scale], is_training=is_training)
            tf.logging.info('Decoder at scale %d tensor shape: %s', scale, x.get_shape())

        # Final 1x1 convolution to produce the segmentation logits
        with tf.variable_scope('last'):
            x = Convolution(self.num_classes, 1, strides=[1] * self.rank)(x)
        outputs['logits'] = x
        tf.logging.info('Logits tensor shape %s', x.get_shape())

        # Class probabilities and hard predictions; the binary case thresholds
        # the single logit instead of taking an argmax
        with tf.variable_scope('pred'):
            y_prob = tf.nn.softmax(x)
            outputs['y_prob'] = y_prob
            y_ = (tf.argmax(x, axis=-1) if self.num_classes > 1
                  else tf.cast(tf.greater_equal(x[..., 0], 0.5), tf.int32))
            outputs['y_'] = y_

        return outputs
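
# A minimal end-to-end sketch, assuming a 3-D segmentation setup with the
# default filters and strides above; the input shape and class count are
# illustrative only:
#
#     net = ResNetFCN(num_classes=4)
#     inp = tf.placeholder(tf.float32, [1, 64, 64, 64, 1])
#     out = net(inp, is_training=True)
#     logits, y_prob, y_ = out['logits'], out['y_prob'], out['y_']
#
# With the default strides the encoder downsamples by a factor of 8 and the
# decoder's Upscore modules restore the input resolution, so logits should
# have shape [1, 64, 64, 64, 4] and y_ shape [1, 64, 64, 64].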