Skip to content

Commit f2bbf98

Browse files
committed
Fix issue where the disable_tracking decorator obfuscates layer constructors.
1 parent 9ad5a18 commit f2bbf98

File tree

4 files changed

+22
-13
lines changed

4 files changed

+22
-13
lines changed

‎docs/mkdocs.yml‎

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,6 @@ nav:
7373
- Baby RNN: examples/babi_rnn.md
7474
- Baby MemNN: examples/babi_memnn.md
7575
- CIFAR-10 CNN: examples/cifar10_cnn.md
76-
- CIFAR-10 CNN-Capsule: examples/cifar10_cnn_capsule.md
77-
- CIFAR-10 CNN with augmentation (TF): examples/cifar10_cnn_tfaugment2d.md
7876
- CIFAR-10 ResNet: examples/cifar10_resnet.md
7977
- Convolution filter visualization: examples/conv_filter_visualization.md
8078
- Convolutional LSTM: examples/conv_lstm.md

‎keras/layers/normalization.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from __future__ import division
66
from __future__ import print_function
77

8-
from ..engine.base_layer import Layer, InputSpec, disable_tracking
8+
from ..engine.base_layer import Layer, InputSpec
99
from .. import initializers
1010
from .. import regularizers
1111
from .. import constraints

‎keras/layers/recurrent.py‎

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ class StackedRNNCells(Layer):
4646
```
4747
"""
4848

49-
@disable_tracking
5049
def __init__(self, cells, **kwargs):
5150
for cell in cells:
5251
if not hasattr(cell, 'call'):
@@ -391,7 +390,6 @@ def call(self, inputs, states):
391390
```
392391
"""
393392

394-
@disable_tracking
395393
def __init__(self, cell,
396394
return_sequences=False,
397395
return_state=False,
@@ -410,7 +408,7 @@ def __init__(self, cell,
410408
'(tuple of integers, '
411409
'one integer per RNN state).')
412410
super(RNN, self).__init__(**kwargs)
413-
self.cell = cell
411+
self._set_cell(cell)
414412
self.return_sequences = return_sequences
415413
self.return_state = return_state
416414
self.go_backwards = go_backwards
@@ -424,6 +422,13 @@ def __init__(self, cell,
424422
self.constants_spec = None
425423
self._num_constants = None
426424

425+
@disable_tracking
426+
def _set_cell(self, cell):
427+
# This is isolated in its own method in order to use
428+
# the disable_tracking decorator without altering the
429+
# visible signature of __init__.
430+
self.cell = cell
431+
427432
@property
428433
def states(self):
429434
if self._states is None:

‎keras/layers/wrappers.py‎

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -360,18 +360,12 @@ class Bidirectional(Wrapper):
360360
```
361361
"""
362362

363-
@disable_tracking
364363
def __init__(self, layer, merge_mode='concat', weights=None, **kwargs):
365364
if merge_mode not in ['sum', 'mul', 'ave', 'concat', None]:
366365
raise ValueError('Invalid merge mode. '
367366
'Merge mode should be one of '
368367
'{"sum", "mul", "ave", "concat", None}')
369-
self.forward_layer = copy.copy(layer)
370-
config = layer.get_config()
371-
config['go_backwards'] = not config['go_backwards']
372-
self.backward_layer = layer.__class__.from_config(config)
373-
self.forward_layer.name = 'forward_' + self.forward_layer.name
374-
self.backward_layer.name = 'backward_' + self.backward_layer.name
368+
self._set_sublayers(layer)
375369
self.merge_mode = merge_mode
376370
if weights:
377371
nw = len(weights)
@@ -386,6 +380,18 @@ def __init__(self, layer, merge_mode='concat', weights=None, **kwargs):
386380
self.input_spec = layer.input_spec
387381
self._num_constants = None
388382

383+
@disable_tracking
384+
def _set_sublayers(self, layer):
385+
# This is isolated in its own method in order to use
386+
# the disable_tracking decorator without altering the
387+
# visible signature of __init__.
388+
self.forward_layer = copy.copy(layer)
389+
config = layer.get_config()
390+
config['go_backwards'] = not config['go_backwards']
391+
self.backward_layer = layer.__class__.from_config(config)
392+
self.forward_layer.name = 'forward_' + self.forward_layer.name
393+
self.backward_layer.name = 'backward_' + self.backward_layer.name
394+
389395
@property
390396
def trainable(self):
391397
return self._trainable

0 commit comments

Comments
 (0)