Skip to content

Commit f92dd24

Browse files
author
zhaijianfeng
committed
Support saving models in .h5 format. This commit only covers the UNet model, but other models can be supported in the same way.
1 parent c06d3d2 commit f92dd24

2 files changed

Lines changed: 83 additions & 6 deletions

File tree

‎tensorflow_advanced_segmentation_models/base/functional.py‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def binary_crossentropy(y_true, y_pred):
112112

113113
def categorical_focal_loss(y_true, y_pred, gamma=2.0, alpha=0.25):
114114
y_true, y_pred = gather_channels(y_true, y_pred)
115+
y_true = K.cast(y_true, K.floatx())
115116
y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())
116117

117118
loss = - y_true * (alpha * K.pow((1 - y_pred), gamma) * K.log(y_pred))

‎tensorflow_advanced_segmentation_models/models/_custom_layers_and_blocks.py‎

Lines changed: 82 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
################################################################################
55
# Layers
66
################################################################################
7+
from tensorflow.keras import activations
8+
9+
710
class ConvolutionBnActivation(tf.keras.layers.Layer):
811
"""
912
"""
@@ -12,8 +15,8 @@ def __init__(self, filters, kernel_size, strides=(1, 1), padding="same", data_fo
1215
groups=1, activation=None, kernel_initializer="glorot_uniform", bias_initializer="zeros", kernel_regularizer=None,
1316
bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None, use_batchnorm=False,
1417
axis=-1, momentum=0.99, epsilon=0.001, center=True, scale=True, trainable=True,
15-
post_activation="relu", block_name=None):
16-
super(ConvolutionBnActivation, self).__init__()
18+
post_activation="relu", block_name=None, **kwargs):
19+
super(ConvolutionBnActivation, self).__init__(**kwargs)
1720

1821

1922
# 2D Convolution Arguments
@@ -46,7 +49,8 @@ def __init__(self, filters, kernel_size, strides=(1, 1), padding="same", data_fo
4649
self.conv = None
4750
self.bn = None
4851
#tf.keras.layers.BatchNormalization(scale=False, momentum=0.9)
49-
self.post_activation = tf.keras.layers.Activation(post_activation)
52+
# self.post_activation = tf.keras.layers.Activation(post_activation)
53+
self.post_activation = activations.get(post_activation)
5054

5155
def build(self, input_shape):
5256
self.conv = tf.keras.layers.Conv2D(
@@ -90,6 +94,49 @@ def compute_output_shape(self, input_shape):
9094
print(input_shape)
9195
return [input_shape[0], input_shape[1], input_shape[2], self.filters]
9296

97+
def get_config(self):
    """Return a JSON-serializable config so this layer can be saved/reloaded.

    Mirrors the constructor arguments of ``ConvolutionBnActivation`` so that
    ``tf.keras.models.save_model`` / ``load_model`` (e.g. the ``.h5`` format)
    can round-trip the layer.

    Returns:
        dict: the base ``Layer`` config merged with this layer's arguments
        (layer-specific keys take precedence over the base keys).
    """
    config = {
        # 2D convolution arguments
        "filters": self.filters,
        "kernel_size": self.kernel_size,
        "strides": self.strides,
        "padding": self.padding,
        "data_format": self.data_format,
        "dilation_rate": self.dilation_rate,
        # NOTE(review): assumes self.activation is a Keras activation
        # (or None) accepted by activations.serialize — confirm __init__.
        "activation": activations.serialize(self.activation),
        # __init__ stores use_bias = not use_batchnorm, so invert it back.
        "use_batchnorm": not self.use_bias,
        "kernel_initializer": self.kernel_initializer,
        "bias_initializer": self.bias_initializer,
        "kernel_regularizer": self.kernel_regularizer,
        "bias_regularizer": self.bias_regularizer,
        "activity_regularizer": self.activity_regularizer,
        "kernel_constraint": self.kernel_constraint,
        "bias_constraint": self.bias_constraint,
        # Batch normalization arguments
        "axis": self.axis,
        "momentum": self.momentum,
        "epsilon": self.epsilon,
        "center": self.center,
        "scale": self.scale,
        "trainable": self.trainable,
        "block_name": self.block_name,
        # post_activation is stored as a callable via activations.get;
        # serialize it back to its string identifier.
        "post_activation": activations.serialize(self.post_activation),
    }
    base_config = super(ConvolutionBnActivation, self).get_config()
    # Merge so that the layer-specific keys override any base duplicates.
    return dict(list(base_config.items()) + list(config.items()))
139+
93140
class AtrousSeparableConvolutionBnReLU(tf.keras.layers.Layer):
94141
"""
95142
"""
@@ -214,9 +261,10 @@ def compute_output_shape(self, input_shape):
214261
class Upsample_x2_Block(tf.keras.layers.Layer):
215262
"""
216263
"""
217-
def __init__(self, filters, trainable=None):
218-
super(Upsample_x2_Block, self).__init__()
264+
def __init__(self, filters, trainable=None, **kwargs):
265+
super(Upsample_x2_Block, self).__init__(**kwargs)
219266
self.trainable = trainable
267+
self.filters = filters
220268

221269
self.upsample2d_size2 = tf.keras.layers.UpSampling2D(size=2, interpolation="bilinear")
222270
self.conv2x2_bn_relu = tf.keras.layers.Conv2D(filters, kernel_size=(2, 2), padding="same")
@@ -242,6 +290,14 @@ def compute_output_shape(self, input_shape):
242290
print(input_shape)
243291
return [input_shape[0], input_shape[1] * 2, input_shape[2] * 2, input_shape[3]]
244292

293+
def get_config(self):
    """Serialize the constructor arguments for model saving/reloading."""
    base_config = super(Upsample_x2_Block, self).get_config()
    # Layer-specific keys override any duplicates from the base config.
    return {
        **base_config,
        "filters": self.filters,
        "trainable": self.trainable,
    }
300+
245301
class Upsample_x2_Add_Block(tf.keras.layers.Layer):
246302
"""
247303
"""
@@ -1176,4 +1232,24 @@ def call(self, input1, input2, input3, input4, training=None):
11761232

11771233
x_fuse.append(self.relu(y))
11781234

1179-
return x_fuse
1235+
return x_fuse
1236+
1237+
1238+
# Register every custom layer/block under its class name so that models saved
# with these layers (e.g. in .h5 format) can be reloaded via
# tf.keras.models.load_model without passing custom_objects explicitly.
_CUSTOM_LAYER_CLASSES = (
    ConvolutionBnActivation,
    AtrousSeparableConvolutionBnReLU,
    AtrousSpatialPyramidPoolingV3,
    Upsample_x2_Block,
    Upsample_x2_Add_Block,
    SpatialContextBlock,
    FPNBlock,
    AtrousSpatialPyramidPoolingV1,
    Base_OC_Module,
    Pyramid_OC_Module,
    ASP_OC_Module,
    PAM_Module,
    CAM_Module,
    SelfAttentionBlock2D,
)

# Each registry key equals the class's __name__, matching the original
# hand-written mapping byte-for-byte.
custom_objects = {cls.__name__: cls for cls in _CUSTOM_LAYER_CLASSES}

tf.keras.utils.get_custom_objects().update(custom_objects)

0 commit comments

Comments
 (0)