-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Expand file tree
/
Copy pathlayers.py
More file actions
executable file
·189 lines (155 loc) · 7.5 KB
/
layers.py
File metadata and controls
executable file
·189 lines (155 loc) · 7.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*-
""" Contains a set of utilities that allow building the UNet model
"""
import tensorflow as tf
def _crop_and_concat(inputs, residual_input):
    """ Centrally crop ``residual_input`` to the spatial size of ``inputs`` and concatenate

    Crop offsets are computed with integer arithmetic, separately for height and
    width. The cropped tensor therefore always matches ``inputs`` exactly. A single
    float fraction (as taken by ``tf.image.central_crop``) can round to a box one
    pixel too large, and only reflects the height ratio.

    Args:
        inputs (tf.Tensor): Tensor with input; assumed channels_last
            (batch, height, width, channels), the Keras Conv2D default used
            throughout this file
        residual_input (tf.Tensor): Residual input with the same layout and
            at least the height and width of ``inputs``

    Return:
        Concatenated tf.Tensor with the spatial size of ``inputs``. The channels of
        ``inputs`` come first, followed by the channels of the cropped residual.

    Raises:
        ValueError: if the static spatial size of ``residual_input`` is smaller
            than that of ``inputs``
    """
    def _spatial_size(tensor, axis):
        # Prefer the static size; fall back to the runtime size when unknown.
        static_size = tensor.shape[axis]
        return static_size if static_size is not None else tf.shape(tensor)[axis]

    in_h, in_w = _spatial_size(inputs, 1), _spatial_size(inputs, 2)
    res_h, res_w = _spatial_size(residual_input, 1), _spatial_size(residual_input, 2)

    static_sizes = (in_h, in_w, res_h, res_w)
    if all(isinstance(size, int) for size in static_sizes) and (res_h < in_h or res_w < in_w):
        raise ValueError(
            'residual_input spatial size ({}, {}) is smaller than inputs ({}, {})'.format(
                res_h, res_w, in_h, in_w))

    # Floor division puts any odd leftover pixel at the bottom/right edge.
    top = (res_h - in_h) // 2
    left = (res_w - in_w) // 2
    cropped = residual_input[:, top:top + in_h, left:left + in_w, :]
    return tf.concat([inputs, cropped], axis=-1)
class InputBlock(tf.keras.Model):
    def __init__(self, filters):
        """ UNet input block

        Perform two unpadded ('valid') 3x3 convolutions with ReLU activation and
        a specified number of filters, then downsample through 2x2 max-pooling
        with stride 2.

        Args:
            filters (int): Number of filters in convolution
        """
        # ``self`` is already bound by ``super()``; Model.__init__ takes no
        # positional model argument.
        super().__init__()
        with tf.name_scope('input_block'):
            self.conv1 = tf.keras.layers.Conv2D(filters=filters,
                                                kernel_size=(3, 3),
                                                activation=tf.nn.relu)
            self.conv2 = tf.keras.layers.Conv2D(filters=filters,
                                                kernel_size=(3, 3),
                                                activation=tf.nn.relu)
            self.maxpool = tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=2)

    def call(self, inputs):
        """ Apply both convolutions, then max-pool

        Args:
            inputs (tf.Tensor): Input tensor

        Return:
            Tuple of (max-pooled output, convolved output before pooling). The
            pre-pooling tensor is returned so the caller can use it as a skip
            connection.
        """
        out = self.conv1(inputs)
        out = self.conv2(out)
        mp = self.maxpool(out)
        return mp, out
class DownsampleBlock(tf.keras.Model):
    def __init__(self, filters, idx):
        """ UNet downsample block

        Perform two unpadded ('valid') 3x3 convolutions with ReLU activation and
        a specified number of filters, then downsample through 2x2 max-pooling
        with stride 2.

        Args:
            filters (int): Number of filters in convolution
            idx (int): Index of block, used only in the name scope
        """
        # ``self`` is already bound by ``super()``; Model.__init__ takes no
        # positional model argument.
        super().__init__()
        with tf.name_scope('downsample_block_{}'.format(idx)):
            self.conv1 = tf.keras.layers.Conv2D(filters=filters,
                                                kernel_size=(3, 3),
                                                activation=tf.nn.relu)
            self.conv2 = tf.keras.layers.Conv2D(filters=filters,
                                                kernel_size=(3, 3),
                                                activation=tf.nn.relu)
            self.maxpool = tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=2)

    def call(self, inputs):
        """ Apply both convolutions, then max-pool

        Args:
            inputs (tf.Tensor): Input tensor

        Return:
            Tuple of convolved ``inputs`` after and before downsampling. The
            pre-pooling tensor is returned so the caller can use it as a skip
            connection.
        """
        out = self.conv1(inputs)
        out = self.conv2(out)
        mp = self.maxpool(out)
        return mp, out
class BottleneckBlock(tf.keras.Model):
    def __init__(self, filters):
        """ UNet central block

        Perform two unpadded ('valid') 3x3 convolutions with ReLU activation and
        a specified number of filters. Then apply dropout (rate 0.5, active only in
        training) and upsample with a stride-2 transposed convolution that halves
        the number of filters.

        Args:
            filters (int): Number of filters in convolution
        """
        # ``self`` is already bound by ``super()``; Model.__init__ takes no
        # positional model argument.
        super().__init__()
        with tf.name_scope('bottleneck_block'):
            self.conv1 = tf.keras.layers.Conv2D(filters=filters,
                                                kernel_size=(3, 3),
                                                activation=tf.nn.relu)
            self.conv2 = tf.keras.layers.Conv2D(filters=filters,
                                                kernel_size=(3, 3),
                                                activation=tf.nn.relu)
            self.dropout = tf.keras.layers.Dropout(rate=0.5)
            # 'same' padding with stride 2 doubles height and width.
            self.conv_transpose = tf.keras.layers.Conv2DTranspose(filters=filters // 2,
                                                                  kernel_size=(3, 3),
                                                                  strides=(2, 2),
                                                                  padding='same',
                                                                  activation=tf.nn.relu)

    def call(self, inputs, training=None):
        """ Apply convolutions, dropout and upsampling

        Args:
            inputs (tf.Tensor): Input tensor
            training (bool, optional): Whether dropout is active. ``None`` lets
                Keras resolve the mode from the calling context.

        Return:
            Upsampled tf.Tensor with ``filters // 2`` channels
        """
        out = self.conv1(inputs)
        out = self.conv2(out)
        out = self.dropout(out, training=training)
        out = self.conv_transpose(out)
        return out
class UpsampleBlock(tf.keras.Model):
    def __init__(self, filters, idx):
        """ UNet upsample block

        Concatenate the input with a centrally cropped residual. Then perform two
        unpadded ('valid') 3x3 convolutions with ReLU activation and a specified
        number of filters, and upsample with a stride-2 transposed convolution
        that halves the number of filters.

        Args:
            filters (int): Number of filters in convolution
            idx (int): Index of block, used only in the name scope
        """
        # ``self`` is already bound by ``super()``; Model.__init__ takes no
        # positional model argument.
        super().__init__()
        with tf.name_scope('upsample_block_{}'.format(idx)):
            self.conv1 = tf.keras.layers.Conv2D(filters=filters,
                                                kernel_size=(3, 3),
                                                activation=tf.nn.relu)
            self.conv2 = tf.keras.layers.Conv2D(filters=filters,
                                                kernel_size=(3, 3),
                                                activation=tf.nn.relu)
            # 'same' padding with stride 2 doubles height and width.
            self.conv_transpose = tf.keras.layers.Conv2DTranspose(filters=filters // 2,
                                                                  kernel_size=(3, 3),
                                                                  strides=(2, 2),
                                                                  padding='same',
                                                                  activation=tf.nn.relu)

    def call(self, inputs, residual_input):
        """ Merge the residual, convolve and upsample

        Args:
            inputs (tf.Tensor): Input tensor from the previous (coarser) block
            residual_input (tf.Tensor): Residual tensor, centrally cropped to the
                spatial size of ``inputs`` before concatenation

        Return:
            Upsampled tf.Tensor with ``filters // 2`` channels
        """
        out = _crop_and_concat(inputs, residual_input)
        out = self.conv1(out)
        out = self.conv2(out)
        out = self.conv_transpose(out)
        return out
class OutputBlock(tf.keras.Model):
    def __init__(self, filters, n_classes):
        """ UNet output block

        Concatenate the input with a centrally cropped residual. Then perform
        three unpadded convolutions:
        - two 3x3 convolutions with ReLU activation;
        - a final 1x1 convolution with one channel per class and no activation,
          so its outputs are unnormalised scores.

        Args:
            filters (int): Number of filters in convolution
            n_classes (int): Number of output classes
        """
        # ``self`` is already bound by ``super()``; Model.__init__ takes no
        # positional model argument.
        super().__init__()
        with tf.name_scope('output_block'):
            self.conv1 = tf.keras.layers.Conv2D(filters=filters,
                                                kernel_size=(3, 3),
                                                activation=tf.nn.relu)
            self.conv2 = tf.keras.layers.Conv2D(filters=filters,
                                                kernel_size=(3, 3),
                                                activation=tf.nn.relu)
            self.conv3 = tf.keras.layers.Conv2D(filters=n_classes,
                                                kernel_size=(1, 1),
                                                activation=None)

    def call(self, inputs, residual_input):
        """ Merge the residual and produce per-class outputs

        Args:
            inputs (tf.Tensor): Input tensor from the last upsample block
            residual_input (tf.Tensor): Residual tensor, centrally cropped to the
                spatial size of ``inputs`` before concatenation

        Return:
            tf.Tensor with ``n_classes`` channels and no activation applied
        """
        out = _crop_and_concat(inputs, residual_input)
        out = self.conv1(out)
        out = self.conv2(out)
        out = self.conv3(out)
        return out