|
26 | 26 | - name: indexing |
27 | 27 | lang: | |
28 | 28 | def indexing(float(H, W) input, int32(L) index) -> (output) {{ |
29 | | - output(l, w) = input(index(l), w) where l in 0:{L} |
| 29 | + output(l, w) = input(index(l), w) where l in 0:{L} |
30 | 30 | }} |
31 | 31 |
|
32 | 32 | - name: lookup_table |
33 | 33 | lang: | |
34 | 34 | def lookup_table(float(B, R) LUT, int32(B, N) I) -> (O) { |
35 | | - O(b, n) +=! LUT(I(b, n), r) |
| 35 | + O(b, n) +=! LUT(I(b, n), r_r) |
36 | 36 | } |
37 | 37 |
|
38 | 38 | - name: matmul |
39 | 39 | lang: | |
40 | | - def matmul(float(M,N) A, float(N,K) B) -> (output) { |
41 | | - output(i, j) +=! A(i, kk) * B(kk, j) |
| 40 | + def matmul(float(M, K) A, float(K, N) B) -> (C) { |
| 41 | + C(m, n) +=! A(m, r_k) * B(r_k, n) |
42 | 42 | } |
43 | 43 | grad: | |
44 | | - def matmul_grad(float(M,N) A, float(N,K), float(M,K) O_grad) -> (A_grad, B_grad){ |
45 | | - A_grad(i, j) +=! O_grad(i, kk) * B(j, kk) |
46 | | - B_grad(i, j) +=! O_grad(kk, j) * A(kk, i) |
| 44 | + def matmul_grad(float(M,K) A, float(K,N) B, float(M,N) g_C) -> (g_A, g_B){ |
| 45 | + g_A(m, k) +=! g_C( m, r_n) * B( k, r_n) |
| 46 | + g_B(k, n) +=! g_C(r_m, n) * A(r_m, k) |
47 | 47 | } |
48 | 48 |
|
49 | 49 | - name: batch_matmul |
50 | 50 | lang: | |
51 | 51 | def batch_matmul(float(B, N, M) X, float(B, M, K) Y) -> (Z) { |
52 | | - Z(b, n, k) +=! X(b, n, mm) * Y(b, mm, k) |
| 52 | + Z(b, n, k) +=! X(b, n, r_m) * Y(b, r_m, k) |
53 | 53 | } |
54 | 54 |
|
55 | 55 | - name: transpose |
56 | 56 | lang: | |
57 | 57 | def transpose(float(N, C, H, W) I) -> (O) { |
58 | | - O(c, n, w, h) = I(n, c, h, w) |
| 58 | + O(c, n, w, h) = I(n, c, h, w) |
59 | 59 | } |
60 | 60 |
|
61 | 61 | - name: avgpool |
62 | 62 | lang: | |
63 | 63 | def avgpool(float(B, C, H, W) input) -> (output) {{ |
64 | | - output(b, c, h, w) +=! input(b, c, h * {sH} + kh, w * {sW} + kw) / ({kH} * {kW}) where kh in 0:{kH}, kw in 0:{kW} |
| 64 | + output(b, c, h, w) +=! input(b, c, h * {sH} + r_kh, w * {sW} + r_kw) / ({kH} * {kW}) |
| 65 | + where r_kh in 0:{kH}, r_kw in 0:{kW} |
65 | 66 | }} |
66 | 67 |
|
67 | 68 | - name: maxpool |
68 | 69 | lang: | |
69 | 70 | def maxpool(float(B, C, H, W) input) -> (output) {{ |
70 | | - output(b, c, h, w) max=! input(b, c, h * {sH} + kh, w * {sW} + kw) where kh in 0:{kH}, kw in 0:{kW} |
| 71 | + output(b, c, h, w) max=! input(b, c, h * {sH} + r_kh, w * {sW} + r_kw) |
| 72 | + where r_kh in 0:{kH}, r_kw in 0:{kW} |
71 | 73 | }} |
72 | 74 |
|
73 | 75 | - name: scale |
74 | 76 | lang: | |
75 | 77 | def scale(float(M, N) I) -> (O) {{ |
76 | | - O(m, n) = I(m, n) * {s} |
| 78 | + O(m, n) = I(m, n) * {s} |
77 | 79 | }} |
78 | 80 |
|
79 | 81 | - name: sigmoid |
80 | 82 | lang: | |
81 | 83 | def sigmoid(float(N, C, H, W) I) -> (O) { |
82 | | - O(n, c, h, w) = 1 / (1 + exp(-I(n, c, h, w))) |
| 84 | + O(n, c, h, w) = 1 / (1 + exp(-I(n, c, h, w))) |
83 | 85 | } |
84 | 86 |
|
85 | 87 | - name: softmax |
86 | 88 | lang: | |
87 | 89 | def softmax(float(N, D) I) -> (O, expsum, maxVal) { |
88 | | - maxVal(n) max= I(n, d) |
89 | | - expsum(n) +=! exp(I(n, d) - maxVal(n)) |
90 | | - O(n, d) = exp(I(n, d)) / expsum(n) |
| 90 | + maxVal(n) max= I(n, d) |
| 91 | + expsum(n) +=! exp(I(n, d) - maxVal(n)) |
| 92 | + O(n, d) = exp(I(n, d)) / expsum(n) |
91 | 93 | } |
92 | 94 |
|
93 | 95 | - name: Tanh |
94 | 96 | lang: | |
95 | 97 | def Tanh(float(M) I) -> (O) { |
96 | | - O(m) = tanh(I(m)) |
| 98 | + O(m) = tanh(I(m)) |
97 | 99 | } |
98 | 100 |
|
99 | 101 | - name: tensordot |
100 | 102 | lang: | |
101 | 103 | def tensordot(float(N, C1, C2, H, W) I0, float(N, C2, C3, H, W) I1) -> (O) { |
102 | | - O(n, c1, c3, h, w) +=! I0(n, c1, c2, h, w) * I1(n, c2, c3, h, w) |
| 104 | + O(n, c1, c3, h, w) +=! I0(n, c1, r_c2, h, w) * I1(n, r_c2, c3, h, w) |
103 | 105 | } |
104 | 106 |
|
105 | 107 | - name: fully_connected |
106 | 108 | lang: | |
107 | 109 | def fully_connected(float(B, M) I, float(N, M) W1, float(N) B1) -> (O1) { |
108 | | - O1(b, n) +=! I(b, m) * W1(n, m) |
109 | | - O1(b, n) = O1(b, n) + B1(n) |
| 110 | + O1(b, n) +=! I(b, r_m) * W1(n, r_m) |
| 111 | + O1(b, n) = O1(b, n) + B1(n) |
110 | 112 | } |
111 | 113 |
|
112 | 114 | - name: relu |
113 | 115 | lang: | |
114 | 116 | def relu(float(B, M) I) -> (O1){ |
115 | | - O1(b, m) = fmax(I(b, m), 0) |
| 117 | + O1(b, m) = fmax(I(b, m), 0) |
116 | 118 | } |
117 | 119 |
|
118 | 120 | - name: fcrelu |
119 | 121 | lang: | |
120 | | - def fcrelu(float(B, M) I, float(N, M) W1, float(N) B1) -> (O1){ |
121 | | - O1(b, n) +=! I(b, m) * W1(n, m) |
122 | | - O1(b, n) = O1(b, n) + B1(n) |
123 | | - O1(b, n) = fmax(O1(b, n), 0) |
| 122 | + def fcrelu(float(B,M) I, float(N,M) W1, float(N) B1) -> (O1){ |
| 123 | + O1(b, n) +=! I(b, r_m) * W1(n, r_m) |
| 124 | + O1(b, n) = O1(b, n) + B1(n) |
| 125 | + O1(b, n) = fmax(O1(b, n), 0) |
124 | 126 | } |
125 | 127 |
|
126 | 128 | - name: cast |
127 | 129 | lang: | |
128 | | - def cast(float(M, N) A) -> (int32(M, N) O1) {{ |
129 | | - O1(m, n) = int32(A(m, n) + {constant}) |
| 130 | + def cast(float(M,N) A) -> (int32(M,N) O1) {{ |
| 131 | + O1(m, n) = int32(A(m, n) + {constant}) |
130 | 132 | }} |
131 | 133 |
|
132 | 134 | - name: concat |
133 | 135 | lang: | |
134 | 136 | def concat(float(M, N) A, float(M, N) B) -> (O1) { |
135 | | - O1(n, i, m) = i == 0 ? A(m, n) : B(m, n) where i in 0:2 |
| 137 | + O1(n, i, m) = i == 0 ? A(m, n) : B(m, n) where i in 0:2 |
136 | 138 | } |
137 | 139 |
|
138 | 140 | - name: convolution |
139 | 141 | lang: | |
140 | 142 | def convolution(float(N, C, H, W) I, float(M, C, KH, KW) W1, float(M) B) -> (O) { |
141 | | - O(n, m, h, w) +=! I(n, c, h + kh, w + kw) * W1(m, c, kh, kw) |
142 | | - O(n, m, h, w) = O(n, m, h, w) + B(m) |
| 143 | + O(n, m, h, w) +=! I(n, r_c, h + r_kh, w + r_kw) * W1(m, r_c, r_kh, r_kw) |
| 144 | + O(n, m, h, w) = O(n, m, h, w) + B(m) |
143 | 145 | } |
144 | 146 |
|
145 | 147 | - name: convolution_strided |
146 | 148 | lang: | |
147 | 149 | def convolution_strided(float(N, C, H, W) I, float(M, C, KH, KW) W1, float(M) B) -> (O) {{ |
148 | | - O(n, m, h, w) +=! I(n, c, {sh} * h + kh, {sw} * w + kw) * W1(m, c, kh, kw) |
149 | | - O(n, m, h, w) = O(n, m, h, w) + B(m) |
| 150 | + O(n, m, h, w) +=! I(n, r_c, {sh} * h + r_kh, {sw} * w + r_kw) * W1(m, r_c, r_kh, r_kw) |
| 151 | + O(n, m, h, w) = O(n, m, h, w) + B(m) |
150 | 152 | }} |
151 | 153 | grad: | |
152 | | - def convolution_strided_grad(float(N, C, H, W) I, float(M, C, KH, KW) W1, float(N, M, H, W) O_grad) |
153 | | - -> (I_grad, W1_grad) {{ |
154 | | - I_grad(n, c, h, w) +=! O_grad(n, m, {sh} * h - kh, {sw} * w - kw) * W1(m, c, kh, kw) |
155 | | - W1_grad(m, c, kh, kw) +=! O_grad(n, m, {sh} * h - kh, {sw} * w - kw) * I(n, c, h, w) |
| 154 | + def convolution_strided_grad(float(N, C, H, W) I, float(M, C, KH, KW) W1, float(N, M, H, W) g_O) -> (g_I, g_W1) {{ |
| 155 | + g_I(n, c, h, w) +=! g_O(n, r_m, {sh} * h - r_kh, {sw} * w - r_kw) * W1(r_m, c, r_kh, r_kw) |
| 156 | + g_W1(m, c, kh, kw) +=! g_O(n, m, {sh} * r_h - kh, {sw} * r_w - kw) * I(r_n, c, r_h, r_w) |
156 | 157 | }} |
157 | 158 |
|
158 | 159 | - name: group_convolution |
159 | 160 | lang: | |
160 | 161 | def group_convolution(float(N, G, C, H, W) I, float(G, F, C, KH, KW) W1, float(G, F) B) -> (O) { |
161 | | - O(n, g, f, h, w) +=! I(n, g, c, h + kh, w + kw) * W1(g, f, c, kh, kw) |
162 | | - O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f) |
| 162 | + O(n, g, f, h, w) +=! I(n, g, r_c, h + r_kh, w + r_kw) * W1(g, f, r_c, r_kh, r_kw) |
| 163 | + O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f) |
163 | 164 | } |
164 | 165 |
|
165 | 166 | - name: group_convolution_strided |
166 | 167 | lang: | |
167 | | - def group_convolution_strided(float(N, G, C, H, W) I, float(G, F, C, KH, KW) W1, float(G, F) B) -> (O) |
168 | | - {{ |
169 | | - O(n, g, f, h, w) +=! I(n, g, c, {sh} * h + kh, {sw} * w + kw) * W1(g, f, c, kh, kw) |
170 | | - O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f) |
| 168 | + def group_convolution_strided(float(N, G, C, H, W) I, float(G, F, C, KH, KW) W1, float(G, F) B) -> (O) {{ |
| 169 | + O(n, g, f, h, w) +=! I(n, g, r_c, {sh} * h + r_kh, {sw} * w + r_kw) * W1(g, f, r_c, r_kh, r_kw) |
| 170 | + O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f) |
171 | 171 | }} |
172 | 172 |
|
173 | 173 | - name: copy2D |
174 | 174 | lang: | |
175 | 175 | def copy2D(float(M, N) I) -> (O) { |
176 | | - O(i, j) = I(i, j) |
| 176 | + O(m, n) = I(m, n) |
177 | 177 | } |
178 | 178 |
|
179 | 179 | - name: copy |
180 | 180 | lang: | |
181 | 181 | def copy(float({dimParams}) I) -> (O) {{ |
182 | | - O({dimIndices}) = I({dimIndices}) |
| 182 | + O({dimIndices}) = I({dimIndices}) |
183 | 183 | }} |
184 | 184 |
|
185 | 185 | - name: cosine |
186 | 186 | lang: | |
187 | 187 | def cosine(float(M) I) -> (O) { |
188 | | - O(i) = cos(I(i)) |
| 188 | + O(i) = cos(I(i)) |
189 | 189 | } |
190 | 190 |
|
191 | 191 | - name: cosine_similarity |
192 | 192 | lang: | |
193 | 193 | def cosine_similarity(float(M, N) I1, float(M, N) I2) -> (O, sumI1, sumI2) {{ |
194 | | - sumI1(m) +=! I1(m, n) * I1(m, n) |
195 | | - sumI2(m) +=! I2(m, n) * I2(m, n) |
196 | | - O(m) +=! (I1(m, n) * I2(m, n)) / fmax(rsqrt(sumI1(m)) * sqrt(sumI2(m)), {eps}) |
| 194 | + sumI1(m) +=! I1(m, r_n) * I1(m, r_n) |
| 195 | + sumI2(m) +=! I2(m, r_n) * I2(m, r_n) |
| 196 | + O(m) +=! (I1(m, r_n) * I2(m, r_n)) / fmax(rsqrt(sumI1(m)) * sqrt(sumI2(m)), {eps}) |
197 | 197 | }} |
198 | 198 |
|
199 | 199 | - name: add |
200 | 200 | lang: | |
201 | 201 | def add(float(N) A, float(N) B) -> (output) { |
202 | | - output(i) = A(i) + B(i) |
| 202 | + output(n) = A(n) + B(n) |
203 | 203 | } |
204 | 204 |
|
205 | 205 | - name: abs |
206 | 206 | lang: | |
207 | 207 | def abs(float(M, N) A) -> (O1) { |
208 | | - O1(m, n) = fabs(A(m, n)) |
| 208 | + O1(m, n) = fabs(A(m, n)) |
209 | 209 | } |
210 | 210 |
|
211 | 211 | - name: layernorm |
212 | 212 | lang: | |
213 | 213 | def layernorm(float(T, B, C) I) -> (O, mean, centered, var) {{ |
214 | | - mean(t, b) +=! I(t, b, c) / C |
215 | | - centered(t, b, c) = I(t, b, c) - mean(t, b) |
216 | | - var(t, b) +=! centered(t, b, c) * centered(t, b, c) |
217 | | - var(t, b) = (var(t, b) + {eps}) / C |
218 | | - O(t, b, c) = centered(t, b, c) / rsqrt(var(t, b)) |
| 214 | + mean(t, b) +=! I(t, b, c) / C |
| 215 | + centered(t, b, c) = I(t, b, c) - mean(t, b) |
| 216 | + var(t, b) +=! centered(t, b, c) * centered(t, b, c) |
| 217 | + var(t, b) = (var(t, b) + {eps}) / C |
| 218 | + O(t, b, c) = centered(t, b, c) / rsqrt(var(t, b)) |
219 | 219 | }} |
220 | 220 |
|
221 | 221 | - name: batchnorm |
222 | 222 | lang: | |
223 | | - def batchnorm(float(N, C, H, W) I, float(C) rMeanIn, float(C) rVarIn) |
| 223 | + def batchnorm(float(N,C,H,W) I, float(C) rMeanIn, float(C) rVarIn) |
224 | 224 | -> (O, rMeanOut, rVarOut, mean, centered, variance, expectedVariance, normalizedOut) |
225 | 225 | {{ |
226 | | - mean(c) +=! I(nn, c, hh, ww) |
227 | | - mean(c) = mean(c) / (N * H * W) |
228 | | - rMeanOut(c) = (1 - {momentum}) * rMeanIn(c) + {momentum} * mean(c) |
229 | | - centered(n, c, h, w) = I(n, c, h, w) - rMeanOut(c) |
230 | | - variance(n, c, h, w) = centered(n, c, h, w) * centered(n, c, h, w) |
231 | | - expectedVariance(c) +=! (variance(n, c, h, w) + {eps}) / (N * H * W) |
232 | | - rVarOut(c) = rsqrt( |
233 | | - (1 - {momentum}) * rVarIn(c) + {momentum} * expectedVariance(c)) |
234 | | - O(n, c, h, w) = centered(n, c, h, w) * rVarOut(c) |
235 | | - normalizedOut(n, c, h, w) = O(n, c, h, w) |
| 226 | + mean(c) +=! I(nn, c, hh, ww) |
| 227 | + mean(c) = mean(c) / (N * H * W) |
| 228 | + rMeanOut(c) = (1 - {momentum}) * rMeanIn(c) + {momentum} * mean(c) |
| 229 | + centered(n, c, h, w) = I(n, c, h, w) - rMeanOut(c) |
| 230 | + variance(n, c, h, w) = centered(n, c, h, w) * centered(n, c, h, w) |
| 231 | + expectedVariance(c) +=! (variance(n, c, h, w) + {eps}) / (N * H * W) |
| 232 | + rVarOut(c) = rsqrt((1 - {momentum}) * rVarIn(c) + {momentum} * expectedVariance(c)) |
| 233 | + O(n, c, h, w) = centered(n, c, h, w) * rVarOut(c) |
| 234 | + normalizedOut(n, c, h, w) = O(n, c, h, w) |
236 | 235 | }} |
237 | 236 |
|
238 | 237 | - name: small_mobilenet |
239 | 238 | lang: | |
240 | 239 | def small_mobilenet(float(C1, H, W) I, float(C1, KH1, KW1) W1, float(C1) B1, float(C2, C1) W2, float(C2) B2) |
241 | | - -> (O1, O2) |
242 | | - { |
243 | | - O1(c1, h, w) +=! I(c1, h + kh, w + kw) * W1(c1, kh, kw) |
244 | | - O1(c1, h, w) = O1(c1, h, w) + B1(c1) |
245 | | - O1(c1, h, w) = fmax(O1(c1, h, w), 0) |
246 | | -
|
247 | | - O2(c2, h, w) +=! O1(c1, h, w) * W2(c2, c1) |
248 | | - O2(c2, h, w) = O2(c2, h, w) + B2(c2) |
249 | | - O2(c2, h, w) = fmax(O2(c2, h, w), 0) |
| 240 | + -> (O1, O2) { |
| 241 | + O1(c1, h, w) +=! I(c1, h + r_kh, w + r_kw) * W1(c1, r_kh, r_kw) |
| 242 | + O1(c1, h, w) = O1(c1, h, w) + B1(c1) |
| 243 | + O1(c1, h, w) = fmax(O1(c1, h, w), 0) |
| 244 | +
|
| 245 | + O2(c2, h, w) +=! O1(r_c1, h, w) * W2(c2, r_c1) |
| 246 | + O2(c2, h, w) = O2( c2, h, w) + B2(c2) |
| 247 | + O2(c2, h, w) = fmax(O2(c2, h, w), 0) |
250 | 248 | } |
0 commit comments