@@ -164,8 +164,9 @@ def load_model_config_from_hf(model_id: str):
164164 if 'label_descriptions' in hf_config :
165165 pretrained_cfg ['label_descriptions' ] = hf_config .pop ('label_descriptions' )
166166
167+ model_args = hf_config .get ('model_args' , {})
167168 model_name = hf_config ['architecture' ]
168- return pretrained_cfg , model_name
169+ return pretrained_cfg , model_name , model_args
169170
170171
171172def load_state_dict_from_hf (model_id : str , filename : str = HF_WEIGHTS_NAME ):
@@ -193,19 +194,23 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
193194def save_config_for_hf (
194195 model ,
195196 config_path : str ,
196- model_config : Optional [dict ] = None
197+ model_config : Optional [dict ] = None ,
198+ model_args : Optional [dict ] = None
197199):
198200 model_config = model_config or {}
199201 hf_config = {}
200202 pretrained_cfg = filter_pretrained_cfg (model .pretrained_cfg , remove_source = True , remove_null = True )
201203 # set some values at root config level
202204 hf_config ['architecture' ] = pretrained_cfg .pop ('architecture' )
203- hf_config ['num_classes' ] = model_config .get ('num_classes' , model .num_classes )
204- hf_config ['num_features' ] = model_config .get ('num_features' , model .num_features )
205- global_pool_type = model_config .get ('global_pool' , getattr (model , 'global_pool' , None ))
205+ hf_config ['num_classes' ] = model_config .pop ('num_classes' , model .num_classes )
206+
207+ # NOTE: these attributes are saved for informational purposes only; they do not impact model build
208+ hf_config ['num_features' ] = model_config .pop ('num_features' , model .num_features )
209+ global_pool_type = model_config .pop ('global_pool' , getattr (model , 'global_pool' , None ))
206210 if isinstance (global_pool_type , str ) and global_pool_type :
207211 hf_config ['global_pool' ] = global_pool_type
208212
213+ # Save class label info
209214 if 'labels' in model_config :
210215 _logger .warning (
211216 "'labels' as a config field for is deprecated. Please use 'label_names' and 'label_descriptions'."
@@ -225,6 +230,9 @@ def save_config_for_hf(
225230 # maps label names -> descriptions
226231 hf_config ['label_descriptions' ] = label_descriptions
227232
233+ if model_args :
234+ hf_config ['model_args' ] = model_args
235+
228236 hf_config ['pretrained_cfg' ] = pretrained_cfg
229237 hf_config .update (model_config )
230238
@@ -236,6 +244,7 @@ def save_for_hf(
236244 model ,
237245 save_directory : str ,
238246 model_config : Optional [dict ] = None ,
247+ model_args : Optional [dict ] = None ,
239248 safe_serialization : Union [bool , Literal ["both" ]] = False ,
240249):
241250 assert has_hf_hub (True )
@@ -251,11 +260,16 @@ def save_for_hf(
251260 torch .save (tensors , save_directory / HF_WEIGHTS_NAME )
252261
253262 config_path = save_directory / 'config.json'
254- save_config_for_hf (model , config_path , model_config = model_config )
263+ save_config_for_hf (
264+ model ,
265+ config_path ,
266+ model_config = model_config ,
267+ model_args = model_args ,
268+ )
255269
256270
257271def push_to_hf_hub (
258- model ,
272+ model : torch . nn . Module ,
259273 repo_id : str ,
260274 commit_message : str = 'Add model' ,
261275 token : Optional [str ] = None ,
@@ -264,6 +278,7 @@ def push_to_hf_hub(
264278 create_pr : bool = False ,
265279 model_config : Optional [dict ] = None ,
266280 model_card : Optional [dict ] = None ,
281+ model_args : Optional [dict ] = None ,
267282 safe_serialization : Union [bool , Literal ["both" ]] = False ,
268283):
269284 """
@@ -291,7 +306,13 @@ def push_to_hf_hub(
291306 # Dump model and push to Hub
292307 with TemporaryDirectory () as tmpdir :
293308 # Save model weights and config.
294- save_for_hf (model , tmpdir , model_config = model_config , safe_serialization = safe_serialization )
309+ save_for_hf (
310+ model ,
311+ tmpdir ,
312+ model_config = model_config ,
313+ model_args = model_args ,
314+ safe_serialization = safe_serialization ,
315+ )
295316
296317 # Add readme if it does not exist
297318 if not has_readme :
0 commit comments