3434import org .nd4j .linalg .indexing .conditions .Conditions ;
3535import org .nd4j .linalg .learning .config .Adam ;
3636import org .nd4j .linalg .lossfunctions .LossFunctions ;
37+ import org .nd4j .shade .guava .collect .Streams ;
3738
3839import javax .imageio .ImageIO ;
3940import javax .swing .*;
4041import java .awt .image .BufferedImage ;
42+ import java .awt .image .DataBufferByte ;
4143import java .io .File ;
44+ import java .util .Random ;
4245
4346/**
4447 * Application to show a neural network learning to draw an image.
4548 * Demonstrates how to feed an NN with externally originated data.
4649 *
4750 * Updates from previous versions:
4851 * - Now uses swing. No longer uses JavaFX which caused problems with the OpenJDK.
49- * - All slow java loops in the dataset creation and image drawing are replaced with fast vectorized code.
5052 *
5153 * @author Robert Altena
5254 * Many thanks to @tmanthey for constructive feedback and suggestions.
@@ -59,17 +61,11 @@ public class ImageDrawer {
5961 private BufferedImage originalImage ;
6062 private JLabel generatedLabel ;
6163
62- private INDArray blueMat ; // color channels of he original image.
63- private INDArray greenMat ;
64- private INDArray redMat ;
65-
66- private INDArray xPixels ; // x coordinates of the pixels for the NN.
67- private INDArray yPixels ; // y coordinates of the pixels for the NN.
68-
6964 private INDArray xyOut ; //x,y grid to calculate the output image. Needs to be calculated once, then re-used.
7065
7166 private Java2DNativeImageLoader j2dNil ; //Datavec class used to read and write images to /from INDArrays.
72-
67+ private FastRGB rgb ; // helper class for fast access to the image pixels.
68+ private Random random ;
7369
7470 private void init () throws Exception {
7571
@@ -78,6 +74,7 @@ private void init() throws Exception {
7874
7975 String localDataPath = DownloaderUtility .DATAEXAMPLES .Download ();
8076 originalImage = ImageIO .read (new File (localDataPath , "Mona_Lisa.png" ));
77+
8178 //start with a blank image of the same size as the original.
8279 BufferedImage generatedImage = new BufferedImage (originalImage .getWidth (), originalImage .getHeight (), originalImage .getType ());
8380
@@ -98,15 +95,13 @@ private void init() throws Exception {
9895 mainFrame .setVisible (true ); // Show UI
9996
10097
101- j2dNil = new Java2DNativeImageLoader (); //Datavec class used to read and write images.
98+ j2dNil = new Java2DNativeImageLoader (); //Datavec class used to write images.
99+ random = new Random ();
102100 nn = createNN (); // Create the neural network.
103101 xyOut = calcGrid (); //Create a mesh used to generate the image.
104102
105103 // read the color channels from the original image.
106- INDArray imageMat = j2dNil .asMatrix (originalImage ).castTo (DataType .DOUBLE ).div (255.0 );
107- blueMat = imageMat .tensorAlongDimension (1 , 0 , 2 , 3 ).reshape (width * height , 1 );
108- greenMat = imageMat .tensorAlongDimension (2 , 0 , 2 , 3 ).reshape (width * height , 1 );
109- redMat = imageMat .tensorAlongDimension (3 , 0 , 2 , 3 ).reshape (width * height , 1 );
104+ rgb = new FastRGB (originalImage );
110105
111106 SwingUtilities .invokeLater (this ::onCalc );
112107 }
@@ -127,30 +122,30 @@ private static MultiLayerNetwork createNN() {
127122 int numOutputs = 3 ; //R, G and B value.
128123
129124 MultiLayerConfiguration conf = new NeuralNetConfiguration .Builder ()
130- .seed (seed )
131- .optimizationAlgo (OptimizationAlgorithm .STOCHASTIC_GRADIENT_DESCENT )
132- .weightInit (WeightInit .XAVIER )
133- .updater (new Adam (learningRate ))
134- .list ()
135- .layer (new DenseLayer .Builder ().nIn (numInputs ).nOut (numHiddenNodes )
136- .activation (Activation .LEAKYRELU )
137- .build ())
138- .layer (new DenseLayer .Builder ().nOut (numHiddenNodes )
139- .activation (Activation .LEAKYRELU )
140- .build ())
141- .layer (new DenseLayer .Builder ().nOut (numHiddenNodes )
142- .activation (Activation .LEAKYRELU )
143- .build ())
144- .layer (new DenseLayer .Builder ().nOut (numHiddenNodes )
145- .activation (Activation .LEAKYRELU )
146- .build ())
147- .layer (new DenseLayer .Builder ().nOut (numHiddenNodes )
148- .activation (Activation .LEAKYRELU )
149- .build ())
150- .layer ( new OutputLayer .Builder (LossFunctions .LossFunction .L2 )
151- .activation (Activation .IDENTITY )
152- .nOut (numOutputs ).build ())
153- .build ();
125+ .seed (seed )
126+ .optimizationAlgo (OptimizationAlgorithm .STOCHASTIC_GRADIENT_DESCENT )
127+ .weightInit (WeightInit .XAVIER )
128+ .updater (new Adam (learningRate ))
129+ .list ()
130+ .layer (new DenseLayer .Builder ().nIn (numInputs ).nOut (numHiddenNodes )
131+ .activation (Activation .LEAKYRELU )
132+ .build ())
133+ .layer (new DenseLayer .Builder ().nOut (numHiddenNodes )
134+ .activation (Activation .LEAKYRELU )
135+ .build ())
136+ .layer (new DenseLayer .Builder ().nOut (numHiddenNodes )
137+ .activation (Activation .LEAKYRELU )
138+ .build ())
139+ .layer (new DenseLayer .Builder ().nOut (numHiddenNodes )
140+ .activation (Activation .LEAKYRELU )
141+ .build ())
142+ .layer (new DenseLayer .Builder ().nOut (numHiddenNodes )
143+ .activation (Activation .LEAKYRELU )
144+ .build ())
145+ .layer ( new OutputLayer .Builder (LossFunctions .LossFunction .L2 )
146+ .activation (Activation .IDENTITY )
147+ .nOut (numOutputs ).build ())
148+ .build ();
154149
155150 MultiLayerNetwork net = new MultiLayerNetwork (conf );
156151 net .init ();
@@ -162,8 +157,9 @@ private static MultiLayerNetwork createNN() {
162157 * Training the NN and updating the current graphical output.
163158 */
164159 private void onCalc (){
165- int batchSize = 1000 ;
166- int numBatches = 10 ;
160+ // Find a reasonable balance between batch size and number of batches per generated redraw.
161+ int batchSize = 1000 ; //larger batch size slows the calculation but speeds up the learning per batch
162+ int numBatches = 10 ; // Drawing the generated image is slow. Doing multiple batches before redrawing increases speed.
167163 for (int i =0 ; i < numBatches ; i ++){
168164 DataSet ds = generateDataSet (batchSize );
169165 nn .fit (ds );
@@ -172,11 +168,13 @@ private void onCalc(){
172168 mainFrame .invalidate ();
173169 mainFrame .repaint ();
174170
175- SwingUtilities .invokeLater (this ::onCalc );
171+ SwingUtilities .invokeLater (this ::onCalc ); //TODO: move training to a worker thread,
176172 }
177173
178174 /**
179175 * Take a batchsize of random samples from the source image.
176+ * This illustrates how to generate a custom dataset. The normal way of doing this would be to generate a dataset
177+ * of the entire source image, train om shuffled batches from there.
180178 *
181179 * @param batchSize number of sample points to take out of the image.
182180 * @return DeepLearning4J DataSet.
@@ -185,22 +183,22 @@ private DataSet generateDataSet(int batchSize) {
185183 int w = originalImage .getWidth ();
186184 int h = originalImage .getHeight ();
187185
188- INDArray xindex = Nd4j . rand ( batchSize ). muli ( w - 1 ). castTo ( DataType . UINT32 ) ;
189- INDArray yindex = Nd4j . rand ( batchSize ). muli ( h - 1 ). castTo ( DataType . UINT32 ) ;
190-
191- INDArray xPos = xPixels . get ( xindex ). reshape ( batchSize ); // Look up the normalized positions pf the pixels.
192- INDArray yPos = yPixels . get ( yindex ). reshape ( batchSize );
193-
194- INDArray xy = Nd4j . vstack ( xPos , yPos ). transpose (); // Create the array that can be fed into the NN.
195-
196- //Look up the correct colors fot our random pixels.
197- INDArray xyIndex = yindex . mul ( w ). add ( xindex ); //TODO: figure out the 2D version of INDArray.get.
198- INDArray b = blueMat . get ( xyIndex ). reshape ( batchSize ) ;
199- INDArray g = greenMat . get ( xyIndex ). reshape ( batchSize );
200- INDArray r = redMat . get ( xyIndex ). reshape ( batchSize );
201- INDArray out = Nd4j .vstack ( r , g , b ). transpose (); // Create the array that can be used for NN training.
202-
203- return new DataSet (xy , out );
186+ float [][] in = new float [ batchSize ][ 2 ] ;
187+ float [][] out = new float [ batchSize ][ 3 ] ;
188+ final int [] i = { 0 };
189+ Streams . forEachPair (
190+ random . ints ( batchSize , 0 , w ). boxed (),
191+ random . ints ( batchSize , 0 , h ). boxed (),
192+ ( a , b ) -> {
193+ final short [] parts = rgb . getRGB ( a , b );
194+ in [ i [ 0 ]] = new float []{(( a / ( float ) w ) - 0.5f ) * 2f , (( b / ( float ) h ) - 0.5f ) * 2f };
195+ out [ i [ 0 ]] = new float []{ parts [ 0 ], parts [ 1 ], parts [ 2 ]};
196+ i [ 0 ]++ ;
197+ }
198+ );
199+ final INDArray input = Nd4j .create ( in );
200+ final INDArray labels = Nd4j . create ( out ). divi ( 255 );
201+ return new DataSet (input , labels );
204202 }
205203
206204 /**
@@ -211,7 +209,7 @@ private void drawImage() {
211209 int h = originalImage .getHeight ();
212210
213211 INDArray out = nn .output (xyOut ); // The raw NN output.
214- BooleanIndexing .replaceWhere (out , 0.0 , Conditions .lessThan (0.0 )); // Cjip between 0 and 1.
212+ BooleanIndexing .replaceWhere (out , 0.0 , Conditions .lessThan (0.0 )); // Clip between 0 and 1.
215213 BooleanIndexing .replaceWhere (out , 1.0 , Conditions .greaterThan (1.0 ));
216214 out = out .mul (255 ).castTo (DataType .BYTE ); //convert to bytes.
217215
@@ -231,13 +229,40 @@ private void drawImage() {
231229 private INDArray calcGrid (){
232230 int w = originalImage .getWidth ();
233231 int h = originalImage .getHeight ();
234- xPixels = Nd4j .linspace (-1.0 , 1.0 , w , DataType .DOUBLE );
235- yPixels = Nd4j .linspace (-1.0 , 1.0 , h , DataType .DOUBLE );
232+ INDArray xPixels = Nd4j .linspace (-1.0 , 1.0 , w , DataType .DOUBLE );
233+ INDArray yPixels = Nd4j .linspace (-1.0 , 1.0 , h , DataType .DOUBLE );
236234 INDArray [] mesh = Nd4j .meshgrid (xPixels , yPixels );
237235
238- xPixels = xPixels .reshape (w , 1 ); // This is a hack to work around a bug in INDArray.get()
239- yPixels = yPixels .reshape (h , 1 ); // in the dataset generation.
240-
241236 return Nd4j .vstack (mesh [0 ].ravel (), mesh [1 ].ravel ()).transpose ();
242237 }
238+
239+
240+ public class FastRGB {
241+ int width ;
242+ int height ;
243+ private boolean hasAlphaChannel ;
244+ private int pixelLength ;
245+ private byte [] pixels ;
246+
247+ FastRGB (BufferedImage image ) {
248+ pixels = ((DataBufferByte ) image .getRaster ().getDataBuffer ()).getData ();
249+ width = image .getWidth ();
250+ height = image .getHeight ();
251+ hasAlphaChannel = image .getAlphaRaster () != null ;
252+ pixelLength = 3 ;
253+ if (hasAlphaChannel )
254+ pixelLength = 4 ;
255+ }
256+
257+ short [] getRGB (int x , int y ) {
258+ int pos = (y * pixelLength * width ) + (x * pixelLength );
259+ short rgb [] = new short [4 ];
260+ if (hasAlphaChannel )
261+ rgb [3 ] = (short ) (pixels [pos ++] & 0xFF ); // Alpha
262+ rgb [2 ] = (short ) (pixels [pos ++] & 0xFF ); // Blue
263+ rgb [1 ] = (short ) (pixels [pos ++] & 0xFF ); // Green
264+ rgb [0 ] = (short ) (pixels [pos ] & 0xFF ); // Red
265+ return rgb ;
266+ }
267+ }
243268}
0 commit comments