77
88# pyre-strict
99
10+ import logging
11+ from typing import Dict , List , Optional
12+
1013import torch
14+ from torchrec .modules .embedding_configs import EmbeddingBagConfig
1115from torchrec .sparse .jagged_tensor import KeyedJaggedTensor
1216
17+ logger : logging .Logger = logging .getLogger (__name__ )
18+
1319
1420def validate_keyed_jagged_tensor (
15- kjt : KeyedJaggedTensor ,
16- ) -> None :
21+ kjt : KeyedJaggedTensor , configs : Optional [ List [ EmbeddingBagConfig ]] = None
22+ ) -> bool :
1723 """
1824 Validates the inputs that construct a KeyedJaggedTensor.
1925
2026 Any invalid input will result in a ValueError being thrown.
27+
28+ Returns:
29+ bool: True if all validations pass (including feature range),
30+ False if feature range validation fails (soft warning).
2131 """
2232 _validate_lengths_and_offsets (kjt )
2333 _validate_keys (kjt )
2434 _validate_weights (kjt )
35+ return _validate_feature_range (kjt , configs )
2536
2637
27- def _validate_lengths_and_offsets (kjt : KeyedJaggedTensor ) -> None :
38+ def _validate_lengths_and_offsets (kjt : KeyedJaggedTensor ) -> bool :
2839 """
2940 Validates the lengths and offsets of a KJT.
3041
3142 - At least one of lengths or offsets is provided
3243 - If both are provided, they are consistent with each other
3344 - The dimensions of these tensors align with the values tensor
45+
46+ Returns:
47+ bool: True if validation passes.
48+
49+ Raises:
50+ ValueError: If validation fails.
3451 """
3552 lengths = kjt .lengths_or_none ()
3653 offsets = kjt .offsets_or_none ()
@@ -44,6 +61,7 @@ def _validate_lengths_and_offsets(kjt: KeyedJaggedTensor) -> None:
4461 _validate_lengths (lengths , kjt .values ())
4562 elif offsets is not None :
4663 _validate_offsets (offsets , kjt .values ())
64+ return True
4765
4866
4967def _validate_lengths_and_offsets_consistency (
@@ -59,16 +77,36 @@ def _validate_lengths_and_offsets_consistency(
5977
6078 if not lengths .equal (torch .diff (offsets )):
6179 raise ValueError ("offsets is not equal to the cumulative sum of lengths" )
80+ return True
6281
6382
64- def _validate_lengths (lengths : torch .Tensor , values : torch .Tensor ) -> None :
83+ def _validate_lengths (lengths : torch .Tensor , values : torch .Tensor ) -> bool :
84+ """
85+ Validates lengths tensor.
86+
87+ Returns:
88+ bool: True if validation passes.
89+
90+ Raises:
91+ ValueError: If validation fails.
92+ """
6593 if lengths .sum ().item () != values .numel ():
6694 raise ValueError (
6795 f"Sum of lengths must equal the number of values, but got { lengths .sum ().item ()} and { values .numel ()} "
6896 )
97+ return True
98+
99+
100+ def _validate_offsets (offsets : torch .Tensor , values : torch .Tensor ) -> bool :
101+ """
102+ Validates offsets tensor.
69103
104+ Returns:
105+ bool: True if validation passes.
70106
71- def _validate_offsets (offsets : torch .Tensor , values : torch .Tensor ) -> None :
107+ Raises:
108+ ValueError: If validation fails.
109+ """
72110 if offsets .numel () == 0 :
73111 raise ValueError ("offsets cannot be empty" )
74112
@@ -79,14 +117,21 @@ def _validate_offsets(offsets: torch.Tensor, values: torch.Tensor) -> None:
79117 raise ValueError (
80118 f"The last element of offsets must equal to the number of values, but got { offsets [- 1 ]} and { values .numel ()} "
81119 )
120+ return True
82121
83122
84- def _validate_keys (kjt : KeyedJaggedTensor ) -> None :
123+ def _validate_keys (kjt : KeyedJaggedTensor ) -> bool :
85124 """
86125 Validates KJT keys, assuming the lengths/offsets input are valid.
87126
88127 - keys must be unique
89128 - For non-VBE cases, the size of lengths is divisible by the number of keys
129+
130+ Returns:
131+ bool: True if validation passes.
132+
133+ Raises:
134+ ValueError: If validation fails.
90135 """
91136 keys = kjt .keys ()
92137
@@ -110,14 +155,60 @@ def _validate_keys(kjt: KeyedJaggedTensor) -> None:
110155 raise ValueError (
111156 f"lengths size must be divisible by keys size, but got { lengths_size } and { len (keys )} "
112157 )
158+ return True
113159
114160
115- def _validate_weights (kjt : KeyedJaggedTensor ) -> None :
161+ def _validate_weights (kjt : KeyedJaggedTensor ) -> bool :
116162 """
117163 Validates if the KJT weights has the same size as values.
164+
165+ Returns:
166+ bool: True if validation passes.
167+
168+ Raises:
169+ ValueError: If validation fails.
118170 """
119171 weights = kjt .weights_or_none ()
120172 if weights is not None and weights .numel () != kjt .values ().numel ():
121173 raise ValueError (
122174 f"weights size must equal to values size, but got { weights .numel ()} and { kjt .values ().numel ()} "
123175 )
176+ return True
177+
178+
179+ def _validate_feature_range (
180+ kjt : KeyedJaggedTensor , configs : List [EmbeddingBagConfig ]
181+ ) -> bool :
182+ """
183+ Validates if the KJT feature range is valid.
184+
185+ Returns:
186+ bool: True if all features are within valid range, False otherwise.
187+ """
188+ feature_to_range_map : Dict [str , int ] = {}
189+ for config in configs :
190+ for feature in config .feature_names :
191+ feature_to_range_map [feature ] = config .num_embeddings
192+
193+ if len (kjt .keys () & feature_to_range_map .keys ()) == 0 :
194+ logger .info (
195+ f"None of KJT._keys { kjt .keys ()} in the config { feature_to_range_map .keys ()} "
196+ )
197+ return True
198+
199+ valid = True
200+ jtd = kjt .to_dict ()
201+ for feature , jt in jtd .items ():
202+ if feature not in feature_to_range_map :
203+ logger .info (f"Feature { feature } is not in the config" )
204+ continue
205+ if jt .values ().numel () == 0 :
206+ continue
207+ min_value , max_value = jt .values ().min (), jt .values ().max ()
208+ if min_value < 0 or max_value >= feature_to_range_map [feature ]:
209+ logger .warning (
210+ f"Feature { feature } has range { min_value , max_value } "
211+ f"which is out of range { 0 , feature_to_range_map [feature ]} "
212+ )
213+ valid = False
214+ return valid
0 commit comments