
Commit afc512b

Create Conv1DUCISequenceClassification.java
1 parent 4c9adda commit afc512b

File tree

Conv1DUCISequenceClassification.java

1 file changed: +235 -0 lines

@@ -0,0 +1,235 @@
/* ******************************************************************************
 *
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.examples.quickstart.modeling.convolution;

import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.InvocationType;
import org.deeplearning4j.optimize.listeners.EvaluativeListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.common.primitives.Pair;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.learning.config.Nadam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.net.URL;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/**
 * Sequence Classification Example Using a 1D Convolutional Neural Network
 *
 * This example learns how to classify univariate time series as belonging to one of six categories.
 * Categories are: Normal, Cyclic, Increasing trend, Decreasing trend, Upward shift, Downward shift
 *
 * Data is the UCI Synthetic Control Chart Time Series Data Set
 * Details: https://archive.ics.uci.edu/ml/datasets/Synthetic+Control+Chart+Time+Series
 * Data: https://archive.ics.uci.edu/ml/machine-learning-databases/synthetic_control-mld/synthetic_control.data
 * Image: https://archive.ics.uci.edu/ml/machine-learning-databases/synthetic_control-mld/data.jpeg
 *
 * This example proceeds as follows:
 * 1. Download and prepare the data (in the downloadUCIData() method)
 *    (a) Split the 600 sequences into a training set of size 450 and a test set of size 150
 *    (b) Write the data into a format suitable for loading with CSVSequenceRecordReader for sequence classification
 *        This format: one time series per file, and a separate file for the labels.
 *        For example, train/features/0.csv is the features file corresponding to the labels file train/labels/0.csv
 *        Because the data is a univariate time series, the feature CSV files have only one column. In general there
 *        would be one column per variable, with one time step per row.
 *        Furthermore, because we have only one label for each time series, the labels CSV files contain only a single value
 *
 * 2. Load the training data using CSVSequenceRecordReader (to load/parse the CSV files) and SequenceRecordReaderDataSetIterator
 *    (to convert it to DataSet objects, ready to train)
 *    For more details on this step, see: https://deeplearning4j.konduit.ai/models/recurrent#data-for-rnns
 *
 * 3. Normalize the data. The raw data contain values that are too large for effective training, and need to be normalized.
 *    Normalization is conducted using NormalizerStandardize, based on statistics (mean, st.dev) collected on the training
 *    data only. Note that both the training data and test data are normalized in the same way.
 *
 * 4. Configure the network
 *    The data set here is very small, so we can't afford to use a large network with many parameters.
 *    We are using one small 1D convolution layer and one output layer
 *
 * 5. Train the network for 40 epochs
 *    At each epoch, evaluate and print the accuracy and F1 score on the test set
 *
 * @author Alex Black
 */
@SuppressWarnings("ResultOfMethodCallIgnored")
public class Conv1DUCISequenceClassification {
    private static final Logger log = LoggerFactory.getLogger(Conv1DUCISequenceClassification.class);

    //'baseDir': Base directory for the data. Change this if you want to save the data somewhere else
    private static File baseDir = new File("src/main/resources/uci/");
    private static File baseTrainDir = new File(baseDir, "train");
    private static File featuresDirTrain = new File(baseTrainDir, "features");
    private static File labelsDirTrain = new File(baseTrainDir, "labels");
    private static File baseTestDir = new File(baseDir, "test");
    private static File featuresDirTest = new File(baseTestDir, "features");
    private static File labelsDirTest = new File(baseTestDir, "labels");
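
    //For reference, the layout produced by downloadUCIData() looks like this (each features file holds one
    //univariate series, one value per line for 60 time steps; each labels file holds a single class index 0-5):
    //  src/main/resources/uci/train/features/0.csv ... 449.csv
    //  src/main/resources/uci/train/labels/0.csv   ... 449.csv
    //  src/main/resources/uci/test/features/0.csv  ... 149.csv
    //  src/main/resources/uci/test/labels/0.csv    ... 149.csv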

    public static void main(String[] args) throws Exception {
        downloadUCIData();

        // ----- Load the training data -----
        //Note that we have 450 training files for features: train/features/0.csv through train/features/449.csv
        SequenceRecordReader trainFeatures = new CSVSequenceRecordReader();
        trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));
        SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
        trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));

        int miniBatchSize = 10;
        int numLabelClasses = 6;
        DataSetIterator trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses,
            false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
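        //Each DataSet from this iterator has features of shape [miniBatchSize, numFeatures, timeSteps] - here [10, 1, 60].
        //The 'false' argument means classification (not regression): the single value in each labels file is treated
        //as a class index and converted to a one-hot label. AlignmentMode.ALIGN_END aligns that single label with the
        //last time step of the sequence, with a label mask marking the remaining steps as absent.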

        //Normalize the training data
        DataNormalization normalizer = new NormalizerStandardize();
        normalizer.fit(trainData);              //Collect training data statistics
        trainData.reset();

        //Use previously collected statistics to normalize on-the-fly. Each DataSet returned by 'trainData' iterator will be normalized
        trainData.setPreProcessor(normalizer);
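        //NormalizerStandardize rescales each feature to zero mean and unit variance using the mean/st.dev collected
        //over the training set during fit(). The reset() call above is needed because fit() consumes the iterator once.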

        // ----- Load the test data -----
        //Same process as for the training data.
        SequenceRecordReader testFeatures = new CSVSequenceRecordReader();
        testFeatures.initialize(new NumberedFileInputSplit(featuresDirTest.getAbsolutePath() + "/%d.csv", 0, 149));
        SequenceRecordReader testLabels = new CSVSequenceRecordReader();
        testLabels.initialize(new NumberedFileInputSplit(labelsDirTest.getAbsolutePath() + "/%d.csv", 0, 149));

        DataSetIterator testData = new SequenceRecordReaderDataSetIterator(testFeatures, testLabels, miniBatchSize, numLabelClasses,
            false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

        testData.setPreProcessor(normalizer);   //Note that we are using the exact same normalization process as the training data

        // ----- Configure the network -----
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(123)    //Random number generator seed for improved repeatability. Optional.
            .weightInit(WeightInit.XAVIER)
            .updater(new Nadam())
            .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)  //Not always required, but helps with this data set
            .gradientNormalizationThreshold(0.5)
            .list()
            .layer(new Convolution1DLayer.Builder()
                .kernelSize(3)
                .stride(1)
                .activation(Activation.TANH).nOut(1).build())
            .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                .activation(Activation.SOFTMAX).nOut(numLabelClasses).build())
            .setInputType(InputType.recurrent(1, -1, RNNFormat.NCW))
            .build();
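        //InputType.recurrent(1, -1, RNNFormat.NCW) declares sequence input with 1 channel, unspecified (-1) sequence
        //length, and [minibatch, channels, width] layout; from this, nIn is set automatically for each layer and the
        //preprocessors needed between the 1D convolution layer and the dense output layer are added for us.
        //ClipElementWiseAbsoluteValue with threshold 0.5 clips each gradient element to the range [-0.5, 0.5].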

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        log.info("Starting training...");
        net.setListeners(new ScoreIterationListener(20), new EvaluativeListener(testData, 1, InvocationType.EPOCH_END));   //Print the score (loss function value) every 20 iterations
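        //The EvaluativeListener evaluates the network on testData at the end of every epoch (InvocationType.EPOCH_END)
        //and logs the resulting accuracy/F1, which is how step 5 in the class Javadoc is carried out.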

        int nEpochs = 40;
        net.fit(trainData, nEpochs);

        log.info("Evaluating...");
        Evaluation eval = net.evaluate(testData);
        log.info(eval.stats());

        log.info("----- Example Complete -----");
    }

    //This method downloads the data, and converts the "one time series per line" format into a suitable
    //CSV sequence format that DataVec (CSVSequenceRecordReader) and DL4J can read.
    private static void downloadUCIData() throws Exception {
        if (baseDir.exists()) return;    //Data already exists, don't download it again

        String url = "https://archive.ics.uci.edu/ml/machine-learning-databases/synthetic_control-mld/synthetic_control.data";
        String data = IOUtils.toString(new URL(url), (Charset) null);

        String[] lines = data.split("\n");

        //Create directories
        baseDir.mkdir();
        baseTrainDir.mkdir();
        featuresDirTrain.mkdir();
        labelsDirTrain.mkdir();
        baseTestDir.mkdir();
        featuresDirTest.mkdir();
        labelsDirTest.mkdir();

        int lineCount = 0;
        List<Pair<String, Integer>> contentAndLabels = new ArrayList<>();
        for (String line : lines) {
            String transposed = line.replaceAll(" +", "\n");
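            //Each raw line holds one series as 60 space-separated values; replacing the runs of spaces with newlines
            //"transposes" it into the one-value-per-row column format that CSVSequenceRecordReader expects.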

            //Labels: first 100 examples (lines) are label 0, second 100 examples are label 1, and so on
            contentAndLabels.add(new Pair<>(transposed, lineCount++ / 100));
        }

        //Randomize and do a train/test split:
        Collections.shuffle(contentAndLabels, new Random(12345));

        int nTrain = 450;   //75% train, 25% test
        int trainCount = 0;
        int testCount = 0;
        for (Pair<String, Integer> p : contentAndLabels) {
            //Write output in a format we can read, in the appropriate locations
            File outPathFeatures;
            File outPathLabels;
            if (trainCount < nTrain) {
                outPathFeatures = new File(featuresDirTrain, trainCount + ".csv");
                outPathLabels = new File(labelsDirTrain, trainCount + ".csv");
                trainCount++;
            } else {
                outPathFeatures = new File(featuresDirTest, testCount + ".csv");
                outPathLabels = new File(labelsDirTest, testCount + ".csv");
                testCount++;
            }

            FileUtils.writeStringToFile(outPathFeatures, p.getFirst(), (Charset) null);
            FileUtils.writeStringToFile(outPathLabels, p.getSecond().toString(), (Charset) null);
        }
    }
}
