1 parent 4733603 commit 57b8a6aCopy full SHA for 57b8a6a
1 file changed
TensorFlow2/Segmentation/UNet_Medical/utils/losses.py
@@ -32,8 +32,8 @@ def partial_losses(predict, target):
32
flat_labels = tf.reshape(target,
33
[tf.shape(input=predict)[0], -1, n_classes])
34
35
- crossentropy_loss = tf.reduce_mean(input_tensor=tf.nn.softmax_cross_entropy_with_logits(logits=flat_logits,
36
- labels=flat_labels),
+ crossentropy_loss = tf.reduce_mean(input_tensor=tf.keras.backend.binary_crossentropy(output=flat_logits,
+ target=flat_labels),
37
name='cross_loss_ref')
38
dice_loss = tf.reduce_mean(input_tensor=1 - dice_coef(flat_logits, flat_labels), name='dice_loss_ref')
39
return crossentropy_loss, dice_loss
0 commit comments