
ComplexDropout2d Device Error #30

@lucacoma


Hi, thank you for the nice library.

There seems to be a small bug in the `complexPyTorch.complexLayers.ComplexDropout2d` layer: it raises a device mismatch error when its input is on the GPU (torch version 2.0.1+cu118):

```
.... line 106, in complex_dropout
return mask*input
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```

I managed to fix it by moving the mask to the input's device in `complexPyTorch.complexFunctions.complex_dropout2d`, as follows:

```python
import torch


def complex_dropout2d(input, p=0.5, training=True):
    # need to have the same dropout mask for real and imaginary part,
    # this not a clean solution!
    device = input.device
    mask = torch.ones(*input.shape, dtype=torch.float32, device=device)
    # dropout2d already rescales the kept channels by 1/(1-p) in training
    # mode, so no extra 1/(1-p) factor is applied here
    mask = torch.nn.functional.dropout2d(mask, p, training)
    mask = mask.type(input.dtype)  # .type() is not in-place, so reassign
    mask = mask.to(device)  # Line added
    return mask * input
```
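A variant that avoids the device bookkeeping altogether is to build the mask from the input itself. The sketch below is not the library's code; it assumes a 4-D complex input of shape (N, C, H, W) and swaps in `torch.ones_like(input.real)`, which inherits the input's device and matching real dtype, so the mask cannot land on the CPU by accident. Because `torch.nn.functional.dropout2d` already rescales the surviving channels by 1/(1-p) in training mode, no extra scaling factor is applied.

```python
import torch
import torch.nn.functional as F


def complex_dropout2d(input: torch.Tensor, p: float = 0.5, training: bool = True) -> torch.Tensor:
    """Channel-wise dropout for a complex tensor of shape (N, C, H, W).

    Sketch only: the same real-valued mask is applied to the real and
    imaginary parts, so a dropped channel is zeroed in both.
    """
    # ones_like(input.real) is created on input's device with input's real
    # dtype (e.g. float32 for complex64), so no explicit device= is needed.
    mask = torch.ones_like(input.real)
    # Zero whole channels with probability p; in training mode the kept
    # channels are rescaled by 1 / (1 - p) inside dropout2d.
    mask = F.dropout2d(mask, p, training)
    # Real * complex multiplication promotes the result to the complex dtype.
    return mask * input
```

In eval mode (`training=False`) the mask stays all ones, so the input passes through unchanged, which matches the usual dropout convention.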

Best!
