1+ import functools
12import inspect
3+ import itertools
4+ import string
25
36import numpy as np
47
1215from keras .src .utils import jax_utils
1316from keras .src .utils import tracking
1417from keras .src .utils .module_utils import jax
18+ from keras .src .utils .module_utils import tensorflow as tf
19+
# Pick the decorator used to switch off TensorFlow's automatic dependency
# tracking. The TensorFlow internals are only touched when TensorFlow is the
# active backend; every other backend gets a pass-through decorator.
if backend.backend() != "tensorflow":

    def tf_no_automatic_dependency_tracking(fn):
        """Return `fn` unchanged (no-op stand-in outside TensorFlow)."""
        return fn

else:
    tf_no_automatic_dependency_tracking = (
        tf.__internal__.tracking.no_automatic_dependency_tracking
    )
28+
29+
def _convert_to_jax_key(tensor):
    """Turn a backend seed tensor into the value handed to JAX as a key.

    Args:
        tensor: the seed value produced by the current backend.

    Returns:
        On the TensorFlow backend, `tensor` bitcast to `tf.uint32` and
        indexed at position 0. On any other backend, `tensor` unchanged.
    """
    if backend.backend() != "tensorflow":
        return tensor
    # NOTE(review): presumably the TF seed has a wider integer dtype, so the
    # bitcast adds a trailing axis and `[0]` selects one key — confirm.
    reinterpreted = tf.bitcast(tensor, tf.uint32)
    return reinterpreted[0]
1534
1635
1736@keras_export ("keras.layers.JaxLayer" )
@@ -219,10 +238,10 @@ def __init__(
219238 seed = None ,
220239 ** kwargs ,
221240 ):
222- if backend .backend () != "jax" :
241+ if backend .backend () not in [ "jax" , "tensorflow" ] :
223242 raise ValueError (
224- "JaxLayer is only supported with the JAX backend. Current "
225- f"backend: { backend .backend ()} "
243+ f" { self . __class__ . __name__ } is only supported with the JAX or "
244+ f" Tensorflow backend. Current backend: { backend .backend ()} "
226245 )
227246
228247 if init_fn is None and params is None and state is None :
@@ -252,6 +271,10 @@ def __init__(
252271 init_fn , "init_fn" , {"rng" , "inputs" , "training" }, {"inputs" }
253272 )
254273
274+ # Attributes for jax2tf functions
275+ self .jax2tf_training_false_fn = None
276+ self .jax2tf_training_true_fn = None
277+
255278 def _validate_signature (self , fn , fn_name , allowed , required ):
256279 fn_parameters = inspect .signature (fn ).parameters
257280 for parameter_name in required :
@@ -272,7 +295,81 @@ def _validate_signature(self, fn, fn_name, allowed, required):
272295
273296 return parameter_names
274297
def _get_jax2tf_input_shape(self, input_shape):
    """Render input shapes as `jax2tf` polymorphic shape strings.

    `jax2tf` wants a symbolic name for every unknown dimension; reusing a
    name ties dimensions together. Keras has no notion of correlated
    dimensions, so each unknown dimension gets its own fresh name
    ('a', 'b', ..., 'z', 'aa', 'ab', ..., 'zz'). The one exception is an
    unknown dimension 0, which is always named 'batch' so that the batch
    size is shared across all inputs.

    Example (spaces added for readability):
    ```
    input_shape: (None , 4 , None, None, 5 )
    result:     "(batch, 4 , a   , b   , 5 )"
    ```

    Args:
        input_shape: a single shape or a structure of shapes for the
            inputs.

    Returns:
        The same structure with every shape replaced by its `jax2tf`
        string form.
    """
    letters = string.ascii_lowercase
    # A single lazy pool shared by every shape in the structure, so a name
    # is never handed out twice within one call.
    symbol_pool = itertools.chain(
        letters,  # a, b, ... z
        (
            "".join(pair)  # aa, ab, ... az, ba, bb, ... zz
            for pair in itertools.product(letters, repeat=2)
        ),
    )

    def to_polymorphic_spec(shape):
        # The conditional expression is lazy: a name is only drawn from the
        # pool for an unknown dimension that is not at position 0.
        parts = [
            str(size)
            if size is not None
            else ("batch" if position == 0 else next(symbol_pool))
            for position, size in enumerate(shape)
        ]
        return "({})".format(", ".join(parts))

    return tree.map_shape_structure(to_polymorphic_spec, input_shape)
341+
def _jax2tf_convert(self, fn, polymorphic_shapes):
    """Convert a JAX function into one callable from TensorFlow.

    Args:
        fn: the JAX function to convert.
        polymorphic_shapes: forwarded to `jax2tf.convert` as its
            `polymorphic_shapes` argument.

    Returns:
        The converted function, marked so that autograph leaves it alone.
    """
    # Late import to only require `jax2tf` when a conversion is requested.
    from jax.experimental import jax2tf

    # Autograph won't work with the output of jax2tf, hence the
    # `do_not_convert` wrapper.
    return tf.autograph.experimental.do_not_convert(
        jax2tf.convert(fn, polymorphic_shapes=polymorphic_shapes)
    )
349+
350+ def _partial_with_positional (self , fn , index , value ):
351+ """Return a new partial with one positional argument set to a value.
352+
353+ This is needed because `jax2tf` only supports positional arguments and
354+ `functools.partial` only supports setting positional arguments starting
355+ from the left. Our use case is the `training` argument which is
356+ typically the righmost argument.
357+
358+ Args:
359+ fn: the function to wrap.
360+ index: the index of the positional argument to set to `value`.
361+ value: the value for the positional argument at `index`.
362+ """
363+
364+ @functools .wraps (fn )
365+ def wrapper (* args ):
366+ args = args [0 :index ] + (value ,) + args [index :]
367+ return fn (* args )
368+
369+ return wrapper
370+
275371 @tracking .no_automatic_dependency_tracking
372+ @tf_no_automatic_dependency_tracking
276373 def _create_variables (self , values , trainable ):
277374 """Create a structure of variables from a structure of JAX arrays.
278375
@@ -296,14 +393,14 @@ def _create_variables(self, values, trainable):
296393
297394 def create_variable (value ):
298395 if backend .is_tensor (value ) or isinstance (
299- value , (np .ndarray , np .generic )
396+ value , (np .ndarray , np .generic , jax . Array )
300397 ):
301398 dtype = value .dtype
302399 if is_float_dtype (dtype ):
303400 dtype = None # Use the layer dtype policy
304401 return self .add_weight (
305402 value .shape ,
306- initializer = value ,
403+ initializer = backend . convert_to_tensor ( value ) ,
307404 dtype = dtype ,
308405 trainable = trainable ,
309406 )
@@ -333,44 +430,46 @@ def create_variable(value):
333430
334431 def _get_init_rng (self ):
335432 """
336- Returns a JAX `PRNGKey` or structure of `PRNGKey`s to pass to `init_fn`.
433+ Returns a key in form of the backend array of size 2 dtype uint32
434+ to pass to `init_fn`.
337435
338- By default, this returns a single `PRNGKey` retrieved by calling
436+ By default, this returns a Jax or TF array of size 2 by calling
339437 `self.seed_generator.next()`. Override this to return a different
340438 structure.
341439
342440 Returns:
343- a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as
344- the `rng` argument of `init_fn`.
441+ a key as an Jax or TF array of size 2 dtype uint32 will be passed
442+ as the `rng` argument of `init_fn`.
345443 """
346444 return self .seed_generator .next ()
347445
348446 def _get_call_rng (self , training ):
349447 """
350- Returns a JAX `PRNGKey` or structure of `PRNGKey`s to pass to `call_fn`.
448+ Returns a key in form of the backend array of size 2 dtype uint32
449+ to pass to `call_fn`.
351450
352- By default, this returns a single `PRNGKey` retrieved by calling
451+ By default, this returns a Jax or TF array of size 2 by calling
353452 `self.seed_generator.next()` when `training` is `True`, and `None` when
354453 `training` is `False`. Override this to return a different structure or
355454 to pass RNGs in inference mode too.
356455
357456 Returns:
358- a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as
359- the `rng` argument of `call_fn`.
457+ a key as an Jax or TF array of size 2 dtype uint32 will be passed
458+ as the `rng` argument of `call_fn`.
360459 """
361460 if training :
362461 return self .seed_generator .next ()
363462 else :
364463 return None
365464
366- def build (self , input_shape ):
367- if self .params is not None or self .state is not None :
368- return
369-
370- if jax_utils .is_in_jax_tracing_scope ():
465+ def _initialize_weights (self , input_shape ):
466+ if jax_utils .is_in_jax_tracing_scope () or tf .inside_function ():
371467 # This exception is not actually shown, it is caught and a detailed
372468 # warning about calling 'build' is printed.
373- raise ValueError ("'JaxLayer' cannot be built in tracing scope" )
469+ raise ValueError (
470+ "'JaxLayer' cannot be built in tracing scope"
471+ "or inside tf function"
472+ )
374473
375474 # Initialize `params` and `state` if needed by calling `init_fn`.
376475 def create_input (shape ):
@@ -381,7 +480,12 @@ def create_input(shape):
381480 init_args = []
382481 for argument_name in self .init_fn_arguments :
383482 if argument_name == "rng" :
384- init_args .append (self ._get_init_rng ())
483+ init_args .append (
484+ jax .tree_util .tree_map (
485+ lambda x : jax .numpy .array (_convert_to_jax_key (x )),
486+ self ._get_init_rng (),
487+ )
488+ )
385489 elif argument_name == "inputs" :
386490 init_args .append (init_inputs )
387491 elif argument_name == "training" :
@@ -398,6 +502,45 @@ def create_input(shape):
398502 )
399503 self .tracked_state = self ._create_variables (init_state , trainable = False )
400504
def build(self, input_shape):
    """Create the layer weights and, on TensorFlow, the `jax2tf` callables.

    Weights are initialized through `_initialize_weights` only when
    neither `params` nor `state` is set. On the TensorFlow backend,
    `call_fn` is then converted with `jax2tf`:

    - If `call_fn` takes a `training` argument, two converted functions
      are stored, one with `training` fixed to `False` and one with it
      fixed to `True`.
    - Otherwise a single converted function is stored in
      `jax2tf_training_false_fn` and `jax2tf_training_true_fn` is `None`.

    Args:
        input_shape: a single shape or a structure of shapes for the
            inputs.
    """
    weights_missing = self.params is None and self.state is None
    if weights_missing:
        self._initialize_weights(input_shape)

    if backend.backend() == "tensorflow":
        argument_names = self.call_fn_arguments

        # One spec per positional argument of the converted function.
        # `training` is skipped because it is bound before conversion;
        # params, state and rng use "..." and only `inputs` gets a real
        # polymorphic shape.
        polymorphic_shapes = [
            self._get_jax2tf_input_shape(input_shape)
            if name == "inputs"
            else "..."
            for name in argument_names
            if name != "training"
        ]

        if "training" not in argument_names:
            self.jax2tf_training_false_fn = self._jax2tf_convert(
                self.call_fn,
                polymorphic_shapes,
            )
            self.jax2tf_training_true_fn = None
        else:
            training_position = argument_names.index("training")

            def convert_with_training(flag):
                # Fix `training` to `flag`, then convert what remains.
                bound_fn = self._partial_with_positional(
                    self.call_fn, training_position, flag
                )
                return self._jax2tf_convert(bound_fn, polymorphic_shapes)

            self.jax2tf_training_false_fn = convert_with_training(False)
            self.jax2tf_training_true_fn = convert_with_training(True)
    super().build(input_shape)
543+
401544 def call (self , inputs , training = False ):
402545 def unwrap_variable (variable ):
403546 return None if variable is None else variable .value
@@ -413,11 +556,16 @@ def unwrap_variable(variable):
413556 jax .tree_util .tree_map (unwrap_variable , self .state )
414557 )
415558 elif argument_name == "rng" :
416- call_args .append (self ._get_call_rng (training ))
559+ call_args .append (
560+ jax .tree_util .tree_map (
561+ _convert_to_jax_key , self ._get_call_rng (training )
562+ )
563+ )
417564 elif argument_name == "inputs" :
418565 call_args .append (inputs )
419566 elif argument_name == "training" :
420- call_args .append (training )
567+ if backend .backend () == "jax" :
568+ call_args .append (training )
421569
422570 def assign_state_to_variable (value , variable ):
423571 # This exists only to make debugging this error case easier.
@@ -429,14 +577,23 @@ def assign_state_to_variable(value, variable):
429577 )
430578 variable .assign (value )
431579
432- if self .has_state :
433- predictions , new_state = self .call_fn (* call_args )
434- jax .tree_util .tree_map (
435- assign_state_to_variable , new_state , self .state
436- )
437- return predictions
438- else :
439- return self .call_fn (* call_args )
580+ def call_with_fn (fn ):
581+ if self .has_state :
582+ predictions , new_state = fn (* call_args )
583+ jax .tree_util .tree_map (
584+ assign_state_to_variable , new_state , self .state
585+ )
586+ return predictions
587+ else :
588+ return fn (* call_args )
589+
590+ if backend .backend () == "jax" :
591+ return call_with_fn (self .call_fn )
592+ elif backend .backend () == "tensorflow" :
593+ if training and self .jax2tf_training_true_fn is not None :
594+ return call_with_fn (self .jax2tf_training_true_fn )
595+ else :
596+ return call_with_fn (self .jax2tf_training_false_fn )
440597
441598 def get_config (self ):
442599 config = {
@@ -556,12 +713,6 @@ def __init__(
556713 # Late import to only require Flax when this is used.
557714 from flax .core import scope as flax_scope
558715
559- if backend .backend () != "jax" :
560- raise ValueError (
561- "FlaxLayer is only supported with the JAX backend. Current "
562- f"backend: { backend .backend ()} "
563- )
564-
565716 self .module = module
566717 self .method = method
567718
0 commit comments