@@ -26,10 +26,11 @@ using namespace scl;
2626
2727using FF = math::Fp<61 >;
2828using Mat = math::Mat<FF>;
29+ using Vec = math::Vec<FF>;
2930
3031namespace {
3132
32- void Populate (Mat& m, const int * values) {
33+ void Populate (Mat& m, const std::vector< int >& values) {
3334 for (std::size_t i = 0 ; i < m.Rows (); i++) {
3435 for (std::size_t j = 0 ; j < m.Cols (); j++) {
3536 m (i, j) = FF (values[i * m.Cols () + j]);
@@ -41,11 +42,9 @@ void Populate(Mat& m, const int* values) {
4142
4243TEST_CASE (" Matrix construction" , " [math][matrix]" ) {
4344 Mat m0 (2 , 2 );
44- int v0[] = {1 , 2 , 5 , 6 };
45- Populate (m0, v0);
45+ Populate (m0, {1 , 2 , 5 , 6 });
4646 Mat m1 (2 , 2 );
47- int v1[] = {4 , 3 , 2 , 1 };
48- Populate (m1, v1);
47+ Populate (m1, {4 , 3 , 2 , 1 });
4948
5049 REQUIRE (!m0.Equals (m1));
5150 REQUIRE (m0.Rows () == 2 );
@@ -100,8 +99,7 @@ TEST_CASE("Matrix construction from Vec", "[math][matrix]") {
10099
101100TEST_CASE (" Matrix mutation" , " [math][matrix]" ) {
102101 Mat m0 (2 , 2 );
103- int v0[] = {1 , 2 , 5 , 6 };
104- Populate (m0, v0);
102+ Populate (m0, {1 , 2 , 5 , 6 });
105103
106104 auto m = m0;
107105 m (0 , 1 ) = FF (100 );
@@ -111,8 +109,7 @@ TEST_CASE("Matrix mutation", "[math][matrix]") {
111109
112110TEST_CASE (" Matrix ToString" , " [math][matrix]" ) {
113111 Mat m (3 , 2 );
114- int v[] = {1 , 2 , 44444 , 5 , 6 , 7 };
115- Populate (m, v);
112+ Populate (m, {1 , 2 , 44444 , 5 , 6 , 7 });
116113 std::string expected =
117114 " \n "
118115 " [ 1 2 ]\n "
@@ -130,11 +127,9 @@ TEST_CASE("Matrix ToString", "[math][matrix]") {
130127
131128TEST_CASE (" Matrix Addition" , " [math][matrix]" ) {
132129 Mat m0 (2 , 2 );
133- int v0[] = {1 , 2 , 5 , 6 };
134- Populate (m0, v0);
130+ Populate (m0, {1 , 2 , 5 , 6 });
135131 Mat m1 (2 , 2 );
136- int v1[] = {4 , 3 , 2 , 1 };
137- Populate (m1, v1);
132+ Populate (m1, {4 , 3 , 2 , 1 });
138133
139134 auto m2 = m0.Add (m1);
140135 REQUIRE (m2.Rows () == 2 );
@@ -149,11 +144,9 @@ TEST_CASE("Matrix Addition", "[math][matrix]") {
149144
150145TEST_CASE (" Matrix Subtraction" , " [math][matrix]" ) {
151146 Mat m0 (2 , 2 );
152- int v0[] = {1 , 2 , 5 , 6 };
153- Populate (m0, v0);
147+ Populate (m0, {1 , 2 , 5 , 6 });
154148 Mat m1 (2 , 2 );
155- int v1[] = {4 , 3 , 2 , 1 };
156- Populate (m1, v1);
149+ Populate (m1, {4 , 3 , 2 , 1 });
157150
158151 auto m2 = m0.Subtract (m1);
159152 REQUIRE (m2 (0 , 0 ) == FF (1 ) - FF (4 ));
@@ -166,11 +159,9 @@ TEST_CASE("Matrix Subtraction", "[math][matrix]") {
166159
167160TEST_CASE (" Matrix MultiplyEntryWise" , " [math][matrix]" ) {
168161 Mat m0 (2 , 2 );
169- int v0[] = {1 , 2 , 5 , 6 };
170- Populate (m0, v0);
162+ Populate (m0, {1 , 2 , 5 , 6 });
171163 Mat m1 (2 , 2 );
172- int v1[] = {4 , 3 , 2 , 1 };
173- Populate (m1, v1);
164+ Populate (m1, {4 , 3 , 2 , 1 });
174165
175166 auto m2 = m0.MultiplyEntryWise (m1);
176167 REQUIRE (m2 (0 , 0 ) == FF (4 ));
@@ -183,11 +174,9 @@ TEST_CASE("Matrix MultiplyEntryWise", "[math][matrix]") {
183174
184175TEST_CASE (" Matrix Multiply" , " [math][matrix]" ) {
185176 Mat m0 (2 , 2 );
186- int v0[] = {1 , 2 , 5 , 6 };
187- Populate (m0, v0);
177+ Populate (m0, {1 , 2 , 5 , 6 });
188178 Mat m1 (2 , 2 );
189- int v1[] = {4 , 3 , 2 , 1 };
190- Populate (m1, v1);
179+ Populate (m1, {4 , 3 , 2 , 1 });
191180
192181 auto m2 = m0.Multiply (m1);
193182 REQUIRE (m2.Rows () == 2 );
@@ -198,32 +187,45 @@ TEST_CASE("Matrix Multiply", "[math][matrix]") {
198187 REQUIRE (m2 (1 , 1 ) == FF (21 ));
199188
200189 Mat m3 (2 , 10 );
201- int v3[] = {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 0 ,
202- 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 , 20 };
203- Populate (m3, v3);
190+ Populate (m3, {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 0 ,
191+ 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 , 20 });
204192
205193 auto m5 = m0.Multiply (m3);
206194 REQUIRE (m5.Rows () == 2 );
207195 REQUIRE (m5.Cols () == 10 );
208196 Mat m4 (2 , 10 );
209- int v4[] = {23 , 26 , 29 , 32 , 35 , 38 , 41 , 44 , 47 , 40 ,
210- 71 , 82 , 93 , 104 , 115 , 126 , 137 , 148 , 159 , 120 };
211- Populate (m4, v4);
197+ Populate (m4, {23 , 26 , 29 , 32 , 35 , 38 , 41 , 44 , 47 , 40 ,
198+ 71 , 82 , 93 , 104 , 115 , 126 , 137 , 148 , 159 , 120 });
212199 REQUIRE (m5.Equals (m4));
213200
214201 REQUIRE_THROWS_MATCHES (
215202 m3.Multiply (m0),
216203 std::invalid_argument,
217- Catch::Matchers::Message (" invalid matrix dimensions for multiply" ));
204+ Catch::Matchers::Message (" matmul: this->Cols() != that->Rows()" ));
205+ }
206+
207+ TEST_CASE (" Matrix vector multiply" , " [math][matrix]" ) {
208+ Mat m0 (2 , 3 );
209+ Populate (m0, {1 , 2 , 3 , 4 , 5 , 6 });
210+ Vec v0 = {FF (1 ), FF (2 ), FF (3 )};
211+
212+ Vec v1 = m0.Multiply (v0);
213+ REQUIRE (v1.Size () == 2 );
214+ REQUIRE (v1[0 ] == FF (1 * 1 + 2 * 2 + 3 * 3 ));
215+ REQUIRE (v1[1 ] == FF (4 * 1 + 5 * 2 + 6 * 3 ));
216+
217+ Vec v2 = {FF (6 ), FF (7 )};
218+ REQUIRE_THROWS_MATCHES (
219+ m0.Multiply (v2),
220+ std::invalid_argument,
221+ Catch::Matchers::Message (" matmul: this->Cols() != vec.Size()" ));
218222}
219223
220224TEST_CASE (" Matrix ScalarMultiply" , " [math][matrix]" ) {
221225 Mat m0 (2 , 2 );
222- int v0[] = {1 , 2 , 5 , 6 };
223- Populate (m0, v0);
226+ Populate (m0, {1 , 2 , 5 , 6 });
224227 Mat m1 (2 , 2 );
225- int v1[] = {4 , 3 , 2 , 1 };
226- Populate (m1, v1);
228+ Populate (m1, {4 , 3 , 2 , 1 });
227229
228230 auto m2 = m0.ScalarMultiply (FF (2 ));
229231 REQUIRE (m2 (0 , 0 ) == FF (2 ));
@@ -240,8 +242,7 @@ TEST_CASE("Matrix ScalarMultiply", "[math][matrix]") {
240242
241243TEST_CASE (" Matrix Transpose" , " [math][matrix]" ) {
242244 Mat m3 (2 , 3 );
243- int v3[] = {1 , 2 , 3 , 11 , 12 , 13 };
244- Populate (m3, v3);
245+ Populate (m3, {1 , 2 , 3 , 11 , 12 , 13 });
245246 auto m4 = m3.Transpose ();
246247 REQUIRE (m4.Rows () == m3.Cols ());
247248 REQUIRE (m4.Cols () == m3.Rows ());
0 commit comments