Skip to content

Commit f899e9d

Browse files
committed
只优化需要梯度的参数
1 parent a93a1f3 commit f899e9d

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

‎torchocr/optimizer/__init__.py‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
def build_optimizer(optim_config, lr_scheduler_config, epochs, step_each_epoch, model):
99
from . import lr
1010
config = copy.deepcopy(optim_config)
11-
optim = getattr(torch.optim, config.pop('name'))(params=model.parameters(), **config)
11+
train_params = filter(lambda p: p.requires_grad, model.parameters())
12+
optim = getattr(torch.optim, config.pop('name'))(params=train_params, **config)
1213

1314
lr_config = copy.deepcopy(lr_scheduler_config)
1415
lr_config.update({'epochs': epochs, 'step_each_epoch': step_each_epoch})

0 commit comments

Comments
 (0)