@@ -32,8 +32,8 @@ Average pooling
32 32
33 33 .. code ::
34 34
35- def avgpool(float(B, C, H, W) input ) -> (output ) {{
36- output (b, c, h, w) +=! input (b, c, h * {sH} + r_kh, w * {sW} + r_kw) / ({kH} * {kW})
35+ def avgpool(float(B, C, H, W) Input ) -> (Output ) {{
36+ Output (b, c, h, w) +=! Input (b, c, h * {sH} + r_kh, w * {sW} + r_kw) / ({kH} * {kW})
37 37 where r_kh in 0:{kH}, r_kw in 0:{kW}
38 38 }}
39 39
@@ -43,8 +43,8 @@ Max pooling
43 43
44 44 .. code ::
45 45
46- def maxpool(float(B, C, H, W) input ) -> (output ) {{
47- output (b, c, h, w) max=! input (b, c, h * {sH} + r_kh, w * {sW} + r_kw)
46+ def maxpool(float(B, C, H, W) Input ) -> (Output ) {{
47+ Output (b, c, h, w) max=! Input (b, c, h * {sH} + r_kh, w * {sW} + r_kw)
48 48 where r_kh in 0:{kH}, r_kw in 0:{kW}
49 49 }}
50 50
@@ -76,9 +76,9 @@ Strided Convolution Gradient
76 76
77 77 .. code ::
78 78
79- def convolution_grad(float(N, C, H, W) I, float(M, C, KH, KW) W1, float(N, M, H, W) g_O ) -> (g_I, g_W1 ) {{
80- g_I (n, c, h, w) +=! g_O (n, r_m, {sh} * h - r_kh, {sw} * w - r_kw) * W1(r_m, c, r_kh, r_kw)
81- g_W1 (m, c, kh, kw) +=! g_O (n, m, {sh} * r_h - kh, {sw} * r_w - kw) * I(r_n, c, r_h, r_w)
79+ def convolution_grad(float(N, C, H, W) I, float(M, C, KH, KW) W1, float(N, M, H, W) d_O ) -> (d_I, d_W1 ) {{
80+ d_I (n, c, h, w) +=! d_O (n, r_m, {sh} * h - r_kh, {sw} * w - r_kw) * W1(r_m, c, r_kh, r_kw)
81+ d_W1 (m, c, kh, kw) +=! d_O (n, m, {sh} * r_h - kh, {sw} * r_w - kw) * I(r_n, c, r_h, r_w)
82 82 }}
83 83
84 84 Simple Group Convolution
@@ -140,11 +140,11 @@ Softmax
140 140
141 141 .. code ::
142 142
143- def softmax(float(N, D) I) -> (O, maxVal, expDistance, expSum ) {
144- maxVal (n) max=! I(n, d)
145- expDistance (n, d) = exp(I(n, d) - maxVal (n))
146- expSum (n) +=! expDistance (n, d)
147- O(n, d) = expDistance (n, d) / expSum (n)
143+ def softmax(float(N, D) I) -> (O, MaxVal, ExpDistance, ExpSum ) {
144+ MaxVal (n) max=! I(n, d)
145+ ExpDistance (n, d) = exp(I(n, d) - MaxVal (n))
146+ ExpSum (n) +=! ExpDistance (n, d)
147+ O(n, d) = ExpDistance (n, d) / ExpSum (n)
148 148 }
149 149
150 150 Tanh
@@ -191,9 +191,9 @@ Matmul Gradient
191 191
192 192 .. code ::
193 193
194- def matmul_bw(float(M,K) A, float(K,N) B, float(M,N) g_C ) -> (g_A, g_B ){
195- g_A (m, k) +=! g_C ( m, r_n) * B( k, r_n)
196- g_B (k, n) +=! g_C (r_m, n) * A(r_m, k)
194+ def matmul_bw(float(M,K) A, float(K,N) B, float(M,N) d_C ) -> (d_A, d_B ){
195+ d_A (m, k) +=! d_C ( m, r_n) * B( k, r_n)
196+ d_B (k, n) +=! d_C (r_m, n) * A(r_m, k)
197 197 }
198 198
199 199 Batch Matmul
219 219
220 220 .. code ::
221 221
222- def add(float(N) A, float(N) B) -> (output ) {
223- output (n) = A(n) + B(n)
222+ def add(float(N) A, float(N) B) -> (Output ) {
223+ Output (n) = A(n) + B(n)
224 224 }
225 225
226 226 Tensor Operations
@@ -231,8 +231,8 @@ Indexing
231 231
232 232 .. code ::
233 233
234- def indexing(float(H, W) input , int32(L) index ) -> (output ) {{
235- output (l, w) = input(index (l), w)
234+ def indexing(float(H, W) Input , int32(L) Index ) -> (Output ) {{
235+ Output (l, w) = Input(Index (l), w)
236 236 }}
237 237
238 238 Lookup Table
@@ -327,17 +327,17 @@ Batch Normalization
327 327
328 328 .. code ::
329 329
330- def batchnorm(float(N,C,H,W) I, float(C) rMeanIn , float(C) rVarIn )
331- -> (O, rMeanOut, rVarOut, mean, centered, variance, expectedVariance , normalizedOut)
330+ def batchnorm(float(N,C,H,W) I, float(C) RMeanIn , float(C) RVarIn )
331+ -> (O, RMeanOut, RVarOut, Mean, Centered, Variance, ExpectedVariance , normalizedOut)
332 332 {{
333- mean (c) +=! I(nn, c, hh, ww)
334- mean (c) = mean (c) / (N * H * W)
335- rMeanOut (c) = (1 - {momentum}) * rMeanIn (c) + {momentum} * mean (c)
336- centered (n, c, h, w) = I(n, c, h, w) - rMeanOut (c)
337- variance (n, c, h, w) = centered (n, c, h, w) * centered (n, c, h, w)
338- expectedVariance (c) +=! (variance (n, c, h, w) + {eps}) / (N * H * W)
339- rVarOut (c) = rsqrt((1 - {momentum}) * rVarIn (c) + {momentum} * expectedVariance (c))
340- O(n, c, h, w) = centered (n, c, h, w) * rVarOut (c)
333+ Mean (c) +=! I(nn, c, hh, ww)
334+ Mean (c) = Mean (c) / (N * H * W)
335+ RMeanOut (c) = (1 - {momentum}) * RMeanIn (c) + {momentum} * Mean (c)
336+ Centered (n, c, h, w) = I(n, c, h, w) - RMeanOut (c)
337+ Variance (n, c, h, w) = Centered (n, c, h, w) * Centered (n, c, h, w)
338+ ExpectedVariance (c) +=! (Variance (n, c, h, w) + {eps}) / (N * H * W)
339+ RVarOut (c) = rsqrt((1 - {momentum}) * RVarIn (c) + {momentum} * ExpectedVariance (c))
340+ O(n, c, h, w) = Centered (n, c, h, w) * RVarOut (c)
341 341 normalizedOut(n, c, h, w) = O(n, c, h, w)
342 342 }}
343 343
@@ -346,12 +346,12 @@ Layer Normalization
346 346
347 347 .. code ::
348 348
349- def layernorm(float(T, B, C) I) -> (O, mean, centered, var ) {{
350- mean (t, b) +=! I(t, b, c) / C
351- centered (t, b, c) = I(t, b, c) - mean (t, b)
352- var (t, b) +=! centered (t, b, c) * centered (t, b, c)
353- var (t, b) = (var (t, b) + {eps}) / C
354- O(t, b, c) = centered (t, b, c) / rsqrt(var (t, b))
349+ def layernorm(float(T, B, C) I) -> (O, Mean, Centered, Var ) {{
350+ Mean (t, b) +=! I(t, b, c) / C
351+ Centered (t, b, c) = I(t, b, c) - Mean (t, b)
352+ Var (t, b) +=! Centered (t, b, c) * Centered (t, b, c)
353+ Var (t, b) = (Var (t, b) + {eps}) / C
354+ O(t, b, c) = Centered (t, b, c) / rsqrt(Var (t, b))
355 355 }}
356 356
357 357 Distance Functions
@@ -362,10 +362,10 @@ Cosine Similarity
362 362
363 363 .. code ::
364 364
365- def cosine_similarity(float(M, N) I1, float(M, N) I2) -> (O, sumI1, sumI2 ) {{
366- sumI1 (m) +=! I1(m, n) * I1(m, n)
367- sumI2 (m) +=! I2(m, n) * I2(m, n)
368- O(m) +=! (I1(m, n) * I2(m, n)) / fmax(rsqrt(sumI1 (m)) * sqrt(sumI2 (m)), {eps})
365+ def cosine_similarity(float(M, N) I1, float(M, N) I2) -> (O, SumI1, SumI2 ) {{
366+ SumI1 (m) +=! I1(m, n) * I1(m, n)
367+ SumI2 (m) +=! I2(m, n) * I2(m, n)
368+ O(m) +=! (I1(m, n) * I2(m, n)) / fmax(rsqrt(SumI1 (m)) * sqrt(SumI2 (m)), {eps})
369 369 }}
370 370
371 371 What operations can not be expressed
0 commit comments