from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
import tensorflow as tf
import numpy as np
def sparse_balanced_crossentropy(logits, labels):
    """
    Calculates a class frequency balanced crossentropy loss from sparse labels.
    Args:
        logits (tf.Tensor): logits prediction for which to calculate
            crossentropy error
        labels (tf.Tensor): sparse labels used for crossentropy error
            calculation
    Returns:
        tf.Tensor: Tensor scalar representing the mean loss
    """
    epsilon = tf.constant(np.finfo(np.float32).tiny)
    num_classes = tf.cast(tf.shape(logits)[-1], tf.int32)

    # Clamp the softmax probabilities away from zero so the logarithm
    # below stays numerically stable
    probs = tf.nn.softmax(logits)
    probs += tf.cast(tf.less(probs, epsilon), tf.float32) * epsilon
    log = -1. * tf.log(probs)
    onehot_labels = tf.one_hot(labels, num_classes)

    # Count the occurrences of each class; stop_gradient keeps the
    # counts out of the gradient computation
    class_frequencies = tf.stop_gradient(tf.bincount(
        labels, minlength=num_classes, dtype=tf.float32))

    # Weight each class by its inverse frequency, rescaled so the
    # weights applied across the batch average to ~1
    weights = (1. / (class_frequencies + tf.constant(1e-8)))
    weights *= (tf.cast(tf.reduce_prod(tf.shape(labels)), tf.float32)
                / tf.cast(num_classes, tf.float32))

    # Reshape to [1, ..., 1, num_classes] so the per-class weights
    # broadcast over the label dimensions
    new_shape = (([1, ] * len(labels.get_shape().as_list()))
                 + [logits.get_shape().as_list()[-1]])
    weights = tf.reshape(weights, new_shape)
    loss = tf.reduce_mean(tf.reduce_sum(onehot_labels * log * weights, axis=-1))
    return loss
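
# A minimal usage sketch (an assumption, not part of the original module):
# in a TF 1.x graph, `x` and `y` below are hypothetical placeholders for
# rank-5 logits over 3D volumes with 4 classes and the matching rank-4
# sparse labels.
#
#     x = tf.placeholder(tf.float32, shape=[None, 32, 32, 32, 4])
#     y = tf.placeholder(tf.int32, shape=[None, 32, 32, 32])
#     loss = sparse_balanced_crossentropy(logits=x, labels=y)
#     train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
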
def dice_loss(logits,
              labels,
              num_classes,
              smooth=1e-5,
              include_background=True,
              only_present=False):
    """Calculates a smooth Dice coefficient loss from sparse labels.
    Args:
        logits (tf.Tensor): logits prediction for which to calculate
            crossentropy error
        labels (tf.Tensor): sparse labels used for crossentropy error
            calculation
        num_classes (int): number of class labels to evaluate on
        smooth (float): smoothing coefficient for the loss computation
        include_background (bool): flag to include a loss on the background
            label or not
        only_present (bool): flag to include only labels present in the
            inputs or not
    Returns:
        tf.Tensor: Tensor scalar representing the loss
    """
    # Get a softmax probability of the logits predictions and a one hot
    # encoding of the labels tensor
    probs = tf.nn.softmax(logits)
    onehot_labels = tf.one_hot(
        indices=labels,
        depth=num_classes,
        dtype=tf.float32,
        name='onehot_labels')
    # Compute the Dice similarity coefficient per sample and per class;
    # axis=[1, 2, 3] reduces over the spatial dimensions of rank-5
    # [batch, x, y, z, num_classes] tensors
    label_sum = tf.reduce_sum(onehot_labels, axis=[1, 2, 3], name='label_sum')
    pred_sum = tf.reduce_sum(probs, axis=[1, 2, 3], name='pred_sum')
    intersection = tf.reduce_sum(onehot_labels * probs, axis=[1, 2, 3],
                                 name='intersection')
    per_sample_per_class_dice = (2. * intersection + smooth)
    per_sample_per_class_dice /= (label_sum + pred_sum + smooth)
    # Include or exclude the background label for the computation
    if include_background:
        flat_per_sample_per_class_dice = tf.reshape(
            per_sample_per_class_dice, (-1, ))
        flat_label = tf.reshape(label_sum, (-1, ))
    else:
        flat_per_sample_per_class_dice = tf.reshape(
            per_sample_per_class_dice[:, 1:], (-1, ))
        flat_label = tf.reshape(label_sum[:, 1:], (-1, ))
    # Either restrict the loss to the labels present in the batch or
    # simply drop any NaN entries from the Dice values
    if only_present:
        masked_dice = tf.boolean_mask(flat_per_sample_per_class_dice,
                                      tf.logical_not(tf.equal(flat_label, 0)))
    else:
        masked_dice = tf.boolean_mask(
            flat_per_sample_per_class_dice,
            tf.logical_not(tf.is_nan(flat_per_sample_per_class_dice)))
    dice = tf.reduce_mean(masked_dice)
    loss = 1. - dice
    return loss
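
# A minimal usage sketch (an assumption, not part of the original module);
# the hypothetical placeholders mirror the ones above: rank-5 logits over
# 3D volumes and rank-4 sparse labels, here excluding the background class
# and scoring only the labels present in the batch.
#
#     x = tf.placeholder(tf.float32, shape=[None, 32, 32, 32, 4])
#     y = tf.placeholder(tf.int32, shape=[None, 32, 32, 32])
#     loss = dice_loss(logits=x, labels=y, num_classes=4,
#                      include_background=False, only_present=True)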