Skip to content

Commit 3066589

Browse files
committed
Add test case and update image processing
1 parent 3b72301 commit 3066589

File tree

3 files changed

+129
-2
lines changed

3 files changed

+129
-2
lines changed

src/transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ def get_optimal_tiled_canvas(
113113
# Pick the resolution that required the least upscaling so that it most closely fits the image
114114
required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
115115
best_grid = possible_resolutions[np.argmin(required_scale)]
116-
return best_grid
116+
best_grid_row, best_grid_col = best_grid
117+
return best_grid_col, best_grid_row # revere the order to align with boilerplate code
117118

118119

119120
@auto_docstring

src/transformers/models/cohere2_vision/modular_cohere2_vision.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,8 @@ def get_optimal_tiled_canvas(
302302
# Pick the resolution that required the least upscaling so that it most closely fits the image
303303
required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
304304
best_grid = possible_resolutions[np.argmin(required_scale)]
305-
return best_grid
305+
best_grid_row, best_grid_col = best_grid
306+
return best_grid_col, best_grid_row # revere the order to align with boilerplate code
306307

307308

308309
@auto_docstring

tests/models/cohere2_vision/test_image_processing_cohere2_vision.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,3 +190,128 @@ def test_call_numpy_4_channels(self):
190190
image_std=1,
191191
).pixel_values
192192
self.assertEqual(tuple(encoded_images.shape), (70, 4, 30, 30))
193+
194+
def test_crop_to_patches_aspect_ratio(self):
195+
"""Test that row/column ordering is correct when cropping non-square images to patches.
196+
197+
This test verifies that patches can be stitched back to reconstruct the original image,
198+
which validates that the row/column ordering in get_optimal_tiled_canvas is correct.
199+
If row/column are swapped, the image would be resized to wrong dimensions and patches
200+
would not match the original content.
201+
"""
202+
for image_processing_class in self.image_processor_list:
203+
patch_size = 64
204+
image_processor = image_processing_class(
205+
do_resize=True,
206+
size={"height": patch_size, "width": patch_size},
207+
do_normalize=False, # Disable normalization to preserve pixel values
208+
do_rescale=False, # Disable rescaling to preserve pixel values
209+
crop_to_patches=True,
210+
min_patches=1,
211+
max_patches=6, # Allow up to 6 patches to test asymmetric grids like 2x3
212+
)
213+
214+
# Create a 2:3 aspect ratio image (2 rows x 3 columns of patches)
215+
# This asymmetric grid will fail if rows/columns are swapped
216+
num_rows, num_cols = 2, 3
217+
image_height = patch_size * num_rows # 128
218+
image_width = patch_size * num_cols # 192
219+
220+
# Create image with unique color for each patch position
221+
test_image = Image.new("RGB", (image_width, image_height))
222+
for row in range(num_rows):
223+
for col in range(num_cols):
224+
patch_idx = row * num_cols + col # 0-5
225+
color = (patch_idx * 40 + 20, 0, 0) # Unique red values: 20, 60, 100, 140, 180, 220
226+
for y in range(patch_size):
227+
for x in range(patch_size):
228+
test_image.putpixel(
229+
(col * patch_size + x, row * patch_size + y),
230+
color,
231+
)
232+
233+
# Process image
234+
result = image_processor(test_image, return_tensors="pt")
235+
patches = result.pixel_values
236+
num_patches_result = result.num_patches
237+
238+
# Should produce 7 patches (6 grid patches + 1 thumbnail)
239+
self.assertEqual(num_patches_result.tolist(), [7])
240+
self.assertEqual(tuple(patches.shape), (7, 3, patch_size, patch_size))
241+
242+
# Verify each patch has the correct color (excluding thumbnail which is last)
243+
# Patches should be ordered row by row: (0,0), (0,1), (0,2), (1,0), (1,1), (1,2)
244+
for patch_idx in range(6):
245+
expected_red = patch_idx * 40 + 20
246+
actual_red = patches[patch_idx, 0, 0, 0].item() # Red channel, top-left pixel
247+
self.assertEqual(
248+
actual_red,
249+
expected_red,
250+
f"Patch {patch_idx} has wrong color. Expected red={expected_red}, got {actual_red}. "
251+
f"This indicates row/column ordering is incorrect.",
252+
)
253+
254+
# Stitch patches back and verify against original
255+
stitched = torch.zeros(3, image_height, image_width)
256+
for patch_idx in range(6):
257+
row = patch_idx // num_cols
258+
col = patch_idx % num_cols
259+
stitched[
260+
:,
261+
row * patch_size : (row + 1) * patch_size,
262+
col * patch_size : (col + 1) * patch_size,
263+
] = patches[patch_idx]
264+
265+
original_tensor = torch.tensor(np.array(test_image)).permute(2, 0, 1).float()
266+
self.assertTrue(
267+
torch.allclose(stitched, original_tensor),
268+
"Patches do not stitch back to original image - row/column ordering may be wrong",
269+
)
270+
271+
def test_get_number_of_image_patches_aspect_ratio(self):
272+
"""Test that get_number_of_image_patches returns correct count for non-square images.
273+
274+
This directly tests the row/column unpacking fix by verifying patch counts match
275+
the expected grid layout. If rows/columns are swapped, the wrong grid would be
276+
chosen for asymmetric images.
277+
"""
278+
for image_processing_class in self.image_processor_list:
279+
patch_size = 64
280+
image_processor = image_processing_class(
281+
size={"height": patch_size, "width": patch_size},
282+
crop_to_patches=True,
283+
min_patches=1,
284+
max_patches=12,
285+
)
286+
287+
# Test 1: Tall image (4 rows x 1 column) should give 5 patches (4 + thumbnail)
288+
tall_patches = image_processor.get_number_of_image_patches(
289+
height=patch_size * 4, # 256
290+
width=patch_size, # 64
291+
images_kwargs={},
292+
)
293+
self.assertEqual(tall_patches, 5, "Tall image (4:1) should produce 5 patches")
294+
295+
# Test 2: Wide image (1 row x 4 columns) should give 5 patches (4 + thumbnail)
296+
wide_patches = image_processor.get_number_of_image_patches(
297+
height=patch_size, # 64
298+
width=patch_size * 4, # 256
299+
images_kwargs={},
300+
)
301+
self.assertEqual(wide_patches, 5, "Wide image (1:4) should produce 5 patches")
302+
303+
# Test 3: Asymmetric image (2 rows x 3 columns) should give 7 patches
304+
asym_patches = image_processor.get_number_of_image_patches(
305+
height=patch_size * 2, # 128
306+
width=patch_size * 3, # 192
307+
images_kwargs={"max_patches": 6},
308+
)
309+
self.assertEqual(asym_patches, 7, "Asymmetric image (2:3) should produce 7 patches")
310+
311+
# Test 4: Opposite asymmetric (3 rows x 2 columns) should also give 7 patches
312+
asym_patches2 = image_processor.get_number_of_image_patches(
313+
height=patch_size * 3, # 192
314+
width=patch_size * 2, # 128
315+
images_kwargs={"max_patches": 6},
316+
)
317+
self.assertEqual(asym_patches2, 7, "Asymmetric image (3:2) should produce 7 patches")

0 commit comments

Comments
 (0)