@@ -21,7 +21,6 @@
 from torchao.prototype.parq.quant import (
     Int4UnifTorchaoQuantizer,
     LSBQuantizer,
-    Quantizer,
     StretchedIntxWeightConfig,
     StretchedUnifTorchaoQuantizer,
     TernaryUnifQuantizer,
@@ -162,29 +161,43 @@ def build_param_groups(
     model,
     b: int = 2,
     group_size: Optional[int] = None,
-    quantizer: Optional[Quantizer] = None,
+    embed_b: int = 4,
 ):
     params_quant, params_embed, params_no_quant = split_param_groups(model)
     quant_kwargs = {}
     if group_size:
         quant_kwargs["quant_block_size"] = group_size
-    if quantizer is not None:
-        quant_kwargs["quantizer"] = quantizer
     param_groups = [
         {"params": params_quant, "quant_bits": b, **quant_kwargs},
         {"params": params_no_quant},
     ]
     if params_embed:
-        param_groups.append(
-            {
-                "params": params_embed,
-                "quant_bits": 4,
-                "quantizer": UnifTorchaoQuantizer(),
-            }
-        )
+        param_groups.append({"params": params_embed, "quant_bits": embed_b})
     return param_groups


+def get_optim_kwargs(
+    model, base_optimizer, embedding=True, quant_cls=UnifTorchaoQuantizer
+):
+    optim_kwargs = {}
+    if embedding:
+        embed_data_ptrs = set(
+            (
+                m.weight.data_ptr()
+                for m in model.modules()
+                if isinstance(m, nn.Embedding)
+            )
+        )
+        group_idx = -1
+        for i, group in enumerate(base_optimizer.param_groups):
+            if all(p.data_ptr() in embed_data_ptrs for p in group["params"]):
+                group_idx = i
+                break
+        assert group_idx > -1
+        optim_kwargs["group_quantizer_map"] = {group_idx: quant_cls()}
+    return optim_kwargs
+
+
 def compare_quantized_models(
     model: nn.Module,
     m_ref: nn.Module,
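With `Quantizer` removed from `build_param_groups`, per-group quantizers are now attached after the optimizer is built: `get_optim_kwargs` finds the param group whose tensors are all `nn.Embedding` weights (matched by `data_ptr`) and maps that group's index to a quantizer instance via `group_quantizer_map`. A minimal usage sketch, assuming the test helpers above are in scope; the toy model is hypothetical, and the `QuantOptimizer`/`ProxHardQuant` import path is assumed:

```python
import torch
import torch.nn as nn

# Assumed import paths for the PARQ optimizer pieces.
from torchao.prototype.parq.optim import ProxHardQuant, QuantOptimizer
from torchao.prototype.parq.quant import UnifTorchaoQuantizer


# Hypothetical toy model: one embedding plus one linear layer.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(128, 16)
        self.linear = nn.Linear(16, 16)


model = TinyModel()

# 2-bit linear weights, 4-bit embedding weights (the new embed_b argument).
param_groups = build_param_groups(model, b=2, embed_b=4)
base_optimizer = torch.optim.AdamW(param_groups)

# Maps the embedding group's index to a quantizer instance, producing
# something like {"group_quantizer_map": {2: UnifTorchaoQuantizer()}}.
optim_kwargs = get_optim_kwargs(model, base_optimizer)

optimizer = QuantOptimizer(
    base_optimizer,
    UnifTorchaoQuantizer(),
    ProxHardQuant(),
    quant_per_channel=True,
    **optim_kwargs,
)
```

Because the lookup requires every tensor in a group to be an embedding weight, the helper asserts that such a group exists, which `build_param_groups` guarantees whenever the model has embeddings.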
@@ -222,7 +235,7 @@ def compare_parq_convert(
     orig_model = copy.deepcopy(model)  # save copy of PARQ quantized model

     # equivalent to torchao's convert step
-    optimizer.torchao_convert(model, weight_only=weight_only)
+    optimizer.torchao_convert(model, weight_only=weight_only, embed_weight_only=True)

     inputs = model.example_inputs(device=_DEVICE)
     torch.testing.assert_close(model(inputs), orig_model(inputs))
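The convert step also gains an `embed_weight_only` flag, so embedding layers are converted with a weight-only config regardless of how the linear layers are handled. A minimal call sketch, assuming `optimizer` and `model` come from a training setup like the ones below:

```python
# Swap PARQ-quantized weights for torchao tensor subclasses. The
# embed_weight_only=True flag (added in this diff) keeps nn.Embedding
# layers on a weight-only config even when linear layers also quantize
# activations.
optimizer.torchao_convert(model, embed_weight_only=True)
```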
@@ -290,15 +303,16 @@ def test_parq_train_loop(
             quantizer = TernaryUnifQuantizer() if b == 0 else UnifQuantizer()
         else:
             quantizer = LSBQuantizer()
-        param_groups = build_param_groups(
-            model, b, quantizer=quantizer if per_group_quantizer else None
-        )
+        param_groups = build_param_groups(model, b, embed_b=b)
         base_optimizer = torch.optim.AdamW(param_groups)

         prox_map = (
             ProxHardQuant() if hard_prox else ProxPARQ(anneal_start=0, anneal_end=2)
         )
-        optimizer = QuantOptimizer(base_optimizer, quantizer, prox_map)
+        optim_kwargs = get_optim_kwargs(
+            model, base_optimizer, quant_cls=type(quantizer), embedding=False
+        )
+        optimizer = QuantOptimizer(base_optimizer, quantizer, prox_map, **optim_kwargs)
         for _ in range(3):
             x = model.example_inputs(device=_DEVICE)
             out = model(x)
@@ -367,11 +381,13 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):

         b = 4
         base_optimizer = torch.optim.AdamW(build_param_groups(model, b, group_size))
+        optim_kwargs = get_optim_kwargs(model, base_optimizer, embedding=False)
         optimizer = QuantOptimizer(
             base_optimizer,
             Int4UnifTorchaoQuantizer(),
             ProxHardQuant(),
             quant_per_channel=True,
+            **optim_kwargs,
         )
         compare_parq_convert(model, m_ref, optimizer, weight_only=True)

@@ -387,11 +403,13 @@ def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
         quantize_(m_ref, config)

         base_optimizer = torch.optim.AdamW(build_param_groups(model, b, group_size))
+        optim_kwargs = get_optim_kwargs(model, base_optimizer, embedding=False)
         optimizer = QuantOptimizer(
             base_optimizer,
             UnifTorchaoQuantizer(),
             ProxHardQuant(),
             quant_per_channel=True,
+            **optim_kwargs,
         )
         compare_parq_convert(model, m_ref, optimizer, weight_only=True)
         check_torchao_tensor_subclass(self, model, weight_only=True)
@@ -462,11 +480,13 @@ def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
         quantize_(m_ref, config, filter_fn=_is_linear)

         base_optimizer = torch.optim.AdamW(build_param_groups(model, b, group_size))
+        optim_kwargs = get_optim_kwargs(model, base_optimizer, embedding=False)
         optimizer = QuantOptimizer(
             base_optimizer,
             quantizer,
             ProxHardQuant(),
             quant_per_channel=True,
+            **optim_kwargs,
         )
         compare_parq_convert(model, m_ref, optimizer, weight_only=True)
         check_torchao_tensor_subclass(self, model, weight_only=True)
@@ -482,14 +502,19 @@ def test_intx_weight_only_tied_embed_linear(

         quantizer = StretchedUnifTorchaoQuantizer(b)
         base_optimizer = torch.optim.SGD(build_param_groups(model, b))
+        optim_kwargs = get_optim_kwargs(model, base_optimizer)
         optimizer = QuantOptimizer(
-            base_optimizer, quantizer, ProxHardQuant(), quant_per_channel=True
+            base_optimizer,
+            quantizer,
+            ProxHardQuant(),
+            quant_per_channel=True,
+            **optim_kwargs,
         )
         optimizer.zero_grad()
         optimizer.step()

         apply_activation_quantization(model, optimizer, model_dtype)
-        optimizer.torchao_convert(model)
+        optimizer.torchao_convert(model, embed_weight_only=True)
         check_torchao_tensor_subclass(self, model)
         self.assertTrue(
             torch.equal(model.embed_tokens.weight.qdata, model.linear2.weight.qdata)
@@ -531,8 +556,13 @@ def test_int8_dynamic_activation_intx_e2e(

         # quantize weights with PARQ
         base_optimizer = torch.optim.SGD(build_param_groups(model, b, group_size))
+        optim_kwargs = get_optim_kwargs(model, base_optimizer, embedding=False)
         optimizer = QuantOptimizer(
-            base_optimizer, quantizer, ProxHardQuant(), quant_per_channel=True
+            base_optimizer,
+            quantizer,
+            ProxHardQuant(),
+            quant_per_channel=True,
+            **optim_kwargs,
         )

         optimizer.zero_grad()