"""
Interpolator implementations
"""
from __future__ import division, unicode_literals, print_function, absolute_import
import numpy as np
import traitlets as tl
# Optional dependencies
try:
import scipy
from scipy.interpolate import griddata, RectBivariateSpline, RegularGridInterpolator
from scipy.spatial import KDTree
except:
scipy = None
# podpac imports
from podpac.core.interpolation.interpolator import COMMON_INTERPOLATOR_DOCS, Interpolator, InterpolatorException
from podpac.core.units import UnitsDataArray
from podpac.core.coordinates import Coordinates, UniformCoordinates1d, StackedCoordinates
from podpac.core.utils import common_doc
from podpac.core.coordinates.utils import get_timedelta
@common_doc(COMMON_INTERPOLATOR_DOCS)
class ScipyPoint(Interpolator):
    """Scipy Point Interpolation

    Nearest-neighbor interpolation of point data (source lat/lon stacked as
    ``lat_lon`` or ``lon_lat``) using ``scipy.spatial.KDTree``. Eval points with
    no source point inside the search tolerance are set to NaN.

    Attributes
    ----------
    {interpolator_attributes}
    """

    methods_supported = ["nearest"]
    method = tl.Unicode(default_value="nearest")
    dims_supported = ["lat", "lon"]

    # TODO: implement these parameters for the method 'nearest'
    # NOTE: neither trait is read by this class yet; the KDTree search radius is
    # derived from the eval coordinate spacing instead (see `interpolate`).
    spatial_tolerance = tl.Float(default_value=np.inf)
    time_tolerance = tl.Union([tl.Unicode(), tl.Instance(np.timedelta64, allow_none=True)])

    @common_doc(COMMON_INTERPOLATOR_DOCS)
    def can_interpolate(self, udims, source_coordinates, eval_coordinates):
        """
        {interpolator_can_interpolate}
        """
        # Requires both lat and lon to be requested, lat/lon to be present in the
        # source only in their unstacked form (presumably: stacked source points),
        # and lat/lon present in the eval coordinates (stacked or not).
        # NOTE(review): the stacked/unstacked reading of `_dim_in` is inferred
        # from its `unstacked` flag -- confirm against the base class.
        # TODO: make this so we don't need to specify lat and lon together
        # or at least throw a warning
        if (
            "lat" in udims
            and "lon" in udims
            and not self._dim_in(["lat", "lon"], source_coordinates)
            and self._dim_in(["lat", "lon"], source_coordinates, unstacked=True)
            and self._dim_in(["lat", "lon"], eval_coordinates, unstacked=True)
        ):
            return tuple(["lat", "lon"])

        # otherwise return no supported dims
        return tuple()

    @staticmethod
    def _eval_spacing(coords1d):
        """Return the nominal spacing of a 1d eval coordinate.

        Uniform coordinates report their ``step``; other coordinates use
        ``(max - min) / (size - 1)`` from their bounds. A coordinate with fewer
        than two points has no spacing, and 0.0 is returned instead of dividing
        by zero.
        """
        if isinstance(coords1d, UniformCoordinates1d):
            return coords1d.step
        if coords1d.size < 2:
            return 0.0
        return (coords1d.bounds[1] - coords1d.bounds[0]) / (coords1d.size - 1)

    @common_doc(COMMON_INTERPOLATOR_DOCS)
    def interpolate(self, udims, source_coordinates, source_data, eval_coordinates, output_data):
        """
        {interpolator_interpolate}
        """
        # name of the stacked source dimension holding the points
        order = "lat_lon" if "lat_lon" in source_coordinates.dims else "lon_lat"

        # Search radius: 8x the diagonal of one eval grid "cell".
        dlat = self._eval_spacing(eval_coordinates["lat"])
        dlon = self._eval_spacing(eval_coordinates["lon"])
        tol = np.linalg.norm([dlat, dlon]) * 8
        if not tol > 0:
            # Single-point (or degenerate) eval coordinates have no spacing; a zero
            # or NaN upper bound would mask (almost) every result, so search unbounded.
            tol = np.inf

        if self._dim_in(["lat", "lon"], eval_coordinates):
            # Presumably: eval lat and lon are independent dims (a grid).
            # Build the source points as (lon, lat) pairs regardless of stacking order.
            pts = np.stack([source_coordinates[dim].coordinates for dim in source_coordinates[order].dims], axis=1)
            if order == "lat_lon":
                pts = pts[:, ::-1]
            pts = KDTree(pts)
            lon, lat = np.meshgrid(eval_coordinates["lon"].coordinates, eval_coordinates["lat"].coordinates)
            dist, ind = pts.query(np.stack((lon.ravel(), lat.ravel()), axis=1), distance_upper_bound=tol)

            # KDTree reports "no neighbor within tol" as index == number of points.
            mask = ind == source_data[order].size
            ind[mask] = 0  # any valid index, so the selection below works; masked values are set to NaN next
            vals = source_data[{order: ind}]
            # Mask along the stacked dimension by name: it need not be the first axis.
            vals[{order: mask}] = np.nan

            # make sure 'lat_lon' or 'lon_lat' is the first dimension
            dims = [dim for dim in source_data.dims if dim != order]
            vals = vals.transpose(order, *dims).data
            shape = vals.shape
            coords = [eval_coordinates["lat"].coordinates, eval_coordinates["lon"].coordinates]
            coords += [source_coordinates[d].coordinates for d in dims]
            # the flattened (lat, lon) meshgrid order unravels back to (lat, lon)
            vals = vals.reshape(eval_coordinates["lat"].size, eval_coordinates["lon"].size, *shape[1:])
            vals = UnitsDataArray(vals, coords=coords, dims=["lat", "lon"] + dims)

            # and transpose back to the destination order
            output_data.data[:] = vals.transpose(*output_data.dims).data[:]
            return output_data

        elif self._dim_in(["lat", "lon"], eval_coordinates, unstacked=True):
            # Presumably: eval lat/lon are stacked points as well.
            dst_order = "lat_lon" if "lat_lon" in eval_coordinates.dims else "lon_lat"
            # Both point sets are stacked in the *source* dim order so they compare like-for-like.
            src_stacked = np.stack(
                [source_coordinates[dim].coordinates for dim in source_coordinates[order].dims], axis=1
            )
            new_stacked = np.stack(
                [eval_coordinates[dim].coordinates for dim in source_coordinates[order].dims], axis=1
            )
            pts = KDTree(src_stacked)
            dist, ind = pts.query(new_stacked, distance_upper_bound=tol)
            mask = ind == source_data[order].size
            ind[mask] = 0
            vals = source_data[{order: ind}]
            vals[{order: mask}] = np.nan
            # vals carries the source stacked dim name; map the output's stacked dim onto it
            dims = list(output_data.dims)
            dims[dims.index(dst_order)] = order
            output_data.data[:] = vals.transpose(*dims).data[:]
            return output_data
@common_doc(COMMON_INTERPOLATOR_DOCS)
class ScipyGrid(ScipyPoint):
    """Scipy Interpolation

    Interpolation of gridded (independent lat and lon dims) source data using
    ``scipy.interpolate.RegularGridInterpolator`` (nearest / bilinear) or
    ``scipy.interpolate.RectBivariateSpline`` (spline methods).

    Attributes
    ----------
    {interpolator_attributes}
    """

    methods_supported = ["nearest", "bilinear", "cubic_spline", "spline_2", "spline_3", "spline_4"]
    method = tl.Unicode(default_value="nearest")

    # TODO: implement these parameters for the method 'nearest'
    spatial_tolerance = tl.Float(default_value=np.inf)
    time_tolerance = tl.Union([tl.Unicode(), tl.Instance(np.timedelta64, allow_none=True)], default_value=None)

    @common_doc(COMMON_INTERPOLATOR_DOCS)
    def can_interpolate(self, udims, source_coordinates, eval_coordinates):
        """
        {interpolator_can_interpolate}
        """
        # Requires both lat and lon to be requested, lat/lon in the source as
        # (presumably unstacked) grid dims, and lat/lon present in the eval
        # coordinates (stacked or not).
        # TODO: make this so we don't need to specify lat and lon together
        # or at least throw a warning
        if (
            "lat" in udims
            and "lon" in udims
            and self._dim_in(["lat", "lon"], source_coordinates)
            and self._dim_in(["lat", "lon"], eval_coordinates, unstacked=True)
        ):
            # tuple, consistent with the other return path and with ScipyPoint
            return tuple(["lat", "lon"])

        # otherwise return no supported dims
        return tuple()

    @common_doc(COMMON_INTERPOLATOR_DOCS)
    def interpolate(self, udims, source_coordinates, source_data, eval_coordinates, output_data):
        """
        {interpolator_interpolate}
        """
        if self._dim_in(["lat", "lon"], eval_coordinates):
            # eval lat/lon form a grid: evaluate on their outer product
            return self._interpolate_irregular_grid(
                udims, source_coordinates, source_data, eval_coordinates, output_data, grid=True
            )
        elif self._dim_in(["lat", "lon"], eval_coordinates, unstacked=True):
            # eval lat/lon are stacked points: unstack and evaluate pointwise
            eval_coordinates_us = eval_coordinates.unstack()
            return self._interpolate_irregular_grid(
                udims, source_coordinates, source_data, eval_coordinates_us, output_data, grid=False
            )

    def _interpolate_irregular_grid(
        self, udims, source_coordinates, source_data, eval_coordinates, output_data, grid=True
    ):
        """Interpolate 2d lat/lon source data into ``output_data``.

        When ``source_data`` has more than two dims, the work is delegated to
        ``self._loop_helper`` (defined outside this class) keeping lat/lon.

        Parameters
        ----------
        grid : bool
            True to evaluate on the outer product of the eval lat and lon
            coordinates; False to evaluate at paired (lat[i], lon[i]) points.

        Returns
        -------
        output_data, filled in place.
        """
        if len(source_data.dims) > 2:
            keep_dims = ["lat", "lon"]
            return self._loop_helper(
                self._interpolate_irregular_grid,
                keep_dims,
                udims,
                source_coordinates,
                source_data,
                eval_coordinates,
                output_data,
                grid=grid,
            )

        # The data axes follow the source coordinate order: (lat, lon) or (lon, lat).
        lat_first = source_coordinates.dims.index("lat") < source_coordinates.dims.index("lon")

        # scipy requires ascending axes: flip descending coordinates together with
        # the matching data axis.
        lat_slice = slice(None, None, -1) if source_coordinates["lat"].is_descending else slice(None, None)
        lon_slice = slice(None, None, -1) if source_coordinates["lon"].is_descending else slice(None, None)
        lat = source_coordinates["lat"].coordinates[lat_slice]
        lon = source_coordinates["lon"].coordinates[lon_slice]
        # Order the flips by data axis (previously always (lat, lon), which flipped
        # the wrong axes for lon-first sources).
        s = (lat_slice, lon_slice) if lat_first else (lon_slice, lat_slice)
        data = source_data.data[s]

        # remove nan's
        I, J = np.isfinite(lat), np.isfinite(lon)
        coords_i = lat[I], lon[J]
        coords_i_dst = [eval_coordinates["lon"].coordinates, eval_coordinates["lat"].coordinates]

        # Swap order in case datasource uses lon,lat ordering instead of lat,lon
        if not lat_first:
            I, J = J, I
            coords_i = coords_i[::-1]
            coords_i_dst = coords_i_dst[::-1]
        data = data[I, :][:, J]

        if self.method in ["bilinear", "nearest"]:
            # RegularGridInterpolator calls bilinear "linear"
            f = RegularGridInterpolator(
                coords_i, data, method=self.method.replace("bi", ""), bounds_error=False, fill_value=np.nan
            )
            if grid:
                x, y = np.meshgrid(*coords_i_dst)
            else:
                x, y = coords_i_dst
            # (y, x) is in the same axis order as coords_i
            output_data.data[:] = f((y.ravel(), x.ravel())).reshape(output_data.shape)

        # TODO: what methods is 'spline' associated with?
        elif "spline" in self.method:
            if self.method == "cubic_spline":
                order = 3
            else:
                # TODO: make this a parameter
                order = int(self.method.split("_")[-1])

            f = RectBivariateSpline(coords_i[0], coords_i[1], data, kx=max(1, order), ky=max(1, order))
            output_data.data[:] = f(coords_i_dst[1], coords_i_dst[0], grid=grid).reshape(output_data.shape)

        return output_data