Source code for podpac.core.interpolation.scipy_interpolator

"""
Interpolator implementations
"""

from __future__ import division, unicode_literals, print_function, absolute_import

import numpy as np
import traitlets as tl

# Optional dependencies
# scipy is an optional dependency: when it is not installed, `scipy` is set to None so the
# rest of the package can still be imported. Only ImportError is handled so that unrelated
# failures (and KeyboardInterrupt / SystemExit) are not silently swallowed.
try:
    import scipy
    from scipy.interpolate import griddata, RectBivariateSpline, RegularGridInterpolator
    from scipy.spatial import KDTree
except ImportError:
    scipy = None

# podpac imports
from podpac.core.interpolation.interpolator import COMMON_INTERPOLATOR_DOCS, Interpolator, InterpolatorException
from podpac.core.units import UnitsDataArray
from podpac.core.coordinates import Coordinates, UniformCoordinates1d, StackedCoordinates
from podpac.core.utils import common_doc
from podpac.core.coordinates.utils import get_timedelta


@common_doc(COMMON_INTERPOLATOR_DOCS)
class ScipyPoint(Interpolator):
    """Scipy Point Interpolation

    Nearest-neighbor interpolation from stacked lat/lon source points, looked up with a
    scipy KDTree.

    Attributes
    ----------
    {interpolator_attributes}
    """

    methods_supported = ["nearest"]
    method = tl.Unicode(default_value="nearest")
    dims_supported = ["lat", "lon"]

    # TODO: implement these parameters for the method 'nearest'
    spatial_tolerance = tl.Float(default_value=np.inf)
    time_tolerance = tl.Union([tl.Unicode(), tl.Instance(np.timedelta64, allow_none=True)])

    @common_doc(COMMON_INTERPOLATOR_DOCS)
    def can_interpolate(self, udims, source_coordinates, eval_coordinates):
        """
        {interpolator_can_interpolate}
        """

        # TODO: make this so we don't need to specify lat and lon together
        # or at least throw a warning
        #
        # Supported only when the source holds lat/lon as a stacked dimension (present
        # unstacked, but not as separate dims) and the eval coordinates contain both.
        if (
            "lat" in udims
            and "lon" in udims
            and not self._dim_in(["lat", "lon"], source_coordinates)
            and self._dim_in(["lat", "lon"], source_coordinates, unstacked=True)
            and self._dim_in(["lat", "lon"], eval_coordinates, unstacked=True)
        ):
            return ("lat", "lon")

        # otherwise return no supported dims
        return tuple()

    @common_doc(COMMON_INTERPOLATOR_DOCS)
    def interpolate(self, udims, source_coordinates, source_data, eval_coordinates, output_data):
        """
        {interpolator_interpolate}
        """

        # name of the stacked source dimension that holds the points
        order = "lat_lon" if "lat_lon" in source_coordinates.dims else "lon_lat"

        # Search radius for the nearest-neighbor query: 8 times the diagonal of one
        # evaluation cell, using the exact step for uniform axes and the mean spacing
        # derived from the bounds otherwise.
        # NOTE(review): a non-uniform eval axis with a single element divides by zero
        # here (size - 1 == 0); confirm the intended tolerance for single-point requests.
        steps = []
        for dim in ("lat", "lon"):
            c1d = eval_coordinates[dim]
            if isinstance(c1d, UniformCoordinates1d):
                steps.append(c1d.step)
            else:
                steps.append((c1d.bounds[1] - c1d.bounds[0]) / (c1d.size - 1))
        tol = np.linalg.norm(steps) * 8

        if self._dim_in(["lat", "lon"], eval_coordinates):
            # --- evaluation coordinates form a lat x lon grid ---
            src_pts = np.stack(
                [source_coordinates[dim].coordinates for dim in source_coordinates[order].dims], axis=1
            )
            if order == "lat_lon":
                # the query points below are built as (lon, lat) columns
                src_pts = src_pts[:, ::-1]
            tree = KDTree(src_pts)

            lon, lat = np.meshgrid(eval_coordinates["lon"].coordinates, eval_coordinates["lat"].coordinates)
            dist, ind = tree.query(np.stack((lon.ravel(), lat.ravel()), axis=1), distance_upper_bound=tol)

            # queries with no neighbor within `tol` come back with index == number of points
            mask = ind == source_data[order].size
            ind[mask] = 0  # placeholder index so the selection below works; masked to NaN right after
            vals = source_data[{order: ind}]
            # Mask along the stacked dimension by name. A bare positional `vals[mask]` would
            # hit the wrong axis whenever the stacked dimension is not the first dimension.
            vals[{order: mask}] = np.nan

            # make sure 'lat_lon' or 'lon_lat' is the first dimension
            dims = [dim for dim in source_data.dims if dim != order]
            vals = vals.transpose(order, *dims).data
            shape = vals.shape

            # unfold the flat point axis back into (lat, lon), keeping any other dims
            coords = [eval_coordinates["lat"].coordinates, eval_coordinates["lon"].coordinates]
            coords += [source_coordinates[d].coordinates for d in dims]
            vals = vals.reshape(eval_coordinates["lat"].size, eval_coordinates["lon"].size, *shape[1:])
            vals = UnitsDataArray(vals, coords=coords, dims=["lat", "lon"] + dims)

            # and transpose back to the destination order
            output_data.data[:] = vals.transpose(*output_data.dims).data[:]
            return output_data

        elif self._dim_in(["lat", "lon"], eval_coordinates, unstacked=True):
            # --- evaluation coordinates are stacked points ---
            dst_order = "lat_lon" if "lat_lon" in eval_coordinates.dims else "lon_lat"

            # both point sets use the column order of the *source* stacked dimension
            src_stacked = np.stack(
                [source_coordinates[dim].coordinates for dim in source_coordinates[order].dims], axis=1
            )
            new_stacked = np.stack(
                [eval_coordinates[dim].coordinates for dim in source_coordinates[order].dims], axis=1
            )
            tree = KDTree(src_stacked)
            dist, ind = tree.query(new_stacked, distance_upper_bound=tol)

            # same "no neighbor found" handling as in the grid case
            mask = ind == source_data[order].size
            ind[mask] = 0
            vals = source_data[{order: ind}]
            vals[{order: mask}] = np.nan

            # the selected values still carry the source's stacked dim name; put it where
            # the destination's stacked dim sits in the output
            dims = list(output_data.dims)
            dims[dims.index(dst_order)] = order
            output_data.data[:] = vals.transpose(*dims).data[:]
            return output_data
@common_doc(COMMON_INTERPOLATOR_DOCS)
class ScipyGrid(ScipyPoint):
    """Scipy Interpolation

    Interpolation from a gridded lat/lon source using scipy's RegularGridInterpolator
    (nearest, bilinear) or RectBivariateSpline (spline methods).

    Attributes
    ----------
    {interpolator_attributes}
    """

    methods_supported = ["nearest", "bilinear", "cubic_spline", "spline_2", "spline_3", "spline_4"]
    method = tl.Unicode(default_value="nearest")

    # TODO: implement these parameters for the method 'nearest'
    spatial_tolerance = tl.Float(default_value=np.inf)
    time_tolerance = tl.Union([tl.Unicode(), tl.Instance(np.timedelta64, allow_none=True)], default_value=None)

    @common_doc(COMMON_INTERPOLATOR_DOCS)
    def can_interpolate(self, udims, source_coordinates, eval_coordinates):
        """
        {interpolator_can_interpolate}
        """

        # TODO: make this so we don't need to specify lat and lon together
        # or at least throw a warning
        #
        # Supported only when the source has lat and lon as separate (gridded) dims and
        # the eval coordinates contain both, stacked or not.
        if (
            "lat" in udims
            and "lon" in udims
            and self._dim_in(["lat", "lon"], source_coordinates)
            and self._dim_in(["lat", "lon"], eval_coordinates, unstacked=True)
        ):
            # a tuple, like the "no supported dims" result below and like ScipyPoint
            return ("lat", "lon")

        # otherwise return no supported dims
        return tuple()

    @common_doc(COMMON_INTERPOLATOR_DOCS)
    def interpolate(self, udims, source_coordinates, source_data, eval_coordinates, output_data):
        """
        {interpolator_interpolate}
        """

        if self._dim_in(["lat", "lon"], eval_coordinates):
            # gridded evaluation coordinates: evaluate on the full lat x lon mesh
            return self._interpolate_irregular_grid(
                udims, source_coordinates, source_data, eval_coordinates, output_data, grid=True
            )

        elif self._dim_in(["lat", "lon"], eval_coordinates, unstacked=True):
            # stacked evaluation coordinates: evaluate point by point
            eval_coordinates_us = eval_coordinates.unstack()
            return self._interpolate_irregular_grid(
                udims, source_coordinates, source_data, eval_coordinates_us, output_data, grid=False
            )

    def _interpolate_irregular_grid(
        self, udims, source_coordinates, source_data, eval_coordinates, output_data, grid=True
    ):
        """Interpolate a 2-D lat/lon source grid onto the evaluation coordinates.

        Sources with more than two dimensions are handed to ``self._loop_helper``, which is
        given this method and ``keep_dims=["lat", "lon"]``.

        Parameters
        ----------
        udims, source_coordinates, source_data, eval_coordinates, output_data
            Same objects that are passed to ``interpolate``. ``eval_coordinates`` must be
            indexable by "lat" and "lon".
        grid : bool, optional
            True to evaluate on the mesh of the eval lat and lon coordinates, False to
            evaluate at paired (lat, lon) points.

        Returns
        -------
        output_data
            The same object that was passed in, with ``output_data.data`` filled in place.
            It is returned untouched if ``self.method`` is neither "nearest"/"bilinear"
            nor a spline method.
        """
        if len(source_data.dims) > 2:
            keep_dims = ["lat", "lon"]
            return self._loop_helper(
                self._interpolate_irregular_grid,
                keep_dims,
                udims,
                source_coordinates,
                source_data,
                eval_coordinates,
                output_data,
                grid=grid,
            )

        # scipy needs ascending axes: for each source axis keep an ascending copy of the
        # coordinates together with the slice that re-orders the matching data axis
        ascending = {}
        for dim in ("lat", "lon"):
            c1d = source_coordinates[dim]
            if c1d.is_descending:
                ascending[dim] = (c1d.coordinates[::-1], slice(None, None, -1))
            else:
                ascending[dim] = (c1d.coordinates, slice(None, None))

        # Order of the data axes. Swap in case datasource uses lon,lat ordering instead of lat,lon
        axis_dims = ["lat", "lon"]
        coords_i_dst = [eval_coordinates["lon"].coordinates, eval_coordinates["lat"].coordinates]
        if source_coordinates.dims.index("lat") > source_coordinates.dims.index("lon"):
            axis_dims = axis_dims[::-1]
            coords_i_dst = coords_i_dst[::-1]

        # Flip each data axis according to the dimension that is actually stored on it.
        # Applying the slices in fixed (lat, lon) order would reverse the wrong axis for a
        # lon-first source with exactly one descending coordinate.
        data = source_data.data[tuple(ascending[dim][1] for dim in axis_dims)]

        # remove nan's
        finite = [np.isfinite(ascending[dim][0]) for dim in axis_dims]
        coords_i = tuple(ascending[dim][0][keep] for dim, keep in zip(axis_dims, finite))
        data = data[finite[0], :][:, finite[1]]

        if self.method in ["bilinear", "nearest"]:
            # RegularGridInterpolator calls the bilinear method "linear"
            f = RegularGridInterpolator(
                coords_i, data, method=self.method.replace("bi", ""), bounds_error=False, fill_value=np.nan
            )
            if grid:
                x, y = np.meshgrid(*coords_i_dst)
            else:
                x, y = coords_i_dst
            # coords_i_dst is in the reverse of the data axis order, so (y, x) matches coords_i
            output_data.data[:] = f((y.ravel(), x.ravel())).reshape(output_data.shape)

        # TODO: what methods is 'spline' associated with?
        elif "spline" in self.method:
            if self.method == "cubic_spline":
                order = 3
            else:
                # TODO: make this a parameter
                # the spline degree is the numeric suffix, e.g. "spline_2" gives 2
                order = int(self.method.split("_")[-1])

            f = RectBivariateSpline(coords_i[0], coords_i[1], data, kx=max(1, order), ky=max(1, order))
            output_data.data[:] = f(coords_i_dst[1], coords_i_dst[0], grid=grid).reshape(output_data.shape)

        return output_data