from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
import numpy as np
[docs]class SlidingWindow(object):
    """SlidingWindow
    Sliding window iterator which produces slice objects to slice in a sliding
    window. This is useful for inference.
    """
    def __init__(self, img_shape, window_shape, has_batch_dim=True,
                 striding=None):
        """Constructs a sliding window iterator
        Args:
            img_shape (array_like): shape of the image to slide over
            window_shape (array_like): shape of the window to extract
            has_batch_dim (bool, optional): flag to indicate whether a batch
                dimension is present
            striding (array_like, optional): amount to move the window between
                each position
        """
        self.img_shape = img_shape
        self.window_shape = window_shape
        self.rank = len(img_shape)
        self.curr_pos = [0] * self.rank
        self.end_pos = [0] * self.rank
        self.done = False
        self.striding = window_shape
        self.has_batch_dim = has_batch_dim
        if striding:
            self.striding = striding
    def __iter__(self):
        return self
    # py 2.* compatability hack
[docs]    def next(self):
        return self.__next__() 
    def __next__(self):
        if self.done:
            raise StopIteration()
        if self.has_batch_dim:
            slicer = [slice(None)] * (self.rank + 1)
        else:
            slicer = [slice(None)] * self.rank
        move_dim = True
        for dim, pos in enumerate(self.curr_pos):
            low = pos
            high = pos + self.window_shape[dim]
            if move_dim:
                if high >= self.img_shape[dim]:
                    self.curr_pos[dim] = 0
                    move_dim = True
                else:
                    self.curr_pos[dim] += self.striding[dim]
                    move_dim = False
            if high >= self.img_shape[dim]:
                low = self.img_shape[dim] - self.window_shape[dim]
                high = self.img_shape[dim]
            if self.has_batch_dim:
                slicer[dim + 1] = slice(low, high)
            else:
                slicer[dim] = slice(low, high)
        if (np.array(self.curr_pos) == self.end_pos).all():
            self.done = True
        return slicer 
[docs]def sliding_window_segmentation_inference(session,
                                          ops_list,
                                          sample_dict,
                                          batch_size=1):
    """
    Utility function to perform sliding window inference for segmentation
    Args:
        session (tf.Session): TensorFlow session to run ops with
        ops_list (array_like): Operators to fetch assemble with sliding window
        sample_dict (dict): Dictionary with tf.Placeholder keys mapping the
        placeholders to their respective input
        batch_size (int, optional): Number of sliding windows to batch for
            calculation
    Returns:
        list: List of np.arrays corresponding to the assembled outputs of
            ops_list
    """
    # TODO: asserts
    assert batch_size > 0, 'Batch size has to be 1 or bigger'
    pl_shape = list(sample_dict.keys())[0].get_shape().as_list()
    pl_bshape = pl_shape[1:-1]
    inp_shape = list(list(sample_dict.values())[0].shape)
    inp_bshape = inp_shape[1:-1]
    out_dummies = [np.zeros(
        [inp_shape[0], ] + inp_bshape + [op.get_shape().as_list()[-1]]
        if len(op.get_shape().as_list()) == len(inp_shape) else [])
                   for op in ops_list]
    out_dummy_counter = [np.zeros_like(o) for o in out_dummies]
    op_shape = list(ops_list[0].get_shape().as_list())
    op_bshape = op_shape[1:-1]
    out_diff = np.array(pl_bshape) - np.array(op_bshape)
    padding = [[0, 0]] + [[diff // 2, diff - diff // 2] for diff
                          in out_diff] + [[0, 0]]
    padded_dict = {k: np.pad(v, padding, mode='constant') for k, v
                   in sample_dict.items()}
    f_bshape = list(padded_dict.values())[0].shape[1:-1]
    striding = list(np.array(op_bshape) // 2) if all(out_diff == 0) else op_bshape
    sw = SlidingWindow(f_bshape, pl_bshape, striding=striding)
    out_sw = SlidingWindow(inp_bshape, op_bshape, striding=striding)
    if batch_size > 1:
        slicers = []
        out_slicers = []
    done = False
    while True:
        try:
            slicer = next(sw)
            out_slicer = next(out_sw)
        except StopIteration:
            done = True
        if batch_size == 1:
            sw_dict = {k: v[slicer] for k, v in padded_dict.items()}
            op_parts = session.run(ops_list, feed_dict=sw_dict)
            for idx in range(len(op_parts)):
                out_dummies[idx][out_slicer] += op_parts[idx]
                out_dummy_counter[idx][out_slicer] += 1
        else:
            slicers.append(slicer)
            out_slicers.append(out_slicer)
            if len(slicers) == batch_size or done:
                slices_dict = {k: np.concatenate(
                    [v[slicer] for slicer in slicers], 0)
                               for k, v in padded_dict.items()}
                all_op_parts = session.run(ops_list, feed_dict=slices_dict)
                zipped_parts = zip(*[np.array_split(part, len(slicers)) for
                                     part in all_op_parts])
                for out_slicer, op_parts in zip(out_slicers, zipped_parts):
                    for idx in range(len(op_parts)):
                        out_dummies[idx][out_slicer] += op_parts[idx]
                        out_dummy_counter[idx][out_slicer] += 1
                slicers = []
                out_slicers = []
        if done:
            break
    return [o / c for o, c in zip(out_dummies, out_dummy_counter)]