Hi, thank you for the nice library.
There seems to be a small bug in the `complexPyTorch.complexLayers.ComplexDropout2d` layer. It raises a device-mismatch error when the input is on the GPU (torch version 2.0.1+cu118):
```
.... line 106, in complex_dropout
    return mask*input
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```
I managed to fix it by moving the mask to the input's device in `complexPyTorch.complexFunctions.complex_dropout2d`, as follows:
```python
import torch


def complex_dropout2d(input, p=0.5, training=True):
    # need to have the same dropout mask for real and imaginary part,
    # this is not a clean solution!
    device = input.device
    mask = torch.ones(*input.shape, dtype=torch.float32, device=device)
    mask = torch.nn.functional.dropout2d(mask, p, training) * 1 / (1 - p)
    mask = mask.type(input.dtype)  # .type() is not in-place, so reassign the result
    mask = mask.to(device)  # Line added
    return mask * input
```
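For comparison, here is a minimal sketch of a device-safe channel dropout for complex tensors. It assumes a complex input of shape (N, C, H, W). The name `complex_dropout2d_sketch` is hypothetical and not part of complexPyTorch. The single real-valued mask is created directly on `inp.device` and shared by the real and imaginary parts, so no CPU tensor is ever multiplied with a CUDA tensor.

```python
import torch
import torch.nn.functional as F


def complex_dropout2d_sketch(inp: torch.Tensor, p: float = 0.5, training: bool = True) -> torch.Tensor:
    """Hypothetical helper: channel-wise dropout for a complex (N, C, H, W) tensor.

    Assumes `inp` is complex. One real mask is shared by the real and
    imaginary parts and is allocated on the same device as `inp`.
    """
    if not training or p == 0.0:
        # Dropout is the identity at inference time.
        return inp
    # Ones on the input's device, using the real dtype of the complex input
    # (e.g. float32 for complex64).
    ones = torch.ones(inp.shape, dtype=inp.real.dtype, device=inp.device)
    # F.dropout2d zeroes whole channels and already rescales the surviving
    # channels by 1 / (1 - p) during training.
    mask = F.dropout2d(ones, p=p, training=True)
    # Real mask times complex input: real and imaginary parts share the same
    # dropped channels; type promotion keeps the complex dtype.
    return mask * inp
```

Unlike the version above, this sketch returns the input unchanged in eval mode and applies no extra `1 / (1 - p)` factor, because `F.dropout2d` already performs that rescaling while training.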
Best!