Source code for dltk.core.modules.convolution

from __future__ import division
from __future__ import absolute_import
from __future__ import print_function

import tensorflow as tf
import numpy as np

from dltk.core.modules.base import AbstractModule


class Convolution(AbstractModule):
    """Convolution module

    This module builds a n-D convolution (n <= 3) and applies it to the input. The dimensionality is inferred
    from whichever of ``filter_shape``, ``strides`` or ``dilation_rate`` is passed as a sequence; if all three
    are scalars a 3D convolution is assumed.
    """

    def __init__(self, out_filters, filter_shape=3, strides=1, dilation_rate=1, padding='SAME', use_bias=False,
                 name='conv'):
        """Constructs the convolution template

        Parameters
        ----------
        out_filters : int
            number of output filters
        filter_shape : int or tuple or list, optional
            shape of the filter to use for the convolution
        strides : int or tuple or list, optional
            stride of the convolution operation
        dilation_rate : int or tuple or list, optional
            dilation rate used for dilated convolution. If used, stride must be 1
        padding : str
            edge padding of convolutions, one of 'SAME' or 'VALID'
        use_bias : bool
            flag to toggle addition of a bias per output filter
        name : string
            name of the module

        Raises
        ------
        ValueError
            if none of the scalar-expansion cases below applies, e.g. when both ``strides`` and
            ``dilation_rate`` are passed as list or tuple
        AssertionError
            if the lengths of ``filter_shape``, ``strides`` and ``dilation_rate`` disagree, if neither the
            product of ``strides`` nor the product of ``dilation_rate`` equals 1, if ``padding`` is not
            'SAME' or 'VALID', or if the resulting rank is larger than 3
        """
        scalar_filter = isinstance(filter_shape, int)
        scalar_strides = isinstance(strides, int)
        scalar_dilation = isinstance(dilation_rate, int)

        # Expand scalar arguments to per-dimension sequences. The branch order matters: the first four
        # branches cover the cases in which exactly one argument (or none) carries the rank, the next two
        # handle a sequence filter_shape combined with one more sequence argument.
        if scalar_filter and scalar_strides and scalar_dilation:
            # No argument carries a rank, so default to a 3D convolution
            filter_shape = np.array([filter_shape] * 3)
            strides = [strides] * 3
            dilation_rate = [dilation_rate] * 3
        elif scalar_filter and isinstance(strides, (list, tuple)) and scalar_dilation:
            filter_shape = np.array([filter_shape] * len(strides))
            dilation_rate = [dilation_rate] * len(strides)
        elif scalar_filter and isinstance(dilation_rate, (list, tuple)) and scalar_strides:
            filter_shape = np.array([filter_shape] * len(dilation_rate))
            strides = [strides] * len(dilation_rate)
        elif isinstance(filter_shape, (list, tuple, np.ndarray)) and scalar_dilation and scalar_strides:
            dilation_rate = [dilation_rate] * len(filter_shape)
            strides = [strides] * len(filter_shape)
        elif scalar_strides:
            strides = [strides] * len(filter_shape)
        elif scalar_dilation:
            dilation_rate = [dilation_rate] * len(filter_shape)
        else:
            # ValueError is a subclass of Exception, so callers catching Exception keep working
            raise ValueError('Could not infer the dimensionality of the operation or both strides and dilation was '
                             'passed as list or tuple')

        assert len(strides) == len(filter_shape), 'Stride len must match len of filter shape'
        assert len(strides) == len(dilation_rate), 'Dilation rate and stride len must match'
        assert np.prod(dilation_rate) == 1 or np.prod(strides) == 1, 'Dilation rate or strides must be 1'
        assert padding == 'SAME' or padding == 'VALID', 'Padding must be either SAME or VALID'

        self.filter_shape = filter_shape
        # Input shape and channel count are only known once the module is applied to a tensor
        self.in_shape = None
        self.in_filters = None
        self.out_filters = out_filters
        self.strides = strides
        self.use_bias = use_bias
        self.dilation_rate = dilation_rate
        self.padding = padding

        self._rank = len(self.filter_shape)
        assert self._rank < 4, 'Convolutions are only supported up to 3D'

        super(Convolution, self).__init__(name=name)

    def _build(self, inp):
        """Applies a convolution operation to an input tensor

        Parameters
        ----------
        inp : tf.Tensor
            input tensor to be convolved. Its rank must be the convolution rank plus two and its last
            dimension is used as the number of input filters

        Returns
        -------
        tf.Tensor
            convolved tensor

        Raises
        ------
        AssertionError
            if the rank of the input does not match the convolution or if the module is applied to an input
            with a different number of channels than it was first built for
        """
        inp_shape = inp.get_shape().as_list()

        assert len(inp_shape) - 2 == self._rank, \
            'The input has {} dimensions but this is a {}D convolution'.format(len(inp_shape), self._rank)

        self.in_shape = tuple(inp_shape)

        # The first call fixes the number of input channels; later calls must agree with it
        if self.in_filters is None:
            self.in_filters = self.in_shape[-1]
        assert self.in_filters == self.in_shape[-1], 'Convolution was built for different number of channels'

        # Kernel layout: spatial filter dimensions followed by [in_filters, out_filters]
        kernel_shape = tuple(list(self.filter_shape) + [self.in_filters, self.out_filters])

        # NOTE(review): WEIGHT_COLLECTIONS, BIAS_COLLECTIONS and variables are not defined in this class;
        # presumably they are provided by AbstractModule - confirm against the base class
        self._k = tf.get_variable("k", shape=kernel_shape, initializer=tf.uniform_unit_scaling_initializer(),
                                  collections=self.WEIGHT_COLLECTIONS)
        self.variables.append(self._k)

        outp = tf.nn.convolution(inp, self._k, padding=self.padding, strides=self.strides,
                                 dilation_rate=self.dilation_rate, name='conv')

        if self.use_bias:
            self._b = tf.get_variable("b", shape=(self.out_filters,), initializer=tf.constant_initializer(),
                                      collections=self.BIAS_COLLECTIONS)
            self.variables.append(self._b)
            outp += self._b

        return outp