
Commit 3ea3f9c

TroyGarden authored and facebook-github-bot committed
add feature range validation for ebc (#3583)
Summary:
Enable validation of KeyedJaggedTensor feature values against EmbeddingBagConfig ranges to catch out-of-range embedding lookups early. Modified all validation functions to return boolean values, allowing callers to programmatically distinguish between hard failures (structural errors that raise ValueError) and soft failures (out-of-range values that return False with warnings).

This supports two use cases:
1. Production monitoring - detect invalid embedding IDs without crashing
2. Data quality checks - identify features with values outside [0, num_embeddings)

All validation functions now return bool for consistency, maintaining full backward compatibility since existing code can continue to ignore return values.

Reviewed By: malaybag, spmex

Differential Revision: D88013492
1 parent 223db0d commit 3ea3f9c
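
As a quick illustration of the new behavior (not part of the commit; the config, feature name, and id values below are hypothetical), a caller can pass EmbeddingBagConfigs to the validator and treat a False return as a soft signal rather than an error:

import torch
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.sparse.jagged_tensor_validator import validate_keyed_jagged_tensor

# Hypothetical table config: ids for feature "f1" must lie in [0, 10).
configs = [
    EmbeddingBagConfig(num_embeddings=10, embedding_dim=4, name="t1", feature_names=["f1"])
]

# Id 42 is outside [0, 10): structural checks still pass, so nothing raises;
# the range check logs a warning and the validator returns False.
kjt = KeyedJaggedTensor(
    keys=["f1"],
    values=torch.tensor([1, 42]),
    lengths=torch.tensor([1, 1]),
)
if not validate_keyed_jagged_tensor(kjt, configs):
    print("out-of-range embedding ids detected")  # monitor without crashing

Structural problems, such as lengths that do not sum to the number of values, still raise ValueError, matching the hard/soft split described in the summary.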

File tree

4 files changed, +293 -10 lines


torchrec/distributed/embeddingbag.py

Lines changed: 1 addition & 1 deletion
@@ -1543,7 +1543,7 @@ def input_dist(
                 "pytorch/torchrec:enable_kjt_validation"
             ):
                 logger.info("Validating input features...")
-                validate_keyed_jagged_tensor(features)
+                validate_keyed_jagged_tensor(features, self._embedding_bag_configs)

             self._create_input_dist(features.keys())
             self._has_uninitialized_input_dist = False
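
For context, the gated call above only runs when the "pytorch/torchrec:enable_kjt_validation" switch is enabled. A minimal, library-agnostic sketch of that shape, where validation_enabled() is a hypothetical stand-in for the real knob check:

import logging
from typing import List

from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.sparse.jagged_tensor_validator import validate_keyed_jagged_tensor

logger = logging.getLogger(__name__)


def validation_enabled() -> bool:
    # Hypothetical stand-in for the real kill-switch check.
    return True


def maybe_validate(features: KeyedJaggedTensor, configs: List[EmbeddingBagConfig]) -> None:
    if validation_enabled():
        logger.info("Validating input features...")
        # With configs passed in, out-of-range ids log a warning and the call
        # returns False instead of raising.
        if not validate_keyed_jagged_tensor(features, configs):
            logger.warning("Input features contain out-of-range embedding ids")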

torchrec/sparse/jagged_tensor.py

Lines changed: 5 additions & 1 deletion
@@ -2825,11 +2825,15 @@ def to_dict(self, compute_offsets: bool = True) -> Dict[str, JaggedTensor]:
             logger.warning(
                 "Trying to non-strict torch.export KJT to_dict, which is extremely slow and not recommended!"
             )
+        length_per_key = self.length_per_key()
+        if isinstance(length_per_key, torch.Tensor):
+            # length_per_key should be a list of ints, but in some (incorrect) cases it is a tensor
+            length_per_key = length_per_key.tolist()
         _jt_dict = _maybe_compute_kjt_to_jt_dict(
             stride=self.stride(),
             stride_per_key=self.stride_per_key(),
             keys=self.keys(),
-            length_per_key=self.length_per_key(),
+            length_per_key=length_per_key,
             lengths=self.lengths(),
             values=self.values(),
             variable_stride_per_key=self.variable_stride_per_key(),
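
The added guard simply normalizes the type before building the per-key dict; an illustrative snippet of the same conversion:

import torch

# In some (incorrect) code paths length_per_key arrives as a tensor ...
length_per_key = torch.tensor([2, 3, 1])
if isinstance(length_per_key, torch.Tensor):
    # ... so convert it to the expected List[int] before passing it on.
    length_per_key = length_per_key.tolist()
assert length_per_key == [2, 3, 1]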

torchrec/sparse/jagged_tensor_validator.py

Lines changed: 101 additions & 7 deletions
@@ -7,30 +7,50 @@
 
 # pyre-strict
 
+import logging
+from typing import Dict, List, Optional
+
 import torch
+from torchrec.modules.embedding_configs import EmbeddingBagConfig
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
+logger: logging.Logger = logging.getLogger(__name__)
+
 
 def validate_keyed_jagged_tensor(
-    kjt: KeyedJaggedTensor,
-) -> None:
+    kjt: KeyedJaggedTensor, configs: Optional[List[EmbeddingBagConfig]] = None
+) -> bool:
     """
     Validates the inputs that construct a KeyedJaggedTensor.
 
     Any invalid input will result in a ValueError being thrown.
+
+    Returns:
+        bool: True if all validations pass (including feature range),
+            False if feature range validation fails (soft warning).
     """
     _validate_lengths_and_offsets(kjt)
     _validate_keys(kjt)
     _validate_weights(kjt)
+    if configs is not None:
+        return _validate_feature_range(kjt, configs)
+    else:
+        return True
 
 
-def _validate_lengths_and_offsets(kjt: KeyedJaggedTensor) -> None:
+def _validate_lengths_and_offsets(kjt: KeyedJaggedTensor) -> bool:
     """
     Validates the lengths and offsets of a KJT.
 
     - At least one of lengths or offsets is provided
     - If both are provided, they are consistent with each other
     - The dimensions of these tensors align with the values tensor
+
+    Returns:
+        bool: True if validation passes.
+
+    Raises:
+        ValueError: If validation fails.
     """
     lengths = kjt.lengths_or_none()
     offsets = kjt.offsets_or_none()

@@ -44,6 +64,7 @@ def _validate_lengths_and_offsets(kjt: KeyedJaggedTensor) -> None:
         _validate_lengths(lengths, kjt.values())
     elif offsets is not None:
         _validate_offsets(offsets, kjt.values())
+    return True
 
 
 def _validate_lengths_and_offsets_consistency(

@@ -59,16 +80,36 @@ def _validate_lengths_and_offsets_consistency(
 
     if not lengths.equal(torch.diff(offsets)):
         raise ValueError("offsets is not equal to the cumulative sum of lengths")
+    return True
 
 
-def _validate_lengths(lengths: torch.Tensor, values: torch.Tensor) -> None:
+def _validate_lengths(lengths: torch.Tensor, values: torch.Tensor) -> bool:
+    """
+    Validates lengths tensor.
+
+    Returns:
+        bool: True if validation passes.
+
+    Raises:
+        ValueError: If validation fails.
+    """
     if lengths.sum().item() != values.numel():
         raise ValueError(
             f"Sum of lengths must equal the number of values, but got {lengths.sum().item()} and {values.numel()}"
         )
+    return True
+
 
+def _validate_offsets(offsets: torch.Tensor, values: torch.Tensor) -> bool:
+    """
+    Validates offsets tensor.
 
-def _validate_offsets(offsets: torch.Tensor, values: torch.Tensor) -> None:
+    Returns:
+        bool: True if validation passes.
+
+    Raises:
+        ValueError: If validation fails.
+    """
     if offsets.numel() == 0:
         raise ValueError("offsets cannot be empty")
 

@@ -79,14 +120,21 @@ def _validate_offsets(offsets: torch.Tensor, values: torch.Tensor) -> None:
         raise ValueError(
             f"The last element of offsets must equal to the number of values, but got {offsets[-1]} and {values.numel()}"
         )
+    return True
 
 
-def _validate_keys(kjt: KeyedJaggedTensor) -> None:
+def _validate_keys(kjt: KeyedJaggedTensor) -> bool:
     """
     Validates KJT keys, assuming the lengths/offsets input are valid.
 
     - keys must be unique
     - For non-VBE cases, the size of lengths is divisible by the number of keys
+
+    Returns:
+        bool: True if validation passes.
+
+    Raises:
+        ValueError: If validation fails.
     """
     keys = kjt.keys()
 

@@ -110,14 +158,60 @@ def _validate_keys(kjt: KeyedJaggedTensor) -> None:
         raise ValueError(
             f"lengths size must be divisible by keys size, but got {lengths_size} and {len(keys)}"
         )
+    return True
 
 
-def _validate_weights(kjt: KeyedJaggedTensor) -> None:
+def _validate_weights(kjt: KeyedJaggedTensor) -> bool:
     """
     Validates if the KJT weights has the same size as values.
+
+    Returns:
+        bool: True if validation passes.
+
+    Raises:
+        ValueError: If validation fails.
     """
     weights = kjt.weights_or_none()
     if weights is not None and weights.numel() != kjt.values().numel():
         raise ValueError(
             f"weights size must equal to values size, but got {weights.numel()} and {kjt.values().numel()}"
         )
+    return True
+
+
+def _validate_feature_range(
+    kjt: KeyedJaggedTensor, configs: List[EmbeddingBagConfig]
+) -> bool:
+    """
+    Validates if the KJT feature range is valid.
+
+    Returns:
+        bool: True if all features are within valid range, False otherwise.
+    """
+    feature_to_range_map: Dict[str, int] = {}
+    for config in configs:
+        for feature in config.feature_names:
+            feature_to_range_map[feature] = config.num_embeddings
+
+    if len(kjt.keys() & feature_to_range_map.keys()) == 0:
+        logger.info(
+            f"None of KJT._keys {kjt.keys()} in the config {feature_to_range_map.keys()}"
+        )
+        return True
+
+    valid = True
+    jtd = kjt.to_dict()
+    for feature, jt in jtd.items():
+        if feature not in feature_to_range_map:
+            logger.info(f"Feature {feature} is not in the config")
+            continue
+        if jt.values().numel() == 0:
+            continue
+        min_value, max_value = jt.values().min(), jt.values().max()
+        if min_value < 0 or max_value >= feature_to_range_map[feature]:
+            logger.warning(
+                f"Feature {feature} has range {min_value, max_value} "
+                f"which is out of range {0, feature_to_range_map[feature]}"
+            )
+            valid = False
+    return valid
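
To contrast the two failure modes introduced above, a small sketch using hypothetical configs and ids (warnings go through the module logger):

import torch
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.sparse.jagged_tensor_validator import validate_keyed_jagged_tensor

configs = [
    EmbeddingBagConfig(num_embeddings=100, embedding_dim=8, name="t1", feature_names=["f1"])
]

# Soft failure: structurally valid KJT, but id 500 >= num_embeddings (100),
# so _validate_feature_range logs a warning and the call returns False.
bad_range = KeyedJaggedTensor(
    keys=["f1"], values=torch.tensor([3, 500]), lengths=torch.tensor([1, 1])
)
assert not validate_keyed_jagged_tensor(bad_range, configs)

# Hard failure: lengths sum to 3 but there are only 2 values -> ValueError.
bad_structure = KeyedJaggedTensor(
    keys=["f1"], values=torch.tensor([3, 4]), lengths=torch.tensor([1, 2])
)
try:
    validate_keyed_jagged_tensor(bad_structure, configs)
except ValueError:
    pass  # structural errors still raise, exactly as before this change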

0 commit comments
