Source code for dltk.core.modules.losses

from __future__ import division
from __future__ import absolute_import
from __future__ import print_function

import tensorflow as tf
from dltk.core.modules.summaries import *
import numpy as np


def sparse_crossentropy(logits, labels, name='crossentropy', collections=None):
    """ Crossentropy loss

    Calculates the mean sparse softmax crossentropy loss and builds a scalar
    summary.

    Parameters
    ----------
    logits : tf.Tensor
        logit prediction for which to calculate crossentropy error, laid out
        as expected by ``tf.nn.sparse_softmax_cross_entropy_with_logits``
        (class axis last)
    labels : tf.Tensor
        sparse (integer class index) labels used for crossentropy error
        calculation
    name : string
        name of this operation and summary
    collections : list or tuple, optional
        list of collections to add the summaries to. Defaults to
        ``['losses']``.

    Returns
    -------
    tf.Tensor
        scalar Tensor representing the loss
    """
    # Build a fresh default per call instead of sharing one mutable list
    # between all calls.
    if collections is None:
        collections = ['losses']

    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
    loss = tf.reduce_mean(ce, name=name)
    scalar_summary(loss, name, collections)
    return loss
def sparse_balanced_crossentropy(logits, labels, name='crossentropy', collections=None):
    """ Class-balanced crossentropy loss

    Calculates a softmax crossentropy loss in which every class is weighted
    inversely to its frequency in ``labels`` and builds a scalar summary.

    The weight of class ``c`` is ``N / (C * count_c)``, where ``N`` is the
    total number of labels, ``C`` the number of classes (size of the last
    axis of ``logits``) and ``count_c`` the number of labels equal to ``c``
    in this batch. The weighted per-element losses are summed over the class
    axis and averaged over all remaining elements.

    Parameters
    ----------
    logits : tf.Tensor
        logit prediction for which to calculate crossentropy error; the last
        axis is the class axis
    labels : tf.Tensor
        sparse integer labels in ``[0, C)``, shaped like ``logits`` without
        its last axis
    name : string
        name of this operation and summary
    collections : list or tuple, optional
        list of collections to add the summaries to. Defaults to
        ``['losses']``.

    Returns
    -------
    tf.Tensor
        scalar Tensor representing the loss
    """
    # Build a fresh default per call instead of sharing one mutable list.
    if collections is None:
        collections = ['losses']

    num_classes = tf.shape(logits)[-1]

    # log_softmax is evaluated stably from the logits, so the probabilities do
    # not need epsilon clamping before taking the log.
    neg_log_probs = -1. * tf.nn.log_softmax(logits)
    oh_labels = tf.one_hot(labels, num_classes, dtype=neg_log_probs.dtype)

    # Per-class label counts in this batch. tf.bincount requires int32 input;
    # maxlength pins the result to exactly num_classes entries even if a label
    # is out of range (such labels also one-hot encode to all zeros).
    class_occurrences = tf.stop_gradient(
        tf.bincount(tf.cast(labels, tf.int32), minlength=num_classes,
                    maxlength=num_classes, dtype=tf.float32))

    num_labels = tf.cast(tf.size(labels), tf.float32)
    # The 1e-8 avoids division by zero for classes absent from the batch; their
    # (huge) weight is multiplied by an all-zero one-hot column.
    weights = (1. / (class_occurrences + 1e-8)) * (num_labels / tf.cast(num_classes, tf.float32))
    weights = tf.cast(weights, neg_log_probs.dtype)

    # weights has shape [C] and broadcasts against the trailing class axis, so
    # no reshape to a statically known class count is needed.
    loss = tf.reduce_mean(tf.reduce_sum(oh_labels * neg_log_probs * weights, axis=-1),
                          name=name)
    scalar_summary(loss, name, collections)
    return loss
def mse(x, y, name='mse', collections=None):
    """ Mean squared error

    Calculates the mean squared error between ``x`` and ``y``, averaged over
    all elements, and builds a scalar summary.

    Parameters
    ----------
    x : tf.Tensor
        prediction for which to calculate the error
    y : tf.Tensor
        targets with which to calculate the error (must be compatible with
        ``x`` for element-wise subtraction)
    name : string
        name of this operation and summary
    collections : list or tuple, optional
        list of collections to add the summaries to. Defaults to
        ``['losses']``.

    Returns
    -------
    tf.Tensor
        scalar Tensor representing the loss
    """
    # Build a fresh default per call instead of sharing one mutable list.
    if collections is None:
        collections = ['losses']

    loss = tf.reduce_mean(tf.square(x - y), name=name)
    scalar_summary(loss, name, collections)
    return loss
def dice_loss(logits, labels, num_classes, smooth=1e-5, include_background=True,
              only_present=False, name='dice_loss', collections=None):
    """ Smooth dice loss

    Calculates the smooth dice loss and builds a scalar summary.

    A soft dice score ``(2 * intersection + smooth) / (label_sum + pred_sum +
    smooth)`` is computed per sample and per class from the softmax
    probabilities and the one-hot encoded labels. The sums run over every
    axis between the batch (first) and class (last) axis. The loss is
    ``1 - mean(dice)`` over the selected sample/class pairs.

    Parameters
    ----------
    logits : tf.Tensor
        prediction for which to calculate the error, with the batch on the
        first axis and ``num_classes`` channels on the last axis
    labels : tf.Tensor
        sparse targets with which to calculate the error, shaped like
        ``logits`` without its last axis
    num_classes : int
        number of class labels to evaluate on
    smooth : float
        smoothing term added to the numerator and denominator of the dice
        score
    include_background : bool
        flag to include a loss on the background label (class 0) or not
    only_present : bool
        if True, average only over sample/class pairs whose class occurs in
        that sample's labels; otherwise average over all pairs whose dice
        score is not NaN
    name : string
        name of this operation and summary
    collections : list or tuple, optional
        list of collections to add the summaries to. Defaults to
        ``['losses']``.

    Returns
    -------
    tf.Tensor
        scalar Tensor representing the loss; 0 when no sample/class pair is
        selected for averaging
    """
    # Build a fresh default per call instead of sharing one mutable list.
    if collections is None:
        collections = ['losses']

    probs = tf.nn.softmax(logits)
    onehot_labels = tf.one_hot(labels, num_classes, dtype=tf.float32, name='onehot_labels')

    # Reduce over all axes between batch (0) and class (-1). For 5-D input
    # this is [1, 2, 3]; other ranks (e.g. 2-D images) are handled as well.
    ndims = probs.get_shape().ndims
    if ndims is not None:
        reduce_axes = list(range(1, ndims - 1))
    else:
        reduce_axes = tf.range(1, tf.rank(probs) - 1)

    label_sum = tf.reduce_sum(onehot_labels, axis=reduce_axes, name='label_sum')
    pred_sum = tf.reduce_sum(probs, axis=reduce_axes, name='pred_sum')
    intersection = tf.reduce_sum(onehot_labels * probs, axis=reduce_axes, name='intersection')
    # Shape [batch, num_classes] after the reduction above.
    per_sample_per_class_dice = (2. * intersection + smooth) / (label_sum + pred_sum + smooth)

    if not include_background:
        # Drop class 0, the background label.
        per_sample_per_class_dice = per_sample_per_class_dice[:, 1:]
        label_sum = label_sum[:, 1:]
    flat_dice = tf.reshape(per_sample_per_class_dice, (-1, ))

    if only_present:
        keep = tf.not_equal(tf.reshape(label_sum, (-1, )), 0)
    else:
        # NaN can only arise for 0/0, i.e. smooth == 0 on an empty class.
        keep = tf.logical_not(tf.is_nan(flat_dice))
    masked_dice = tf.boolean_mask(flat_dice, keep)

    # reduce_mean of an empty tensor is NaN, which would poison training.
    # If nothing is selected, treat dice as perfect so the loss contributes 0.
    dice = tf.cond(tf.size(masked_dice) > 0,
                   lambda: tf.reduce_mean(masked_dice),
                   lambda: tf.ones([], dtype=masked_dice.dtype))
    loss = tf.subtract(1., dice, name=name)
    scalar_summary(loss, name, collections)
    return loss