Commit e048ae4

Support jax2tf in JaxLayer for tf backend (#21842)

* Support jax2tf in JaxLayer for the tf backend
* Update keras/src/utils/jax_layer.py
* Use the backend's seed generator, but pass dtype uint32, since that is the dtype of a JAX key
* Use the jax random dtype; change seed_gen to the backend's dtype and convert later for the GPU test
* Revert seed_generator change; fix jax backend bug
* Skip GPU test
* Formatting, lint, local imports, docstring updates, and review-comment fixes

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent edbf8f5 commit e048ae4

File tree

3 files changed: +364, −45 lines

keras/src/utils/jax_layer.py

Lines changed: 187 additions & 36 deletions
@@ -1,4 +1,7 @@
+import functools
 import inspect
+import itertools
+import string
 
 import numpy as np
 
@@ -12,6 +15,22 @@
 from keras.src.utils import jax_utils
 from keras.src.utils import tracking
 from keras.src.utils.module_utils import jax
+from keras.src.utils.module_utils import tensorflow as tf
+
+if backend.backend() == "tensorflow":
+    tf_no_automatic_dependency_tracking = (
+        tf.__internal__.tracking.no_automatic_dependency_tracking
+    )
+else:
+
+    def tf_no_automatic_dependency_tracking(fn):
+        return fn
+
+
+def _convert_to_jax_key(tensor):
+    if backend.backend() == "tensorflow":
+        return tf.bitcast(tensor, tf.uint32)[0]
+    return tensor
 
 
 @keras_export("keras.layers.JaxLayer")
@@ -219,10 +238,10 @@ def __init__(
         seed=None,
         **kwargs,
     ):
-        if backend.backend() != "jax":
+        if backend.backend() not in ["jax", "tensorflow"]:
             raise ValueError(
-                "JaxLayer is only supported with the JAX backend. Current "
-                f"backend: {backend.backend()}"
+                f"{self.__class__.__name__} is only supported with the JAX or"
+                f" TensorFlow backend. Current backend: {backend.backend()}"
             )
 
         if init_fn is None and params is None and state is None:
@@ -252,6 +271,10 @@ def __init__(
             init_fn, "init_fn", {"rng", "inputs", "training"}, {"inputs"}
         )
 
+        # Attributes for jax2tf functions
+        self.jax2tf_training_false_fn = None
+        self.jax2tf_training_true_fn = None
+
     def _validate_signature(self, fn, fn_name, allowed, required):
         fn_parameters = inspect.signature(fn).parameters
         for parameter_name in required:
@@ -272,7 +295,81 @@ def _validate_signature(self, fn, fn_name, allowed, required):
 
         return parameter_names
 
+    def _get_jax2tf_input_shape(self, input_shape):
+        """Convert input shape in a format suitable for `jax2tf`.
+
+        `jax2tf` expects a letter for each unknown dimension, which allows
+        correlated dimensions. Since correlated dimensions are not supported
+        by Keras, we simply use 'a', 'b', 'c'..., for each unknown dimension.
+        We however use 'batch' for dimension 0 if not defined to correlate
+        the batch size across inputs.
+
+        Example (spaces added for readability):
+        ```
+        input_shape:  (None , 4, None, None, 5)
+        result:      "(batch, 4, a   , b   , 5)"
+        ```
+
+        Args:
+            input_shape: a single shape or a structure of shapes for the
+                inputs.
+        Returns:
+            the shape or shapes structure in the `jax2tf` format as strings.
+        """
+        dim_names = itertools.chain(
+            string.ascii_lowercase,  # a, b, ... z
+            itertools.starmap(  # aa, ab, ... az, ba, bb, ... zz
+                lambda a, b: a + b,
+                itertools.product(string.ascii_lowercase, repeat=2),
+            ),
+        )
+
+        def get_single_jax2tf_shape(shape):
+            jax2tf_shape = []
+
+            for index, dim in enumerate(shape):
+                if dim is not None:
+                    jax2tf_shape.append(str(dim))
+                elif index == 0:
+                    jax2tf_shape.append("batch")
+                else:
+                    jax2tf_shape.append(next(dim_names))
+
+            return "(" + ", ".join(jax2tf_shape) + ")"
+
+        res = tree.map_shape_structure(get_single_jax2tf_shape, input_shape)
+        return res
+
+    def _jax2tf_convert(self, fn, polymorphic_shapes):
+        from jax.experimental import jax2tf
+
+        converted_fn = jax2tf.convert(fn, polymorphic_shapes=polymorphic_shapes)
+        # Autograph won't work with the output of jax2tf.
+        converted_fn = tf.autograph.experimental.do_not_convert(converted_fn)
+        return converted_fn
+
+    def _partial_with_positional(self, fn, index, value):
+        """Return a new partial with one positional argument set to a value.
+
+        This is needed because `jax2tf` only supports positional arguments
+        and `functools.partial` only supports setting positional arguments
+        starting from the left. Our use case is the `training` argument
+        which is typically the rightmost argument.
+
+        Args:
+            fn: the function to wrap.
+            index: the index of the positional argument to set to `value`.
+            value: the value for the positional argument at `index`.
+        """
+
+        @functools.wraps(fn)
+        def wrapper(*args):
+            args = args[0:index] + (value,) + args[index:]
+            return fn(*args)
+
+        return wrapper
+
     @tracking.no_automatic_dependency_tracking
+    @tf_no_automatic_dependency_tracking
     def _create_variables(self, values, trainable):
         """Create a structure of variables from a structure of JAX arrays.
 
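To make the polymorphic-shape strings concrete, here is a small sketch of how such a string drives `jax2tf.convert`; the function and shapes are illustrative, not taken from the commit:

```python
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def reduce_fn(x):
    # Any JAX function of the input works here.
    return jnp.sum(x, axis=-1)

# "(batch, a, 5)" marks dims 0 and 1 as symbolic, which is the string
# _get_jax2tf_input_shape would produce for input_shape (None, None, 5).
tf_fn = jax2tf.convert(reduce_fn, polymorphic_shapes=["(batch, a, 5)"])

out = tf_fn(tf.ones((2, 3, 5)))  # works for any batch size and middle dim
print(out.shape)  # (2, 3)
```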
@@ -296,14 +393,14 @@ def _create_variables(self, values, trainable):
 
         def create_variable(value):
             if backend.is_tensor(value) or isinstance(
-                value, (np.ndarray, np.generic)
+                value, (np.ndarray, np.generic, jax.Array)
             ):
                 dtype = value.dtype
                 if is_float_dtype(dtype):
                     dtype = None  # Use the layer dtype policy
                 return self.add_weight(
                     value.shape,
-                    initializer=value,
+                    initializer=backend.convert_to_tensor(value),
                     dtype=dtype,
                     trainable=trainable,
                 )
@@ -333,44 +430,46 @@ def create_variable(value):
 
     def _get_init_rng(self):
         """
-        Returns a JAX `PRNGKey` or structure of `PRNGKey`s to pass to `init_fn`.
+        Returns a key in the form of a backend array of size 2, dtype uint32,
+        to pass to `init_fn`.
 
-        By default, this returns a single `PRNGKey` retrieved by calling
+        By default, this returns a JAX or TF array of size 2 by calling
         `self.seed_generator.next()`. Override this to return a different
         structure.
 
         Returns:
-            a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as
-            the `rng` argument of `init_fn`.
+            a key as a JAX or TF array of size 2, dtype uint32, that will be
+            passed as the `rng` argument of `init_fn`.
         """
         return self.seed_generator.next()
 
     def _get_call_rng(self, training):
         """
-        Returns a JAX `PRNGKey` or structure of `PRNGKey`s to pass to `call_fn`.
+        Returns a key in the form of a backend array of size 2, dtype uint32,
+        to pass to `call_fn`.
 
-        By default, this returns a single `PRNGKey` retrieved by calling
+        By default, this returns a JAX or TF array of size 2 by calling
         `self.seed_generator.next()` when `training` is `True`, and `None` when
         `training` is `False`. Override this to return a different structure or
         to pass RNGs in inference mode too.
 
         Returns:
-            a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as
-            the `rng` argument of `call_fn`.
+            a key as a JAX or TF array of size 2, dtype uint32, that will be
+            passed as the `rng` argument of `call_fn`.
         """
         if training:
             return self.seed_generator.next()
         else:
             return None
 
-    def build(self, input_shape):
-        if self.params is not None or self.state is not None:
-            return
-
-        if jax_utils.is_in_jax_tracing_scope():
+    def _initialize_weights(self, input_shape):
+        if jax_utils.is_in_jax_tracing_scope() or tf.inside_function():
             # This exception is not actually shown, it is caught and a detailed
             # warning about calling 'build' is printed.
-            raise ValueError("'JaxLayer' cannot be built in tracing scope")
+            raise ValueError(
+                "'JaxLayer' cannot be built in tracing scope "
+                "or inside tf function"
+            )
 
         # Initialize `params` and `state` if needed by calling `init_fn`.
         def create_input(shape):
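For context on the key format both RNG helpers produce: the raw form of a JAX `PRNGKey` is a `uint32` array of shape `(2,)`, which is what the seed generator output is converted into before it reaches `init_fn`. A quick standalone illustration:

```python
import jax

# A raw JAX PRNG key is a uint32 array of shape (2,).
key = jax.random.PRNGKey(0)
print(key.shape, key.dtype)  # (2,) uint32

# Typical consumer-side use inside an init_fn:
params_key, dropout_key = jax.random.split(key)
```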
@@ -381,7 +480,12 @@ def create_input(shape):
         init_args = []
         for argument_name in self.init_fn_arguments:
             if argument_name == "rng":
-                init_args.append(self._get_init_rng())
+                init_args.append(
+                    jax.tree_util.tree_map(
+                        lambda x: jax.numpy.array(_convert_to_jax_key(x)),
+                        self._get_init_rng(),
+                    )
+                )
             elif argument_name == "inputs":
                 init_args.append(init_inputs)
             elif argument_name == "training":
@@ -398,6 +502,45 @@ def create_input(shape):
         )
         self.tracked_state = self._create_variables(init_state, trainable=False)
 
+    def build(self, input_shape):
+        if self.params is None and self.state is None:
+            self._initialize_weights(input_shape)
+
+        if backend.backend() == "tensorflow":
+            polymorphic_shapes = []
+            for argument in self.call_fn_arguments:
+                if argument == "inputs":
+                    polymorphic_shapes.append(
+                        self._get_jax2tf_input_shape(input_shape)
+                    )
+                elif argument != "training":
+                    # params, state, rng
+                    polymorphic_shapes.append("...")
+
+            if "training" in self.call_fn_arguments:
+                training_argument_index = self.call_fn_arguments.index(
+                    "training"
+                )
+                self.jax2tf_training_false_fn = self._jax2tf_convert(
+                    self._partial_with_positional(
+                        self.call_fn, training_argument_index, False
+                    ),
+                    polymorphic_shapes,
+                )
+                self.jax2tf_training_true_fn = self._jax2tf_convert(
+                    self._partial_with_positional(
+                        self.call_fn, training_argument_index, True
+                    ),
+                    polymorphic_shapes,
+                )
+            else:
+                self.jax2tf_training_false_fn = self._jax2tf_convert(
+                    self.call_fn,
+                    polymorphic_shapes,
+                )
+                self.jax2tf_training_true_fn = None
+        super().build(input_shape)
+
     def call(self, inputs, training=False):
         def unwrap_variable(variable):
             return None if variable is None else variable.value
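`build` converts `call_fn` twice because `jax2tf` traces Python-valued arguments such as `training` into the converted function as constants, so each variant pins the argument up front via `_partial_with_positional`. A standalone sketch of that pinning with a hypothetical `call_fn` (same wrapper logic as the method above):

```python
import functools

def call_fn(params, inputs, training):
    # Hypothetical layer function; training is the rightmost argument.
    return ("train" if training else "infer", params, inputs)

def partial_with_positional(fn, index, value):
    # Mirror of _partial_with_positional: pin one positional arg by index.
    @functools.wraps(fn)
    def wrapper(*args):
        args = args[0:index] + (value,) + args[index:]
        return fn(*args)
    return wrapper

inference_fn = partial_with_positional(call_fn, 2, False)
print(inference_fn("params", "inputs"))  # ('infer', 'params', 'inputs')
```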
@@ -413,11 +556,16 @@ def unwrap_variable(variable):
                     jax.tree_util.tree_map(unwrap_variable, self.state)
                 )
             elif argument_name == "rng":
-                call_args.append(self._get_call_rng(training))
+                call_args.append(
+                    jax.tree_util.tree_map(
+                        _convert_to_jax_key, self._get_call_rng(training)
+                    )
+                )
             elif argument_name == "inputs":
                 call_args.append(inputs)
             elif argument_name == "training":
-                call_args.append(training)
+                if backend.backend() == "jax":
+                    call_args.append(training)
 
         def assign_state_to_variable(value, variable):
             # This exists only to make debugging this error case easier.
@@ -429,14 +577,23 @@ def assign_state_to_variable(value, variable):
                 )
             variable.assign(value)
 
-        if self.has_state:
-            predictions, new_state = self.call_fn(*call_args)
-            jax.tree_util.tree_map(
-                assign_state_to_variable, new_state, self.state
-            )
-            return predictions
-        else:
-            return self.call_fn(*call_args)
+        def call_with_fn(fn):
+            if self.has_state:
+                predictions, new_state = fn(*call_args)
+                jax.tree_util.tree_map(
+                    assign_state_to_variable, new_state, self.state
+                )
+                return predictions
+            else:
+                return fn(*call_args)
+
+        if backend.backend() == "jax":
+            return call_with_fn(self.call_fn)
+        elif backend.backend() == "tensorflow":
+            if training and self.jax2tf_training_true_fn is not None:
+                return call_with_fn(self.jax2tf_training_true_fn)
+            else:
+                return call_with_fn(self.jax2tf_training_false_fn)
 
     def get_config(self):
         config = {
@@ -556,12 +713,6 @@ def __init__(
         # Late import to only require Flax when this is used.
         from flax.core import scope as flax_scope
 
-        if backend.backend() != "jax":
-            raise ValueError(
-                "FlaxLayer is only supported with the JAX backend. Current "
-                f"backend: {backend.backend()}"
-            )
-
         self.module = module
         self.method = method
 

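Taken together, the change lets a `JaxLayer` run under the TensorFlow backend by converting `call_fn` with `jax2tf` at build time. A hypothetical end-to-end usage sketch (function, parameter names, and shapes are illustrative):

```python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"  # set before importing keras

import jax.numpy as jnp
import keras

def my_call_fn(params, inputs):
    # A simple dense transform written in JAX.
    return jnp.dot(inputs, params["w"]) + params["b"]

params = {"w": jnp.ones((4, 2)), "b": jnp.zeros((2,))}
layer = keras.layers.JaxLayer(my_call_fn, params=params)

x = keras.ops.ones((8, 4))
y = layer(x)  # executes the jax2tf-converted function on the TF backend
print(y.shape)  # (8, 2)
```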