/* *****************************************************************************
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.examples.quickstart.modeling.convolution;

import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.InvocationType;
import org.deeplearning4j.optimize.listeners.EvaluativeListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.common.primitives.Pair;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.learning.config.Nadam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.net.URL;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/**
 * Sequence Classification Example Using a 1D Convolutional Neural Network
 *
 * This example trains a network to classify univariate time series as belonging to one of six categories:
 * Normal, Cyclic, Increasing trend, Decreasing trend, Upward shift, Downward shift.
 *
 * Data is the UCI Synthetic Control Chart Time Series Data Set
 * Details: https://archive.ics.uci.edu/ml/datasets/Synthetic+Control+Chart+Time+Series
 * Data: https://archive.ics.uci.edu/ml/machine-learning-databases/synthetic_control-mld/synthetic_control.data
 * Image: https://archive.ics.uci.edu/ml/machine-learning-databases/synthetic_control-mld/data.jpeg
 *
 * This example proceeds as follows:
 * 1. Download and prepare the data (in the downloadUCIData() method)
 *    (a) Split the 600 sequences into a training set of size 450 and a test set of size 150
 *    (b) Write the data into a format suitable for loading with CSVSequenceRecordReader for sequence classification.
 *        This format: one time series per file, and a separate file for the labels.
 *        For example, train/features/0.csv contains the features that go with the label in train/labels/0.csv.
 *        Because the data is a univariate time series, the CSV files have only one column, with one time step per
 *        row; a multivariate series would have one column per feature. Because there is only one label for each
 *        time series, each labels CSV file contains just a single value.
 *
 * 2. Load the training data using CSVSequenceRecordReader (to load/parse the CSV files) and
 *    SequenceRecordReaderDataSetIterator (to convert it to DataSet objects, ready for training).
 *    For more details on this step, see: https://deeplearning4j.konduit.ai/models/recurrent#data-for-rnns
 *
 * 3. Normalize the data. The raw data contain values that are too large for effective training and need to be
 *    normalized. Normalization is conducted using NormalizerStandardize, based on statistics (mean, st. dev.)
 *    collected on the training data only. Note that both the training data and the test data are normalized in
 *    the same way.
 *
 * 4. Configure the network.
 *    The data set here is very small, so we can't afford to use a large network with many parameters.
 *    We are using one small 1D convolution layer and one output layer.
 *
 * 5. Train the network for 40 epochs.
 *    At the end of each epoch, evaluate and print the accuracy and F1 score on the test set.
 *
 * @author Alex Black
 */
@SuppressWarnings("ResultOfMethodCallIgnored")
public class Conv1DUCISequenceClassification {
    private static final Logger log = LoggerFactory.getLogger(Conv1DUCISequenceClassification.class);

    //'baseDir': Base directory for the data. Change this if you want to save the data somewhere else
    private static final File baseDir = new File("src/main/resources/uci/");
    private static final File baseTrainDir = new File(baseDir, "train");
    private static final File featuresDirTrain = new File(baseTrainDir, "features");
    private static final File labelsDirTrain = new File(baseTrainDir, "labels");
    private static final File baseTestDir = new File(baseDir, "test");
    private static final File featuresDirTest = new File(baseTestDir, "features");
    private static final File labelsDirTest = new File(baseTestDir, "labels");

    public static void main(String[] args) throws Exception {
        downloadUCIData();

        // ----- Load the training data -----
        //Note that we have 450 training files for features: train/features/0.csv through train/features/449.csv
        SequenceRecordReader trainFeatures = new CSVSequenceRecordReader();
        trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));
        SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
        trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));

        int miniBatchSize = 10;
        int numLabelClasses = 6;
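        //Each DataSet produced by the iterator below holds up to 10 sequences, with features of shape
        //[miniBatchSize, 1, 60] ([minibatch, channels, time steps] - each UCI series is 60 steps long) and one
        //class label per sequence. ALIGN_END aligns that single label value with the last time step of the sequence.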
        DataSetIterator trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses,
            false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

        //Normalize the training data
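        //NormalizerStandardize standardizes each feature: x -> (x - mean) / stdev, where the mean and standard
        //deviation are collected from the training data only (in the fit() call below)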
        DataNormalization normalizer = new NormalizerStandardize();
        normalizer.fit(trainData);      //Collect training data statistics
        trainData.reset();

        //Use previously collected statistics to normalize on-the-fly. Each DataSet returned by the 'trainData' iterator will be normalized
        trainData.setPreProcessor(normalizer);

        // ----- Load the test data -----
        //Same process as for the training data.
        SequenceRecordReader testFeatures = new CSVSequenceRecordReader();
        testFeatures.initialize(new NumberedFileInputSplit(featuresDirTest.getAbsolutePath() + "/%d.csv", 0, 149));
        SequenceRecordReader testLabels = new CSVSequenceRecordReader();
        testLabels.initialize(new NumberedFileInputSplit(labelsDirTest.getAbsolutePath() + "/%d.csv", 0, 149));

        DataSetIterator testData = new SequenceRecordReaderDataSetIterator(testFeatures, testLabels, miniBatchSize, numLabelClasses,
            false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

        testData.setPreProcessor(normalizer);   //Note that we are using the exact same normalization process as for the training data


        // ----- Configure the network -----
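        //The 1D convolution slides a kernel of size 3 along the time dimension of each sequence.
        //InputType.recurrent(1, -1, RNNFormat.NCW) tells DL4J that the input is a univariate (1 feature) time series
        //of unspecified length (-1), laid out as [minibatch, channels, width/time]; from this, DL4J infers the nIn
        //values for each layer and adds any input preprocessors required between them.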
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(123)    //Random number generator seed for improved repeatability. Optional.
            .weightInit(WeightInit.XAVIER)
            .updater(new Nadam())
            .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)  //Not always required, but helps with this data set
            .gradientNormalizationThreshold(0.5)
            .list()
            .layer(new Convolution1DLayer.Builder()
                .kernelSize(3)
                .stride(1)
                .activation(Activation.TANH).nOut(1).build())
            .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                .activation(Activation.SOFTMAX).nOut(numLabelClasses).build())
            .setInputType(InputType.recurrent(1, -1, RNNFormat.NCW))
            .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        log.info("Starting training...");
        //Print the score (loss function value) every 20 iterations, and evaluate on the test set at the end of each epoch
        net.setListeners(new ScoreIterationListener(20), new EvaluativeListener(testData, 1, InvocationType.EPOCH_END));

        int nEpochs = 40;
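        //fit(DataSetIterator, int) trains for the given number of epochs, resetting the iterator between epochs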
        net.fit(trainData, nEpochs);

        log.info("Evaluating...");
        Evaluation eval = net.evaluate(testData);
        log.info(eval.stats());
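
        //If desired, the trained model could be saved for later reuse - for example (file name is illustrative):
        //net.save(new File(baseDir, "uci-conv1d-model.zip"));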

        log.info("----- Example Complete -----");
    }


    //This method downloads the data, and converts the "one time series per line" format into a suitable
    //CSV sequence format that DataVec (CSVSequenceRecordReader) and DL4J can read.
    private static void downloadUCIData() throws Exception {
        if (baseDir.exists()) return;    //Data already exists, don't download it again

        String url = "https://archive.ics.uci.edu/ml/machine-learning-databases/synthetic_control-mld/synthetic_control.data";
        String data = IOUtils.toString(new URL(url), (Charset) null);

        String[] lines = data.split("\n");
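        //The raw file contains 600 lines: one 60-value, space-separated time series per line, 100 lines per class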

        //Create directories
        baseDir.mkdir();
        baseTrainDir.mkdir();
        featuresDirTrain.mkdir();
        labelsDirTrain.mkdir();
        baseTestDir.mkdir();
        featuresDirTest.mkdir();
        labelsDirTest.mkdir();

        int lineCount = 0;
        List<Pair<String, Integer>> contentAndLabels = new ArrayList<>();
        for (String line : lines) {
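            //Each input line holds one whole series; convert it to one value per line, so each CSV row is one time step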
            String transposed = line.replaceAll(" +", "\n");

            //Labels: first 100 examples (lines) are label 0, second 100 examples are label 1, and so on
            contentAndLabels.add(new Pair<>(transposed, lineCount++ / 100));
        }

        //Randomize and do a train/test split:
        Collections.shuffle(contentAndLabels, new Random(12345));

        int nTrain = 450;    //75% train, 25% test
        int trainCount = 0;
        int testCount = 0;
        for (Pair<String, Integer> p : contentAndLabels) {
            //Write output in a format we can read, in the appropriate locations
            File outPathFeatures;
            File outPathLabels;
            if (trainCount < nTrain) {
                outPathFeatures = new File(featuresDirTrain, trainCount + ".csv");
                outPathLabels = new File(labelsDirTrain, trainCount + ".csv");
                trainCount++;
            } else {
                outPathFeatures = new File(featuresDirTest, testCount + ".csv");
                outPathLabels = new File(labelsDirTest, testCount + ".csv");
                testCount++;
            }

            FileUtils.writeStringToFile(outPathFeatures, p.getFirst(), (Charset) null);
            FileUtils.writeStringToFile(outPathLabels, p.getSecond().toString(), (Charset) null);
        }
    }
}