|
516 | 516 | "source": [ |
517 | 517 | "### Important note about BatchNormalization layers\n", |
518 | 518 | "\n", |
519 | | - "Many models contain `tf.keras.layers.BatchNormalization` layers. This layer is a special case and precautions should be taken in the context of fine-tuning, as shown later in this tutorial. \n", |
| 519 | + "Many models contain `tf.keras.layers.BatchNormalization` layers. This layer is a special case and precautions should be taken in the context of fine-tuning, as shown later in this tutorial.\n", |
520 | 520 | "\n", |
521 | | - "When you set `layer.trainable = False`, the `BatchNormalization` layer will run in inference mode, and will not update its mean and variance statistics. \n", |
| 521 | + "When you set `layer.trainable = False`, the `BatchNormalization` layer will run in inference mode, and will not update its mean and variance statistics.\n", |
522 | 522 | "\n", |
523 | 523 | "When you unfreeze a model that contains BatchNormalization layers in order to do fine-tuning, you should keep the BatchNormalization layers in inference mode by passing `training = False` when calling the base model. Otherwise, the updates applied to the non-trainable weights will destroy what the model has learned.\n", |
524 | 524 | "\n", |
|
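The note above about passing `training = False` when calling the base model is the key detail; a condensed sketch of what that looks like in code (this assumes the tutorial's MobileNetV2 base and a 160x160 input, and omits the augmentation and preprocessing layers for brevity):

```python
import tensorflow as tf

# Illustrative input shape; the tutorial resizes images to its IMG_SIZE earlier on.
IMG_SHAPE = (160, 160, 3)

base_model = tf.keras.applications.MobileNetV2(
    input_shape=IMG_SHAPE, include_top=False, weights='imagenet')
base_model.trainable = False  # freeze the convolutional base for feature extraction

inputs = tf.keras.Input(shape=IMG_SHAPE)
# training=False keeps the BatchNormalization layers in inference mode, so their
# moving mean/variance statistics are not updated, even after the base model is
# later unfrozen for fine-tuning.
x = base_model(inputs, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(inputs, outputs)

# Later, when fine-tuning, unfreezing the weights does not switch the
# BatchNormalization layers back to training mode, because training=False
# was fixed at call time above.
base_model.trainable = True
```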
617 | 617 | "model = tf.keras.Model(inputs, outputs)" |
618 | 618 | ] |
619 | 619 | }, |
| 620 | + { |
| 621 | + "cell_type": "code", |
| 622 | + "execution_count": null, |
| 623 | + "metadata": { |
| 624 | + "id": "I8ARiyMFsgbH" |
| 625 | + }, |
| 626 | + "outputs": [], |
| 627 | + "source": [ |
| 628 | + "model.summary()" |
| 629 | + ] |
| 630 | + }, |
620 | 631 | { |
621 | 632 | "cell_type": "markdown", |
622 | 633 | "metadata": { |
623 | | - "id": "g0ylJXE_kRLi" |
| 634 | + "id": "lxOcmVr0ydFZ" |
624 | 635 | }, |
625 | 636 | "source": [ |
626 | | - "### Compile the model\n", |
627 | | - "\n", |
628 | | - "Compile the model before training it. Since there are two classes, use the `tf.keras.losses.BinaryCrossentropy` loss with `from_logits=True` since the model provides a linear output." |
| 637 | + "The 8+ million parameters in MobileNet are frozen, but there are 1.2 thousand _trainable_ parameters in the Dense layer. These are divided between two `tf.Variable` objects, the weights and biases." |
629 | 638 | ] |
630 | 639 | }, |
631 | 640 | { |
632 | 641 | "cell_type": "code", |
633 | 642 | "execution_count": null, |
634 | 643 | "metadata": { |
635 | | - "id": "RpR8HdyMhukJ" |
| 644 | + "id": "krvBumovycVA" |
636 | 645 | }, |
637 | 646 | "outputs": [], |
638 | 647 | "source": [ |
639 | | - "base_learning_rate = 0.0001\n", |
640 | | - "model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),\n", |
641 | | - " loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n", |
642 | | - " metrics=['accuracy'])" |
| 648 | + "len(model.trainable_variables)" |
643 | 649 | ] |
644 | 650 | }, |
645 | 651 | { |
646 | 652 | "cell_type": "code", |
647 | 653 | "execution_count": null, |
648 | 654 | "metadata": { |
649 | | - "id": "I8ARiyMFsgbH" |
| 655 | + "id": "jeGk93R2ahav" |
650 | 656 | }, |
651 | 657 | "outputs": [], |
652 | 658 | "source": [ |
653 | | - "model.summary()" |
| 659 | + "tf.keras.utils.plot_model(model, show_shapes=True)" |
654 | 660 | ] |
655 | 661 | }, |
656 | 662 | { |
657 | 663 | "cell_type": "markdown", |
658 | 664 | "metadata": { |
659 | | - "id": "lxOcmVr0ydFZ" |
| 665 | + "id": "g0ylJXE_kRLi" |
660 | 666 | }, |
661 | 667 | "source": [ |
662 | | - "The 2.5 million parameters in MobileNet are frozen, but there are 1.2 thousand _trainable_ parameters in the Dense layer. These are divided between two `tf.Variable` objects, the weights and biases." |
| 668 | + "### Compile the model\n", |
| 669 | + "\n", |
| 670 | + "Compile the model before training it. Since there are two classes, use the `tf.keras.losses.BinaryCrossentropy` loss with `from_logits=True` since the model provides a linear output." |
663 | 671 | ] |
664 | 672 | }, |
665 | 673 | { |
666 | 674 | "cell_type": "code", |
667 | 675 | "execution_count": null, |
668 | 676 | "metadata": { |
669 | | - "id": "krvBumovycVA" |
| 677 | + "id": "RpR8HdyMhukJ" |
670 | 678 | }, |
671 | 679 | "outputs": [], |
672 | 680 | "source": [ |
673 | | - "len(model.trainable_variables)" |
| 681 | + "base_learning_rate = 0.0001\n", |
| 682 | + "model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),\n", |
| 683 | + " loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n", |
| 684 | + " metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0, name='accuracy')])" |
674 | 685 | ] |
675 | 686 | }, |
676 | 687 | { |
|
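To make the "two `tf.Variable` objects" above concrete, a small sketch (assuming the `model` built in the cells above, i.e. a frozen MobileNetV2 base topped by a single `Dense(1)` layer) that lists each trainable variable with its shape:

```python
# Assumes `model` from the cells above: frozen MobileNetV2 base + Dense(1) head.
for var in model.trainable_variables:
    print(var.name, var.shape)

# Expected output (exact names may vary): the Dense layer's kernel of shape
# (1280, 1) and its bias of shape (1,), i.e. the ~1.2 thousand trainable
# parameters reported by model.summary().
```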
681 | 692 | "source": [ |
682 | 693 | "### Train the model\n", |
683 | 694 | "\n", |
684 | | - "After training for 10 epochs, you should see ~94% accuracy on the validation set.\n" |
| 695 | + "After training for 10 epochs, you should see ~96% accuracy on the validation set.\n" |
685 | 696 | ] |
686 | 697 | }, |
687 | 698 | { |
|
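For reference, a hedged sketch of the training call the note above refers to; `train_dataset` and `validation_dataset` are assumed to be the `tf.data.Dataset` objects prepared earlier in the notebook:

```python
# Assumed names for the datasets built earlier in the tutorial.
initial_epochs = 10

history = model.fit(train_dataset,
                    epochs=initial_epochs,
                    validation_data=validation_dataset)
```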
863 | 874 | "source": [ |
864 | 875 | "model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n", |
865 | 876 | " optimizer = tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate/10),\n", |
866 | | - " metrics=['accuracy'])" |
| 877 | + " metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0, name='accuracy')])" |
867 | 878 | ] |
868 | 879 | }, |
869 | 880 | { |
|
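Both `compile` calls now use `tf.keras.metrics.BinaryAccuracy(threshold=0, ...)` because the model outputs raw logits rather than probabilities: a logit of 0 corresponds to a sigmoid probability of 0.5, so thresholding logits at 0 gives the same decisions as thresholding probabilities at the default 0.5. A self-contained sanity check (the logit and label values are made up for illustration):

```python
import tensorflow as tf

# Made-up logits and labels, purely for illustration.
logits = tf.constant([[-2.3], [0.4], [1.7], [-0.1]])
labels = tf.constant([[0.0], [1.0], [1.0], [1.0]])

# Accuracy computed on raw logits with threshold=0 ...
acc_logits = tf.keras.metrics.BinaryAccuracy(threshold=0)
acc_logits.update_state(labels, logits)

# ... matches accuracy computed on sigmoid probabilities with the default 0.5.
acc_probs = tf.keras.metrics.BinaryAccuracy()
acc_probs.update_state(labels, tf.nn.sigmoid(logits))

print(acc_logits.result().numpy(), acc_probs.result().numpy())  # 0.75 0.75
```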
1070 | 1081 | "\n", |
1071 | 1082 | "To learn more, visit the [Transfer learning guide](https://www.tensorflow.org/guide/keras/transfer_learning).\n" |
1072 | 1083 | ] |
| 1084 | + }, |
| 1085 | + { |
| 1086 | + "cell_type": "code", |
| 1087 | + "execution_count": null, |
| 1088 | + "metadata": { |
| 1089 | + "id": "uKIByL01da8c" |
| 1090 | + }, |
| 1091 | + "outputs": [], |
| 1092 | + "source": [] |
1073 | 1093 | } |
1074 | 1094 | ], |
1075 | 1095 | "metadata": { |
1076 | 1096 | "accelerator": "GPU", |
1077 | 1097 | "colab": { |
1078 | | - "collapsed_sections": [], |
1079 | 1098 | "name": "transfer_learning.ipynb", |
| 1099 | + "private_outputs": true, |
| 1100 | + "provenance": [], |
1080 | 1101 | "toc_visible": true |
1081 | 1102 | }, |
1082 | 1103 | "kernelspec": { |
|