Skip to content

Commit e687a71

Browse files
authored
Merge pull request #4630 from BlGene/load_hdf5_fix
Made load_hd5 check blob dims by default, instead of reshaping.
2 parents 85ab610 + 95a436c commit e687a71

5 files changed

Lines changed: 39 additions & 14 deletions

File tree

include/caffe/util/hdf5.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@ namespace caffe {
1313
template <typename Dtype>
1414
void hdf5_load_nd_dataset_helper(
1515
hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
16-
Blob<Dtype>* blob);
16+
Blob<Dtype>* blob, bool reshape);
1717

1818
template <typename Dtype>
1919
void hdf5_load_nd_dataset(
2020
hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
21-
Blob<Dtype>* blob);
21+
Blob<Dtype>* blob, bool reshape = false);
2222

2323
template <typename Dtype>
2424
void hdf5_save_nd_dataset(

src/caffe/layers/hdf5_data_layer.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,9 @@ void HDF5DataLayer<Dtype>::LoadHDF5FileData(const char* filename) {
3939

4040
for (int i = 0; i < top_size; ++i) {
4141
hdf_blobs_[i] = shared_ptr<Blob<Dtype> >(new Blob<Dtype>());
42+
// Allow reshape here, as we are loading data not params
4243
hdf5_load_nd_dataset(file_id, this->layer_param_.top(i).c_str(),
43-
MIN_DATA_DIM, MAX_DATA_DIM, hdf_blobs_[i].get());
44+
MIN_DATA_DIM, MAX_DATA_DIM, hdf_blobs_[i].get(), true);
4445
}
4546

4647
herr_t status = H5Fclose(file_id);

src/caffe/test/test_hdf5_output_layer.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,12 @@ TYPED_TEST(HDF5OutputLayerTest, TestForward) {
7777
H5P_DEFAULT);
7878
ASSERT_GE(file_id, 0)<< "Failed to open HDF5 file" <<
7979
this->input_file_name_;
80+
// Allow reshape here as we are loading data not params
81+
bool reshape = true;
8082
hdf5_load_nd_dataset(file_id, HDF5_DATA_DATASET_NAME, 0, 4,
81-
this->blob_data_);
83+
this->blob_data_, reshape);
8284
hdf5_load_nd_dataset(file_id, HDF5_DATA_LABEL_NAME, 0, 4,
83-
this->blob_label_);
85+
this->blob_label_, reshape);
8486
herr_t status = H5Fclose(file_id);
8587
EXPECT_GE(status, 0)<< "Failed to close HDF5 file " <<
8688
this->input_file_name_;
@@ -105,12 +107,12 @@ TYPED_TEST(HDF5OutputLayerTest, TestForward) {
105107

106108
Blob<Dtype>* blob_data = new Blob<Dtype>();
107109
hdf5_load_nd_dataset(file_id, HDF5_DATA_DATASET_NAME, 0, 4,
108-
blob_data);
110+
blob_data, reshape);
109111
this->CheckBlobEqual(*(this->blob_data_), *blob_data);
110112

111113
Blob<Dtype>* blob_label = new Blob<Dtype>();
112114
hdf5_load_nd_dataset(file_id, HDF5_DATA_LABEL_NAME, 0, 4,
113-
blob_label);
115+
blob_label, reshape);
114116
this->CheckBlobEqual(*(this->blob_label_), *blob_label);
115117

116118
status = H5Fclose(file_id);

src/caffe/test/test_hdf5data_layer.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ TYPED_TEST(HDF5DataLayerTest, TestRead) {
7070
int height = 6;
7171
int width = 5;
7272

73-
// Test that the layer setup got the correct parameters.
73+
// Test that the layer setup gives correct parameters.
7474
HDF5DataLayer<Dtype> layer(param);
7575
layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
7676
EXPECT_EQ(this->blob_top_data_->num(), batch_size);

src/caffe/util/hdf5.cpp

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace caffe {
99
template <typename Dtype>
1010
void hdf5_load_nd_dataset_helper(
1111
hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
12-
Blob<Dtype>* blob) {
12+
Blob<Dtype>* blob, bool reshape) {
1313
// Verify that the dataset exists.
1414
CHECK(H5LTfind_dataset(file_id, dataset_name_))
1515
<< "Failed to find HDF5 dataset " << dataset_name_;
@@ -56,26 +56,48 @@ void hdf5_load_nd_dataset_helper(
5656
LOG(FATAL) << "Datatype class unknown";
5757
}
5858

59+
5960
vector<int> blob_dims(dims.size());
6061
for (int i = 0; i < dims.size(); ++i) {
6162
blob_dims[i] = dims[i];
6263
}
63-
blob->Reshape(blob_dims);
64+
65+
if (reshape) {
66+
blob->Reshape(blob_dims);
67+
} else {
68+
if (blob_dims != blob->shape()) {
69+
// create shape string for error message
70+
ostringstream stream;
71+
int count = 1;
72+
for (int i = 0; i < blob_dims.size(); ++i) {
73+
stream << blob_dims[i] << " ";
74+
count = count * blob_dims[i];
75+
}
76+
stream << "(" << count << ")";
77+
string source_shape_string = stream.str();
78+
79+
CHECK(blob_dims == blob->shape()) << "Cannot load blob from hdf5; shape "
80+
<< "mismatch. Source shape is " << source_shape_string
81+
<< " target shape is " << blob->shape_string();
82+
}
83+
}
6484
}
6585

6686
template <>
6787
void hdf5_load_nd_dataset<float>(hid_t file_id, const char* dataset_name_,
68-
int min_dim, int max_dim, Blob<float>* blob) {
69-
hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob);
88+
int min_dim, int max_dim, Blob<float>* blob, bool reshape) {
89+
hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob,
90+
reshape);
7091
herr_t status = H5LTread_dataset_float(
7192
file_id, dataset_name_, blob->mutable_cpu_data());
7293
CHECK_GE(status, 0) << "Failed to read float dataset " << dataset_name_;
7394
}
7495

7596
template <>
7697
void hdf5_load_nd_dataset<double>(hid_t file_id, const char* dataset_name_,
77-
int min_dim, int max_dim, Blob<double>* blob) {
78-
hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob);
98+
int min_dim, int max_dim, Blob<double>* blob, bool reshape) {
99+
hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob,
100+
reshape);
79101
herr_t status = H5LTread_dataset_double(
80102
file_id, dataset_name_, blob->mutable_cpu_data());
81103
CHECK_GE(status, 0) << "Failed to read double dataset " << dataset_name_;

0 commit comments

Comments
 (0)