@@ -55,7 +55,7 @@ def _mobius_add(self, x, y):
5555 x_y = tf .reduce_sum (x * y , axis = - 1 , keepdims = True )
5656 k = tf .cast (self .k , x .dtype )
5757 return ((1 + 2 * k * x_y + k * y_2 ) * x + (1 - k * x_2 ) * y ) / (
58- 1 + 2 * k * x_y + k ** 2 * x_2 * y_2
58+ 1 + 2 * k * x_y + k ** 2 * x_2 * y_2
5959 )
6060
6161 def _mobius_scal_mul (self , x , r ):
@@ -91,15 +91,15 @@ def _lambda(self, x, keepdims=False):
9191
9292 def inner (self , x , u , v , keepdims = False ):
9393 lambda_x = self ._lambda (x , keepdims = keepdims )
94- return tf .reduce_sum (u * v , axis = - 1 , keepdims = keepdims ) * lambda_x ** 2
94+ return tf .reduce_sum (u * v , axis = - 1 , keepdims = keepdims ) * lambda_x ** 2
9595
9696 def norm (self , x , u , keepdims = False ):
9797 lambda_x = self ._lambda (x , keepdims = keepdims )
9898 return tf .linalg .norm (u , axis = - 1 , keepdims = keepdims ) * lambda_x
9999
100100 def proju (self , x , u ):
101101 lambda_x = self ._lambda (x , keepdims = True )
102- return u / lambda_x ** 2
102+ return u / lambda_x ** 2
103103
104104 def projx (self , x ):
105105 sqrt_k = tf .math .sqrt (tf .cast (self .k , x .dtype ))
0 commit comments