# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import tensorflow as tf
from tensorflow.python.keras.layers import InputLayer, Conv2D, Activation, MaxPooling2D, Dropout, Flatten, Dense
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.optimizers import RMSprop
from tensorflow.python.saved_model.signature_constants import PREDICT_INPUTS

HEIGHT = 32
WIDTH = 32
DEPTH = 3
NUM_CLASSES = 10
NUM_DATA_BATCHES = 5
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 10000 * NUM_DATA_BATCHES
BATCH_SIZE = 128


def keras_model_fn(hyperparameters):
    """keras_model_fn receives hyperparameters from the training job and returns a compiled Keras model.
    The model is converted to a TensorFlow Estimator before training and is saved as a
    TensorFlow Serving SavedModel at the end of training.
    Args:
        hyperparameters: The hyperparameters passed to the SageMaker TrainingJob that runs your TensorFlow
            training script.
    Returns: A compiled Keras model.
    """
    model = Sequential()

    # TensorFlow Serving default prediction input tensor name is PREDICT_INPUTS.
    # We must conform to this naming scheme.
    model.add(InputLayer(input_shape=(HEIGHT, WIDTH, DEPTH), name=PREDICT_INPUTS))
    model.add(Conv2D(32, (3, 3), padding='same'))
    model.add(Activation('relu'))
    model.add(Conv2D(32, (3, 3)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))

    model.add(Conv2D(64, (3, 3), padding='same'))
    model.add(Activation('relu'))
    model.add(Conv2D(64, (3, 3)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))

    model.add(Flatten())
    model.add(Dense(512))
    model.add(Activation('relu'))
    model.add(Dropout(0.5))
    model.add(Dense(NUM_CLASSES))
    model.add(Activation('softmax'))

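    # Re-wrap the Sequential model as a functional tf.keras.Model built from the
    # same input and output tensors; this compiled functional model is what gets
    # converted to a TensorFlow Estimator for training and serving.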
    _model = tf.keras.Model(inputs=model.input, outputs=model.output)

    opt = RMSprop(lr=hyperparameters['learning_rate'], decay=hyperparameters['decay'])

    _model.compile(loss='categorical_crossentropy',
                   optimizer=opt,
                   metrics=['accuracy'])

    return _model


def serving_input_fn(params):
    # Notice that the input placeholder has the same input shape as the Keras model input
    tensor = tf.placeholder(tf.float32, shape=[None, HEIGHT, WIDTH, DEPTH])

    # The inputs key PREDICT_INPUTS matches the Keras InputLayer name
    inputs = {PREDICT_INPUTS: tensor}
    return tf.estimator.export.ServingInputReceiver(inputs, inputs)


def train_input_fn(training_dir, params):
    return _input(tf.estimator.ModeKeys.TRAIN,
                  batch_size=BATCH_SIZE, data_dir=training_dir)


def eval_input_fn(training_dir, params):
    return _input(tf.estimator.ModeKeys.EVAL,
                  batch_size=BATCH_SIZE, data_dir=training_dir)


def _input(mode, batch_size, data_dir):
    """Uses the tf.data input pipeline for the CIFAR-10 dataset.
    Args:
        mode: Standard names for model modes (tf.estimator.ModeKeys).
        batch_size: The number of samples per batch of input requested.
        data_dir: The directory containing the extracted CIFAR-10 binary files.
    """
    dataset = _record_dataset(_filenames(mode, data_dir))

    # For training, repeat forever.
    if mode == tf.estimator.ModeKeys.TRAIN:
        dataset = dataset.repeat()

    dataset = dataset.map(_dataset_parser)
    dataset = dataset.prefetch(2 * batch_size)

    # For training, preprocess the image and shuffle.
    if mode == tf.estimator.ModeKeys.TRAIN:
        dataset = dataset.map(_train_preprocess_fn)
        dataset = dataset.prefetch(2 * batch_size)

        # Ensure that the capacity is sufficiently large to provide good random
        # shuffling.
        buffer_size = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * 0.4) + 3 * batch_size
        dataset = dataset.shuffle(buffer_size=buffer_size)

    # Subtract off the mean and divide by the variance of the pixels.
    dataset = dataset.map(
        lambda image, label: (tf.image.per_image_standardization(image), label))
    dataset = dataset.prefetch(2 * batch_size)

    # Batch results by up to batch_size, and then fetch the tuple from the
    # iterator.
    iterator = dataset.batch(batch_size).make_one_shot_iterator()
    images, labels = iterator.get_next()

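    # The features dict key matches the Keras InputLayer name (PREDICT_INPUTS),
    # so the Estimator feeds these images into the model's input tensor.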
    return {PREDICT_INPUTS: images}, labels


def _train_preprocess_fn(image, label):
    """Preprocess a single training image of layout [height, width, depth]."""
    # Resize the image to add four extra pixels on each side.
    image = tf.image.resize_image_with_crop_or_pad(image, HEIGHT + 8, WIDTH + 8)

    # Randomly crop a [HEIGHT, WIDTH] section of the image.
    image = tf.random_crop(image, [HEIGHT, WIDTH, DEPTH])

    # Randomly flip the image horizontally.
    image = tf.image.random_flip_left_right(image)

    return image, label


def _dataset_parser(value):
    """Parse a CIFAR-10 record from value."""
    # Every record consists of a label followed by the image, with a fixed number
    # of bytes for each.
    label_bytes = 1
    image_bytes = HEIGHT * WIDTH * DEPTH
    record_bytes = label_bytes + image_bytes

    # Convert from a string to a vector of uint8 that is record_bytes long.
    raw_record = tf.decode_raw(value, tf.uint8)

    # The first byte represents the label, which we convert from uint8 to int32.
    label = tf.cast(raw_record[0], tf.int32)

    # The remaining bytes after the label represent the image, which we reshape
    # from [depth * height * width] to [depth, height, width].
    depth_major = tf.reshape(raw_record[label_bytes:record_bytes],
                             [DEPTH, HEIGHT, WIDTH])

    # Convert from [depth, height, width] to [height, width, depth], and cast as
    # float32.
    image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)

    return image, tf.one_hot(label, NUM_CLASSES)


def _record_dataset(filenames):
    """Returns an input pipeline Dataset from `filenames`."""
    record_bytes = HEIGHT * WIDTH * DEPTH + 1
    return tf.data.FixedLengthRecordDataset(filenames, record_bytes)


def _filenames(mode, data_dir):
    """Returns a list of filenames based on 'mode'."""
    data_dir = os.path.join(data_dir, 'cifar-10-batches-bin')

    assert os.path.exists(data_dir), ('Run cifar10_download_and_extract.py first '
                                      'to download and extract the CIFAR-10 data.')

    if mode == tf.estimator.ModeKeys.TRAIN:
        return [
            os.path.join(data_dir, 'data_batch_%d.bin' % i)
            for i in range(1, NUM_DATA_BATCHES + 1)
        ]
    elif mode == tf.estimator.ModeKeys.EVAL:
        return [os.path.join(data_dir, 'test_batch.bin')]
    else:
        raise ValueError('Invalid mode: %s' % mode)
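

# Minimal local smoke test (illustrative only; not part of the SageMaker
# entry-point contract). SageMaker imports this module and invokes the
# functions above itself, so this block runs only when the file is executed
# directly. The hyperparameter values below are placeholder assumptions.
if __name__ == '__main__':
    example_hyperparameters = {'learning_rate': 1.0e-4, 'decay': 1.0e-6}
    smoke_test_model = keras_model_fn(example_hyperparameters)
    smoke_test_model.summary()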