# Source code for dltk.core.losses

from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import

import tensorflow as tf
import numpy as np


def sparse_balanced_crossentropy(logits, labels):
    """Calculates a class frequency balanced crossentropy loss from sparse labels.

    Each element's crossentropy term is weighted by the inverse frequency of
    its ground-truth class in ``labels``, scaled by
    ``num_elements / num_classes``. In a perfectly balanced batch, every class
    weight is therefore (approximately) 1.

    Args:
        logits (tf.Tensor): logits prediction for which to calculate
            crossentropy error. The last axis indexes the classes.
        labels (tf.Tensor): sparse integer labels used for crossentropy error
            calculation; expected to have the shape of ``logits`` without its
            last (class) axis.

    Returns:
        tf.Tensor: Tensor scalar representing the mean loss
    """
    # Smallest positive normal float32, used to keep -log(probs) finite.
    epsilon = tf.constant(np.finfo(np.float32).tiny)
    # tf.shape already returns int32, so no cast is needed.
    num_classes = tf.shape(logits)[-1]

    probs = tf.nn.softmax(logits)
    # Add epsilon to probabilities that fell below it (e.g. underflowed to 0),
    # so the log below does not produce -inf.
    probs += tf.cast(tf.less(probs, epsilon), tf.float32) * epsilon
    log = -1. * tf.log(probs)

    onehot_labels = tf.one_hot(labels, num_classes)

    # tf.bincount expects an int32 input, whereas tf.one_hot also accepts
    # e.g. int64 labels; cast so both paths accept the same label dtypes.
    # stop_gradient treats the resulting class weights as constants.
    class_frequencies = tf.stop_gradient(tf.bincount(
        tf.cast(labels, tf.int32), minlength=num_classes, dtype=tf.float32))

    # Inverse-frequency weights. The 1e-8 avoids division by zero for classes
    # absent from labels; their weight is always multiplied by a zero one-hot
    # entry below, so it never contributes to the loss.
    weights = 1. / (class_frequencies + tf.constant(1e-8))
    weights *= (tf.cast(tf.size(labels), tf.float32) /
                tf.cast(num_classes, tf.float32))

    # weights has shape [num_classes] and broadcasts against the trailing
    # class axis of onehot_labels * log. No reshape is needed, so neither the
    # rank of labels nor the class count has to be statically known.
    loss = tf.reduce_mean(
        tf.reduce_sum(onehot_labels * log * weights, axis=-1))
    return loss
def dice_loss(logits, labels, num_classes, smooth=1e-5,
              include_background=True, only_present=False):
    """Calculates a smooth Dice coefficient loss from sparse labels.

    A soft Dice coefficient is computed per sample and per class by summing
    over every spatial axis, i.e. all axes except the first (batch) and the
    last (class) axis. The loss is ``1 - mean(dice)`` over the selected
    (sample, class) pairs.

    Args:
        logits (tf.Tensor): logits prediction for which to calculate the Dice
            loss. The first axis is the batch axis, the last axis the classes.
        labels (tf.Tensor): sparse labels used for the Dice loss calculation;
            expected to have the shape of ``logits`` without its last axis.
        num_classes (int): number of class labels to evaluate on
        smooth (float): smoothing coefficient for the loss computation
        include_background (bool): flag to include a loss on the background
            label (class 0) or not
        only_present (bool): flag to include only labels present in the
            inputs or not

    Returns:
        tf.Tensor: Tensor scalar representing the loss
    """
    # Get a softmax probability of the logits predictions and a one hot
    # encoding of the labels tensor
    probs = tf.nn.softmax(logits)
    onehot_labels = tf.one_hot(
        indices=labels, depth=num_classes, dtype=tf.float32,
        name='onehot_labels')

    # Reduce over all spatial axes (everything but batch and class). For 5-D
    # inputs this is [1, 2, 3], as before; for 4-D (2-D image) inputs it is
    # [1, 2] rather than wrongly also summing over the class axis.
    ndims = logits.get_shape().ndims
    if ndims is not None:
        spatial_axes = list(range(1, ndims - 1))
    else:
        spatial_axes = tf.range(1, tf.rank(logits) - 1)

    # Compute the Dice similarity coefficient
    label_sum = tf.reduce_sum(
        onehot_labels, axis=spatial_axes, name='label_sum')
    pred_sum = tf.reduce_sum(probs, axis=spatial_axes, name='pred_sum')
    intersection = tf.reduce_sum(
        onehot_labels * probs, axis=spatial_axes, name='intersection')

    per_sample_per_class_dice = (2. * intersection + smooth)
    per_sample_per_class_dice /= (label_sum + pred_sum + smooth)

    # Include or exclude the background label (class 0) for the computation
    if include_background:
        flat_per_sample_per_class_dice = tf.reshape(
            per_sample_per_class_dice, (-1, ))
        flat_label = tf.reshape(label_sum, (-1, ))
    else:
        flat_per_sample_per_class_dice = tf.reshape(
            per_sample_per_class_dice[:, 1:], (-1, ))
        flat_label = tf.reshape(label_sum[:, 1:], (-1, ))

    # Include or exclude non-present labels for the computation
    if only_present:
        masked_dice = tf.boolean_mask(flat_per_sample_per_class_dice,
                                      tf.not_equal(flat_label, 0))
    else:
        # Drop NaN terms, e.g. 0 / 0 when smooth == 0 and a class has zero
        # label and prediction mass in a sample.
        masked_dice = tf.boolean_mask(
            flat_per_sample_per_class_dice,
            tf.logical_not(tf.is_nan(flat_per_sample_per_class_dice)))

    # NOTE(review): if no (sample, class) pair survives the masks (e.g.
    # only_present=True, include_background=False and no foreground labels in
    # the batch), this is the mean of an empty tensor, which evaluates to NaN.
    dice = tf.reduce_mean(masked_dice)
    loss = 1. - dice
    return loss