diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py
index d880b7b751fe..8412ab5ae25a 100644
--- a/src/transformers/data/data_collator.py
+++ b/src/transformers/data/data_collator.py
@@ -1413,9 +1413,17 @@ def __call__(self, features, return_tensors=None, separator_id=None):
         max_length = 0
         for seq_idx, sample in enumerate(features):
             input_ids = sample["input_ids"]
+            # Convert to list if tensor
+            if hasattr(input_ids, "tolist"):
+                input_ids = input_ids.tolist()
             batch["input_ids"] += input_ids
+
             if is_labels_provided:
-                batch["labels"] += [separator_id] + sample["labels"][1:]
+                labels = sample["labels"]
+                # Convert to list if tensor
+                if hasattr(labels, "tolist"):
+                    labels = labels.tolist()
+                batch["labels"] += [separator_id] + labels[1:]
             else:
                 batch["labels"] += [separator_id] + input_ids[1:]
             if self.return_position_ids:
diff --git a/tests/trainer/test_data_collator.py b/tests/trainer/test_data_collator.py
index b5cbb5ecea28..4c57d284686e 100644
--- a/tests/trainer/test_data_collator.py
+++ b/tests/trainer/test_data_collator.py
@@ -1965,3 +1965,55 @@ def test__whole_word_mask(self):
         ).astype(bool)
         np.testing.assert_array_equal(output_mask, expected_mask)
+
+
+class DataCollatorWithFlatteningTest(unittest.TestCase):
+    """Tests for DataCollatorWithFlattening"""
+
+    def test_flattening_with_tensor_labels(self):
+        """Test that DataCollatorWithFlattening supports tensor labels (fixes issue #42599)."""
+        features = [
+            {
+                "input_ids": torch.tensor([1, 2, 3, 4]),
+                "labels": torch.tensor([10, 11, 12, 13]),
+            },
+            {
+                "input_ids": torch.tensor([5, 6, 7]),
+                "labels": torch.tensor([14, 15, 16]),
+            },
+        ]
+        collator = DataCollatorWithFlattening(return_tensors="pt")
+
+        # This should not raise TypeError anymore
+        batch = collator(features)
+
+        # Verify the output
+        self.assertIsInstance(batch, dict)
+        self.assertIn("input_ids", batch)
+        self.assertIn("labels", batch)
+        self.assertIn("position_ids", batch)
+
+        # Check shapes
+        self.assertEqual(batch["input_ids"].shape, (1, 7))  # 4 + 3 tokens
+        self.assertEqual(batch["labels"].shape, (1, 7))
+        self.assertEqual(batch["position_ids"].shape, (1, 7))
+
+    def test_flattening_with_list_labels(self):
+        """Test that DataCollatorWithFlattening still works with list labels."""
+        features = [
+            {
+                "input_ids": torch.tensor([1, 2, 3, 4]),
+                "labels": [10, 11, 12, 13],
+            },
+            {
+                "input_ids": torch.tensor([5, 6, 7]),
+                "labels": [14, 15, 16],
+            },
+        ]
+        collator = DataCollatorWithFlattening(return_tensors="pt")
+        batch = collator(features)
+
+        # Verify it still works with lists
+        self.assertIsInstance(batch, dict)
+        self.assertEqual(batch["input_ids"].shape, (1, 7))
+        self.assertEqual(batch["labels"].shape, (1, 7))