# Source code for dltk.core.io.sliding_window

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np


class SlidingWindow(object):
    """SlidingWindow

    Sliding window iterator which produces slice objects to slice in a
    sliding window. This is useful for inference.

    Each iteration returns a list of ``slice`` objects describing one window
    position. The first image dimension advances fastest. A window that would
    run past the end of a dimension is shifted back so that it ends exactly at
    the image border. If the window is larger than the image in a dimension,
    it covers that whole dimension instead.

    Note: a *list* of slices is returned for backward compatibility. Recent
    NumPy versions no longer accept a list of slices as a multidimensional
    index, so index arrays with ``arr[tuple(slicer)]``.
    """

    def __init__(self, img_shape, window_shape, has_batch_dim=True,
                 striding=None):
        """Constructs a sliding window iterator

        Parameters
        ----------
        img_shape : tuple or list
            shape of the image to slide over (without a batch dimension)
        window_shape : tuple or list
            shape of the window to extract
        has_batch_dim : bool, optional
            if True, every slicer gets an extra leading ``slice(None)`` so it
            can index arrays that carry a leading batch axis
        striding : tuple or list, optional
            amount to move the window between each position; defaults to
            ``window_shape`` (adjacent, non-overlapping windows)

        Raises
        ------
        ValueError
            if a stride is not positive in a dimension where the window is
            smaller than the image (the iteration would never terminate)
        """
        self.img_shape = img_shape
        self.window_shape = window_shape
        self.rank = len(img_shape)
        # Current window origin. Iteration finishes once it returns to
        # ``end_pos`` (all zeros), i.e. after the last dimension wraps.
        self.curr_pos = [0] * self.rank
        self.end_pos = [0] * self.rank
        self.done = False
        self.has_batch_dim = has_batch_dim

        # Use ``is None`` / ``len`` rather than truthiness so NumPy arrays
        # are accepted (``bool(ndarray)`` raises for more than one element).
        # An empty sequence still falls back to the window shape.
        if striding is None or len(striding) == 0:
            striding = window_shape
        self.striding = striding

        # A zero or negative step never reaches the border, so the window
        # would never wrap and __next__ would loop forever.
        for img_len, win_len, step in zip(img_shape, window_shape,
                                          self.striding):
            if step <= 0 and win_len < img_len:
                raise ValueError(
                    'striding must be positive along every dimension the '
                    'window slides over, got {}'.format(self.striding))

    def __iter__(self):
        return self

    # Python 2 compatibility hack: Python 2 iterators call ``next``.
    def next(self):
        return self.__next__()

    def __next__(self):
        """Return the slicer for the current window and advance.

        Returns
        -------
        list of slice
            one slice per image dimension, prefixed with ``slice(None)``
            when ``has_batch_dim`` is True

        Raises
        ------
        StopIteration
            once every window position has been produced
        """
        if self.done:
            raise StopIteration()

        offset = 1 if self.has_batch_dim else 0
        slicer = [slice(None)] * (self.rank + offset)

        # Odometer-style advance: a dimension only moves when every
        # lower-index dimension has just wrapped back to 0.
        move_dim = True
        for dim, pos in enumerate(self.curr_pos):
            img_len = self.img_shape[dim]
            win_len = self.window_shape[dim]
            low = pos
            high = pos + win_len

            if move_dim:
                if high >= img_len:
                    # This window touches the border: wrap and carry on.
                    self.curr_pos[dim] = 0
                else:
                    self.curr_pos[dim] += self.striding[dim]
                    move_dim = False

            if high >= img_len:
                # Shift the window back so it ends at the border. Clamp at 0
                # so a window larger than the image never produces a
                # negative (wrapping) start index.
                low = max(img_len - win_len, 0)
                high = img_len

            slicer[dim + offset] = slice(low, high)

        if (np.array(self.curr_pos) == self.end_pos).all():
            self.done = True
        return slicer