@@ -46,6 +46,7 @@ def exposed_init_model(self, kvargs):
         quant_type = self.args.vit_quant_type
         quant_cfg = self.args.vit_quant_cfg
         max_batch_size = min(self.args.visual_infer_batch_size // self.args.visual_dp, 1)
+        remote_vit = True if self.args.run_mode == "visual" else False

         self.dp_rank_id = kvargs["dp_rank_id"]
         self.tp_rank_id = kvargs["tp_rank_id"]
@@ -62,10 +63,11 @@ def exposed_init_model(self, kvargs):
6263 "quant_type" : quant_type ,
6364 "quant_cfg" : quant_cfg ,
6465 "max_batch_size" : max_batch_size ,
66+ "remote_vit" : remote_vit ,
6567 }
6668 self .model_type = model_cfg ["model_type" ]
6769 if self .model_type == "qwen" :
68- self .model = QWenVisionTransformer (** model_cfg ["visual" ]).eval ().bfloat16 ()
70+ self .model = QWenVisionTransformer (kvargs , ** model_cfg ["visual" ]).eval ().bfloat16 ()
6971 elif self .model_type == "qwen2_vl" :
7072 self .model = (
7173 Qwen2VisionTransformerPretrainedModel (kvargs , ** model_cfg ["vision_config" ]).eval ().bfloat16 ()
@@ -75,14 +77,14 @@ def exposed_init_model(self, kvargs):
                 Qwen2_5_VisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16()
             )
         elif model_cfg["architectures"][0] == "TarsierForConditionalGeneration":
-            self.model = TarsierVisionTransformerPretrainedModel(**model_cfg).eval().bfloat16()
+            self.model = TarsierVisionTransformerPretrainedModel(kvargs, **model_cfg).eval().bfloat16()
         elif self.model_type == "llava":
-            self.model = LlavaVisionModel()
+            self.model = LlavaVisionModel(kvargs)
         elif self.model_type == "internvl_chat":
             self.model = VisionTransformer(kvargs)
             # self.model = InternVLVisionModel()
         elif self.model_type == "gemma3":
-            self.model = Gemma3VisionModel()
+            self.model = Gemma3VisionModel(kvargs)
         else:
             raise Exception(f"can not support {self.model_type} now")
         self.model.load_model(weight_dir)