Skip to content

Commit 78c2c00

Browse files
committed
fix: support tensor labels in DataCollatorWithFlattening
- Add tensor-to-list conversion in DataCollatorWithFlattening
- Convert input_ids and labels to lists if they are tensors
- Add tests for both tensor and list labels
- Fixes #42599
1 parent: 9b74e4c; commit: 78c2c00

File tree

2 files changed

+62
-1
lines changed

2 files changed

+62
-1
lines changed

src/transformers/data/data_collator.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1413,9 +1413,17 @@ def __call__(self, features, return_tensors=None, separator_id=None):
14131413
max_length = 0
14141414
for seq_idx, sample in enumerate(features):
14151415
input_ids = sample["input_ids"]
1416+
# Convert to list if tensor
1417+
if hasattr(input_ids, "tolist"):
1418+
input_ids = input_ids.tolist()
14161419
batch["input_ids"] += input_ids
1420+
14171421
if is_labels_provided:
1418-
batch["labels"] += [separator_id] + sample["labels"][1:]
1422+
labels = sample["labels"]
1423+
# Convert to list if tensor
1424+
if hasattr(labels, "tolist"):
1425+
labels = labels.tolist()
1426+
batch["labels"] += [separator_id] + labels[1:]
14191427
else:
14201428
batch["labels"] += [separator_id] + input_ids[1:]
14211429
if self.return_position_ids:

tests/trainer/test_data_collator.py

Lines changed: 53 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1965,3 +1965,56 @@ def test__whole_word_mask(self):
19651965
).astype(bool)
19661966

19671967
np.testing.assert_array_equal(output_mask, expected_mask)
1968+
1969+
class DataCollatorWithFlatteningTest(unittest.TestCase):
1970+
"""Tests for DataCollatorWithFlattening"""
1971+
1972+
def test_flattening_with_tensor_labels(self):
1973+
"""Test that DataCollatorWithFlattening supports tensor labels (fixes issue #42599)."""
1974+
features = [
1975+
{
1976+
"input_ids": torch.tensor([1, 2, 3, 4]),
1977+
"labels": torch.tensor([10, 11, 12, 13]),
1978+
},
1979+
{
1980+
"input_ids": torch.tensor([5, 6, 7]),
1981+
"labels": torch.tensor([14, 15, 16]),
1982+
},
1983+
]
1984+
collator = DataCollatorWithFlattening(return_tensors="pt")
1985+
1986+
# This should not raise TypeError anymore
1987+
batch = collator(features)
1988+
1989+
# Verify the output
1990+
self.assertIsInstance(batch, dict)
1991+
self.assertIn("input_ids", batch)
1992+
self.assertIn("labels", batch)
1993+
self.assertIn("position_ids", batch)
1994+
1995+
# Check shapes
1996+
self.assertEqual(batch["input_ids"].shape, (1, 7)) # 4 + 3 tokens
1997+
self.assertEqual(batch["labels"].shape, (1, 7))
1998+
self.assertEqual(batch["position_ids"].shape, (1, 7))
1999+
2000+
def test_flattening_with_list_labels(self):
2001+
"""Test that DataCollatorWithFlattening still works with list labels."""
2002+
features = [
2003+
{
2004+
"input_ids": torch.tensor([1, 2, 3, 4]),
2005+
"labels": [10, 11, 12, 13],
2006+
},
2007+
{
2008+
"input_ids": torch.tensor([5, 6, 7]),
2009+
"labels": [14, 15, 16],
2010+
},
2011+
]
2012+
collator = DataCollatorWithFlattening(return_tensors="pt")
2013+
batch = collator(features)
2014+
2015+
# Verify it still works with lists
2016+
self.assertIsInstance(batch, dict)
2017+
self.assertEqual(batch["input_ids"].shape, (1, 7))
2018+
self.assertEqual(batch["labels"].shape, (1, 7))
2019+
2020+

0 commit comments

Comments
 (0)