Skip to content

Commit 440e55a

Browse files
jeffkbkim authored and meta-codesync[bot] committed
Pop _trained_batches key from state_dict on load_state_dict (#3573)
Summary: Pull Request resolved: #3573 Remove _trained_batches key from metric module state_dict upon loading. Confirmed that the new unit test, `test_load_state_dict_with_trained_batches_key`, fails without the new load_state_dict hook. Reviewed By: iamzainhuda Differential Revision: D87669499 fbshipit-source-id: 1b3f3f0fca4bec9a8d2b339a4e2ca67fcb0983f0
1 parent 217889e commit 440e55a

File tree

2 files changed

+67
-1
lines changed

2 files changed

+67
-1
lines changed

torchrec/metrics/metric_module.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import concurrent
1414
import logging
1515
import time
16-
from collections import defaultdict
16+
from collections import defaultdict, OrderedDict
1717
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
1818

1919
import torch
@@ -228,6 +228,26 @@ def __init__(
228228
)
229229
self.last_compute_time = -1.0
230230

231+
self._register_load_state_dict_pre_hook(self.load_state_dict_hook)
232+
233+
def load_state_dict_hook(
234+
self,
235+
state_dict: OrderedDict[str, torch.Tensor],
236+
prefix: str,
237+
local_metadata: Dict[str, Any],
238+
strict: bool,
239+
missing_keys: List[str],
240+
unexpected_keys: List[str],
241+
error_msgs: List[str],
242+
) -> None:
243+
"""Remove _trained_batches key for backward compatibility."""
244+
key = f"{prefix}_trained_batches"
245+
if key in state_dict:
246+
state_dict.pop(key)
247+
logger.warning(
248+
f"Removed key '{key}' from state_dict for backward compatibility"
249+
)
250+
231251
def _update_rec_metrics(
232252
self, model_out: Dict[str, torch.Tensor], **kwargs: Any
233253
) -> None:

torchrec/metrics/tests/test_metric_module.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,52 @@ def test_async_compute_raises_exception(self) -> None:
664664
):
665665
metric_module.async_compute(concurrent.futures.Future())
666666

667+
def test_load_state_dict_with_trained_batches_key(self) -> None:
668+
metric_module = generate_metric_module(
669+
TestMetricModule,
670+
metrics_config=DefaultMetricsConfig,
671+
batch_size=128,
672+
world_size=1,
673+
my_rank=0,
674+
state_metrics_mapping={},
675+
device=torch.device("cpu"),
676+
)
677+
state_dict = metric_module.state_dict()
678+
679+
# Add the _trained_batches key to simulate old checkpoint
680+
state_dict["_trained_batches"] = torch.tensor(42, dtype=torch.long)
681+
682+
# Load the state_dict with _trained_batches
683+
# This should not raise an error
684+
metric_module.load_state_dict(state_dict)
685+
metric_module.update(gen_test_batch(128))
686+
result = metric_module.compute()
687+
self.assertIsInstance(result, dict)
688+
self.assertTrue(len(result) > 0)
689+
690+
def test_load_state_dict_without_trained_batches_key(self) -> None:
691+
metric_module = generate_metric_module(
692+
TestMetricModule,
693+
metrics_config=DefaultMetricsConfig,
694+
batch_size=128,
695+
world_size=1,
696+
my_rank=0,
697+
state_metrics_mapping={},
698+
device=torch.device("cpu"),
699+
)
700+
state_dict = metric_module.state_dict()
701+
702+
# Verify the key is not in the state_dict
703+
self.assertNotIn("_trained_batches", state_dict)
704+
705+
# Load the clean state_dict
706+
# This should not raise an error
707+
metric_module.load_state_dict(state_dict)
708+
metric_module.update(gen_test_batch(128))
709+
result = metric_module.compute()
710+
self.assertIsInstance(result, dict)
711+
self.assertTrue(len(result) > 0)
712+
667713

668714
def metric_module_gather_state(
669715
rank: int,

0 commit comments

Comments (0)