2929NUM_DATA_BATCHES = 5
3030NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 10000 * NUM_DATA_BATCHES
3131BATCH_SIZE = 128
32+ INPUT_TENSOR_NAME = PREDICT_INPUTS
3233
3334
3435def keras_model_fn (hyperparameters ):
@@ -77,122 +78,34 @@ def keras_model_fn(hyperparameters):
7778 return _model
7879
7980
80- def serving_input_fn (params ):
81- # Notice that the input placeholder has the same input shape as the Keras model input
82- tensor = tf .placeholder (tf .float32 , shape = [None , HEIGHT , WIDTH , DEPTH ])
83-
84- # The inputs key PREDICT_INPUTS matches the Keras InputLayer name
85- inputs = {PREDICT_INPUTS : tensor }
81+ def serving_input_fn (hyperpameters ):
82+ inputs = {PREDICT_INPUTS : tf .placeholder (tf .float32 , [None , 32 , 32 , 3 ])}
8683 return tf .estimator .export .ServingInputReceiver (inputs , inputs )
8784
8885
89- def train_input_fn (training_dir , params ):
90- return _input (tf .estimator .ModeKeys .TRAIN ,
91- batch_size = BATCH_SIZE , data_dir = training_dir )
92-
93-
94- def eval_input_fn (training_dir , params ):
95- return _input (tf .estimator .ModeKeys .EVAL ,
96- batch_size = BATCH_SIZE , data_dir = training_dir )
97-
98-
99- def _input (mode , batch_size , data_dir ):
100- """Uses the tf.data input pipeline for CIFAR-10 dataset.
101- Args:
102- mode: Standard names for model modes (tf.estimators.ModeKeys).
103- batch_size: The number of samples per batch of input requested.
104- """
105- dataset = _record_dataset (_filenames (mode , data_dir ))
106-
107- # For training repeat forever.
108- if mode == tf .estimator .ModeKeys .TRAIN :
109- dataset = dataset .repeat ()
110-
111- dataset = dataset .map (_dataset_parser )
112- dataset .prefetch (2 * batch_size )
113-
114- # For training, preprocess the image and shuffle.
115- if mode == tf .estimator .ModeKeys .TRAIN :
116- dataset = dataset .map (_train_preprocess_fn )
117- dataset .prefetch (2 * batch_size )
118-
119- # Ensure that the capacity is sufficiently large to provide good random
120- # shuffling.
121- buffer_size = int (NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * 0.4 ) + 3 * batch_size
122- dataset = dataset .shuffle (buffer_size = buffer_size )
123-
124- # Subtract off the mean and divide by the variance of the pixels.
125- dataset = dataset .map (
126- lambda image , label : (tf .image .per_image_standardization (image ), label ))
127- dataset .prefetch (2 * batch_size )
128-
129- # Batch results by up to batch_size, and then fetch the tuple from the
130- # iterator.
131- iterator = dataset .batch (batch_size ).make_one_shot_iterator ()
132- images , labels = iterator .get_next ()
133-
134- return {PREDICT_INPUTS : images }, labels
135-
136-
137- def _train_preprocess_fn (image , label ):
138- """Preprocess a single training image of layout [height, width, depth]."""
139- # Resize the image to add four extra pixels on each side.
140- image = tf .image .resize_image_with_crop_or_pad (image , HEIGHT + 8 , WIDTH + 8 )
141-
142- # Randomly crop a [HEIGHT, WIDTH] section of the image.
143- image = tf .random_crop (image , [HEIGHT , WIDTH , DEPTH ])
144-
145- # Randomly flip the image horizontally.
146- image = tf .image .random_flip_left_right (image )
147-
148- return image , label
149-
150-
151- def _dataset_parser (value ):
152- """Parse a CIFAR-10 record from value."""
153- # Every record consists of a label followed by the image, with a fixed number
154- # of bytes for each.
155- label_bytes = 1
156- image_bytes = HEIGHT * WIDTH * DEPTH
157- record_bytes = label_bytes + image_bytes
158-
159- # Convert from a string to a vector of uint8 that is record_bytes long.
160- raw_record = tf .decode_raw (value , tf .uint8 )
161-
162- # The first byte represents the label, which we convert from uint8 to int32.
163- label = tf .cast (raw_record [0 ], tf .int32 )
164-
165- # The remaining bytes after the label represent the image, which we reshape
166- # from [depth * height * width] to [depth, height, width].
167- depth_major = tf .reshape (raw_record [label_bytes :record_bytes ],
168- [DEPTH , HEIGHT , WIDTH ])
169-
170- # Convert from [depth, height, width] to [height, width, depth], and cast as
171- # float32.
172- image = tf .cast (tf .transpose (depth_major , [1 , 2 , 0 ]), tf .float32 )
173-
174- return image , tf .one_hot (label , NUM_CLASSES )
86+ def train_input_fn (training_dir , hyperparameters ):
87+ return _generate_synthetic_data (tf .estimator .ModeKeys .TRAIN , batch_size = BATCH_SIZE )
17588
17689
177- def _record_dataset (filenames ):
178- """Returns an input pipeline Dataset from `filenames`."""
179- record_bytes = HEIGHT * WIDTH * DEPTH + 1
180- return tf .data .FixedLengthRecordDataset (filenames , record_bytes )
90+ def eval_input_fn (training_dir , hyperparameters ):
91+ return _generate_synthetic_data (tf .estimator .ModeKeys .EVAL , batch_size = BATCH_SIZE )
18192
18293
183- def _filenames (mode , data_dir ):
184- """Returns a list of filenames based on 'mode'."""
185- data_dir = os .path .join (data_dir , 'cifar-10-batches-bin' )
94+ def _generate_synthetic_data (mode , batch_size ):
95+ input_shape = [batch_size , HEIGHT , WIDTH , DEPTH ]
96+ images = tf .truncated_normal (
97+ input_shape ,
98+ dtype = tf .float32 ,
99+ stddev = 1e-1 ,
100+ name = 'synthetic_images' )
101+ labels = tf .random_uniform (
102+ [batch_size , NUM_CLASSES ],
103+ minval = 0 ,
104+ maxval = NUM_CLASSES - 1 ,
105+ dtype = tf .float32 ,
106+ name = 'synthetic_labels' )
186107
187- assert os . path . exists ( data_dir ), ( 'Run cifar10_download_and_extract.py first '
188- 'to download and extract the CIFAR-10 data. ' )
108+ images = tf . contrib . framework . local_variable ( images , name = 'images' )
109+ labels = tf . contrib . framework . local_variable ( labels , name = 'labels ' )
189110
190- if mode == tf .estimator .ModeKeys .TRAIN :
191- return [
192- os .path .join (data_dir , 'data_batch_%d.bin' % i )
193- for i in range (1 , NUM_DATA_BATCHES + 1 )
194- ]
195- elif mode == tf .estimator .ModeKeys .EVAL :
196- return [os .path .join (data_dir , 'test_batch.bin' )]
197- else :
198- raise ValueError ('Invalid mode: %s' % mode )
111+ return {INPUT_TENSOR_NAME : images }, labels
0 commit comments