Skip to content

Commit eeff326

Browse files
author
helga_sh
committed
CNN example with Deeplearning4j in Java: refactor
1 parent 6e1cab0 commit eeff326

6 files changed

Lines changed: 26 additions & 28 deletions

File tree

deeplearning4j/pom.xml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,12 @@
4040
<dependency>
4141
<groupId>org.slf4j</groupId>
4242
<artifactId>slf4j-api</artifactId>
43-
<version>1.7.5</version>
43+
<version>${sl4j.version}</version>
4444
</dependency>
4545
<dependency>
4646
<groupId>org.slf4j</groupId>
4747
<artifactId>slf4j-log4j12</artifactId>
48-
<version>1.7.5</version>
48+
<version>${sl4j.version}</version>
4949
</dependency>
5050
<!-- https://mvnrepository.com/artifact/org.datavec/datavec-api -->
5151
<dependency>
@@ -63,6 +63,7 @@
6363
<properties>
6464
<dl4j.version>0.9.1</dl4j.version> <!-- Latest non beta version -->
6565
<httpclient.version>4.3.5</httpclient.version>
66+
<sl4j.version>1.7.5</sl4j.version>
6667
</properties>
6768

6869
</project>

deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/service/dataset/CifarDataSetService.java renamed to deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CifarDataSetService.java

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package com.baeldung.deeplearning4j.cnn.service.dataset;
1+
package com.baeldung.deeplearning4j.cnn;
22

33
import lombok.Getter;
44
import org.deeplearning4j.datasets.iterator.impl.CifarDataSetIterator;
@@ -8,18 +8,19 @@
88
import java.util.List;
99

1010
@Getter
11-
public class CifarDataSetService implements IDataSetService {
11+
class CifarDataSetService implements IDataSetService {
1212

13-
private CifarDataSetIterator trainIterator;
14-
private CifarDataSetIterator testIterator;
15-
16-
private final InputType inputType = InputType.convolutional(32,32,3);
13+
private final InputType inputType = InputType.convolutional(32, 32, 3);
1714
private final int trainImagesNum = 512;
1815
private final int testImagesNum = 128;
1916
private final int trainBatch = 16;
2017
private final int testBatch = 8;
2118

22-
public CifarDataSetService() {
19+
private final CifarDataSetIterator trainIterator;
20+
21+
private final CifarDataSetIterator testIterator;
22+
23+
CifarDataSetService() {
2324
trainIterator = new CifarDataSetIterator(trainBatch, trainImagesNum, true);
2425
testIterator = new CifarDataSetIterator(testBatch, testImagesNum, false);
2526
}

deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CnnExample.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
11
package com.baeldung.deeplearning4j.cnn;
22

33

4-
import com.baeldung.deeplearning4j.cnn.domain.network.CnnModel;
5-
import com.baeldung.deeplearning4j.cnn.domain.network.CnnModelProperties;
6-
import com.baeldung.deeplearning4j.cnn.service.dataset.CifarDataSetService;
74
import lombok.extern.slf4j.Slf4j;
85
import org.deeplearning4j.eval.Evaluation;
96

107
@Slf4j
11-
public class CnnExample {
8+
class CnnExample {
129

1310
public static void main(String... args) {
1411
CnnModel network = new CnnModel(new CifarDataSetService(), new CnnModelProperties());

deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/domain/network/CnnModel.java renamed to deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CnnModel.java

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
package com.baeldung.deeplearning4j.cnn.domain.network;
1+
package com.baeldung.deeplearning4j.cnn;
22

3-
import com.baeldung.deeplearning4j.cnn.service.dataset.IDataSetService;
43
import lombok.extern.slf4j.Slf4j;
54
import org.deeplearning4j.eval.Evaluation;
65
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
@@ -17,15 +16,15 @@
1716
import java.util.stream.IntStream;
1817

1918
@Slf4j
20-
public class CnnModel {
19+
class CnnModel {
2120

2221
private final IDataSetService dataSetService;
2322

24-
private MultiLayerNetwork network;
23+
private final MultiLayerNetwork network;
2524

2625
private final CnnModelProperties properties;
2726

28-
public CnnModel(IDataSetService dataSetService, CnnModelProperties properties) {
27+
CnnModel(IDataSetService dataSetService, CnnModelProperties properties) {
2928

3029
this.dataSetService = dataSetService;
3130
this.properties = properties;
@@ -52,17 +51,17 @@ public CnnModel(IDataSetService dataSetService, CnnModelProperties properties) {
5251
network = new MultiLayerNetwork(configuration);
5352
}
5453

55-
public void train() {
54+
void train() {
5655
network.init();
5756
int epochsNum = properties.getEpochsNum();
5857
IntStream.range(1, epochsNum + 1).forEach(epoch -> {
59-
log.info(String.format("Epoch %d?%d", epoch, epochsNum));
58+
log.info("Epoch {} / {}", epoch, epochsNum);
6059
network.fit(dataSetService.trainIterator());
6160
});
6261
}
6362

64-
public Evaluation evaluate() {
65-
return network.evaluate(dataSetService.testIterator());
63+
Evaluation evaluate() {
64+
return network.evaluate(dataSetService.testIterator());
6665
}
6766

6867
private ConvolutionLayer conv5x5() {
@@ -84,7 +83,7 @@ private SubsamplingLayer pooling2x2Stride2() {
8483
}
8584

8685
private ConvolutionLayer conv3x3Stride1Padding2() {
87-
return new ConvolutionLayer.Builder(3, 3)
86+
return new ConvolutionLayer.Builder(3, 3)
8887
.nOut(32)
8988
.stride(1, 1)
9089
.padding(2, 2)
@@ -95,7 +94,7 @@ private ConvolutionLayer conv3x3Stride1Padding2() {
9594

9695
private SubsamplingLayer pooling2x2Stride1() {
9796
return new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
98-
.kernelSize(2,2)
97+
.kernelSize(2, 2)
9998
.stride(1, 1)
10099
.build();
101100
}

deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/domain/network/CnnModelProperties.java renamed to deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/CnnModelProperties.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
package com.baeldung.deeplearning4j.cnn.domain.network;
1+
package com.baeldung.deeplearning4j.cnn;
22

33
import lombok.Value;
44
import org.deeplearning4j.nn.conf.Updater;
55

66
@Value
7-
public class CnnModelProperties {
7+
class CnnModelProperties {
88
private final int epochsNum = 512;
99

1010
private final double learningRate = 0.001;

deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/service/dataset/IDataSetService.java renamed to deeplearning4j/src/main/java/com/baeldung/deeplearning4j/cnn/IDataSetService.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
package com.baeldung.deeplearning4j.cnn.service.dataset;
1+
package com.baeldung.deeplearning4j.cnn;
22

33
import org.deeplearning4j.nn.conf.inputs.InputType;
44
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
55

66
import java.util.List;
77

8-
public interface IDataSetService {
8+
interface IDataSetService {
99
DataSetIterator trainIterator();
1010

1111
DataSetIterator testIterator();

0 commit comments

Comments
 (0)