Source code for dltk.utils

from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import

import numpy as np


class SlidingWindow(object):
    """SlidingWindow

    Sliding window iterator which produces slice objects to slice in a
    sliding window. This is useful for inference.

    Each iteration yields a list with one ``slice`` per image dimension
    (preceded by ``slice(None)`` when ``has_batch_dim`` is set). The first
    dimension moves fastest. A window that would reach or cross the end of a
    dimension is shifted back so that it ends exactly at the image border.
    """

    def __init__(self, img_shape, window_shape, has_batch_dim=True,
                 striding=None):
        """Constructs a sliding window iterator

        Args:
            img_shape (array_like): shape of the image to slide over
            window_shape (array_like): shape of the window to extract
            has_batch_dim (bool, optional): flag to indicate whether a batch
                dimension is present
            striding (array_like, optional): amount to move the window between
                each position. Defaults to ``window_shape`` (non-overlapping
                windows) when omitted or empty.
        """
        self.img_shape = img_shape
        self.window_shape = window_shape
        self.rank = len(img_shape)
        self.curr_pos = [0] * self.rank
        # Iteration is finished once the position wraps back to the origin.
        self.end_pos = [0] * self.rank
        self.done = False
        self.striding = window_shape
        self.has_batch_dim = has_batch_dim
        # A plain ``if striding:`` raises ValueError for multi-element numpy
        # arrays (ambiguous truth value), so test the dimensionality and size
        # explicitly. ``None``, scalars and empty sequences keep the default.
        if np.ndim(striding) > 0 and np.size(striding) > 0:
            self.striding = striding

    def __iter__(self):
        return self

    # py 2.* compatibility hack
    def next(self):
        """Python 2 iterator protocol; forwards to ``__next__``."""
        return self.__next__()

    def __next__(self):
        """Return the slicer for the current window and advance the position.

        Returns:
            list: ``slice`` objects, one per dimension (plus a leading
                ``slice(None)`` for the batch dimension if present)

        Raises:
            StopIteration: once every window position has been produced
        """
        if self.done:
            raise StopIteration()

        offset = 1 if self.has_batch_dim else 0
        slicer = [slice(None)] * (self.rank + offset)

        # Odometer-style advance: a dimension only moves while every
        # dimension before it has just wrapped around to zero.
        move_dim = True
        for dim, pos in enumerate(self.curr_pos):
            low = pos
            high = pos + self.window_shape[dim]
            at_border = high >= self.img_shape[dim]

            if move_dim:
                if at_border:
                    # Wrap this dimension and carry over to the next one.
                    self.curr_pos[dim] = 0
                else:
                    self.curr_pos[dim] += self.striding[dim]
                    move_dim = False

            if at_border:
                # Shift the window back so it ends at the image border.
                low = self.img_shape[dim] - self.window_shape[dim]
                high = self.img_shape[dim]

            slicer[dim + offset] = slice(low, high)

        if (np.array(self.curr_pos) == self.end_pos).all():
            self.done = True

        return slicer
def _run_window_batch(session, ops_list, padded_dict, windows, out_dummies,
                      out_dummy_counter):
    """Run ``ops_list`` on a batch of windows and accumulate the outputs."""
    # Stack all input windows along the first axis for a single session run.
    feed_dict = {k: np.concatenate([v[slicer] for slicer, _ in windows], 0)
                 for k, v in padded_dict.items()}
    all_op_parts = session.run(ops_list, feed_dict=feed_dict)

    # Split each fetched output back into one part per window and regroup
    # them so that every entry holds the outputs of all ops for one window.
    zipped_parts = zip(*[np.array_split(part, len(windows))
                         for part in all_op_parts])

    for (_, out_slicer), op_parts in zip(windows, zipped_parts):
        for idx, part in enumerate(op_parts):
            out_dummies[idx][out_slicer] += part
            out_dummy_counter[idx][out_slicer] += 1


def sliding_window_segmentation_inference(session, ops_list, sample_dict,
                                          batch_size=1):
    """Utility function to perform sliding window inference for segmentation

    The inputs are zero-padded by the difference between the placeholder and
    the output shape, windows are extracted from the padded inputs, run
    through the ops and the results are summed into output buffers. Every
    output element is finally divided by the number of windows that
    contributed to it.

    Args:
        session (tf.Session): TensorFlow session to run ops with
        ops_list (array_like): Operators to fetch assemble with sliding window
        sample_dict (dict): Dictionary with tf.Placeholder keys mapping the
            placeholders to their respective input
        batch_size (int, optional): Number of sliding windows to batch for
            calculation

    Returns:
        list: List of np.arrays corresponding to the assembled outputs of
            ops_list
    """
    # TODO: asserts
    assert batch_size > 0, 'Batch size has to be 1 or bigger'

    # The "b" shapes drop the first and last dimension -- presumably the
    # batch and channel dimensions (TODO confirm against callers).
    pl_shape = list(sample_dict.keys())[0].get_shape().as_list()
    pl_bshape = pl_shape[1:-1]

    inp_shape = list(list(sample_dict.values())[0].shape)
    inp_bshape = inp_shape[1:-1]

    # One accumulation buffer per op; ops whose rank differs from the input
    # rank get a zero-dimensional buffer.
    out_dummies = [np.zeros(
        [inp_shape[0], ] + inp_bshape + [op.get_shape().as_list()[-1]]
        if len(op.get_shape().as_list()) == len(inp_shape) else [])
        for op in ops_list]

    # Counts how many windows contributed to each output element.
    out_dummy_counter = [np.zeros_like(o) for o in out_dummies]

    op_shape = list(ops_list[0].get_shape().as_list())
    op_bshape = op_shape[1:-1]

    # Pad the inputs by the size difference between placeholder and output.
    out_diff = np.array(pl_bshape) - np.array(op_bshape)
    padding = ([[0, 0]]
               + [[diff // 2, diff - diff // 2] for diff in out_diff]
               + [[0, 0]])

    padded_dict = {k: np.pad(v, padding, mode='constant')
                   for k, v in sample_dict.items()}

    f_bshape = list(padded_dict.values())[0].shape[1:-1]

    # If input and output windows have the same size, use overlapping
    # windows with half the window size as stride; otherwise tile the output.
    if all(out_diff == 0):
        striding = list(np.array(op_bshape) // 2)
    else:
        striding = op_bshape

    sw = SlidingWindow(f_bshape, pl_bshape, striding=striding)
    out_sw = SlidingWindow(inp_bshape, op_bshape, striding=striding)

    # Iterate both windows in lockstep. Each window is processed exactly
    # once; re-processing the final window after exhaustion would give it
    # double weight wherever windows overlap.
    windows = []
    for slicer, out_slicer in zip(sw, out_sw):
        # numpy needs a tuple (not a list) of slices to index several
        # dimensions at once.
        windows.append((tuple(slicer), tuple(out_slicer)))

        if len(windows) == batch_size:
            _run_window_batch(session, ops_list, padded_dict, windows,
                              out_dummies, out_dummy_counter)
            windows = []

    # Flush the last, possibly incomplete batch.
    if windows:
        _run_window_batch(session, ops_list, padded_dict, windows,
                          out_dummies, out_dummy_counter)

    return [o / c for o, c in zip(out_dummies, out_dummy_counter)]