77
88# pyre-strict
99
10+ import logging
11+ from typing import Dict , List , Optional
12+
1013import torch
14+ from torchrec .modules .embedding_configs import EmbeddingBagConfig
1115from torchrec .sparse .jagged_tensor import KeyedJaggedTensor
1216
17+ logger : logging .Logger = logging .getLogger (__name__ )
18+
1319
1420def validate_keyed_jagged_tensor (
15- kjt : KeyedJaggedTensor ,
16- ) -> None :
21+ kjt : KeyedJaggedTensor , configs : Optional [ List [ EmbeddingBagConfig ]] = None
22+ ) -> bool :
1723 """
1824 Validates the inputs that construct a KeyedJaggedTensor.
1925
2026 Any invalid input will result in a ValueError being thrown.
27+
28+ Returns:
29+ bool: True if all validations pass (including feature range),
30+ False if feature range validation fails (soft warning).
2131 """
2232 _validate_lengths_and_offsets (kjt )
2333 _validate_keys (kjt )
2434 _validate_weights (kjt )
35+ if configs is not None :
36+ return _validate_feature_range (kjt , configs )
37+ else :
38+ return True
2539
2640
27- def _validate_lengths_and_offsets (kjt : KeyedJaggedTensor ) -> None :
41+ def _validate_lengths_and_offsets (kjt : KeyedJaggedTensor ) -> bool :
2842 """
2943 Validates the lengths and offsets of a KJT.
3044
3145 - At least one of lengths or offsets is provided
3246 - If both are provided, they are consistent with each other
3347 - The dimensions of these tensors align with the values tensor
48+
49+ Returns:
50+ bool: True if validation passes.
51+
52+ Raises:
53+ ValueError: If validation fails.
3454 """
3555 lengths = kjt .lengths_or_none ()
3656 offsets = kjt .offsets_or_none ()
@@ -44,6 +64,7 @@ def _validate_lengths_and_offsets(kjt: KeyedJaggedTensor) -> None:
4464 _validate_lengths (lengths , kjt .values ())
4565 elif offsets is not None :
4666 _validate_offsets (offsets , kjt .values ())
67+ return True
4768
4869
4970def _validate_lengths_and_offsets_consistency (
@@ -59,16 +80,36 @@ def _validate_lengths_and_offsets_consistency(
5980
6081 if not lengths .equal (torch .diff (offsets )):
6182 raise ValueError ("offsets is not equal to the cumulative sum of lengths" )
83+ return True
6284
6385
64- def _validate_lengths (lengths : torch .Tensor , values : torch .Tensor ) -> None :
86+ def _validate_lengths (lengths : torch .Tensor , values : torch .Tensor ) -> bool :
87+ """
88+ Validates lengths tensor.
89+
90+ Returns:
91+ bool: True if validation passes.
92+
93+ Raises:
94+ ValueError: If validation fails.
95+ """
6596 if lengths .sum ().item () != values .numel ():
6697 raise ValueError (
6798 f"Sum of lengths must equal the number of values, but got { lengths .sum ().item ()} and { values .numel ()} "
6899 )
100+ return True
101+
69102
103+ def _validate_offsets (offsets : torch .Tensor , values : torch .Tensor ) -> bool :
104+ """
105+ Validates offsets tensor.
70106
71- def _validate_offsets (offsets : torch .Tensor , values : torch .Tensor ) -> None :
107+ Returns:
108+ bool: True if validation passes.
109+
110+ Raises:
111+ ValueError: If validation fails.
112+ """
72113 if offsets .numel () == 0 :
73114 raise ValueError ("offsets cannot be empty" )
74115
@@ -79,14 +120,21 @@ def _validate_offsets(offsets: torch.Tensor, values: torch.Tensor) -> None:
79120 raise ValueError (
80121 f"The last element of offsets must equal to the number of values, but got { offsets [- 1 ]} and { values .numel ()} "
81122 )
123+ return True
82124
83125
84- def _validate_keys (kjt : KeyedJaggedTensor ) -> None :
126+ def _validate_keys (kjt : KeyedJaggedTensor ) -> bool :
85127 """
86128 Validates KJT keys, assuming the lengths/offsets input are valid.
87129
88130 - keys must be unique
89131 - For non-VBE cases, the size of lengths is divisible by the number of keys
132+
133+ Returns:
134+ bool: True if validation passes.
135+
136+ Raises:
137+ ValueError: If validation fails.
90138 """
91139 keys = kjt .keys ()
92140
@@ -110,14 +158,60 @@ def _validate_keys(kjt: KeyedJaggedTensor) -> None:
110158 raise ValueError (
111159 f"lengths size must be divisible by keys size, but got { lengths_size } and { len (keys )} "
112160 )
161+ return True
113162
114163
115- def _validate_weights (kjt : KeyedJaggedTensor ) -> None :
164+ def _validate_weights (kjt : KeyedJaggedTensor ) -> bool :
116165 """
117166 Validates if the KJT weights has the same size as values.
167+
168+ Returns:
169+ bool: True if validation passes.
170+
171+ Raises:
172+ ValueError: If validation fails.
118173 """
119174 weights = kjt .weights_or_none ()
120175 if weights is not None and weights .numel () != kjt .values ().numel ():
121176 raise ValueError (
122177 f"weights size must equal to values size, but got { weights .numel ()} and { kjt .values ().numel ()} "
123178 )
179+ return True
180+
181+
182+ def _validate_feature_range (
183+ kjt : KeyedJaggedTensor , configs : List [EmbeddingBagConfig ]
184+ ) -> bool :
185+ """
186+ Validates if the KJT feature range is valid.
187+
188+ Returns:
189+ bool: True if all features are within valid range, False otherwise.
190+ """
191+ feature_to_range_map : Dict [str , int ] = {}
192+ for config in configs :
193+ for feature in config .feature_names :
194+ feature_to_range_map [feature ] = config .num_embeddings
195+
196+ if len (kjt .keys () & feature_to_range_map .keys ()) == 0 :
197+ logger .info (
198+ f"None of KJT._keys { kjt .keys ()} in the config { feature_to_range_map .keys ()} "
199+ )
200+ return True
201+
202+ valid = True
203+ jtd = kjt .to_dict ()
204+ for feature , jt in jtd .items ():
205+ if feature not in feature_to_range_map :
206+ logger .info (f"Feature { feature } is not in the config" )
207+ continue
208+ if jt .values ().numel () == 0 :
209+ continue
210+ min_value , max_value = jt .values ().min (), jt .values ().max ()
211+ if min_value < 0 or max_value >= feature_to_range_map [feature ]:
212+ logger .warning (
213+ f"Feature { feature } has range { min_value , max_value } "
214+ f"which is out of range { 0 , feature_to_range_map [feature ]} "
215+ )
216+ valid = False
217+ return valid
0 commit comments