
Commit 3350681

Remove entmax dependency (#352)
* factored out entmax
* removed entmax dependency
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent e1d5752 commit 3350681

File tree

2 files changed: +319 -7 lines changed


requirements/base.txt

Lines changed: 0 additions & 1 deletion

@@ -14,4 +14,3 @@ matplotlib >3.1
 ipywidgets
 einops >=0.6.0, <0.8.0
 rich >=11.0.0
-entmax>=1.1

src/pytorch_tabular/models/common/layers/activations.py

Lines changed: 319 additions & 6 deletions
@@ -1,7 +1,7 @@
 # W605
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
-from entmax import entmax15, sparsemax
 from torch import Tensor
 from torch.autograd import Function
 from torch.jit import script
@@ -43,11 +43,6 @@ def sparsemoid(input):
     return (0.5 * input + 0.5).clamp_(0, 1)
 
 
-entmoid15 = Entmoid15.apply
-entmax15 = entmax15
-sparsemax = sparsemax
-
-
 def t_softmax(input: Tensor, t: Tensor = None, dim: int = -1) -> Tensor:
     if t is None:
         t = torch.tensor(0.5, device=input.device)
@@ -98,3 +93,321 @@ def calculate_t(cls, input: Tensor, r: Tensor, dim: int = -1, eps: float = 1e-8)
     def forward(self, input: Tensor, r: Tensor):
         t = RSoftmax.calculate_t(input, r, self.dim, self.eps)
         return self.tsoftmax(input, t)
+
+
+"""
+An implementation of entmax (Peters et al., 2019). See
+https://arxiv.org/pdf/1905.05702 for detailed description.
+
+This builds on previous work with sparsemax (Martins & Astudillo, 2016).
+See https://arxiv.org/pdf/1602.02068.
+"""
+
+# Author: Ben Peters
+# Author: Vlad Niculae <vlad@vene.ro>
+# License: MIT
+# Including here for conda compatibility
+
+
+def _make_ix_like(X, dim):
+    d = X.size(dim)
+    rho = torch.arange(1, d + 1, device=X.device, dtype=X.dtype)
+    view = [1] * X.dim()
+    view[0] = -1
+    return rho.view(view).transpose(0, dim)
+
+
+def _roll_last(X, dim):
+    if dim == -1:
+        return X
+    elif dim < 0:
+        dim = X.dim() - dim
+
+    perm = [i for i in range(X.dim()) if i != dim] + [dim]
+    return X.permute(perm)
+
+
+def _sparsemax_threshold_and_support(X, dim=-1, k=None):
+    """Core computation for sparsemax: optimal threshold and support size.
+
+    Parameters
+    ----------
+    X : torch.Tensor
+        The input tensor to compute thresholds over.
+
+    dim : int
+        The dimension along which to apply sparsemax.
+
+    k : int or None
+        number of largest elements to partial-sort over. For optimal
+        performance, should be slightly bigger than the expected number of
+        nonzeros in the solution. If the solution is more than k-sparse,
+        this function is recursively called with a 2*k schedule.
+        If `None`, full sorting is performed from the beginning.
+
+    Returns
+    -------
+    tau : torch.Tensor like `X`, with all but the `dim` dimension intact
+        the threshold value for each vector
+    support_size : torch LongTensor, shape like `tau`
+        the number of nonzeros in each vector.
+    """
+
+    if k is None or k >= X.shape[dim]:  # do full sort
+        topk, _ = torch.sort(X, dim=dim, descending=True)
+    else:
+        topk, _ = torch.topk(X, k=k, dim=dim)
+
+    topk_cumsum = topk.cumsum(dim) - 1
+    rhos = _make_ix_like(topk, dim)
+    support = rhos * topk > topk_cumsum
+
+    support_size = support.sum(dim=dim).unsqueeze(dim)
+    tau = topk_cumsum.gather(dim, support_size - 1)
+    tau /= support_size.to(X.dtype)
+
+    if k is not None and k < X.shape[dim]:
+        unsolved = (support_size == k).squeeze(dim)
+
+        if torch.any(unsolved):
+            in_ = _roll_last(X, dim)[unsolved]
+            tau_, ss_ = _sparsemax_threshold_and_support(in_, dim=-1, k=2 * k)
+            _roll_last(tau, dim)[unsolved] = tau_
+            _roll_last(support_size, dim)[unsolved] = ss_
+
+    return tau, support_size
+
+
+def _entmax_threshold_and_support(X, dim=-1, k=None):
+    """Core computation for 1.5-entmax: optimal threshold and support size.
+
+    Parameters
+    ----------
+    X : torch.Tensor
+        The input tensor to compute thresholds over.
+
+    dim : int
+        The dimension along which to apply 1.5-entmax.
+
+    k : int or None
+        number of largest elements to partial-sort over. For optimal
+        performance, should be slightly bigger than the expected number of
+        nonzeros in the solution. If the solution is more than k-sparse,
+        this function is recursively called with a 2*k schedule.
+        If `None`, full sorting is performed from the beginning.
+
+    Returns
+    -------
+    tau : torch.Tensor like `X`, with all but the `dim` dimension intact
+        the threshold value for each vector
+    support_size : torch LongTensor, shape like `tau`
+        the number of nonzeros in each vector.
+    """
+
+    if k is None or k >= X.shape[dim]:  # do full sort
+        Xsrt, _ = torch.sort(X, dim=dim, descending=True)
+    else:
+        Xsrt, _ = torch.topk(X, k=k, dim=dim)
+
+    rho = _make_ix_like(Xsrt, dim)
+    mean = Xsrt.cumsum(dim) / rho
+    mean_sq = (Xsrt**2).cumsum(dim) / rho
+    ss = rho * (mean_sq - mean**2)
+    delta = (1 - ss) / rho
+
+    # NOTE this is not exactly the same as in reference algo
+    # Fortunately it seems the clamped values never wrongly
+    # get selected by tau <= sorted_z. Prove this!
+    delta_nz = torch.clamp(delta, 0)
+    tau = mean - torch.sqrt(delta_nz)
+
+    support_size = (tau <= Xsrt).sum(dim).unsqueeze(dim)
+    tau_star = tau.gather(dim, support_size - 1)
+
+    if k is not None and k < X.shape[dim]:
+        unsolved = (support_size == k).squeeze(dim)
+
+        if torch.any(unsolved):
+            X_ = _roll_last(X, dim)[unsolved]
+            tau_, ss_ = _entmax_threshold_and_support(X_, dim=-1, k=2 * k)
+            _roll_last(tau_star, dim)[unsolved] = tau_
+            _roll_last(support_size, dim)[unsolved] = ss_
+
+    return tau_star, support_size
+
+
+class SparsemaxFunction(Function):
+    @classmethod
+    def forward(cls, ctx, X, dim=-1, k=None):
+        ctx.dim = dim
+        max_val, _ = X.max(dim=dim, keepdim=True)
+        X = X - max_val  # same numerical stability trick as softmax
+        tau, supp_size = _sparsemax_threshold_and_support(X, dim=dim, k=k)
+        output = torch.clamp(X - tau, min=0)
+        ctx.save_for_backward(supp_size, output)
+        return output
+
+    @classmethod
+    def backward(cls, ctx, grad_output):
+        supp_size, output = ctx.saved_tensors
+        dim = ctx.dim
+        grad_input = grad_output.clone()
+        grad_input[output == 0] = 0
+
+        v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze(dim)
+        v_hat = v_hat.unsqueeze(dim)
+        grad_input = torch.where(output != 0, grad_input - v_hat, grad_input)
+        return grad_input, None, None
+
+
+class Entmax15Function(Function):
+    @classmethod
+    def forward(cls, ctx, X, dim=0, k=None):
+        ctx.dim = dim
+
+        max_val, _ = X.max(dim=dim, keepdim=True)
+        X = X - max_val  # same numerical stability trick as for softmax
+        X = X / 2  # divide by 2 to solve actual Entmax
+
+        tau_star, _ = _entmax_threshold_and_support(X, dim=dim, k=k)
+
+        Y = torch.clamp(X - tau_star, min=0) ** 2
+        ctx.save_for_backward(Y)
+        return Y
+
+    @classmethod
+    def backward(cls, ctx, dY):
+        (Y,) = ctx.saved_tensors
+        gppr = Y.sqrt()  # = 1 / g'' (Y)
+        dX = dY * gppr
+        q = dX.sum(ctx.dim) / gppr.sum(ctx.dim)
+        q = q.unsqueeze(ctx.dim)
+        dX -= q * gppr
+        return dX, None, None
+
+
+def sparsemax(X, dim=-1, k=None):
+    """sparsemax: normalizing sparse transform (a la softmax).
+
+    Solves the projection:
+
+        min_p ||x - p||_2   s.t.    p >= 0, sum(p) == 1.
+
+    Parameters
+    ----------
+    X : torch.Tensor
+        The input tensor.
+
+    dim : int
+        The dimension along which to apply sparsemax.
+
+    k : int or None
+        number of largest elements to partial-sort over. For optimal
+        performance, should be slightly bigger than the expected number of
+        nonzeros in the solution. If the solution is more than k-sparse,
+        this function is recursively called with a 2*k schedule.
+        If `None`, full sorting is performed from the beginning.
+
+    Returns
+    -------
+    P : torch tensor, same shape as X
+        The projection result, such that P.sum(dim=dim) == 1 elementwise.
+    """
+
+    return SparsemaxFunction.apply(X, dim, k)
+
+
+def entmax15(X, dim=-1, k=None):
+    """1.5-entmax: normalizing sparse transform (a la softmax).
+
+    Solves the optimization problem:
+
+        max_p <x, p> - H_1.5(p)    s.t.    p >= 0, sum(p) == 1.
+
+    where H_1.5(p) is the Tsallis alpha-entropy with alpha=1.5.
+
+    Parameters
+    ----------
+    X : torch.Tensor
+        The input tensor.
+
+    dim : int
+        The dimension along which to apply 1.5-entmax.
+
+    k : int or None
+        number of largest elements to partial-sort over. For optimal
+        performance, should be slightly bigger than the expected number of
+        nonzeros in the solution. If the solution is more than k-sparse,
+        this function is recursively called with a 2*k schedule.
+        If `None`, full sorting is performed from the beginning.
+
+    Returns
+    -------
+    P : torch tensor, same shape as X
+        The projection result, such that P.sum(dim=dim) == 1 elementwise.
+    """
+
+    return Entmax15Function.apply(X, dim, k)
+
+
+class Sparsemax(nn.Module):
+    def __init__(self, dim=-1, k=None):
+        """sparsemax: normalizing sparse transform (a la softmax).
+
+        Solves the projection:
+
+            min_p ||x - p||_2   s.t.    p >= 0, sum(p) == 1.
+
+        Parameters
+        ----------
+        dim : int
+            The dimension along which to apply sparsemax.
+
+        k : int or None
+            number of largest elements to partial-sort over. For optimal
+            performance, should be slightly bigger than the expected number of
+            nonzeros in the solution. If the solution is more than k-sparse,
+            this function is recursively called with a 2*k schedule.
+            If `None`, full sorting is performed from the beginning.
+        """
+        self.dim = dim
+        self.k = k
+        super().__init__()
+
+    def forward(self, X):
+        return sparsemax(X, dim=self.dim, k=self.k)
+
+
+class Entmax15(nn.Module):
+    def __init__(self, dim=-1, k=None):
+        """1.5-entmax: normalizing sparse transform (a la softmax).
+
+        Solves the optimization problem:
+
+            max_p <x, p> - H_1.5(p)    s.t.    p >= 0, sum(p) == 1.
+
+        where H_1.5(p) is the Tsallis alpha-entropy with alpha=1.5.
+
+        Parameters
+        ----------
+        dim : int
+            The dimension along which to apply 1.5-entmax.
+
+        k : int or None
+            number of largest elements to partial-sort over. For optimal
+            performance, should be slightly bigger than the expected number of
+            nonzeros in the solution. If the solution is more than k-sparse,
+            this function is recursively called with a 2*k schedule.
+            If `None`, full sorting is performed from the beginning.
+        """
+        self.dim = dim
+        self.k = k
+        super().__init__()
+
+    def forward(self, X):
+        return entmax15(X, dim=self.dim, k=self.k)
+
+
+entmoid15 = Entmoid15.apply
+entmax15 = entmax15
+sparsemax = sparsemax
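
For readers of this commit, here is a minimal sanity-check sketch (not part of the diff) showing the vendored functions acting as drop-in replacements for the removed entmax package. It assumes the file's installed import path mirrors its location under src/, i.e. pytorch_tabular.models.common.layers.activations.

# Minimal sanity-check sketch (not part of this commit). Assumes the
# import path pytorch_tabular.models.common.layers.activations, derived
# from the file header shown in the diff above.
import torch
import torch.nn.functional as F

from pytorch_tabular.models.common.layers.activations import Entmax15, entmax15, sparsemax

logits = torch.tensor([[2.0, 1.0, 0.1, -1.0]], requires_grad=True)

p_soft = F.softmax(logits, dim=-1)    # dense: every entry is positive
p_sparse = sparsemax(logits, dim=-1)  # can assign exact zeros
p_ent = entmax15(logits, dim=-1)      # sparsity between softmax and sparsemax

# All three normalize to probability distributions along dim=-1.
for p in (p_soft, p_sparse, p_ent):
    assert torch.allclose(p.sum(dim=-1), torch.ones(1))

# The custom autograd Functions above provide backward passes, so the
# transforms can be trained through just like softmax.
p_ent.sum().backward()

# The nn.Module wrapper gives the same result as the functional form.
assert torch.allclose(Entmax15(dim=-1)(logits.detach()), p_ent.detach())

Like softmax, each transform normalizes to a probability distribution along the chosen dimension; unlike softmax, sparsemax and entmax15 can assign exact zeros to low-scoring entries.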
