3535import org .nd4j .linalg .api .ndarray .INDArray ;
3636import org .nd4j .linalg .factory .Nd4j ;
3737import org .nd4j .linalg .indexing .NDArrayIndex ;
38- import org .nd4j .linalg .io .ClassPathResource ;
3938import org .nd4j .linalg .learning .config .Adam ;
4039import org .nd4j .linalg .lossfunctions .LossFunctions ;
4140import org .nd4j .resources .Downloader ;
@@ -98,17 +97,17 @@ public static void main(String[] args) throws Exception {
9897
9998 //Set up network configuration
10099 MultiLayerConfiguration conf = new NeuralNetConfiguration .Builder ()
101- .seed (seed )
102- .updater (new Adam (5e-3 ))
103- .l2 (1e-5 )
104- .weightInit (WeightInit .XAVIER )
105- .gradientNormalization (GradientNormalization .ClipElementWiseAbsoluteValue ).gradientNormalizationThreshold (1.0 )
106- .list ()
107- .layer (new LSTM .Builder ().nIn (vectorSize ).nOut (256 )
108- .activation (Activation .TANH ).build ())
109- .layer (new RnnOutputLayer .Builder ().activation (Activation .SOFTMAX )
110- .lossFunction (LossFunctions .LossFunction .MCXENT ).nIn (256 ).nOut (2 ).build ())
111- .build ();
100+ .seed (seed )
101+ .updater (new Adam (5e-3 ))
102+ .l2 (1e-5 )
103+ .weightInit (WeightInit .XAVIER )
104+ .gradientNormalization (GradientNormalization .ClipElementWiseAbsoluteValue ).gradientNormalizationThreshold (1.0 )
105+ .list ()
106+ .layer (new LSTM .Builder ().nIn (vectorSize ).nOut (256 )
107+ .activation (Activation .TANH ).build ())
108+ .layer (new RnnOutputLayer .Builder ().activation (Activation .SOFTMAX )
109+ .lossFunction (LossFunctions .LossFunction .MCXENT ).nIn (256 ).nOut (2 ).build ())
110+ .build ();
112111
113112 MultiLayerNetwork net = new MultiLayerNetwork (conf );
114113 net .init ();
@@ -171,11 +170,12 @@ public static void downloadData() throws Exception {
171170
172171 public static void checkDownloadW2VECModel () throws IOException {
173172 String defaultwordVectorsPath = FilenameUtils .concat (System .getProperty ("user.home" ), "dl4j-examples-data/w2vec300" );
173+ String md5w2vec = "1c892c4707a8a1a508b01a01735c0339" ;
174174 wordVectorsPath = new File (defaultwordVectorsPath , "GoogleNews-vectors-negative300.bin.gz" ).getAbsolutePath ();
175175 if (new File (wordVectorsPath ).exists ()) {
176176 System .out .println ("\n \t GoogleNews-vectors-negative300.bin.gz file found at path: " + defaultwordVectorsPath );
177177 System .out .println ("\t Checking md5 of existing file.." );
178- if (Downloader .checkMD5OfFile ("1c892c4707a8a1a508b01a01735c0339" , new File (wordVectorsPath ))) {
178+ if (Downloader .checkMD5OfFile (md5w2vec , new File (wordVectorsPath ))) {
179179 System .out .println ("\t Existing file hash matches." );
180180 return ;
181181 } else {
@@ -189,23 +189,8 @@ public static void checkDownloadW2VECModel() throws IOException {
189189 Scanner scanner = new Scanner (System .in );
190190 scanner .nextLine ();
191191 System .out .println ("Starting model download (1.5GB!)..." );
192- String downloadScript = new ClassPathResource ("w2vecdownload/word2vec-download300model.sh" ).getFile ().getAbsolutePath ();
193- ProcessBuilder processBuilder = new ProcessBuilder (downloadScript , defaultwordVectorsPath );
194- try {
195- processBuilder .inheritIO ();
196- Process process = processBuilder .start ();
197- int exitVal = process .waitFor ();
198- if (exitVal == 0 ) {
199- System .out .println ("Successfully downloaded word2vec model!" );
200- } else {
201- System .out .println ("Download failed. Please download model manually and set the \" wordVectorsPath\" in the code with the path to it." );
202- System .exit (0 );
203- }
204- } catch (IOException e ) {
205- e .printStackTrace ();
206- } catch (InterruptedException e ) {
207- e .printStackTrace ();
208- }
192+ Downloader .download ("Word2Vec" , new URL ("https://dl4jdata.blob.core.windows.net/resources/wordvectors/GoogleNews-vectors-negative300.bin.gz" ), new File (wordVectorsPath ), md5w2vec , 5 );
193+ System .out .println ("Successfully downloaded word2vec model to " + wordVectorsPath );
209194 }
210195}
211196
0 commit comments