Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 29 additions & 29 deletions glow/information_bottleneck/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import math
import torch
import glow.utils.hsic_utils as kernel_module
import numpy as np


class Estimator:
Expand Down Expand Up @@ -112,12 +113,13 @@ class EDGE(Estimator):

"""

def __init__(self, hash_function, gpu=True, **kwargs):
def __init__(self, hash_function, U=10, gpu=True, **kwargs):
super().__init__(gpu, **kwargs)
self.hash_function = hash_function
self.U = U

def g(self, x):
return x * torch.log(x) * (1 / math.log(10))
return x * math.log(x) * (1 / math.log(2))

def criterion(self, x, y):
"""
Expand All @@ -127,38 +129,36 @@ def criterion(self, x, y):
"""
h = hash_module.get(self.hash_function, self.params_dict)
num_samples = x.shape[0]
if "F" in self.params_dict.keys():
F = self.params_dict["F"] * num_samples
else:
raise Exception(
"Cannot find argument for number of nodes of dependency graph in EDGE estimator"
)

N = torch.zeros(F, 1)
M = torch.zeros(F, 1)
L = torch.zeros(F, F)
"""
edge_list = []
N = {}
M = {}
L = {}

for k, x_k in enumerate(x):
y_k = y[k]
i = h(x_k)
j = h(y_k)
N[i] = N[i] + 1
M[j] = M[j] + 1
L[i][j] = L[i][j] + 1
"""
N = torch.nn.functional.one_hot(N.long().view(-1, 1), F)
N = torch.sum(N, dim=0)

M = torch.nn.functional.one_hot(M.long().view(-1, 1), F)
M = torch.sum(M, dim=0)

n = (1 / num_samples) * N
m = (1 / num_samples) * M
temp_matrix = torch.mm(N, torch.transpose(M, 0, 1))
zero_matrix = torch.zeros(F, F)
w = torch.addcdiv(zero_matrix, num_samples, L, temp_matrix)
temp_matrix = torch.mm(n, torch.transpose(m, 0, 1))
mut_info = torch.sum(temp_matrix * g_hat(w))
if list(x_k.size()) == []:
i = i.item()
j = j.item()
else:
i = tuple(i.tolist())
j = tuple(j.tolist())

N[i] = (N[i] + 1.0) if i in N else 1.0
M[j] = (M[j] + 1.0) if j in M else 1.0
L[i,j] = (L[i,j] + 1.0) if (i,j) in L else 1.0
edge_list.append((i, j))

mut_info = 0.0

for i, j in edge_list:
wi = 1.0 * N[i] / num_samples
wj = 1.0 * M[j] / num_samples
wij = min(self.U, 1.0 * L[i,j] * num_samples / (N[i]*M[j]))
mut_info += wi * wj * self.g(wij)

return mut_info


Expand Down