
Commit 72b7703

TroyGarden authored and facebook-github-bot committed
add feature range validation for ebc (#3583)
Summary:
Enable validation of KeyedJaggedTensor feature values against EmbeddingBagConfig ranges to catch out-of-range embedding lookups early.

All validation functions were modified to return boolean values, allowing callers to programmatically distinguish between hard failures (structural errors that raise ValueError) and soft failures (out-of-range values that return False with warnings). This supports two use cases:

1. Production monitoring - detect invalid embedding IDs without crashing
2. Data quality checks - identify features with values outside [0, num_embeddings)

All validation functions now return bool for consistency, maintaining full backward compatibility since existing code can continue to ignore the return values.

Reviewed By: malaybag

Differential Revision: D88013492
1 parent 223db0d commit 72b7703
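
The hard/soft failure split described in the summary is consumed through the return value. Below is a minimal sketch of that caller pattern using the public names touched by this commit; the table config and input ids are made up for illustration.

import torch

from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.sparse.jagged_tensor_validator import validate_keyed_jagged_tensor

# Hypothetical table: ids for feature "f1" must fall in [0, 100).
configs = [
    EmbeddingBagConfig(
        name="t1", embedding_dim=8, num_embeddings=100, feature_names=["f1"]
    )
]

# Illustrative input: the id 150 is outside [0, 100).
kjt = KeyedJaggedTensor(
    keys=["f1"],
    values=torch.tensor([3, 7, 150]),
    lengths=torch.tensor([1, 2]),
)

# Structural problems (bad lengths/offsets, duplicate keys, ...) still raise
# ValueError; out-of-range ids only log a warning and surface as False here,
# so monitoring code can count them without crashing the job.
if not validate_keyed_jagged_tensor(kjt, configs):
    print("found embedding ids outside [0, num_embeddings)")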

File tree

4 files changed: +290 -10 lines changed


torchrec/distributed/embeddingbag.py

Lines changed: 1 addition & 1 deletion
@@ -1543,7 +1543,7 @@ def input_dist(
                 "pytorch/torchrec:enable_kjt_validation"
             ):
                 logger.info("Validating input features...")
-                validate_keyed_jagged_tensor(features)
+                validate_keyed_jagged_tensor(features, self._embedding_bag_configs)
 
             self._create_input_dist(features.keys())
             self._has_uninitialized_input_dist = False

torchrec/sparse/jagged_tensor.py

Lines changed: 5 additions & 1 deletion
@@ -2825,11 +2825,15 @@ def to_dict(self, compute_offsets: bool = True) -> Dict[str, JaggedTensor]:
             logger.warning(
                 "Trying to non-strict torch.export KJT to_dict, which is extremely slow and not recommended!"
             )
+        length_per_key = self.length_per_key()
+        if isinstance(length_per_key, torch.Tensor):
+            # length_per_key should be a list of ints, but in some (incorrect) cases it is a tensor
+            length_per_key = length_per_key.tolist()
         _jt_dict = _maybe_compute_kjt_to_jt_dict(
             stride=self.stride(),
             stride_per_key=self.stride_per_key(),
             keys=self.keys(),
-            length_per_key=self.length_per_key(),
+            length_per_key=length_per_key,
             lengths=self.lengths(),
             values=self.values(),
             variable_stride_per_key=self.variable_stride_per_key(),
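
The to_dict change above normalizes length_per_key before handing it to _maybe_compute_kjt_to_jt_dict. A tiny standalone sketch of that guard follows; the helper name is invented for illustration.

from typing import List, Union

import torch


def _as_length_list(length_per_key: Union[List[int], torch.Tensor]) -> List[int]:
    # The KJT-to-JT split expects a plain list of ints; convert if a tensor
    # slipped through from an upstream (incorrect) producer.
    if isinstance(length_per_key, torch.Tensor):
        return length_per_key.tolist()
    return length_per_key


assert _as_length_list(torch.tensor([2, 1, 3])) == [2, 1, 3]
assert _as_length_list([2, 1, 3]) == [2, 1, 3]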

torchrec/sparse/jagged_tensor_validator.py

Lines changed: 98 additions & 7 deletions
@@ -7,30 +7,47 @@
 
 # pyre-strict
 
+import logging
+from typing import Dict, List, Optional
+
 import torch
+from torchrec.modules.embedding_configs import EmbeddingBagConfig
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
+logger: logging.Logger = logging.getLogger(__name__)
+
 
 def validate_keyed_jagged_tensor(
-    kjt: KeyedJaggedTensor,
-) -> None:
+    kjt: KeyedJaggedTensor, configs: Optional[List[EmbeddingBagConfig]] = None
+) -> bool:
     """
     Validates the inputs that construct a KeyedJaggedTensor.
 
     Any invalid input will result in a ValueError being thrown.
+
+    Returns:
+        bool: True if all validations pass (including feature range),
+            False if feature range validation fails (soft warning).
     """
     _validate_lengths_and_offsets(kjt)
     _validate_keys(kjt)
     _validate_weights(kjt)
+    return _validate_feature_range(kjt, configs)
 
 
-def _validate_lengths_and_offsets(kjt: KeyedJaggedTensor) -> None:
+def _validate_lengths_and_offsets(kjt: KeyedJaggedTensor) -> bool:
     """
     Validates the lengths and offsets of a KJT.
 
     - At least one of lengths or offsets is provided
     - If both are provided, they are consistent with each other
     - The dimensions of these tensors align with the values tensor
+
+    Returns:
+        bool: True if validation passes.
+
+    Raises:
+        ValueError: If validation fails.
     """
     lengths = kjt.lengths_or_none()
     offsets = kjt.offsets_or_none()
@@ -44,6 +61,7 @@ def _validate_lengths_and_offsets(kjt: KeyedJaggedTensor) -> None:
         _validate_lengths(lengths, kjt.values())
     elif offsets is not None:
         _validate_offsets(offsets, kjt.values())
+    return True
 
 
 def _validate_lengths_and_offsets_consistency(
@@ -59,16 +77,36 @@ def _validate_lengths_and_offsets_consistency(
 
     if not lengths.equal(torch.diff(offsets)):
         raise ValueError("offsets is not equal to the cumulative sum of lengths")
+    return True
 
 
-def _validate_lengths(lengths: torch.Tensor, values: torch.Tensor) -> None:
+def _validate_lengths(lengths: torch.Tensor, values: torch.Tensor) -> bool:
+    """
+    Validates lengths tensor.
+
+    Returns:
+        bool: True if validation passes.
+
+    Raises:
+        ValueError: If validation fails.
+    """
     if lengths.sum().item() != values.numel():
         raise ValueError(
             f"Sum of lengths must equal the number of values, but got {lengths.sum().item()} and {values.numel()}"
         )
+    return True
+
+
+def _validate_offsets(offsets: torch.Tensor, values: torch.Tensor) -> bool:
+    """
+    Validates offsets tensor.
 
+    Returns:
+        bool: True if validation passes.
 
-def _validate_offsets(offsets: torch.Tensor, values: torch.Tensor) -> None:
+    Raises:
+        ValueError: If validation fails.
+    """
     if offsets.numel() == 0:
         raise ValueError("offsets cannot be empty")
 
@@ -79,14 +117,21 @@ def _validate_offsets(offsets: torch.Tensor, values: torch.Tensor) -> None:
         raise ValueError(
             f"The last element of offsets must equal to the number of values, but got {offsets[-1]} and {values.numel()}"
         )
+    return True
 
 
-def _validate_keys(kjt: KeyedJaggedTensor) -> None:
+def _validate_keys(kjt: KeyedJaggedTensor) -> bool:
     """
     Validates KJT keys, assuming the lengths/offsets input are valid.
 
     - keys must be unique
     - For non-VBE cases, the size of lengths is divisible by the number of keys
+
+    Returns:
+        bool: True if validation passes.
+
+    Raises:
+        ValueError: If validation fails.
     """
     keys = kjt.keys()
 
@@ -110,14 +155,60 @@ def _validate_keys(kjt: KeyedJaggedTensor) -> None:
         raise ValueError(
             f"lengths size must be divisible by keys size, but got {lengths_size} and {len(keys)}"
         )
+    return True
 
 
-def _validate_weights(kjt: KeyedJaggedTensor) -> None:
+def _validate_weights(kjt: KeyedJaggedTensor) -> bool:
     """
     Validates if the KJT weights has the same size as values.
+
+    Returns:
+        bool: True if validation passes.
+
+    Raises:
+        ValueError: If validation fails.
     """
     weights = kjt.weights_or_none()
     if weights is not None and weights.numel() != kjt.values().numel():
         raise ValueError(
             f"weights size must equal to values size, but got {weights.numel()} and {kjt.values().numel()}"
         )
+    return True
+
+
+def _validate_feature_range(
+    kjt: KeyedJaggedTensor, configs: List[EmbeddingBagConfig]
+) -> bool:
+    """
+    Validates if the KJT feature range is valid.
+
+    Returns:
+        bool: True if all features are within valid range, False otherwise.
+    """
+    feature_to_range_map: Dict[str, int] = {}
+    for config in configs:
+        for feature in config.feature_names:
+            feature_to_range_map[feature] = config.num_embeddings
+
+    if len(kjt.keys() & feature_to_range_map.keys()) == 0:
+        logger.info(
+            f"None of KJT._keys {kjt.keys()} in the config {feature_to_range_map.keys()}"
+        )
+        return True
+
+    valid = True
+    jtd = kjt.to_dict()
+    for feature, jt in jtd.items():
+        if feature not in feature_to_range_map:
+            logger.info(f"Feature {feature} is not in the config")
+            continue
+        if jt.values().numel() == 0:
+            continue
+        min_value, max_value = jt.values().min(), jt.values().max()
+        if min_value < 0 or max_value >= feature_to_range_map[feature]:
+            logger.warning(
+                f"Feature {feature} has range {min_value, max_value} "
+                f"which is out of range {0, feature_to_range_map[feature]}"
+            )
+            valid = False
+    return valid
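
For contrast with the soft feature-range path above, structural problems still raise. A short sketch with fabricated inputs (lengths sum to 4 but only 3 values are given), assuming the validator is importable from the file path shown in this diff:

import torch

from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.sparse.jagged_tensor_validator import validate_keyed_jagged_tensor

# lengths sums to 4 but only 3 values are provided, so _validate_lengths
# raises ValueError before the feature-range check is ever reached.
bad_kjt = KeyedJaggedTensor(
    keys=["f1"],
    values=torch.tensor([1, 2, 3]),
    lengths=torch.tensor([2, 2]),
)

try:
    validate_keyed_jagged_tensor(bad_kjt)
except ValueError as err:
    print(f"hard failure: {err}")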
