from __future__ import division, unicode_literals, print_function, absolute_import
from six import string_types

import numpy as np
import xarray as xr
import traitlets as tl
from scipy.spatial import cKDTree

from podpac.core.interpolation.interpolator import COMMON_INTERPOLATOR_DOCS, Interpolator, InterpolatorException
from podpac.core.coordinates import Coordinates, UniformCoordinates1d, StackedCoordinates
from podpac.core.coordinates.utils import make_coord_delta, make_coord_value
from podpac.core.utils import common_doc
from podpac.core.coordinates.utils import get_timedelta
from podpac.core.interpolation.selector import Selector, _higher_precision_time_coords1d, _higher_precision_time_stack

@common_doc(COMMON_INTERPOLATOR_DOCS)
class NearestNeighbor(Interpolator):
    """Nearest Neighbor Interpolation

    {nearest_neighbor_attributes}
    """

    dims_supported = ["lat", "lon", "alt", "time"]
    methods_supported = ["nearest"]

    # defined at instantiation
    method = tl.Unicode(default_value="nearest")

    # Per-dimension tolerances: a request point whose nearest source point is farther away than the
    # tolerance is filled with nan. An infinite (or empty-string, for time) tolerance disables the check.
    spatial_tolerance = tl.Float(default_value=np.inf, allow_none=True)
    time_tolerance = tl.Union([tl.Unicode(), tl.Instance(np.timedelta64, allow_none=True)])
    alt_tolerance = tl.Float(default_value=np.inf, allow_none=True)

    # spatial_scale only applies when the source is stacked with time or alt. The supplied value will be assigned a distance of "1'"
    spatial_scale = tl.Float(default_value=1, allow_none=True)
    # time_scale only applies when the source is stacked with lat, lon, or alt. The supplied value will be assigned a distance of "1'"
    time_scale = tl.Union([tl.Unicode(), tl.Instance(np.timedelta64, allow_none=True)])
    # alt_scale only applies when the source is stacked with lat, lon, or time. The supplied value will be assigned a distance of "1'"
    alt_scale = tl.Float(default_value=1, allow_none=True)

    # When True, request points outside the source bounds are filled with nan instead of the edge value
    respect_bounds = tl.Bool(True)
    # When True, nan entries are dropped from the source before the nearest-neighbor search
    remove_nan = tl.Bool(False)
    # When False, can_select reports that no dimensions can be pre-selected
    use_selector = tl.Bool(True)

    def __repr__(self):
        return super(NearestNeighbor, self).__repr__()

    @common_doc(COMMON_INTERPOLATOR_DOCS)
    def can_interpolate(self, udims, source_coordinates, eval_coordinates):
        """
        {interpolator_interpolate}
        """
        udims_subset = self._filter_udims_supported(udims)
        return udims_subset

    def can_select(self, udims, source_coordinates, eval_coordinates):
        """Return the udims the parent class can select, or an empty tuple when use_selector is False."""
        # Explicit two-argument super() matches __repr__ and the Python 2 compatible imports of this module
        selector = super(NearestNeighbor, self).can_select(udims, source_coordinates, eval_coordinates)
        if self.use_selector:
            return selector
        return ()

    @common_doc(COMMON_INTERPOLATOR_DOCS)
    def interpolate(self, udims, source_coordinates, source_data, eval_coordinates, output_data):
        """
        {interpolator_interpolate}
        """
        # Note, some of the following code duplicates code in the Selector class.
        # This duplication is for the sake of optimization

        def is_stacked(d):
            return "_" in d

        if hasattr(source_data, "attrs") and "bounds" in source_data.attrs:
            # Work on a shallow copy: the time bounds are replaced with floats below, and that
            # conversion must not leak back into the caller's source_data.attrs
            bounds = dict(source_data.attrs["bounds"])
            if "time" in bounds and bounds["time"]:
                if "time" in eval_coordinates.udims:
                    time_request = eval_coordinates["time"]
                else:
                    time_request = source_coordinates["time"]
                bounds["time"] = [
                    self._atime_to_float(b, source_coordinates["time"], time_request) for b in bounds["time"]
                ]
        else:
            bounds = None

        if self.remove_nan:
            # Eliminate nans from the source data. Note, this could turn a uniform gridded dataset into a stacked one
            source_data, source_coordinates = self._remove_nans(source_data, source_coordinates)

        data_index = []
        for d in source_coordinates.dims:
            source = source_coordinates[d]

            # Make sure we're supposed to do nearest neighbor interpolation for this UDIM, otherwise skip this dimension
            if not any(dd in udims for dd in d.split("_")):
                index = self._resize_unstacked_index(np.arange(source.size), d, eval_coordinates)
                data_index.append(index)
                continue

            if is_stacked(d):
                if bounds is not None:
                    # One column per stacked udim; row 0 holds the lower bounds, row 1 the upper bounds
                    bound = np.stack([bounds[dd] for dd in d.split("_")], axis=1)
                else:
                    bound = None
                index = self._get_stacked_index(d, source, eval_coordinates, bound)

                if len(source.shape) == 2:  # Handle case of 2D-stacked coordinates
                    ncols = source.shape[1]
                    index1 = index // ncols
                    index1 = self._resize_stacked_index(index1, d, eval_coordinates)
                    # With nD stacked coordinates, there are 'n' indices in the tuple
                    # All of these need to get into the data_index, and in the right order
                    data_index.append(index1)  # This is a hack
                    index = index % ncols  # The second half can go through the usual machinery
                elif len(source.shape) > 2:  # Handle case of nD-stacked coordinates
                    raise NotImplementedError
                index = self._resize_stacked_index(index, d, eval_coordinates)
            else:
                request = eval_coordinates[d]
                bound = bounds[d] if bounds is not None else None
                if source.is_uniform:
                    index = self._get_uniform_index(d, source, request, bound)
                else:
                    # non-uniform coordinates... probably an optimization here
                    index = self._get_nonuniform_index(d, source, request, bound)
                index = self._resize_unstacked_index(index, d, eval_coordinates)

            data_index.append(index)

        index = tuple(data_index)

        # NOTE(review): the assignment targets on the next statements were lost in the scraped source;
        # they are reconstructed as in-place writes to output_data.data -- confirm against upstream.
        output_data.data[:] = np.array(source_data)[index]

        # An index of -1 in any dimension marks a request point with no valid neighbor
        bool_inds = sum([i == -1 for i in index]).astype(bool)
        output_data.data[bool_inds] = np.nan

        return output_data

    def _remove_nans(self, source_data, source_coordinates):
        """Drop nan entries from the source.

        Returns the inputs unchanged when the data contains no nan. Otherwise returns the flattened
        non-nan values together with a single stacked Coordinates built from the matching coordinates.
        """
        index = np.array(np.isnan(source_data), bool)
        if not np.any(index):
            return source_data, source_coordinates

        # NOTE(review): the right-hand side was truncated to "[~index]" in the scraped source;
        # reconstructed as the raw array of source_data -- confirm against upstream.
        data = source_data.data[~index]
        coords = np.meshgrid(
            *[source_coordinates[d.split("_")[0]].coordinates for d in source_coordinates.dims], indexing="ij"
        )
        repeat_shape = coords[0].shape
        ndim = len(repeat_shape)
        coords = [c[~index] for c in coords]
        final_dims = [d.split("_")[0] for d in source_coordinates.dims]

        # Add back in any stacked coordinates
        for i, d in enumerate(source_coordinates.dims):
            dims = d.split("_")
            if len(dims) == 1:
                continue
            # Use the rank of the data grid here. The coords list grows as stacked coordinates are
            # appended, so its length is only equal to the grid rank for the first stacked dimension.
            reshape = np.ones(ndim, int)
            reshape[i] = -1
            repeats = list(repeat_shape)
            repeats[i] = 1
            for dd in dims[1:]:
                crds = source_coordinates[dd].coordinates.reshape(*reshape)
                for j, r in enumerate(repeats):
                    crds = crds.repeat(r, axis=j)
                coords.append(crds[~index])
                final_dims.append(dd)
        return data, Coordinates([coords], dims=[final_dims])

    def _get_tol(self, dim, source, request):
        """Return the tolerance configured for dim; the time tolerance is converted to a float."""
        if dim in ["lat", "lon"]:
            return self.spatial_tolerance
        if dim == "alt":
            return self.alt_tolerance
        if dim == "time":
            if self.time_tolerance == "":
                return np.inf
            return self._time_to_float(self.time_tolerance, source, request)
        raise NotImplementedError()

    def _get_scale(self, dim, source_1d, request_1d):
        """Return the multiplicative factor (one over the configured scale) applied to dim coordinates."""
        if dim in ["lat", "lon"]:
            return 1 / self.spatial_scale
        if dim == "alt":
            return 1 / self.alt_scale
        if dim == "time":
            if self.time_scale == "":
                return 1.0
            return 1 / self._time_to_float(self.time_scale, source_1d, request_1d)
        raise NotImplementedError()

    def _time_to_float(self, time, time_source, time_request):
        """Convert a time delta to a float, using the larger of the source and request coordinate dtypes.

        Values that make_coord_delta does not turn into a numpy timedelta64 are returned unchanged,
        mirroring _atime_to_float (previously this case raised UnboundLocalError).
        """
        dtype0 = time_source.coordinates[0].dtype
        dtype1 = time_request.coordinates[0].dtype
        dtype = dtype0 if dtype0 > dtype1 else dtype1
        time = make_coord_delta(time)
        if not isinstance(time, np.timedelta64):
            return time
        # Measure the delta as the difference of two datetimes cast to the common dtype
        reference = np.datetime64("2000")
        return (time + reference).astype(dtype).astype(float) - reference.astype(dtype).astype(float)

    def _atime_to_float(self, time, time_source, time_request):
        """Convert an absolute time to a float, using the larger of the source and request coordinate dtypes."""
        dtype0 = time_source.coordinates[0].dtype
        dtype1 = time_request.coordinates[0].dtype
        dtype = dtype0 if dtype0 > dtype1 else dtype1
        time = make_coord_value(time)
        if isinstance(time, np.datetime64):
            time = time.astype(dtype).astype(float)
        return time

    def _get_stacked_index(self, dim, source, request, bounds=None):
        """Find, for every request point, the index of the nearest point in a stacked source dimension.

        bounds, when given, is indexed as bounds[0] (lower) and bounds[1] (upper), with one column per
        source udim. Entries of the returned index are -1 where the request point is out of bounds
        (if respect_bounds is set) or farther than the combined tolerance from its nearest neighbor.
        """
        # The udims are in the order of the request so that the meshgrid calls will be in the right order
        udims = [ud for ud in request.udims if ud in source.udims]

        time_source = time_request = None
        if "time" in udims:
            time_source = source["time"]
            time_request = request["time"]

        tols = np.array([self._get_tol(d, time_source, time_request) for d in udims])[None, :]
        scales = np.array([self._get_scale(d, time_source, time_request) for d in udims])[None, :]
        tol = np.linalg.norm((tols * scales).squeeze())

        src_coords, req_coords_diag = _higher_precision_time_stack(source, request, udims)
        # We need to unwravel the nD stacked coordinates
        src_flat = src_coords.reshape(src_coords.shape[0], -1).T
        ckdtree_source = cKDTree(src_flat * scales)

        # if the udims are all stacked in the same stack as part of the request coordinates, then we're done.
        # Otherwise we have to evaluate each unstacked set of dimensions independently
        # Note, part of this code is duplicated in the selector
        indep_evals = [ud for ud in udims if not request.is_stacked(ud)]
        # two udims could be stacked, but in different dim groups, e.g. source (lat, lon), request (lat, time), (lon, alt)
        stacked = {d for d in request.dims for ud in udims if ud in d and request.is_stacked(ud)}

        if (len(indep_evals) + len(stacked)) <= 1:  # output is stacked in the same way
            # The ckdtree call below needs the lat/lon pairs in the last axis position
            req_coords = np.moveaxis(req_coords_diag, 0, -1)
        elif len(stacked) == 0 or (len(indep_evals) == 0 and len(stacked) == len(udims)):
            req_coords = np.stack([c.ravel() for c in np.meshgrid(*req_coords_diag, indexing="ij")], axis=1)
        else:
            # Rare cases? E.g. lat_lon_time_alt source to lon, time_alt, lat destination
            sizes = [request[d].size for d in request.dims]
            reshape = np.ones(len(request.dims), int)
            coords = [None] * len(udims)
            for i, ud in enumerate(udims):
                ii = [k for k in range(len(request.dims)) if ud in request.dims[k]][0]
                reshape[:] = 1
                reshape[ii] = -1
                coords[i] = req_coords_diag[i].reshape(*reshape)
                for j, d in enumerate(request.dims):
                    if ud in d:  # Then we don't need to repeat
                        continue
                    coords[i] = coords[i].repeat(sizes[j], axis=j)
            req_coords = np.stack([c.ravel() for c in coords], axis=1)

        dist, index = ckdtree_source.query(req_coords * scales, k=1)

        if self.respect_bounds:
            if bounds is None:
                # Row 0 = lower bounds, row 1 = upper bounds, one column per udim. These are derived
                # from src_coords, which is already in udims order, so no re-ordering is needed.
                bounds = np.stack([src_flat.min(0), src_flat.max(0)], axis=0)
            else:
                # Supplied bounds are in source order: fix order of bounds to match udims
                bounds = bounds[:, [source.udims.index(ud) for ud in udims]]
            index[np.any((req_coords > bounds[1]), axis=-1) | np.any((req_coords < bounds[0]), axis=-1)] = -1

        if tol and tol != np.inf:
            index[dist > tol] = -1

        return index

    def _get_uniform_index(self, dim, source, request, bounds=None):
        """Compute nearest indices into uniformly spaced source coordinates by rounding.

        Out-of-range indices become -1 when respect_bounds is set and are clipped to the valid range
        otherwise. Points farther than the tolerance from their nearest source point also become -1.
        The bounds argument is accepted for signature parity with the other index helpers but unused.
        """
        tol = self._get_tol(dim, source, request)

        index = (request.coordinates - source.start) / source.step
        rindex = np.around(index).astype(int)
        stop_ind = int(source.size)
        if self.respect_bounds:
            rindex[(rindex < 0) | (rindex >= stop_ind)] = -1
        else:
            rindex = np.clip(rindex, 0, stop_ind - 1)

        if tol and tol != np.inf:
            if dim == "time":
                step = self._time_to_float(source.step, source, request)
            else:
                step = source.step
            # Fractional index distance times step size gives the distance in coordinate units
            rindex[np.abs(index - rindex) * np.abs(step) > tol] = -1

        return rindex

    def _get_nonuniform_index(self, dim, source, request, bounds=None):
        """Compute nearest indices into non-uniform 1D source coordinates using a KD-tree.

        Entries are -1 where no neighbor was found, where the request point lies outside bounds
        (if respect_bounds is set), or where the neighbor is farther away than the tolerance.
        """
        tol = self._get_tol(dim, source, request)

        src, req = _higher_precision_time_coords1d(source, request)
        ckdtree_source = cKDTree(src.reshape(-1, 1))
        dist, index = ckdtree_source.query(req[:].reshape(-1, 1), k=1)
        # cKDTree reports a missing neighbor with an index equal to the number of points in the tree
        index[index == source.coordinates.size] = -1

        if self.respect_bounds:
            if bounds is None:
                bounds = [src.min(), src.max()]
            index[(req.ravel() > bounds[1]) | (req.ravel() < bounds[0])] = -1

        if tol and tol != np.inf:
            index[dist > tol] = -1

        return index

    def _resize_unstacked_index(self, index, source_dim, request):
        """Reshape index so that it broadcasts against the request shape along source_dim only."""
        # When the request is stacked, and the stacked dimensions are n-dimensions where n > 1,
        # Then len(request.shape) != len(request.dims), so it takes a little bit of footwork
        # to get the correct shape for the index
        reshape = np.array(request.shape)
        i = 0
        for dim in request.dims:
            addnext = len(request[dim].shape)
            if source_dim not in dim:
                reshape[i : i + addnext] = 1
            i += addnext
        return index.reshape(*reshape)

    def _resize_stacked_index(self, index, source_dim, request):
        """Reshape index so that it broadcasts against every request dim sharing a udim with source_dim."""
        reshape = np.array(request.shape)
        i = 0
        for dim in request.dims:
            addnext = len(request[dim].shape)
            if not any(dd in source_dim for dd in dim.split("_")):
                reshape[i : i + addnext] = 1
            i += addnext
        return index.reshape(*reshape)
@common_doc(COMMON_INTERPOLATOR_DOCS)
class NearestPreview(NearestNeighbor):
    """Nearest Neighbor (Preview) Interpolation

    {nearest_neighbor_attributes}
    """

    methods_supported = ["nearest_preview"]
    method = tl.Unicode(default_value="nearest_preview")
    spatial_tolerance = tl.Float(read_only=True, allow_none=True, default_value=None)

    @common_doc(COMMON_INTERPOLATOR_DOCS)
    def can_select(self, udims, source_coordinates, eval_coordinates):
        """
        {interpolator_can_select}
        """
        udims_subset = self._filter_udims_supported(udims)

        # confirm that udims are in source and eval coordinates
        # TODO: handle stacked coordinates
        if self._dim_in(udims_subset, source_coordinates):
            return udims_subset
        else:
            return tuple()

    @staticmethod
    def _start_stop_step(coords):
        """Return (start, stop, step) for 1D coordinates; the step is the mean spacing when not uniform."""
        if isinstance(coords, UniformCoordinates1d):
            return coords.start, coords.stop, coords.step
        start = coords.coordinates[0]
        stop = coords.coordinates[-1]
        # A single coordinate divides zero by zero; silence the warning and let the nan propagate
        with np.errstate(invalid="ignore"):
            step = (stop - start) / (coords.size - 1)
        return start, stop, step

    @common_doc(COMMON_INTERPOLATOR_DOCS)
    def select_coordinates(self, udims, source_coordinates, eval_coordinates, index_type="numpy"):
        """
        {interpolator_select}
        """
        new_coords = []
        new_coords_idx = []

        source_coords, source_coords_index = source_coordinates.intersect(
            eval_coordinates, outer=True, return_index=True
        )

        if source_coords.size == 0:
            return source_coords, source_coords_index

        # iterate over the source coordinate dims in case they are stacked
        for src_dim, idx in zip(source_coords, source_coords_index):

            # TODO: handle stacked coordinates
            if isinstance(source_coords[src_dim], StackedCoordinates):
                raise InterpolatorException("NearestPreview select does not yet support stacked dimensions")

            if src_dim in eval_coordinates.dims:
                src_coords = source_coords[src_dim]
                dst_coords = eval_coordinates[src_dim]

                _, _, dst_delta = self._start_stop_step(dst_coords)
                src_start, src_stop, src_delta = self._start_stop_step(src_coords)

                # Number of source steps per request step, never less than one
                ndelta = max(1, np.round(np.abs(dst_delta / src_delta)))
                idx_offset = 0
                if src_coords.size == 1:
                    c = src_coords.copy()
                else:
                    # NOTE(review): the keyword expansion after "**" was truncated in the scraped source;
                    # reconstructed as the properties of the source coordinates -- confirm against upstream.
                    c_test = UniformCoordinates1d(src_start, src_stop, ndelta * src_delta, **src_coords.properties)
                    bounds = source_coordinates[src_dim].bounds
                    # The delta/2 ensures the endpoint is included when there is a floating point rounding error
                    # the delta/2 is more than needed, but does guarantee.
                    src_stop = np.clip(src_stop + ndelta * src_delta / 2, bounds[0], bounds[1])
                    c = UniformCoordinates1d(src_start, src_stop, ndelta * src_delta, **src_coords.properties)
                    if c.size > c_test.size:  # need to adjust the index as well
                        idx_offset = int(ndelta)

                idx_start = idx.start if isinstance(idx, slice) else idx[0]
                idx_stop = idx.stop if isinstance(idx, slice) else idx[-1]
                if idx_stop is not None:
                    idx_stop += idx_offset
                idx = slice(idx_start, idx_stop, int(ndelta))
            else:
                c = source_coords[src_dim]

            new_coords.append(c)
            new_coords_idx.append(idx)

        return Coordinates(new_coords, validate_crs=False), tuple(new_coords_idx)