Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 5b2a75e

Browse files
Update library/layers.yaml to follow TC coding style
1 parent 60fcf52 commit 5b2a75e

File tree

1 file changed

+73
-75
lines changed

1 file changed

+73
-75
lines changed

tensor_comprehensions/library/layers.yaml

Lines changed: 73 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -26,225 +26,223 @@
2626
- name: indexing
2727
lang: |
2828
def indexing(float(H, W) input, int32(L) index) -> (output) {{
29-
output(l, w) = input(index(l), w) where l in 0:{L}
29+
output(l, w) = input(index(l), w) where l in 0:{L}
3030
}}
3131
3232
- name: lookup_table
3333
lang: |
3434
def lookup_table(float(B, R) LUT, int32(B, N) I) -> (O) {
35-
O(b, n) +=! LUT(I(b, n), r)
35+
O(b, n) +=! LUT(I(b, n), r_r)
3636
}
3737
3838
- name: matmul
3939
lang: |
40-
def matmul(float(M,N) A, float(N,K) B) -> (output) {
41-
output(i, j) +=! A(i, kk) * B(kk, j)
40+
def matmul(float(M, K) A, float(K, N) B) -> (C) {
41+
C(m, n) +=! A(m, r_k) * B(r_k, n)
4242
}
4343
grad: |
44-
def matmul_grad(float(M,N) A, float(N,K), float(M,K) O_grad) -> (A_grad, B_grad){
45-
A_grad(i, j) +=! O_grad(i, kk) * B(j, kk)
46-
B_grad(i, j) +=! O_grad(kk, j) * A(kk, i)
44+
def matmul_grad(float(M,K) A, float(K,N) B, float(M,N) g_C) -> (g_A, g_B){
45+
g_A(m, k) +=! g_C( m, r_n) * B( k, r_n)
46+
g_B(k, n) +=! g_C(r_m, n) * A(r_m, k)
4747
}
4848
4949
- name: batch_matmul
5050
lang: |
5151
def batch_matmul(float(B, N, M) X, float(B, M, K) Y) -> (Z) {
52-
Z(b, n, k) +=! X(b, n, mm) * Y(b, mm, k)
52+
Z(b, n, k) +=! X(b, n, r_m) * Y(b, r_m, k)
5353
}
5454
5555
- name: transpose
5656
lang: |
5757
def transpose(float(N, C, H, W) I) -> (O) {
58-
O(c, n, w, h) = I(n, c, h, w)
58+
O(c, n, w, h) = I(n, c, h, w)
5959
}
6060
6161
- name: avgpool
6262
lang: |
6363
def avgpool(float(B, C, H, W) input) -> (output) {{
64-
output(b, c, h, w) +=! input(b, c, h * {sH} + kh, w * {sW} + kw) / ({kH} * {kW}) where kh in 0:{kH}, kw in 0:{kW}
64+
output(b, c, h, w) +=! input(b, c, h * {sH} + r_kh, w * {sW} + r_kw) / ({kH} * {kW})
65+
where r_kh in 0:{kH}, r_kw in 0:{kW}
6566
}}
6667
6768
- name: maxpool
6869
lang: |
6970
def maxpool(float(B, C, H, W) input) -> (output) {{
70-
output(b, c, h, w) max=! input(b, c, h * {sH} + kh, w * {sW} + kw) where kh in 0:{kH}, kw in 0:{kW}
71+
output(b, c, h, w) max=! input(b, c, h * {sH} + r_kh, w * {sW} + r_kw)
72+
where r_kh in 0:{kH}, r_kw in 0:{kW}
7173
}}
7274
7375
- name: scale
7476
lang: |
7577
def scale(float(M, N) I) -> (O) {{
76-
O(m, n) = I(m, n) * {s}
78+
O(m, n) = I(m, n) * {s}
7779
}}
7880
7981
- name: sigmoid
8082
lang: |
8183
def sigmoid(float(N, C, H, W) I) -> (O) {
82-
O(n, c, h, w) = 1 / (1 + exp(-I(n, c, h, w)))
84+
O(n, c, h, w) = 1 / (1 + exp(-I(n, c, h, w)))
8385
}
8486
8587
- name: softmax
8688
lang: |
8789
def softmax(float(N, D) I) -> (O, expsum, maxVal) {
88-
maxVal(n) max= I(n, d)
89-
expsum(n) +=! exp(I(n, d) - maxVal(n))
90-
O(n, d) = exp(I(n, d)) / expsum(n)
90+
maxVal(n) max= I(n, d)
91+
expsum(n) +=! exp(I(n, d) - maxVal(n))
92+
O(n, d) = exp(I(n, d)) / expsum(n)
9193
}
9294
9395
- name: Tanh
9496
lang: |
9597
def Tanh(float(M) I) -> (O) {
96-
O(m) = tanh(I(m))
98+
O(m) = tanh(I(m))
9799
}
98100
99101
- name: tensordot
100102
lang: |
101103
def tensordot(float(N, C1, C2, H, W) I0, float(N, C2, C3, H, W) I1) -> (O) {
102-
O(n, c1, c3, h, w) +=! I0(n, c1, c2, h, w) * I1(n, c2, c3, h, w)
104+
O(n, c1, c3, h, w) +=! I0(n, c1, r_c2, h, w) * I1(n, r_c2, c3, h, w)
103105
}
104106
105107
- name: fully_connected
106108
lang: |
107109
def fully_connected(float(B, M) I, float(N, M) W1, float(N) B1) -> (O1) {
108-
O1(b, n) +=! I(b, m) * W1(n, m)
109-
O1(b, n) = O1(b, n) + B1(n)
110+
O1(b, n) +=! I(b, r_m) * W1(n, r_m)
111+
O1(b, n) = O1(b, n) + B1(n)
110112
}
111113
112114
- name: relu
113115
lang: |
114116
def relu(float(B, M) I) -> (O1){
115-
O1(b, m) = fmax(I(b, m), 0)
117+
O1(b, m) = fmax(I(b, m), 0)
116118
}
117119
118120
- name: fcrelu
119121
lang: |
120-
def fcrelu(float(B, M) I, float(N, M) W1, float(N) B1) -> (O1){
121-
O1(b, n) +=! I(b, m) * W1(n, m)
122-
O1(b, n) = O1(b, n) + B1(n)
123-
O1(b, n) = fmax(O1(b, n), 0)
122+
def fcrelu(float(B,M) I, float(N,M) W1, float(N) B1) -> (O1){
123+
O1(b, n) +=! I(b, r_m) * W1(n, r_m)
124+
O1(b, n) = O1(b, n) + B1(n)
125+
O1(b, n) = fmax(O1(b, n), 0)
124126
}
125127
126128
- name: cast
127129
lang: |
128-
def cast(float(M, N) A) -> (int32(M, N) O1) {{
129-
O1(m, n) = int32(A(m, n) + {constant})
130+
def cast(float(M,N) A) -> (int32(M,N) O1) {{
131+
O1(m, n) = int32(A(m, n) + {constant})
130132
}}
131133
132134
- name: concat
133135
lang: |
134136
def concat(float(M, N) A, float(M, N) B) -> (O1) {
135-
O1(n, i, m) = i == 0 ? A(m, n) : B(m, n) where i in 0:2
137+
O1(n, i, m) = i == 0 ? A(m, n) : B(m, n) where i in 0:2
136138
}
137139
138140
- name: convolution
139141
lang: |
140142
def convolution(float(N, C, H, W) I, float(M, C, KH, KW) W1, float(M) B) -> (O) {
141-
O(n, m, h, w) +=! I(n, c, h + kh, w + kw) * W1(m, c, kh, kw)
142-
O(n, m, h, w) = O(n, m, h, w) + B(m)
143+
O(n, m, h, w) +=! I(n, r_c, h + r_kh, w + r_kw) * W1(m, r_c, r_kh, r_kw)
144+
O(n, m, h, w) = O(n, m, h, w) + B(m)
143145
}
144146
145147
- name: convolution_strided
146148
lang: |
147149
def convolution_strided(float(N, C, H, W) I, float(M, C, KH, KW) W1, float(M) B) -> (O) {{
148-
O(n, m, h, w) +=! I(n, c, {sh} * h + kh, {sw} * w + kw) * W1(m, c, kh, kw)
149-
O(n, m, h, w) = O(n, m, h, w) + B(m)
150+
O(n, m, h, w) +=! I(n, r_c, {sh} * h + r_kh, {sw} * w + r_kw) * W1(m, r_c, r_kh, r_kw)
151+
O(n, m, h, w) = O(n, m, h, w) + B(m)
150152
}}
151153
grad: |
152-
def convolution_strided_grad(float(N, C, H, W) I, float(M, C, KH, KW) W1, float(N, M, H, W) O_grad)
153-
-> (I_grad, W1_grad) {{
154-
I_grad(n, c, h, w) +=! O_grad(n, m, {sh} * h - kh, {sw} * w - kw) * W1(m, c, kh, kw)
155-
W1_grad(m, c, kh, kw) +=! O_grad(n, m, {sh} * h - kh, {sw} * w - kw) * I(n, c, h, w)
154+
def convolution_strided_grad(float(N, C, H, W) I, float(M, C, KH, KW) W1, float(N, M, H, W) g_O) -> (g_I, g_W1) {{
155+
g_I(n, c, h, w) +=! g_O(n, r_m, {sh} * h - r_kh, {sw} * w - r_kw) * W1(r_m, c, r_kh, r_kw)
156+
g_W1(m, c, kh, kw) +=! g_O(n, m, {sh} * r_h - kh, {sw} * r_w - kw) * I(r_n, c, r_h, r_w)
156157
}}
157158
158159
- name: group_convolution
159160
lang: |
160161
def group_convolution(float(N, G, C, H, W) I, float(G, F, C, KH, KW) W1, float(G, F) B) -> (O) {
161-
O(n, g, f, h, w) +=! I(n, g, c, h + kh, w + kw) * W1(g, f, c, kh, kw)
162-
O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f)
162+
O(n, g, f, h, w) +=! I(n, g, r_c, h + r_kh, w + r_kw) * W1(g, f, r_c, r_kh, r_kw)
163+
O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f)
163164
}
164165
165166
- name: group_convolution_strided
166167
lang: |
167-
def group_convolution_strided(float(N, G, C, H, W) I, float(G, F, C, KH, KW) W1, float(G, F) B) -> (O)
168-
{{
169-
O(n, g, f, h, w) +=! I(n, g, c, {sh} * h + kh, {sw} * w + kw) * W1(g, f, c, kh, kw)
170-
O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f)
168+
def group_convolution_strided(float(N, G, C, H, W) I, float(G, F, C, KH, KW) W1, float(G, F) B) -> (O) {{
169+
O(n, g, f, h, w) +=! I(n, g, r_c, {sh} * h + r_kh, {sw} * w + r_kw) * W1(g, f, r_c, r_kh, r_kw)
170+
O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f)
171171
}}
172172
173173
- name: copy2D
174174
lang: |
175175
def copy2D(float(M, N) I) -> (O) {
176-
O(i, j) = I(i, j)
176+
O(m, n) = I(m, n)
177177
}
178178
179179
- name: copy
180180
lang: |
181181
def copy(float({dimParams}) I) -> (O) {{
182-
O({dimIndices}) = I({dimIndices})
182+
O({dimIndices}) = I({dimIndices})
183183
}}
184184
185185
- name: cosine
186186
lang: |
187187
def cosine(float(M) I) -> (O) {
188-
O(i) = cos(I(i))
188+
O(i) = cos(I(i))
189189
}
190190
191191
- name: cosine_similarity
192192
lang: |
193193
def cosine_similarity(float(M, N) I1, float(M, N) I2) -> (O, sumI1, sumI2) {{
194-
sumI1(m) +=! I1(m, n) * I1(m, n)
195-
sumI2(m) +=! I2(m, n) * I2(m, n)
196-
O(m) +=! (I1(m, n) * I2(m, n)) / fmax(rsqrt(sumI1(m)) * sqrt(sumI2(m)), {eps})
194+
sumI1(m) +=! I1(m, r_n) * I1(m, r_n)
195+
sumI2(m) +=! I2(m, r_n) * I2(m, r_n)
196+
O(m) +=! (I1(m, r_n) * I2(m, r_n)) / fmax(rsqrt(sumI1(m)) * sqrt(sumI2(m)), {eps})
197197
}}
198198
199199
- name: add
200200
lang: |
201201
def add(float(N) A, float(N) B) -> (output) {
202-
output(i) = A(i) + B(i)
202+
output(n) = A(n) + B(n)
203203
}
204204
205205
- name: abs
206206
lang: |
207207
def abs(float(M, N) A) -> (O1) {
208-
O1(m, n) = fabs(A(m, n))
208+
O1(m, n) = fabs(A(m, n))
209209
}
210210
211211
- name: layernorm
212212
lang: |
213213
def layernorm(float(T, B, C) I) -> (O, mean, centered, var) {{
214-
mean(t, b) +=! I(t, b, c) / C
215-
centered(t, b, c) = I(t, b, c) - mean(t, b)
216-
var(t, b) +=! centered(t, b, c) * centered(t, b, c)
217-
var(t, b) = (var(t, b) + {eps}) / C
218-
O(t, b, c) = centered(t, b, c) / rsqrt(var(t, b))
214+
mean(t, b) +=! I(t, b, c) / C
215+
centered(t, b, c) = I(t, b, c) - mean(t, b)
216+
var(t, b) +=! centered(t, b, c) * centered(t, b, c)
217+
var(t, b) = (var(t, b) + {eps}) / C
218+
O(t, b, c) = centered(t, b, c) / rsqrt(var(t, b))
219219
}}
220220
221221
- name: batchnorm
222222
lang: |
223-
def batchnorm(float(N, C, H, W) I, float(C) rMeanIn, float(C) rVarIn)
223+
def batchnorm(float(N,C,H,W) I, float(C) rMeanIn, float(C) rVarIn)
224224
-> (O, rMeanOut, rVarOut, mean, centered, variance, expectedVariance, normalizedOut)
225225
{{
226-
mean(c) +=! I(nn, c, hh, ww)
227-
mean(c) = mean(c) / (N * H * W)
228-
rMeanOut(c) = (1 - {momentum}) * rMeanIn(c) + {momentum} * mean(c)
229-
centered(n, c, h, w) = I(n, c, h, w) - rMeanOut(c)
230-
variance(n, c, h, w) = centered(n, c, h, w) * centered(n, c, h, w)
231-
expectedVariance(c) +=! (variance(n, c, h, w) + {eps}) / (N * H * W)
232-
rVarOut(c) = rsqrt(
233-
(1 - {momentum}) * rVarIn(c) + {momentum} * expectedVariance(c))
234-
O(n, c, h, w) = centered(n, c, h, w) * rVarOut(c)
235-
normalizedOut(n, c, h, w) = O(n, c, h, w)
226+
mean(c) +=! I(nn, c, hh, ww)
227+
mean(c) = mean(c) / (N * H * W)
228+
rMeanOut(c) = (1 - {momentum}) * rMeanIn(c) + {momentum} * mean(c)
229+
centered(n, c, h, w) = I(n, c, h, w) - rMeanOut(c)
230+
variance(n, c, h, w) = centered(n, c, h, w) * centered(n, c, h, w)
231+
expectedVariance(c) +=! (variance(n, c, h, w) + {eps}) / (N * H * W)
232+
rVarOut(c) = rsqrt((1 - {momentum}) * rVarIn(c) + {momentum} * expectedVariance(c))
233+
O(n, c, h, w) = centered(n, c, h, w) * rVarOut(c)
234+
normalizedOut(n, c, h, w) = O(n, c, h, w)
236235
}}
237236
238237
- name: small_mobilenet
239238
lang: |
240239
def small_mobilenet(float(C1, H, W) I, float(C1, KH1, KW1) W1, float(C1) B1, float(C2, C1) W2, float(C2) B2)
241-
-> (O1, O2)
242-
{
243-
O1(c1, h, w) +=! I(c1, h + kh, w + kw) * W1(c1, kh, kw)
244-
O1(c1, h, w) = O1(c1, h, w) + B1(c1)
245-
O1(c1, h, w) = fmax(O1(c1, h, w), 0)
246-
247-
O2(c2, h, w) +=! O1(c1, h, w) * W2(c2, c1)
248-
O2(c2, h, w) = O2(c2, h, w) + B2(c2)
249-
O2(c2, h, w) = fmax(O2(c2, h, w), 0)
240+
-> (O1, O2) {
241+
O1(c1, h, w) +=! I(c1, h + r_kh, w + r_kw) * W1(c1, r_kh, r_kw)
242+
O1(c1, h, w) = O1(c1, h, w) + B1(c1)
243+
O1(c1, h, w) = fmax(O1(c1, h, w), 0)
244+
245+
O2(c2, h, w) +=! O1(r_c1, h, w) * W2(c2, r_c1)
246+
O2(c2, h, w) = O2( c2, h, w) + B2(c2)
247+
O2(c2, h, w) = fmax(O2(c2, h, w), 0)
250248
}

0 commit comments

Comments
 (0)