// ~/deeplearning4j/run.sc.html
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
// Startup marker so it is obvious on the console that the script began running.
println("hello")

// Input geometry and label count for MNIST: 28x28 pixel images, 10 digit classes.
val numRows = 28
val numColumns = 28
val outputNum = 10

// Training hyper-parameters. rngSeed is shared by the data iterators and the
// network configuration so that runs are reproducible.
val batchSize = 128
val rngSeed = 123
val numEpochs = 15

// The boolean argument selects the split: true = training set, false = test set.
val mnistTrain: DataSetIterator = new MnistDataSetIterator(batchSize, true, rngSeed)
val mnistTest: DataSetIterator = new MnistDataSetIterator(batchSize, false, rngSeed)
println("Build model....")

// Width of the single hidden layer; it is also the input width of the output layer.
val hiddenSize = 1000

// Fully connected ReLU layer fed with the flattened image (numRows * numColumns inputs).
val hiddenLayer = new DenseLayer.Builder()
  .nIn(numRows * numColumns)
  .nOut(hiddenSize)
  .activation(Activation.RELU)
  .weightInit(WeightInit.XAVIER)
  .build()

// Softmax classifier over the outputNum classes, trained with negative log-likelihood.
val outputLayer = new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
  .nIn(hiddenSize)
  .nOut(outputNum)
  .activation(Activation.SOFTMAX)
  .weightInit(WeightInit.XAVIER)
  .build()

// Global settings: fixed seed, Nesterov momentum updater (learning rate 0.006,
// momentum 0.9) and L2 regularisation with coefficient 1e-4.
val conf: MultiLayerConfiguration = new NeuralNetConfiguration.Builder()
  .seed(rngSeed)
  .updater(new Nesterovs(0.006, 0.9))
  .l2(1e-4)
  .list()
  .layer(0, hiddenLayer)
  .layer(1, outputLayer)
  .build()

val model: MultiLayerNetwork = new MultiLayerNetwork(conf)
model.init()
// Report the training score after every single iteration.
model.setListeners(new ScoreIterationListener(1))
println("Train model....")
// Run exactly numEpochs passes over the training iterator.
// `until` is exclusive of the upper bound; the previous `0 to numEpochs` was
// inclusive and therefore trained for numEpochs + 1 epochs.
for (_ <- 0 until numEpochs) {
  model.fit(mnistTrain)
}
println("Evaluate model....")

// Accumulates classification statistics over outputNum classes.
val eval: Evaluation = new Evaluation(outputNum)

// Feed every test batch through the network and compare predictions with labels.
while (mnistTest.hasNext()) {
  val batch: DataSet = mnistTest.next()
  val predictions: INDArray = model.output(batch.getFeatures)
  eval.eval(batch.getLabels, predictions)
}

println(eval.stats())
println("****************Example finished********************")
// Comments
// Post a Comment