from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
import tensorflow as tf
import numpy as np
from dltk.core.modules import *


class ResNet(AbstractModule):
"""ResNet module
This module builds a ResNet for classification according to He et al. 2015
"""
def __init__(self, num_classes, num_residual_units=5, filters=[16, 16, 32, 64],
strides=[[1, 1, 1], [1, 1, 1], [2, 2, 2], [2, 2, 2]], relu_leakiness=0.01,
name='resnet32'):
"""Builds a residual FCN for segmentation
Parameters
----------
num_classes : int
number of classes to segment
num_residual_units : int
number of residual units per scale
filters : tuple or list
number of filters per scale. The first is used for the initial convolution without residual connections
strides : tuple or list
strides per scale. The first is used for the initial convolution without residual connections
relu_leakiness : float
leakiness of the relus used
name : string
name of the network
"""
self.num_classes = num_classes
self.num_residual_units = num_residual_units
self.filters = filters
self.strides = strides
self.relu_leakiness = relu_leakiness
self.rank = None
super(ResNet, self).__init__(name)

    def _build(self, inp, is_training=True):
        """Constructs a ResNet using the input tensor

        Parameters
        ----------
        inp : tf.Tensor
            input tensor
        is_training : bool
            flag to specify whether this is training - passed to batch
            normalization

        Returns
        -------
        dict
            output dictionary containing:
                - `logits` - logits of the classification
                - `y_prob` - classification probabilities
                - `y_` - prediction of the classification
        """
outputs = {}
filters = self.filters
strides = self.strides
        assert len(strides) == len(filters), \
            'filters and strides must define the same number of scales'
        if self.rank is None:
            self.rank = len(strides[0])
        assert len(inp.get_shape().as_list()) == self.rank + 2, \
            'Strides specify rank {} but the input is rank {}'.format(
                self.rank, len(inp.get_shape().as_list()) - 2)
x = inp
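        # initial convolution, without a residual connection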
x = Convolution(filters[0], strides=strides[0])(x)
tf.logging.info(x.get_shape())
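        # residual scales: the first unit of each scale applies the scale's
        # stride (possibly downsampling), the remaining units keep stride 1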
        for scale in range(1, len(filters)):
            with tf.variable_scope('unit_%d_0' % scale):
                x = VanillaResidualUnit(filters[scale],
                                        stride=strides[scale])(x, is_training=is_training)
            for i in range(1, self.num_residual_units):
                with tf.variable_scope('unit_%d_%d' % (scale, i)):
                    x = VanillaResidualUnit(filters[scale],
                                            stride=[1] * self.rank)(x, is_training=is_training)
tf.logging.info('feat_scale_%d shape %s', scale, x.get_shape())
        with tf.variable_scope('unit_last'):
            x = BatchNorm()(x)
            x = leaky_relu(x, self.relu_leakiness)
            # global average pooling over all spatial dimensions
            # (every axis except batch and channels)
            axis = tuple(range(1, len(x.get_shape().as_list()) - 1))
            x = tf.reduce_mean(x, axis=axis, name='global_avg_pool')
            tf.logging.info('unit_last axis %s shape %s', axis, x.get_shape())
with tf.variable_scope('logits'):
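            # flatten to [batch, filters[-1]] for the final linear classifier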
x = tf.reshape(x, (tf.shape(x)[0], filters[-1]))
x = Linear(self.num_classes)(x)
outputs['logits'] = x
        tf.logging.info('logits shape %s', x.get_shape())
with tf.variable_scope('pred'):
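            # class probabilities via softmax; the hard prediction is the argmax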
y_prob = tf.nn.softmax(x)
outputs['y_prob'] = y_prob
y_ = tf.argmax(x, axis=-1)
outputs['y_'] = y_
return outputs
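

if __name__ == '__main__':
    # Usage sketch (illustrative, not part of the original module). It assumes
    # AbstractModule makes instances callable and dispatches to `_build`, as in
    # Sonnet-style module libraries, and uses an arbitrary example input shape.
    # With the default 3-D strides the input must be rank 5:
    # [batch, x, y, z, channels].
    images = tf.placeholder(tf.float32, shape=[None, 32, 32, 32, 1])
    net = ResNet(num_classes=2)
    out = net(images, is_training=True)
    print(out['logits'], out['y_prob'], out['y_'])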