44################################################################################
55# Layers
66################################################################################
7+ from tensorflow .keras import activations
8+
9+
710class ConvolutionBnActivation (tf .keras .layers .Layer ):
811 """
912 """
@@ -12,8 +15,8 @@ def __init__(self, filters, kernel_size, strides=(1, 1), padding="same", data_fo
1215 groups = 1 , activation = None , kernel_initializer = "glorot_uniform" , bias_initializer = "zeros" , kernel_regularizer = None ,
1316 bias_regularizer = None , activity_regularizer = None , kernel_constraint = None , bias_constraint = None , use_batchnorm = False ,
1417 axis = - 1 , momentum = 0.99 , epsilon = 0.001 , center = True , scale = True , trainable = True ,
15- post_activation = "relu" , block_name = None ):
16- super (ConvolutionBnActivation , self ).__init__ ()
18+ post_activation = "relu" , block_name = None , ** kwargs ):
19+ super (ConvolutionBnActivation , self ).__init__ (** kwargs )
1720
1821
1922 # 2D Convolution Arguments
@@ -46,7 +49,8 @@ def __init__(self, filters, kernel_size, strides=(1, 1), padding="same", data_fo
4649 self .conv = None
4750 self .bn = None
4851 #tf.keras.layers.BatchNormalization(scale=False, momentum=0.9)
49- self .post_activation = tf .keras .layers .Activation (post_activation )
52+ # self.post_activation = tf.keras.layers.Activation(post_activation)
53+ self .post_activation = activations .get (post_activation )
5054
5155 def build (self , input_shape ):
5256 self .conv = tf .keras .layers .Conv2D (
@@ -90,6 +94,49 @@ def compute_output_shape(self, input_shape):
9094 print (input_shape )
9195 return [input_shape [0 ], input_shape [1 ], input_shape [2 ], self .filters ]
9296
97+ def get_config (self ):
98+ config = {
99+ "filters" : self .filters ,
100+ "kernel_size" : self .kernel_size ,
101+ "strides" : self .strides ,
102+ "padding" : self .padding ,
103+ "data_format" : self .data_format ,
104+ "dilation_rate" : self .dilation_rate ,
105+ "activation" : activations .serialize (self .activation ),
106+ "use_batchnorm" : not self .use_bias ,
107+ "kernel_initializer" : self .kernel_initializer ,
108+ "bias_initializer" : self .bias_initializer ,
109+ "kernel_regularizer" : self .kernel_regularizer ,
110+ "bias_regularizer" : self .bias_regularizer ,
111+ "activity_regularizer" : self .activity_regularizer ,
112+ "kernel_constraint" : self .kernel_constraint ,
113+ "bias_constraint" : self .bias_constraint ,
114+ # Batch Normalization Arguments
115+ "axis" : self .axis ,
116+ "momentum" : self .momentum ,
117+ "epsilon" : self .epsilon ,
118+ "center" : self .center ,
119+ "scale" : self .scale ,
120+ "trainable" : self .trainable ,
121+ "block_name" : self .block_name ,
122+
123+
124+ # self.use_bias = not (use_batchnorm)
125+ #
126+ #
127+ #
128+ #
129+ # self.conv = None
130+ # self.bn = None
131+
132+
133+ # tf.keras.layers.BatchNormalization(scale=False, momentum=0.9)
134+ # self.post_activation = tf.keras.layers.Activation(post_activation)
135+ "post_activation" : activations .serialize (self .post_activation ),
136+ }
137+ base_config = super (ConvolutionBnActivation , self ).get_config ()
138+ return dict (list (base_config .items ()) + list (config .items ()))
139+
93140class AtrousSeparableConvolutionBnReLU (tf .keras .layers .Layer ):
94141 """
95142 """
@@ -214,9 +261,10 @@ def compute_output_shape(self, input_shape):
214261class Upsample_x2_Block (tf .keras .layers .Layer ):
215262 """
216263 """
217- def __init__ (self , filters , trainable = None ):
218- super (Upsample_x2_Block , self ).__init__ ()
264+ def __init__ (self , filters , trainable = None , ** kwargs ):
265+ super (Upsample_x2_Block , self ).__init__ (** kwargs )
219266 self .trainable = trainable
267+ self .filters = filters
220268
221269 self .upsample2d_size2 = tf .keras .layers .UpSampling2D (size = 2 , interpolation = "bilinear" )
222270 self .conv2x2_bn_relu = tf .keras .layers .Conv2D (filters , kernel_size = (2 , 2 ), padding = "same" )
@@ -242,6 +290,14 @@ def compute_output_shape(self, input_shape):
242290 print (input_shape )
243291 return [input_shape [0 ], input_shape [1 ] * 2 , input_shape [2 ] * 2 , input_shape [3 ]]
244292
293+ def get_config (self ):
294+ config = {
295+ "filters" : self .filters ,
296+ "trainable" : self .trainable ,
297+ }
298+ base_config = super (Upsample_x2_Block , self ).get_config ()
299+ return dict (list (base_config .items ()) + list (config .items ()))
300+
245301class Upsample_x2_Add_Block (tf .keras .layers .Layer ):
246302 """
247303 """
@@ -1176,4 +1232,24 @@ def call(self, input1, input2, input3, input4, training=None):
11761232
11771233 x_fuse .append (self .relu (y ))
11781234
1179- return x_fuse
1235+ return x_fuse
1236+
1237+
1238+ custom_objects = {
1239+ 'ConvolutionBnActivation' : ConvolutionBnActivation ,
1240+ 'AtrousSeparableConvolutionBnReLU' : AtrousSeparableConvolutionBnReLU ,
1241+ 'AtrousSpatialPyramidPoolingV3' : AtrousSpatialPyramidPoolingV3 ,
1242+ 'Upsample_x2_Block' : Upsample_x2_Block ,
1243+ 'Upsample_x2_Add_Block' : Upsample_x2_Add_Block ,
1244+ 'SpatialContextBlock' : SpatialContextBlock ,
1245+ 'FPNBlock' : FPNBlock ,
1246+ 'AtrousSpatialPyramidPoolingV1' : AtrousSpatialPyramidPoolingV1 ,
1247+ 'Base_OC_Module' : Base_OC_Module ,
1248+ 'Pyramid_OC_Module' : Pyramid_OC_Module ,
1249+ 'ASP_OC_Module' : ASP_OC_Module ,
1250+ 'PAM_Module' : PAM_Module ,
1251+ 'CAM_Module' : CAM_Module ,
1252+ 'SelfAttentionBlock2D' : SelfAttentionBlock2D ,
1253+ }
1254+
1255+ tf .keras .utils .get_custom_objects ().update (custom_objects )
0 commit comments