@@ -9,7 +9,7 @@ namespace caffe {
99template <typename Dtype>
1010void hdf5_load_nd_dataset_helper (
1111 hid_t file_id, const char * dataset_name_, int min_dim, int max_dim,
12- Blob<Dtype>* blob) {
12+ Blob<Dtype>* blob, bool reshape ) {
1313 // Verify that the dataset exists.
1414 CHECK (H5LTfind_dataset (file_id, dataset_name_))
1515 << " Failed to find HDF5 dataset " << dataset_name_;
@@ -56,26 +56,48 @@ void hdf5_load_nd_dataset_helper(
5656 LOG (FATAL) << " Datatype class unknown" ;
5757 }
5858
59+
5960 vector<int > blob_dims (dims.size ());
6061 for (int i = 0 ; i < dims.size (); ++i) {
6162 blob_dims[i] = dims[i];
6263 }
63- blob->Reshape (blob_dims);
64+
65+ if (reshape) {
66+ blob->Reshape (blob_dims);
67+ } else {
68+ if (blob_dims != blob->shape ()) {
69+ // create shape string for error message
70+ ostringstream stream;
71+ int count = 1 ;
72+ for (int i = 0 ; i < blob_dims.size (); ++i) {
73+ stream << blob_dims[i] << " " ;
74+ count = count * blob_dims[i];
75+ }
76+ stream << " (" << count << " )" ;
77+ string source_shape_string = stream.str ();
78+
79+ CHECK (blob_dims == blob->shape ()) << " Cannot load blob from hdf5; shape "
80+ << " mismatch. Source shape is " << source_shape_string
81+ << " target shape is " << blob->shape_string ();
82+ }
83+ }
6484}
6585
6686template <>
6787void hdf5_load_nd_dataset<float >(hid_t file_id, const char * dataset_name_,
68- int min_dim, int max_dim, Blob<float >* blob) {
69- hdf5_load_nd_dataset_helper (file_id, dataset_name_, min_dim, max_dim, blob);
88+ int min_dim, int max_dim, Blob<float >* blob, bool reshape) {
89+ hdf5_load_nd_dataset_helper (file_id, dataset_name_, min_dim, max_dim, blob,
90+ reshape);
7091 herr_t status = H5LTread_dataset_float (
7192 file_id, dataset_name_, blob->mutable_cpu_data ());
7293 CHECK_GE (status, 0 ) << " Failed to read float dataset " << dataset_name_;
7394}
7495
7596template <>
7697void hdf5_load_nd_dataset<double >(hid_t file_id, const char * dataset_name_,
77- int min_dim, int max_dim, Blob<double >* blob) {
78- hdf5_load_nd_dataset_helper (file_id, dataset_name_, min_dim, max_dim, blob);
98+ int min_dim, int max_dim, Blob<double >* blob, bool reshape) {
99+ hdf5_load_nd_dataset_helper (file_id, dataset_name_, min_dim, max_dim, blob,
100+ reshape);
79101 herr_t status = H5LTread_dataset_double (
80102 file_id, dataset_name_, blob->mutable_cpu_data ());
81103 CHECK_GE (status, 0 ) << " Failed to read double dataset " << dataset_name_;
0 commit comments