File tree Expand file tree Collapse file tree 1 file changed +4
-2
lines changed
Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -98,7 +98,7 @@ def step(self, closure=None):
9898 and returns the loss.
9999 """
100100 device = self .param_groups [0 ]["params" ][0 ].device
101- one_tensor = torch .tensor (1.0 , device = device )
101+ one_tensor = torch .tensor (1.0 , device = device ) # because torch.where doesn't handle scalars correctly
102102
103103 loss = None
104104 if closure is not None :
@@ -115,7 +115,9 @@ def step(self, closure=None):
115115 global_grad_norm .add_ (grad .pow (2 ).sum ())
116116
117117 global_grad_norm = torch .sqrt (global_grad_norm )
118- max_grad_norm = self .defaults ['max_grad_norm' ]
118+ # FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes
119+ # scalar types properly https://github.com/pytorch/pytorch/issues/9190
120+ max_grad_norm = torch .tensor (self .defaults ['max_grad_norm' ], device = device )
119121 clip_global_grad_norm = torch .where (
120122 global_grad_norm > max_grad_norm ,
121123 global_grad_norm / max_grad_norm ,
You can’t perform that action at this time.
0 commit comments