From 526956ac056827a9b37372da63d7feb354aa5b6c Mon Sep 17 00:00:00 2001
From: lsewcx
Date: Mon, 11 Nov 2024 09:56:04 +0800
Subject: [PATCH 1/3] add softmax

---
 README.md                          | 1 +
 complexPyTorch/complexFunctions.py | 9 +++++++++
 2 files changed, 10 insertions(+)

diff --git a/README.md b/README.md
index 09f15b3..1d45599 100755
--- a/README.md
+++ b/README.md
@@ -27,6 +27,7 @@ Following [[C. Trabelsi et al., International Conference on Learning Representat
 * Relu (ℂRelu)
 * Sigmoid
 * Tanh
+* softmax
 * Dropout2d
 * BatchNorm1d (Naive and Covariance approach)
 * BatchNorm2d (Naive and Covariance approach)
diff --git a/complexPyTorch/complexFunctions.py b/complexPyTorch/complexFunctions.py
index 5910058..34eddfa 100755
--- a/complexPyTorch/complexFunctions.py
+++ b/complexPyTorch/complexFunctions.py
@@ -15,6 +15,7 @@
     relu,
     sigmoid,
     tanh,
+    softmax,
 )


@@ -209,3 +210,11 @@ def complex_dropout2d(inp, p=0.5, training=True):
     mask = dropout2d(mask, p, training) * 1 / (1 - p)
     mask.type(inp.dtype)
     return mask * inp
+
+def complex_softmax(inp, dim, dtype=None):
+    """
+    Perform complex softmax.
+    """
+    real_softmax = softmax(inp.real, dim=dim)
+    imag_softmax = softmax(inp.imag, dim=dim)
+    return real_softmax.type(torch.complex64) + 1j * imag_softmax.type(torch.complex64)

From 1e3cb994ef5ac3599df18497a791d936eb801015 Mon Sep 17 00:00:00 2001
From: lsewcx
Date: Mon, 11 Nov 2024 10:44:50 +0800
Subject: [PATCH 2/3] dtype make torch.complex64

---
 complexPyTorch/complexFunctions.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/complexPyTorch/complexFunctions.py b/complexPyTorch/complexFunctions.py
index 34eddfa..e5be5de 100755
--- a/complexPyTorch/complexFunctions.py
+++ b/complexPyTorch/complexFunctions.py
@@ -211,10 +211,10 @@ def complex_dropout2d(inp, p=0.5, training=True):
     mask.type(inp.dtype)
     return mask * inp

-def complex_softmax(inp, dim, dtype=None):
+def complex_softmax(inp, dim, dtype=torch.complex64):
     """
     Perform complex softmax.
     """
-    real_softmax = softmax(inp.real, dim=dim)
-    imag_softmax = softmax(inp.imag, dim=dim)
+    real_softmax = softmax(inp.real, dim=dim, dtype=dtype)
+    imag_softmax = softmax(inp.imag, dim=dim, dtype=dtype)
     return real_softmax.type(torch.complex64) + 1j * imag_softmax.type(torch.complex64)

From 9a206759bafaefef796e7d9ec93aa6642817dd34 Mon Sep 17 00:00:00 2001
From: lsewcx
Date: Mon, 11 Nov 2024 10:47:50 +0800
Subject: [PATCH 3/3] remove dtype

---
 complexPyTorch/complexFunctions.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/complexPyTorch/complexFunctions.py b/complexPyTorch/complexFunctions.py
index e5be5de..9d15c45 100755
--- a/complexPyTorch/complexFunctions.py
+++ b/complexPyTorch/complexFunctions.py
@@ -211,10 +211,10 @@ def complex_dropout2d(inp, p=0.5, training=True):
     mask.type(inp.dtype)
     return mask * inp

-def complex_softmax(inp, dim, dtype=torch.complex64):
+def complex_softmax(inp, dim):
     """
     Perform complex softmax.
     """
-    real_softmax = softmax(inp.real, dim=dim, dtype=dtype)
-    imag_softmax = softmax(inp.imag, dim=dim, dtype=dtype)
+    real_softmax = softmax(inp.real, dim=dim)
+    imag_softmax = softmax(inp.imag, dim=dim)
     return real_softmax.type(torch.complex64) + 1j * imag_softmax.type(torch.complex64)