diff --git a/README.md b/README.md
index 09f15b3..1d45599 100755
--- a/README.md
+++ b/README.md
@@ -27,6 +27,7 @@ Following [[C. Trabelsi et al., International Conference on Learning Representat
 * Relu (ℂRelu)
 * Sigmoid
 * Tanh
+* Softmax
 * Dropout2d
 * BatchNorm1d (Naive and Covariance approach)
 * BatchNorm2d (Naive and Covariance approach)
diff --git a/complexPyTorch/complexFunctions.py b/complexPyTorch/complexFunctions.py
index 5910058..9d15c45 100755
--- a/complexPyTorch/complexFunctions.py
+++ b/complexPyTorch/complexFunctions.py
@@ -15,6 +15,7 @@
     relu,
     sigmoid,
     tanh,
+    softmax,
 )
 
 
@@ -209,3 +210,14 @@ def complex_dropout2d(inp, p=0.5, training=True):
     mask = dropout2d(mask, p, training) * 1 / (1 - p)
     mask.type(inp.dtype)
     return mask * inp
+
+def complex_softmax(inp, dim):
+    """
+    Apply softmax independently to the real and imaginary parts of *inp*.
+
+    Note: this is not a canonical complex-valued softmax; the real and
+    imaginary components are each normalized separately along *dim*.
+    The output dtype matches the input (complex64 in, complex64 out).
+    """
+    real_softmax = softmax(inp.real, dim=dim)
+    imag_softmax = softmax(inp.imag, dim=dim)
+    return torch.complex(real_softmax, imag_softmax)