Skip to content

Commit 6e1cab0

Browse files
author
helga_sh
committed
CNN example with Deeplearning4j in Java
1 parent cd2f453 commit 6e1cab0

6 files changed

Lines changed: 226 additions & 0 deletions

File tree

deeplearning4j/pom.xml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,16 @@
3737
<artifactId>deeplearning4j-nn</artifactId>
3838
<version>${dl4j.version}</version>
3939
</dependency>
40+
<dependency>
41+
<groupId>org.slf4j</groupId>
42+
<artifactId>slf4j-api</artifactId>
43+
<version>1.7.5</version>
44+
</dependency>
45+
<dependency>
46+
<groupId>org.slf4j</groupId>
47+
<artifactId>slf4j-log4j12</artifactId>
48+
<version>1.7.5</version>
49+
</dependency>
4050
<!-- https://mvnrepository.com/artifact/org.datavec/datavec-api -->
4151
<dependency>
4252
<groupId>org.datavec</groupId>
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package com.baeldung.deeplearning4j.cnn;
2+
3+
4+
import com.baeldung.deeplearning4j.cnn.domain.network.CnnModel;
5+
import com.baeldung.deeplearning4j.cnn.domain.network.CnnModelProperties;
6+
import com.baeldung.deeplearning4j.cnn.service.dataset.CifarDataSetService;
7+
import lombok.extern.slf4j.Slf4j;
8+
import org.deeplearning4j.eval.Evaluation;
9+
10+
@Slf4j
11+
public class CnnExample {
12+
13+
public static void main(String... args) {
14+
CnnModel network = new CnnModel(new CifarDataSetService(), new CnnModelProperties());
15+
16+
network.train();
17+
Evaluation evaluation = network.evaluate();
18+
19+
log.info(evaluation.stats());
20+
}
21+
}
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
package com.baeldung.deeplearning4j.cnn.domain.network;
2+
3+
import com.baeldung.deeplearning4j.cnn.service.dataset.IDataSetService;
4+
import lombok.extern.slf4j.Slf4j;
5+
import org.deeplearning4j.eval.Evaluation;
6+
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
7+
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
8+
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
9+
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
10+
import org.deeplearning4j.nn.conf.layers.OutputLayer;
11+
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
12+
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
13+
import org.deeplearning4j.nn.weights.WeightInit;
14+
import org.nd4j.linalg.activations.Activation;
15+
import org.nd4j.linalg.lossfunctions.LossFunctions;
16+
17+
import java.util.stream.IntStream;
18+
19+
@Slf4j
20+
public class CnnModel {
21+
22+
private final IDataSetService dataSetService;
23+
24+
private MultiLayerNetwork network;
25+
26+
private final CnnModelProperties properties;
27+
28+
public CnnModel(IDataSetService dataSetService, CnnModelProperties properties) {
29+
30+
this.dataSetService = dataSetService;
31+
this.properties = properties;
32+
33+
MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
34+
.seed(1611)
35+
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
36+
.learningRate(properties.getLearningRate())
37+
.regularization(true)
38+
.updater(properties.getOptimizer())
39+
.list()
40+
.layer(0, conv5x5())
41+
.layer(1, pooling2x2Stride2())
42+
.layer(2, conv3x3Stride1Padding2())
43+
.layer(3, pooling2x2Stride1())
44+
.layer(4, conv3x3Stride1Padding1())
45+
.layer(5, pooling2x2Stride1())
46+
.layer(6, dense())
47+
.pretrain(false)
48+
.backprop(true)
49+
.setInputType(dataSetService.inputType())
50+
.build();
51+
52+
network = new MultiLayerNetwork(configuration);
53+
}
54+
55+
public void train() {
56+
network.init();
57+
int epochsNum = properties.getEpochsNum();
58+
IntStream.range(1, epochsNum + 1).forEach(epoch -> {
59+
log.info(String.format("Epoch %d?%d", epoch, epochsNum));
60+
network.fit(dataSetService.trainIterator());
61+
});
62+
}
63+
64+
public Evaluation evaluate() {
65+
return network.evaluate(dataSetService.testIterator());
66+
}
67+
68+
private ConvolutionLayer conv5x5() {
69+
return new ConvolutionLayer.Builder(5, 5)
70+
.nIn(3)
71+
.nOut(16)
72+
.stride(1, 1)
73+
.padding(1, 1)
74+
.weightInit(WeightInit.XAVIER_UNIFORM)
75+
.activation(Activation.RELU)
76+
.build();
77+
}
78+
79+
private SubsamplingLayer pooling2x2Stride2() {
80+
return new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
81+
.kernelSize(2, 2)
82+
.stride(2, 2)
83+
.build();
84+
}
85+
86+
private ConvolutionLayer conv3x3Stride1Padding2() {
87+
return new ConvolutionLayer.Builder(3, 3)
88+
.nOut(32)
89+
.stride(1, 1)
90+
.padding(2, 2)
91+
.weightInit(WeightInit.XAVIER_UNIFORM)
92+
.activation(Activation.RELU)
93+
.build();
94+
}
95+
96+
private SubsamplingLayer pooling2x2Stride1() {
97+
return new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
98+
.kernelSize(2,2)
99+
.stride(1, 1)
100+
.build();
101+
}
102+
103+
private ConvolutionLayer conv3x3Stride1Padding1() {
104+
return new ConvolutionLayer.Builder(3, 3)
105+
.nOut(64)
106+
.stride(1, 1)
107+
.padding(1, 1)
108+
.weightInit(WeightInit.XAVIER_UNIFORM)
109+
.activation(Activation.RELU)
110+
.build();
111+
}
112+
113+
private OutputLayer dense() {
114+
return new OutputLayer.Builder(LossFunctions.LossFunction.MEAN_SQUARED_LOGARITHMIC_ERROR)
115+
.activation(Activation.SOFTMAX)
116+
.weightInit(WeightInit.XAVIER_UNIFORM)
117+
.nOut(dataSetService.labels().size() - 1)
118+
.build();
119+
}
120+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package com.baeldung.deeplearning4j.cnn.domain.network;
2+
3+
import lombok.Value;
4+
import org.deeplearning4j.nn.conf.Updater;
5+
6+
@Value
7+
public class CnnModelProperties {
8+
private final int epochsNum = 512;
9+
10+
private final double learningRate = 0.001;
11+
12+
private final Updater optimizer = Updater.ADAM;
13+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package com.baeldung.deeplearning4j.cnn.service.dataset;
2+
3+
import lombok.Getter;
4+
import org.deeplearning4j.datasets.iterator.impl.CifarDataSetIterator;
5+
import org.deeplearning4j.nn.conf.inputs.InputType;
6+
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
7+
8+
import java.util.List;
9+
10+
@Getter
11+
public class CifarDataSetService implements IDataSetService {
12+
13+
private CifarDataSetIterator trainIterator;
14+
private CifarDataSetIterator testIterator;
15+
16+
private final InputType inputType = InputType.convolutional(32,32,3);
17+
private final int trainImagesNum = 512;
18+
private final int testImagesNum = 128;
19+
private final int trainBatch = 16;
20+
private final int testBatch = 8;
21+
22+
public CifarDataSetService() {
23+
trainIterator = new CifarDataSetIterator(trainBatch, trainImagesNum, true);
24+
testIterator = new CifarDataSetIterator(testBatch, testImagesNum, false);
25+
}
26+
27+
@Override
28+
public DataSetIterator trainIterator() {
29+
return trainIterator;
30+
}
31+
32+
@Override
33+
public DataSetIterator testIterator() {
34+
return testIterator;
35+
}
36+
37+
@Override
38+
public InputType inputType() {
39+
return inputType;
40+
}
41+
42+
@Override
43+
public List<String> labels() {
44+
return trainIterator.getLabels();
45+
}
46+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package com.baeldung.deeplearning4j.cnn.service.dataset;
2+
3+
import org.deeplearning4j.nn.conf.inputs.InputType;
4+
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
5+
6+
import java.util.List;
7+
8+
public interface IDataSetService {
9+
DataSetIterator trainIterator();
10+
11+
DataSetIterator testIterator();
12+
13+
InputType inputType();
14+
15+
List<String> labels();
16+
}

0 commit comments

Comments
 (0)