@@ -75,6 +75,18 @@ struct GenericHalideCoreTest : public ::testing::Test {
7575 curPos = newPos;
7676 }
7777 }
78+ void CheckC (const std::string& tc, const std::string& expected) {
79+ std::istringstream stream (expected);
80+ std::string line;
81+ std::vector<std::string> split;
82+ while (std::getline (stream, line)) {
83+ // Skip lines containing (only) closing brace.
84+ if (line.find (' }' ) == std::string::npos) {
85+ split.emplace_back (line);
86+ }
87+ }
88+ CheckC (tc, split);
89+ }
7890};
7991
8092TEST_F (GenericHalideCoreTest, TwoMatmul) {
@@ -86,18 +98,32 @@ def fun(float(M, K) I, float(K, N) W1, float(N, P) W2) -> (O1, O2) {
8698)TC" ;
8799 CheckC (
88100 tc,
89- {
90- " for (int O1_s0_m = 0; O1_s0_m < M; O1_s0_m++) {" ,
91- " for (int O1_s0_n = 0; O1_s0_n < N; O1_s0_n++) {" ,
92- " O1[O1_s0_m][O1_s0_n] = 0.000000f" ,
93- " for (int O1_s1_r_k = 0; O1_s1_r_k < K; O1_s1_r_k++) {" ,
94- " O1[O1_s0_m][O1_s0_n] = (O1[O1_s0_m][O1_s0_n] + (I[O1_s0_m][O1_s1_r_k]*W1[O1_s1_r_k][O1_s0_n]))" ,
95- " for (int O2_s0_m = 0; O2_s0_m < M; O2_s0_m++) {" ,
96- " for (int O2_s0_p = 0; O2_s0_p < P; O2_s0_p++) {" ,
97- " O2[O2_s0_m][O2_s0_p] = 0.000000f" ,
98- " for (int O2_s1_r_n = 0; O2_s1_r_n < N; O2_s1_r_n++) {" ,
99- " O2[O2_s0_m][O2_s0_p] = (O2[O2_s0_m][O2_s0_p] + (O1[O2_s0_m][O2_s1_r_n]*W2[O2_s1_r_n][O2_s0_p]))" ,
100- });
101+ R"C(
102+ for (int O1_s0_m = 0; O1_s0_m < M; O1_s0_m++) {
103+ for (int O1_s0_n = 0; O1_s0_n < N; O1_s0_n++) {
104+ O1[O1_s0_m][O1_s0_n] = 0.000000f;
105+ }
106+ }
107+ for (int O1_s1_m = 0; O1_s1_m < M; O1_s1_m++) {
108+ for (int O1_s1_n = 0; O1_s1_n < N; O1_s1_n++) {
109+ for (int O1_s1_r_k = 0; O1_s1_r_k < K; O1_s1_r_k++) {
110+ O1[O1_s1_m][O1_s1_n] = (O1[O1_s1_m][O1_s1_n] + (I[O1_s1_m][O1_s1_r_k]*W1[O1_s1_r_k][O1_s1_n]));
111+ }
112+ }
113+ }
114+ for (int O2_s0_m = 0; O2_s0_m < M; O2_s0_m++) {
115+ for (int O2_s0_p = 0; O2_s0_p < P; O2_s0_p++) {
116+ O2[O2_s0_m][O2_s0_p] = 0.000000f;
117+ }
118+ }
119+ for (int O2_s1_m = 0; O2_s1_m < M; O2_s1_m++) {
120+ for (int O2_s1_p = 0; O2_s1_p < P; O2_s1_p++) {
121+ for (int O2_s1_r_n = 0; O2_s1_r_n < N; O2_s1_r_n++) {
122+ O2[O2_s1_m][O2_s1_p] = (O2[O2_s1_m][O2_s1_p] + (O1[O2_s1_m][O2_s1_r_n]*W2[O2_s1_r_n][O2_s1_p]));
123+ }
124+ }
125+ }
126+ )C" );
101127}
102128
103129TEST_F (GenericHalideCoreTest, Convolution) {
@@ -108,15 +134,32 @@ def fun(float(N, C, H, W) I1, float(C, F, KH, KW) W1) -> (O1) {
108134)TC" ;
109135 CheckC (
110136 tc,
111- {" for (int O1_s0_n = 0; O1_s0_n < N; O1_s0_n++) {" ,
112- " for (int O1_s0_f = 0; O1_s0_f < F; O1_s0_f++) {" ,
113- " for (int O1_s0_h = 0; O1_s0_h < ((H - KH) + 1); O1_s0_h++) {" ,
114- " for (int O1_s0_w = 0; O1_s0_w < ((W - KW) + 1); O1_s0_w++) {" ,
115- " O1[O1_s0_n][O1_s0_f][O1_s0_h][O1_s0_w] = 0.000000f" ,
116- " for (int O1_s1_r_c = 0; O1_s1_r_c < C; O1_s1_r_c++) {" ,
117- " for (int O1_s1_r_kh = 0; O1_s1_r_kh < KH; O1_s1_r_kh++) {" ,
118- " for (int O1_s1_r_kw = 0; O1_s1_r_kw < KW; O1_s1_r_kw++) {" ,
119- " O1[O1_s0_n][O1_s0_f][O1_s0_h][O1_s0_w] = (O1[O1_s0_n][O1_s0_f][O1_s0_h][O1_s0_w] + (I1[O1_s0_n][O1_s1_r_c][(O1_s0_h + O1_s1_r_kh)][(O1_s0_w + O1_s1_r_kw)]*W1[O1_s1_r_c][O1_s0_f][O1_s1_r_kh][O1_s1_r_kw]))" });
137+ R"C(
138+ for (int O1_s0_n = 0; O1_s0_n < N; O1_s0_n++) {
139+ for (int O1_s0_f = 0; O1_s0_f < F; O1_s0_f++) {
140+ for (int O1_s0_h = 0; O1_s0_h < ((H - KH) + 1); O1_s0_h++) {
141+ for (int O1_s0_w = 0; O1_s0_w < ((W - KW) + 1); O1_s0_w++) {
142+ O1[O1_s0_n][O1_s0_f][O1_s0_h][O1_s0_w] = 0.000000f;
143+ }
144+ }
145+ }
146+ }
147+ for (int O1_s1_n = 0; O1_s1_n < N; O1_s1_n++) {
148+ for (int O1_s1_f = 0; O1_s1_f < F; O1_s1_f++) {
149+ for (int O1_s1_h = 0; O1_s1_h < ((H - KH) + 1); O1_s1_h++) {
150+ for (int O1_s1_w = 0; O1_s1_w < ((W - KW) + 1); O1_s1_w++) {
151+ for (int O1_s1_r_c = 0; O1_s1_r_c < C; O1_s1_r_c++) {
152+ for (int O1_s1_r_kh = 0; O1_s1_r_kh < KH; O1_s1_r_kh++) {
153+ for (int O1_s1_r_kw = 0; O1_s1_r_kw < KW; O1_s1_r_kw++) {
154+ O1[O1_s1_n][O1_s1_f][O1_s1_h][O1_s1_w] = (O1[O1_s1_n][O1_s1_f][O1_s1_h][O1_s1_w] + (I1[O1_s1_n][O1_s1_r_c][(O1_s1_h + O1_s1_r_kh)][(O1_s1_w + O1_s1_r_kw)]*W1[O1_s1_r_c][O1_s1_f][O1_s1_r_kh][O1_s1_r_kw]));
155+ }
156+ }
157+ }
158+ }
159+ }
160+ }
161+ }
162+ )C" );
120163}
121164
122165TEST_F (GenericHalideCoreTest, Copy) {
@@ -136,27 +179,55 @@ def fun(float(N, G, C, H, W) I1, float(G, C, F, KH, KW) W1) -> (O1) {
136179)TC" ;
137180 CheckC (
138181 tc,
139- {" for (int O1_s0_n = 0; O1_s0_n < N; O1_s0_n++) {" ,
140- " for (int O1_s0_g = 0; O1_s0_g < G; O1_s0_g++) {" ,
141- " for (int O1_s0_f = 0; O1_s0_f < F; O1_s0_f++) {" ,
142- " for (int O1_s0_h = 0; O1_s0_h < ((H - KH) + 1); O1_s0_h++) {" ,
143- " for (int O1_s0_w = 0; O1_s0_w < ((W - KW) + 1); O1_s0_w++) {" ,
144- " O1[O1_s0_n][O1_s0_g][O1_s0_f][O1_s0_h][O1_s0_w] = 0.000000f" ,
145- " for (int O1_s1_r_c = 0; O1_s1_r_c < C; O1_s1_r_c++) {" ,
146- " for (int O1_s1_r_kh = 0; O1_s1_r_kh < KH; O1_s1_r_kh++) {" ,
147- " for (int O1_s1_r_kw = 0; O1_s1_r_kw < KW; O1_s1_r_kw++) {" ,
148- " O1[O1_s0_n][O1_s0_g][O1_s0_f][O1_s0_h][O1_s0_w] = (O1[O1_s0_n][O1_s0_g][O1_s0_f][O1_s0_h][O1_s0_w] + (I1[O1_s0_n][O1_s0_g][O1_s1_r_c][(O1_s0_h + O1_s1_r_kh)][(O1_s0_w + O1_s1_r_kw)]*W1[O1_s0_g][O1_s1_r_c][O1_s0_f][O1_s1_r_kh][O1_s1_r_kw]))" });
182+ R"C(
183+ for (int O1_s0_n = 0; O1_s0_n < N; O1_s0_n++) {
184+ for (int O1_s0_g = 0; O1_s0_g < G; O1_s0_g++) {
185+ for (int O1_s0_f = 0; O1_s0_f < F; O1_s0_f++) {
186+ for (int O1_s0_h = 0; O1_s0_h < ((H - KH) + 1); O1_s0_h++) {
187+ for (int O1_s0_w = 0; O1_s0_w < ((W - KW) + 1); O1_s0_w++) {
188+ O1[O1_s0_n][O1_s0_g][O1_s0_f][O1_s0_h][O1_s0_w] = 0.000000f;
189+ }
190+ }
191+ }
192+ }
193+ }
194+ for (int O1_s1_n = 0; O1_s1_n < N; O1_s1_n++) {
195+ for (int O1_s1_g = 0; O1_s1_g < G; O1_s1_g++) {
196+ for (int O1_s1_f = 0; O1_s1_f < F; O1_s1_f++) {
197+ for (int O1_s1_h = 0; O1_s1_h < ((H - KH) + 1); O1_s1_h++) {
198+ for (int O1_s1_w = 0; O1_s1_w < ((W - KW) + 1); O1_s1_w++) {
199+ for (int O1_s1_r_c = 0; O1_s1_r_c < C; O1_s1_r_c++) {
200+ for (int O1_s1_r_kh = 0; O1_s1_r_kh < KH; O1_s1_r_kh++) {
201+ for (int O1_s1_r_kw = 0; O1_s1_r_kw < KW; O1_s1_r_kw++) {
202+ O1[O1_s1_n][O1_s1_g][O1_s1_f][O1_s1_h][O1_s1_w] = (O1[O1_s1_n][O1_s1_g][O1_s1_f][O1_s1_h][O1_s1_w] + (I1[O1_s1_n][O1_s1_g][O1_s1_r_c][(O1_s1_h + O1_s1_r_kh)][(O1_s1_w + O1_s1_r_kw)]*W1[O1_s1_g][O1_s1_r_c][O1_s1_f][O1_s1_r_kh][O1_s1_r_kw]));
203+ }
204+ }
205+ }
206+ }
207+ }
208+ }
209+ }
210+ }
211+ )C" );
149212}
150213
// Matmul: passes the TC returned by makeMatmulTc(false, false) and an expected
// listing to CheckC. In the raw-string listing, O is zero-initialized in its
// own i/j loop nest (O_s0_*), followed by a separate i/j/k loop nest (O_s1_*)
// that accumulates A[i][k] * B[k][j] into O[i][j].
// NOTE(review): presumably CheckC compares the listing against code generated
// from the TC -- confirm against the CheckC overload taking a vector of lines.
151214TEST_F (GenericHalideCoreTest, Matmul) {
152215 CheckC (
153216 makeMatmulTc (false , false ),
154- std::vector<std::string>{
155- " for (int O_s0_i = 0; O_s0_i < N; O_s0_i++) {" ,
156- " for (int O_s0_j = 0; O_s0_j < M; O_s0_j++) {" ,
157- " O[O_s0_i][O_s0_j] = 0.000000f;" ,
158- " for (int O_s1_k = 0; O_s1_k < K; O_s1_k++) {" ,
159- " O[O_s0_i][O_s0_j] = (O[O_s0_i][O_s0_j] + (A[O_s0_i][O_s1_k]*B[O_s1_k][O_s0_j]));" });
217+ R"C(
218+ for (int O_s0_i = 0; O_s0_i < N; O_s0_i++) {
219+ for (int O_s0_j = 0; O_s0_j < M; O_s0_j++) {
220+ O[O_s0_i][O_s0_j] = 0.000000f;
221+ }
222+ }
223+ for (int O_s1_i = 0; O_s1_i < N; O_s1_i++) {
224+ for (int O_s1_j = 0; O_s1_j < M; O_s1_j++) {
225+ for (int O_s1_k = 0; O_s1_k < K; O_s1_k++) {
226+ O[O_s1_i][O_s1_j] = (O[O_s1_i][O_s1_j] + (A[O_s1_i][O_s1_k]*B[O_s1_k][O_s1_j]));
227+ }
228+ }
229+ }
230+ )C" );
160231}
161232
162233using namespace isl ::with_exceptions;
0 commit comments