# Source code for dltk.core.modules.summaries

from __future__ import division
from __future__ import absolute_import
from __future__ import print_function

from io import BytesIO

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

try:
    from StringIO import StringIO
except ImportError:
    from io import StringIO

def _png_summary_value(tag, img):
    """Encode one image as PNG and wrap it in a ``tf.Summary.Value``.

    Parameters
    ----------
    tag : string
        tag of the produced summary value
    img : np.ndarray
        image of form (x, y, channels); a single channel is plotted with the
        ``gray`` colormap, any other array is handed to ``plt.imsave`` as is

    Returns
    -------
    tf.Summary.Value
        summary value holding the PNG-encoded image
    """
    # PNG output is binary: a text-mode StringIO cannot receive it on Python 3.
    buf = BytesIO()
    if img.shape[-1] == 1:
        plt.imsave(buf, img[:, :, 0], format='png', cmap='gray')
    else:
        plt.imsave(buf, img, format='png')
    img_sum = tf.Summary.Image(encoded_image_string=buf.getvalue(),
                               height=img.shape[0],
                               width=img.shape[1])
    return tf.Summary.Value(tag=tag, image=img_sum)


def _tensor_image_summary(img, summary_name, collections):
    """Build a merged image summary op for a 4D or 5D ``tf.Tensor``.

    5D tensors (batch, x, y, z, channels) are summarised as three 4D slices,
    one per spatial axis; any other rank is passed to ``tf.summary.image``
    unchanged.
    """
    shape = img.get_shape().as_list()
    summaries = []
    if len(shape) == 5:
        for dim in range(3):
            size = shape[dim + 1]
            slicer = [slice(None)] * 4
            # Middle slice when the axis length is known statically,
            # otherwise fall back to the first slice.
            slicer[dim + 1] = size // 2 if size else 0
            summaries.append(
                tf.summary.image('{}_dim{}'.format(summary_name, dim),
                                 img[tuple(slicer)],
                                 collections=collections))
    else:
        summaries.append(
            tf.summary.image(summary_name, img, collections=collections))
    return tf.summary.merge(summaries)


def _array_image_summary(img, summary_name):
    """Build a ``tf.Summary`` holding PNG image(s) of a 3D or 4D array.

    The batch dimension is not used: 3D arrays are (x, y, channels) and are
    plotted whole; 4D arrays are (x, y, z, channels) and the middle slice
    along each spatial axis is plotted.
    """
    # Work on a float copy: the caller's array stays untouched and integer
    # inputs (e.g. uint8) can be rescaled without an in-place division error.
    img = np.array(img, dtype=np.float64)
    if img.min() < 0.:
        img -= img.min()
    if img.max() > 0.:  # avoid 0/0 -> NaN for an all-zero image
        img /= img.max()

    if img.ndim != 4:
        return tf.Summary(value=[_png_summary_value(summary_name, img)])

    values = []
    for dim in range(3):
        slicer = [slice(None)] * 3
        slicer[dim] = img.shape[dim] // 2
        # Rescale each slice to [0, 1] independently for better contrast.
        tmp_img = img[tuple(slicer)]
        tmp_img = tmp_img - tmp_img.min()
        if tmp_img.max() > 0.:
            tmp_img /= tmp_img.max()
        values.append(
            _png_summary_value('{}_dim{}'.format(summary_name, dim), tmp_img))
    return tf.Summary(value=values)


def image_summary(img, summary_name, collections=None):
    """Builds an image summary from a tf.Tensor or np.ndarray

    If the image is a tf.Tensor, 4D and 5D tensors of form
    (batch, x, y, channels) and (batch, x, y, z, channels) are supported.
    For 5D tensors the middle slice along each spatial axis is plotted if the
    size of that axis is known statically; otherwise the first slice is taken.

    If the image is a np.ndarray, 3D and 4D arrays of form (x, y, channels)
    and (x, y, z, channels) are supported. The array is copied and rescaled
    to [0, 1]; for 4D arrays the middle slice along each spatial axis is
    plotted, each slice rescaled to [0, 1] on its own.

    Parameters
    ----------
    img : tf.Tensor or np.ndarray
        image to be plotted
    summary_name : string
        name of the summary to be produced
    collections : list or tuple, optional
        list of collections this summary should be added to additionally to
        `tf.GraphKeys.SUMMARIES` and `image_summaries`

    Returns
    -------
    tf.Tensor or tf.Summary
        Tensor produced from tf.summary or Summary object with the plotted
        image(s)

    Raises
    ------
    TypeError
        if `img` is neither a tf.Tensor nor a np.ndarray
    """
    if isinstance(img, tf.Tensor):
        # The default collections are always used; `collections` only adds
        # to them (it may be None, a list or a tuple).
        collections = ([tf.GraphKeys.SUMMARIES, 'image_summaries']
                       + list(collections or []))
        return _tensor_image_summary(img, summary_name, collections)
    if isinstance(img, np.ndarray):
        return _array_image_summary(img, summary_name)
    raise TypeError('Only tf.Tensors and np.ndarrays are supported.')
def scalar_summary(x, summary_name, collections=None):
    """Builds a scalar summary

    If x is a tf.Tensor it creates the summary operation to track x.
    If x is a scalar it creates the tf.Summary object to be written by a
    summary writer.
    If x is a list, tuple or dict a tf.Summary value is created for each
    element; the index or key is appended to `summary_name` for naming.

    Parameters
    ----------
    x : tf.Tensor or scalar or list or tuple or dict
        scalar data to be plotted
    summary_name : string
        name of the summary to be produced
    collections : list or tuple, optional
        list of collections this summary should be added to additionally to
        `tf.GraphKeys.SUMMARIES` and `scalar_summaries`

    Returns
    -------
    tf.Tensor or tf.Summary
        Tensor produced from tf.summary or Summary object with the summarised
        data

    Raises
    ------
    TypeError
        if `x` is not a tf.Tensor, scalar, list, tuple or dict
    """
    if isinstance(x, tf.Tensor):
        # The default collections are always used; `collections` only adds
        # to them (it may be None, a list or a tuple).
        collections = ([tf.GraphKeys.SUMMARIES, 'scalar_summaries']
                       + list(collections or []))
        return tf.summary.scalar(summary_name, x, collections=collections)

    if np.isscalar(x):
        items = [(summary_name, x)]
    elif isinstance(x, (list, tuple)):
        items = [('{}_{}'.format(summary_name, i), xi)
                 for i, xi in enumerate(x)]
    elif isinstance(x, dict):
        items = [('{}_{}'.format(summary_name, key), value)
                 for key, value in x.items()]
    else:
        raise TypeError(
            'Only tf.Tensors, scalars, lists, tuples and dicts are supported.')

    # simple_value is a float field; convert numpy/int scalars explicitly.
    return tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=float(value))
                             for tag, value in items])