diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 1e6f893..3bed241 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -9,7 +9,7 @@ see `#17 `__ with psy.plot.mapplot('file.nc') as sp: sp.export('output.png') - sp will be closed automatically (see commit `ee7415b `__) + sp will be closed automatically (see `#18 `__) Changed ------- @@ -18,6 +18,8 @@ Changed * Specifying names in `x`, `y`, `t` and `z` attributes of the `CFDecoder` class now means that any other attribute (such as the `coordinates` or `axis` attribute) are ignored +* If a given variable cannot be found in the provided coords to ``CFDecoder.get_variable_by_axis``, + we fall back to the ``CFDecoder.ds.coords`` attribute, see `#19 `__ v1.2.1 diff --git a/psyplot/data.py b/psyplot/data.py index 3e4f2c4..ac04b90 100755 --- a/psyplot/data.py +++ b/psyplot/data.py @@ -905,6 +905,23 @@ def get_variable_by_axis(self, var, axis, coords=None): See Also -------- get_x, get_y, get_z, get_t""" + + def get_coord(cname, raise_error=True): + try: + return coords[cname] + except KeyError: + if cname not in self.ds.coords: + if raise_error: + raise + return None + ret = self.ds.coords[cname] + try: + idims = var.psy.idims + except AttributeError: # got xarray.Variable + idims = {} + return ret.isel(**{d: sl for d, sl in idims.items() + if d in ret.dims}) + axis = axis.lower() if axis not in list('xyzt'): raise ValueError("Axis must be one of X, Y, Z, T, not {0}".format( @@ -949,23 +966,23 @@ def get_variable_by_axis(self, var, axis, coords=None): if axis == 'x': for cname in filter(lambda cname: re.search('lon', cname), coord_names): - return coords[cname] - return coords.get(coord_names[-1]) + return get_coord(cname) + return get_coord(coord_names[-1], raise_error=False) elif axis == 'y' and len(coord_names) >= 2: for cname in filter(lambda cname: re.search('lat', cname), coord_names): - return coords[cname] - return coords.get(coord_names[-2]) + return get_coord(cname) + return get_coord(coord_names[-2], raise_error=False) elif (axis == 'z' and len(coord_names) >= 3 and coord_names[-3] not in tnames): - return coords.get(coord_names[-3]) + return get_coord(coord_names[-3], raise_error=False) elif axis == 't' and tnames: tname = next(iter(tnames)) if len(tnames) > 1: warn("Found multiple matches for time coordinate in the " "coordinates: %s. I use %s" % (', '.join(tnames), tname), PsyPlotRuntimeWarning) - return coords.get(tname) + return get_coord(tname, raise_error=False) @docstrings.get_sectionsf("CFDecoder.get_x", sections=[ 'Parameters', 'Returns']) diff --git a/tests/test_data.py b/tests/test_data.py index 9caeebe..166381d 100755 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -346,6 +346,19 @@ def test_get_variable_by_axis(self): # close the dataset ds.close() + def test_get_variable_by_axis_02(self): + """Test the :meth:`CFDecoder.get_variable_by_axis` method with missing + coordinates, see https://github.com/psyplot/psyplot/pull/19""" + fname = os.path.join(bt.test_dir, 'icon_test.nc') + with psyd.open_dataset(fname) as ds: + ds['ncells'] = ('ncells', np.arange(ds.dims['ncells'])) + decoder = psyd.CFDecoder(ds) + arr = ds.psy['t2m'].psy.isel(ncells=slice(3, 10)) + del arr['clon'] + xcoord = decoder.get_variable_by_axis(arr, 'x', arr.coords) + self.assertEqual(xcoord.name, 'clon') + self.assertEqual(list(xcoord.ncells), list(arr.ncells)) + def test_plot_bounds_1d(self): """Test to get 2d-interval breaks""" x = xr.Variable(('x', ), np.arange(1, 5))