# Models Collection

Implementations of some models in JAX.

## Models

| Model | Architecture Type | Reference |
|-------|-------------------|-----------|
| GPT-2 | Decoder-only Transformer (causal LM) | ref |

More to come.
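The GPT-2 entry is listed as a decoder-only Transformer trained as a causal language model, meaning each position can attend only to itself and earlier positions. The sketch below shows that causal masking in a single multi-head self-attention layer written with plain `jax.numpy`. It is an illustration under stated assumptions, not code from this repository's GPT-2 implementation. The function name `causal_self_attention`, the fused `w_qkv` projection layout and the unbatched `(seq_len, d_model)` input are all hypothetical choices made for brevity.

```python
import jax
import jax.numpy as jnp


def causal_self_attention(x, w_qkv, w_out, n_heads):
    """Hypothetical multi-head causal self-attention over one sequence.

    x:      (seq_len, d_model) input activations
    w_qkv:  (d_model, 3 * d_model) fused query/key/value projection (assumed layout)
    w_out:  (d_model, d_model) output projection
    """
    seq_len, d_model = x.shape
    head_dim = d_model // n_heads

    # Project once, then split into queries, keys and values.
    q, k, v = jnp.split(x @ w_qkv, 3, axis=-1)

    # Reshape each to (n_heads, seq_len, head_dim).
    def to_heads(t):
        return t.reshape(seq_len, n_heads, head_dim).transpose(1, 0, 2)

    q, k, v = to_heads(q), to_heads(k), to_heads(v)

    # Scaled dot-product scores with shape (n_heads, seq_len, seq_len).
    scores = q @ k.transpose(0, 2, 1) / jnp.sqrt(head_dim)

    # Causal mask: position i may only attend to positions j <= i.
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)

    weights = jax.nn.softmax(scores, axis=-1)
    out = weights @ v  # (n_heads, seq_len, head_dim)

    # Merge heads back to (seq_len, d_model) and apply the output projection.
    out = out.transpose(1, 0, 2).reshape(seq_len, d_model)
    return out @ w_out


# Usage with small random weights (illustrative sizes only).
key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
x = jax.random.normal(k1, (5, 8))
w_qkv = jax.random.normal(k2, (8, 24)) * 0.1
w_out = jax.random.normal(k3, (8, 8)) * 0.1
y = causal_self_attention(x, w_qkv, w_out, n_heads=2)
```

Because of the lower-triangular mask, changing the last token leaves the outputs at all earlier positions unchanged, which is the property that lets a causal LM be trained on next-token prediction for every position in parallel.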