1 | 1 | # W605 |
2 | 2 | import torch |
| 3 | +import torch.nn as nn |
3 | 4 | import torch.nn.functional as F |
4 | | -from entmax import entmax15, sparsemax |
5 | 5 | from torch import Tensor |
6 | 6 | from torch.autograd import Function |
7 | 7 | from torch.jit import script |
@@ -43,11 +43,6 @@ def sparsemoid(input): |
43 | 43 | return (0.5 * input + 0.5).clamp_(0, 1) |
44 | 44 |
45 | 45 |
46 | | -entmoid15 = Entmoid15.apply |
47 | | -entmax15 = entmax15 |
48 | | -sparsemax = sparsemax |
49 | | - |
50 | | - |
51 | 46 | def t_softmax(input: Tensor, t: Tensor = None, dim: int = -1) -> Tensor: |
52 | 47 | if t is None: |
53 | 48 | t = torch.tensor(0.5, device=input.device) |
@@ -98,3 +93,321 @@ def calculate_t(cls, input: Tensor, r: Tensor, dim: int = -1, eps: float = 1e-8) |
98 | 93 | def forward(self, input: Tensor, r: Tensor): |
99 | 94 | t = RSoftmax.calculate_t(input, r, self.dim, self.eps) |
100 | 95 | return self.tsoftmax(input, t) |
| 96 | + |
| 97 | + |
| 98 | +""" |
| 99 | +An implementation of entmax (Peters et al., 2019). See |
| 100 | +https://arxiv.org/pdf/1905.05702 for a detailed description.
| 101 | +
| 102 | +This builds on previous work with sparsemax (Martins & Astudillo, 2016). |
| 103 | +See https://arxiv.org/pdf/1602.02068. |
| 104 | +""" |
| 105 | + |
| 106 | +# Author: Ben Peters |
| 107 | +# Author: Vlad Niculae <vlad@vene.ro> |
| 108 | +# License: MIT |
| 109 | +# Vendored here (instead of importing from the entmax package) for conda compatibility
| 110 | + |
| 111 | + |
| 112 | +def _make_ix_like(X, dim): |
| 113 | + d = X.size(dim) |
| 114 | + rho = torch.arange(1, d + 1, device=X.device, dtype=X.dtype) |
| 115 | + view = [1] * X.dim() |
| 116 | + view[0] = -1 |
| 117 | + return rho.view(view).transpose(0, dim) |
| 118 | + |
| 119 | + |
| 120 | +def _roll_last(X, dim): |
| 121 | + if dim == -1: |
| 122 | + return X |
| 123 | + elif dim < 0: |
| 124 | + dim = X.dim() + dim  # map a negative dim to its non-negative index
| 125 | + |
| 126 | + perm = [i for i in range(X.dim()) if i != dim] + [dim] |
| 127 | + return X.permute(perm) |
| 128 | + |
| 129 | + |
| 130 | +def _sparsemax_threshold_and_support(X, dim=-1, k=None): |
| 131 | + """Core computation for sparsemax: optimal threshold and support size. |
| 132 | +
| 133 | + Parameters |
| 134 | + ---------- |
| 135 | + X : torch.Tensor |
| 136 | + The input tensor to compute thresholds over. |
| 137 | +
| 138 | + dim : int |
| 139 | + The dimension along which to apply sparsemax. |
| 140 | +
| 141 | + k : int or None |
| 142 | + number of largest elements to partial-sort over. For optimal |
| 143 | + performance, should be slightly bigger than the expected number of |
| 144 | + nonzeros in the solution. If the solution is more than k-sparse, |
| 145 | + this function is recursively called with a 2*k schedule. |
| 146 | + If `None`, full sorting is performed from the beginning. |
| 147 | +
| 148 | + Returns |
| 149 | + ------- |
| 150 | + tau : torch.Tensor like `X`, with all but the `dim` dimension intact |
| 151 | + the threshold value for each vector |
| 152 | + support_size : torch LongTensor, shape like `tau` |
| 153 | + the number of nonzeros in each vector. |
| 154 | + """ |
| 155 | + |
| 156 | + if k is None or k >= X.shape[dim]: # do full sort |
| 157 | + topk, _ = torch.sort(X, dim=dim, descending=True) |
| 158 | + else: |
| 159 | + topk, _ = torch.topk(X, k=k, dim=dim) |
| 160 | + |
| 161 | + topk_cumsum = topk.cumsum(dim) - 1 |
| 162 | + rhos = _make_ix_like(topk, dim) |
| 163 | + support = rhos * topk > topk_cumsum |
| 164 | + |
| 165 | + support_size = support.sum(dim=dim).unsqueeze(dim) |
| 166 | + tau = topk_cumsum.gather(dim, support_size - 1) |
| 167 | + tau /= support_size.to(X.dtype) |
| 168 | + |
| 169 | + if k is not None and k < X.shape[dim]: |
| 170 | + unsolved = (support_size == k).squeeze(dim) |
| 171 | + |
| 172 | + if torch.any(unsolved): |
| 173 | + in_ = _roll_last(X, dim)[unsolved] |
| 174 | + tau_, ss_ = _sparsemax_threshold_and_support(in_, dim=-1, k=2 * k) |
| 175 | + _roll_last(tau, dim)[unsolved] = tau_ |
| 176 | + _roll_last(support_size, dim)[unsolved] = ss_ |
| 177 | + |
| 178 | + return tau, support_size |
| 179 | + |
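| | +# Worked example (editorial addition, not in the original source): for the row
| | +# x = [1.0, 0.5, -1.0], the sorted cumulative sums minus one are [0.0, 0.5, -0.5],
| | +# the condition rho * x_sorted > cumsum - 1 holds for the first two entries,
| | +# so support_size = 2 and tau = 0.5 / 2 = 0.25.
| | +#
| | +#   >>> x = torch.tensor([[1.0, 0.5, -1.0]])
| | +#   >>> tau, support = _sparsemax_threshold_and_support(x, dim=-1)
| | +#   >>> tau, support            # (tensor([[0.2500]]), tensor([[2]]))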
| 180 | + |
| 181 | +def _entmax_threshold_and_support(X, dim=-1, k=None): |
| 182 | + """Core computation for 1.5-entmax: optimal threshold and support size. |
| 183 | +
| 184 | + Parameters |
| 185 | + ---------- |
| 186 | + X : torch.Tensor |
| 187 | + The input tensor to compute thresholds over. |
| 188 | +
| 189 | + dim : int |
| 190 | + The dimension along which to apply 1.5-entmax. |
| 191 | +
| 192 | + k : int or None |
| 193 | + number of largest elements to partial-sort over. For optimal |
| 194 | + performance, should be slightly bigger than the expected number of |
| 195 | + nonzeros in the solution. If the solution is more than k-sparse, |
| 196 | + this function is recursively called with a 2*k schedule. |
| 197 | + If `None`, full sorting is performed from the beginning. |
| 198 | +
| 199 | + Returns |
| 200 | + ------- |
| 201 | + tau : torch.Tensor like `X`, with all but the `dim` dimension intact |
| 202 | + the threshold value for each vector |
| 203 | + support_size : torch LongTensor, shape like `tau` |
| 204 | + the number of nonzeros in each vector. |
| 205 | + """ |
| 206 | + |
| 207 | + if k is None or k >= X.shape[dim]: # do full sort |
| 208 | + Xsrt, _ = torch.sort(X, dim=dim, descending=True) |
| 209 | + else: |
| 210 | + Xsrt, _ = torch.topk(X, k=k, dim=dim) |
| 211 | + |
| 212 | + rho = _make_ix_like(Xsrt, dim) |
| 213 | + mean = Xsrt.cumsum(dim) / rho |
| 214 | + mean_sq = (Xsrt**2).cumsum(dim) / rho |
| 215 | + ss = rho * (mean_sq - mean**2) |
| 216 | + delta = (1 - ss) / rho |
| 217 | + |
| 218 | + # NOTE this is not exactly the same as in reference algo |
| 219 | + # Fortunately it seems the clamped values never wrongly |
| 220 | + # get selected by tau <= sorted_z. Prove this! |
| 221 | + delta_nz = torch.clamp(delta, 0) |
| 222 | + tau = mean - torch.sqrt(delta_nz) |
| 223 | + |
| 224 | + support_size = (tau <= Xsrt).sum(dim).unsqueeze(dim) |
| 225 | + tau_star = tau.gather(dim, support_size - 1) |
| 226 | + |
| 227 | + if k is not None and k < X.shape[dim]: |
| 228 | + unsolved = (support_size == k).squeeze(dim) |
| 229 | + |
| 230 | + if torch.any(unsolved): |
| 231 | + X_ = _roll_last(X, dim)[unsolved] |
| 232 | + tau_, ss_ = _entmax_threshold_and_support(X_, dim=-1, k=2 * k) |
| 233 | + _roll_last(tau_star, dim)[unsolved] = tau_ |
| 234 | + _roll_last(support_size, dim)[unsolved] = ss_ |
| 235 | + |
| 236 | + return tau_star, support_size |
| 237 | + |
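| | +# Worked example (editorial addition, not in the original source): the per-prefix
| | +# threshold is tau_i = mean_i - sqrt(delta_i), and tau_star is the value at the
| | +# largest prefix that still satisfies tau <= sorted value. For the already
| | +# shifted-and-halved row [0.0, -0.5, -1.5] the support ends up with 2 entries.
| | +#
| | +#   >>> x = torch.tensor([[0.0, -0.5, -1.5]])
| | +#   >>> tau_star, support = _entmax_threshold_and_support(x, dim=-1)
| | +#   >>> support                 # tensor([[2]])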
| 238 | + |
| 239 | +class SparsemaxFunction(Function): |
| 240 | + @classmethod |
| 241 | + def forward(cls, ctx, X, dim=-1, k=None): |
| 242 | + ctx.dim = dim |
| 243 | + max_val, _ = X.max(dim=dim, keepdim=True) |
| 244 | + X = X - max_val # same numerical stability trick as softmax |
| 245 | + tau, supp_size = _sparsemax_threshold_and_support(X, dim=dim, k=k) |
| 246 | + output = torch.clamp(X - tau, min=0) |
| 247 | + ctx.save_for_backward(supp_size, output) |
| 248 | + return output |
| 249 | + |
| 250 | + @classmethod |
| 251 | + def backward(cls, ctx, grad_output): |
| 252 | + supp_size, output = ctx.saved_tensors |
| 253 | + dim = ctx.dim |
| 254 | + grad_input = grad_output.clone() |
| 255 | + grad_input[output == 0] = 0 |
| 256 | + |
| 257 | + v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze(dim) |
| 258 | + v_hat = v_hat.unsqueeze(dim) |
| 259 | + grad_input = torch.where(output != 0, grad_input - v_hat, grad_input) |
| 260 | + return grad_input, None, None |
| 261 | + |
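| | +# Usage sketch (editorial addition, not in the original source): the backward
| | +# pass zeroes the gradient outside the support and subtracts the support mean
| | +# inside it, which is the Jacobian-vector product of sparsemax.
| | +#
| | +#   >>> x = torch.tensor([[1.0, 0.5, -1.0]], requires_grad=True)
| | +#   >>> y = SparsemaxFunction.apply(x, -1, None)   # [[0.75, 0.25, 0.00]]
| | +#   >>> y[0, 0].backward()
| | +#   >>> x.grad                                     # [[0.5, -0.5, 0.0]]; zero outside the support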
| 262 | + |
| 263 | +class Entmax15Function(Function): |
| 264 | + @classmethod |
| 265 | + def forward(cls, ctx, X, dim=0, k=None): |
| 266 | + ctx.dim = dim |
| 267 | + |
| 268 | + max_val, _ = X.max(dim=dim, keepdim=True) |
| 269 | + X = X - max_val # same numerical stability trick as for softmax |
| 270 | + X = X / 2 # divide by 2 to solve actual Entmax |
| 271 | + |
| 272 | + tau_star, _ = _entmax_threshold_and_support(X, dim=dim, k=k) |
| 273 | + |
| 274 | + Y = torch.clamp(X - tau_star, min=0) ** 2 |
| 275 | + ctx.save_for_backward(Y) |
| 276 | + return Y |
| 277 | + |
| 278 | + @classmethod |
| 279 | + def backward(cls, ctx, dY): |
| 280 | + (Y,) = ctx.saved_tensors |
| 281 | + gppr = Y.sqrt() # = 1 / g'' (Y) |
| 282 | + dX = dY * gppr |
| 283 | + q = dX.sum(ctx.dim) / gppr.sum(ctx.dim) |
| 284 | + q = q.unsqueeze(ctx.dim) |
| 285 | + dX -= q * gppr |
| 286 | + return dX, None, None |
| 287 | + |
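| | +# Usage sketch (editorial addition, not in the original source): backward weights
| | +# the incoming gradient by sqrt(Y), so entries outside the support (Y == 0)
| | +# receive exactly zero gradient, mirroring SparsemaxFunction above.
| | +#
| | +#   >>> x = torch.tensor([[2.0, 1.0, -1.0]], requires_grad=True)
| | +#   >>> y = Entmax15Function.apply(x, -1, None)    # roughly [[0.83, 0.17, 0.00]]
| | +#   >>> y[0, 0].backward()
| | +#   >>> x.grad[0, 2].item()                        # 0.0 for the pruned entry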
| 288 | + |
| 289 | +def sparsemax(X, dim=-1, k=None): |
| 290 | + """sparsemax: normalizing sparse transform (a la softmax). |
| 291 | +
| 292 | + Solves the projection: |
| 293 | +
| 294 | + min_p ||x - p||_2 s.t. p >= 0, sum(p) == 1. |
| 295 | +
| 296 | + Parameters |
| 297 | + ---------- |
| 298 | + X : torch.Tensor |
| 299 | + The input tensor. |
| 300 | +
| 301 | + dim : int |
| 302 | + The dimension along which to apply sparsemax. |
| 303 | +
| 304 | + k : int or None |
| 305 | + number of largest elements to partial-sort over. For optimal |
| 306 | + performance, should be slightly bigger than the expected number of |
| 307 | + nonzeros in the solution. If the solution is more than k-sparse, |
| 308 | + this function is recursively called with a 2*k schedule. |
| 309 | + If `None`, full sorting is performed from the beginning. |
| 310 | +
| 311 | + Returns |
| 312 | + ------- |
| 313 | + P : torch tensor, same shape as X |
| 314 | + The projection result, such that P.sum(dim=dim) == 1 elementwise. |
| 315 | + """ |
| 316 | + |
| 317 | + return SparsemaxFunction.apply(X, dim, k) |
| 318 | + |
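| | +# Usage sketch (editorial addition, not in the original source): unlike softmax,
| | +# sparsemax can assign exact zeros while the row still sums to one along `dim`.
| | +#
| | +#   >>> x = torch.tensor([[1.0, 0.5, -1.0]])
| | +#   >>> sparsemax(x, dim=-1)          # [[0.75, 0.25, 0.00]]
| | +#   >>> sparsemax(x, dim=-1, k=2)     # same result, using a partial sort over the top 2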
| 319 | + |
| 320 | +def entmax15(X, dim=-1, k=None): |
| 321 | + """1.5-entmax: normalizing sparse transform (a la softmax). |
| 322 | +
| 323 | + Solves the optimization problem: |
| 324 | +
| 325 | + max_p <x, p> + H_1.5(p) s.t. p >= 0, sum(p) == 1.
| 326 | +
| 327 | + where H_1.5(p) is the Tsallis alpha-entropy with alpha=1.5. |
| 328 | +
| 329 | + Parameters |
| 330 | + ---------- |
| 331 | + X : torch.Tensor |
| 332 | + The input tensor. |
| 333 | +
| 334 | + dim : int |
| 335 | + The dimension along which to apply 1.5-entmax. |
| 336 | +
| 337 | + k : int or None |
| 338 | + number of largest elements to partial-sort over. For optimal |
| 339 | + performance, should be slightly bigger than the expected number of |
| 340 | + nonzeros in the solution. If the solution is more than k-sparse, |
| 341 | + this function is recursively called with a 2*k schedule. |
| 342 | + If `None`, full sorting is performed from the beginning. |
| 343 | +
| 344 | + Returns |
| 345 | + ------- |
| 346 | + P : torch tensor, same shape as X |
| 347 | + The projection result, such that P.sum(dim=dim) == 1 elementwise. |
| 348 | + """ |
| 349 | + |
| 350 | + return Entmax15Function.apply(X, dim, k) |
| 351 | + |
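| | +# Usage sketch (editorial addition, not in the original source): 1.5-entmax
| | +# interpolates between softmax (alpha=1, dense) and sparsemax (alpha=2), and it
| | +# can still return exact zeros for low-scoring entries.
| | +#
| | +#   >>> x = torch.tensor([[2.0, 1.0, -1.0]])
| | +#   >>> entmax15(x, dim=-1)           # roughly [[0.83, 0.17, 0.00]]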
| 352 | + |
| 353 | +class Sparsemax(nn.Module): |
| 354 | + def __init__(self, dim=-1, k=None): |
| 355 | + """sparsemax: normalizing sparse transform (a la softmax). |
| 356 | +
| 357 | + Solves the projection: |
| 358 | +
| 359 | + min_p ||x - p||_2 s.t. p >= 0, sum(p) == 1. |
| 360 | +
| 361 | + Parameters |
| 362 | + ---------- |
| 363 | + dim : int |
| 364 | + The dimension along which to apply sparsemax. |
| 365 | +
| 366 | + k : int or None |
| 367 | + number of largest elements to partial-sort over. For optimal |
| 368 | + performance, should be slightly bigger than the expected number of |
| 369 | + nonzeros in the solution. If the solution is more than k-sparse, |
| 370 | + this function is recursively called with a 2*k schedule. |
| 371 | + If `None`, full sorting is performed from the beginning. |
| 372 | + """ |
| 373 | + self.dim = dim |
| 374 | + self.k = k |
| 375 | + super().__init__() |
| 376 | + |
| 377 | + def forward(self, X): |
| 378 | + return sparsemax(X, dim=self.dim, k=self.k) |
| 379 | + |
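| | +# Usage sketch (editorial addition, not in the original source): the module form
| | +# is a drop-in replacement for nn.Softmax wherever a sparse distribution over
| | +# features or attention scores is wanted.
| | +#
| | +#   >>> head = nn.Sequential(nn.Linear(16, 8), Sparsemax(dim=-1))
| | +#   >>> probs = head(torch.randn(4, 16))   # each row sums to 1, some entries exactly 0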
| 380 | + |
| 381 | +class Entmax15(nn.Module): |
| 382 | + def __init__(self, dim=-1, k=None): |
| 383 | + """1.5-entmax: normalizing sparse transform (a la softmax). |
| 384 | +
| 385 | + Solves the optimization problem: |
| 386 | +
| 387 | + max_p <x, p> + H_1.5(p) s.t. p >= 0, sum(p) == 1.
| 388 | +
| 389 | + where H_1.5(p) is the Tsallis alpha-entropy with alpha=1.5. |
| 390 | +
| 391 | + Parameters |
| 392 | + ---------- |
| 393 | + dim : int |
| 394 | + The dimension along which to apply 1.5-entmax. |
| 395 | +
| 396 | + k : int or None |
| 397 | + number of largest elements to partial-sort over. For optimal |
| 398 | + performance, should be slightly bigger than the expected number of |
| 399 | + nonzeros in the solution. If the solution is more than k-sparse, |
| 400 | + this function is recursively called with a 2*k schedule. |
| 401 | + If `None`, full sorting is performed from the beginning. |
| 402 | + """ |
| 403 | + self.dim = dim |
| 404 | + self.k = k |
| 405 | + super().__init__() |
| 406 | + |
| 407 | + def forward(self, X): |
| 408 | + return entmax15(X, dim=self.dim, k=self.k) |
| 409 | + |
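| | +# Usage sketch (editorial addition, not in the original source): Entmax15 is used
| | +# the same way as Sparsemax; it typically keeps more entries nonzero while still
| | +# pruning the clearly irrelevant ones.
| | +#
| | +#   >>> head = nn.Sequential(nn.Linear(16, 8), Entmax15(dim=-1))
| | +#   >>> probs = head(torch.randn(4, 16))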
| 410 | + |
| 411 | +entmoid15 = Entmoid15.apply