from __future__ import division, unicode_literals, print_function, absolute_import
from podpac.core.cache import cache_ctrl
import traitlets as tl
from copy import deepcopy
from collections import OrderedDict
from six import string_types
import logging
import traitlets as tl
import numpy as np
from podpac.core.settings import settings
from podpac.core.node import Node
from podpac.core.utils import NodeTrait, common_doc, cached_property
from podpac.core.units import UnitsDataArray
from podpac.core.coordinates import merge_dims, Coordinates
from podpac.core.interpolation.interpolation_manager import InterpolationManager, InterpolationTrait
from podpac.core.cache.cache_ctrl import CacheCtrl
from podpac.core.data.datasource import DataSource
# Module-level logger, named after this module via ``__name__``.
_logger = logging.getLogger(__name__)
class InterpolationMixin(tl.HasTraits):
    """Mixin that routes a node's evaluation through an :class:`Interpolate` node.

    The mixin overrides ``_eval``: the parent class' ``_eval`` produces the raw
    source data, which is handed to a freshly built :class:`Interpolate` node
    (through its ``_source_xr`` attribute) that then evaluates the request.

    Attributes
    ----------
    interpolation : InterpolationTrait
        Interpolation definition forwarded to the helper :class:`Interpolate` node.
    _interp_node : Interpolate or None
        The helper node built by the most recent ``_eval`` call; ``None`` until
        ``_eval`` has run.
    """

    interpolation = InterpolationTrait().tag(attr=True)
    _interp_node = None

    @property
    def _repr_keys(self):
        """Return the parent class' repr keys with ``"interpolation"`` appended."""
        parent_keys = super()._repr_keys
        return parent_keys + ["interpolation"]

    def _eval(self, coordinates, output=None, _selector=None):
        """Evaluate the parent implementation, then interpolate its result.

        Parameters
        ----------
        coordinates : Coordinates
            Coordinates passed to the parent class' ``_eval`` and, for nodes that
            are not a ``DataSource``, to the helper node's ``eval``.
        output : UnitsDataArray, optional
            Forwarded unchanged to the helper node's ``eval``.
        _selector : callable, optional
            Accepted for interface compatibility; this method does not use it.
            The parent ``_eval`` always receives the helper node's own
            ``select_coordinates`` instead.

        Returns
        -------
        The value returned by the helper :class:`Interpolate` node's ``eval``.
        """
        # NOTE(review): the helper gets a CacheCtrl built from an empty list,
        # presumably so it performs no caching of its own -- confirm.
        interp_node = Interpolate(
            interpolation=self.interpolation,
            source_id=self.hash,
            force_eval=True,
            cache_ctrl=CacheCtrl([]),
            style=self.style,
        )
        interp_node._set_interpolation()
        coordinate_selector = interp_node._interpolation.select_coordinates

        # Pre-load the helper with the parent's data so it does not need a
        # ``source`` node of its own.
        interp_node._source_xr = super()._eval(coordinates, _selector=coordinate_selector)
        self._interp_node = interp_node

        if isinstance(self, DataSource):
            # This is required to ensure that the output coordinates
            # match the requested coordinates to floating point precision
            target_coordinates = self._requested_coordinates
        else:
            target_coordinates = coordinates
        result = interp_node.eval(target_coordinates, output=output)

        # Helpful for debugging
        self._from_cache = interp_node._from_cache
        return result
class Interpolate(Node):
    """Node used to interpolate from self.source.coordinates to the user-specified, evaluated coordinates.

    Parameters
    ----------
    source : Any
        The source node which will be interpolated
    source_id : str, optional
        Identifier string for the source. :class:`InterpolationMixin` sets this to the hash of the node it wraps.
    interpolation : str, dict, optional
        Interpolation definition for the data source.
        By default, the interpolation method is set to `podpac.settings["DEFAULT_INTERPOLATION"]` which defaults to
        ``'nearest'`` for all dimensions.

         - If input is a string, it must match one of the interpolation shortcuts defined in
           :attr:`podpac.data.INTERPOLATION_SHORTCUTS`. The interpolation method associated
           with this string will be applied to all dimensions at the same time.
         - If input is a dict or list of dict, the dict or dict elements must adhere to the following format:

           The key ``'method'`` defining the interpolation method name.
           If the interpolation method is not one of :attr:`podpac.data.INTERPOLATION_SHORTCUTS`, a
           second key ``'interpolators'`` must be defined with a list of
           :class:`podpac.interpolators.Interpolator` classes to use in order of usage.
           The dictionary may contain an optional ``'params'`` key which contains a dict of parameters to pass
           along to the :class:`podpac.interpolators.Interpolator` classes associated with the interpolation method.

           The dict may contain the key ``'dims'`` which specifies dimension names
           (i.e. ``'time'`` or ``('lat', 'lon')`` ).
           If the dictionary does not contain a key for all unstacked dimensions of the source coordinates, the
           :attr:`podpac.data.INTERPOLATION_DEFAULT` value will be used.
           All dimension keys must be unstacked even if the underlying coordinate dimensions are stacked.
           Any extra dimensions included but not found in the source coordinates will be ignored.
         - If input is a :class:`podpac.data.Interpolation` class, this Interpolation
           class will be used without modification.
    cache_output : bool
        Should the node's output be cached? If not provided, uses the default from
        ``settings["CACHE_NODE_OUTPUT_DEFAULT"]``. If True, outputs will be cached and retrieved from cache. If False,
        outputs will not be cached OR retrieved from cache (even if they exist in cache).

    Examples
    --------
    To use bilinear interpolation for [lat, lon], a specific interpolator for [time], and the default for [alt],
    use:

    >>> interp_node = Interpolate(
    ...     source=some_node,
    ...     interpolation=[
    ...         {
    ...             'method': 'bilinear',
    ...             'dims': ['lat', 'lon']
    ...         },
    ...         {
    ...             'method': 'nearest',
    ...             'interpolators': [podpac.interpolators.NearestNeighbor],
    ...             'dims': ['time']
    ...         }
    ...     ]
    ... )
    """

    source = NodeTrait(allow_none=True).tag(attr=True, required=True)
    source_id = tl.Unicode(allow_none=True).tag(attr=True)
    _source_xr = tl.Instance(UnitsDataArray, allow_none=True)  # This is needed for the Interpolation Mixin

    interpolation = InterpolationTrait().tag(attr=True)
    cache_output = tl.Bool()

    # privates
    _interpolation = tl.Instance(InterpolationManager)
    _coordinates = tl.Instance(Coordinates, allow_none=True, default_value=None, read_only=True)

    _requested_source_coordinates = tl.Instance(Coordinates)
    _requested_source_coordinates_index = tl.Tuple()
    _requested_source_data = tl.Instance(UnitsDataArray)
    _evaluated_coordinates = tl.Instance(Coordinates)

    @tl.default("style")
    def _default_style(self):  # Pass through source style by default
        """Return the source node's style, or the parent class default when there is no source."""
        if self.source is not None:
            return self.source.style
        else:
            return super()._default_style()

    # Build the interpolation manager lazily, so that `_interpolation` can be
    # inspected before the node has been evaluated.
    @tl.default("_interpolation")
    def _default_interpolation(self):
        """Build the interpolation manager from ``interpolation`` and return it."""
        self._set_interpolation()
        return self._interpolation

    @tl.default("cache_output")
    def _cache_output_default(self):
        """Return ``settings["CACHE_NODE_OUTPUT_DEFAULT"]``."""
        return settings["CACHE_NODE_OUTPUT_DEFAULT"]

    @tl.default("units")
    def _use_source_units(self):
        """Return the source's ``units`` attribute, or None if the source has none."""
        return getattr(self.source, "units", None)

    # ------------------------------------------------------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------------------------------------------------------

    @property
    def interpolation_class(self):
        """Get the interpolation manager currently set for this node.

        The ``interpolation`` property is used to define the
        :class:`podpac.data.InterpolationManager` class that will handle interpolation for requested coordinates.

        Returns
        -------
        :class:`podpac.data.InterpolationManager`
            InterpolationManager instance defined by the `interpolation` definition
        """
        return self._interpolation

    @property
    def interpolators(self):
        """Return the interpolators selected for the previous node evaluation interpolation.

        If the node has not been evaluated, or if interpolation was not necessary, this will return
        an empty OrderedDict

        Returns
        -------
        OrderedDict
            Key are tuple of unstacked dimensions, the value is the interpolator used to interpolate these dimensions
        """
        if self._interpolation._last_interpolator_queue is not None:
            return self._interpolation._last_interpolator_queue
        else:
            return OrderedDict()

    def _set_interpolation(self):
        """Update the `_interpolation` manager from the current `interpolation` definition.

        An :class:`InterpolationManager` instance is used as-is; any other definition is wrapped in a new
        :class:`InterpolationManager`.
        """
        if isinstance(self.interpolation, InterpolationManager):
            self._interpolation = self.interpolation
        else:
            self._interpolation = InterpolationManager(self.interpolation)

    def _eval(self, coordinates, output=None, _selector=None):
        """Evaluate the source and interpolate its data onto the requested coordinates.

        The interpolation manager is rebuilt, the source data is obtained (see `_source_eval`) using the manager's
        ``select_coordinates`` as selector, and the result is interpolated onto `coordinates`. Requested dimensions
        that are absent from the source data are dropped, and the interpolation is carried out in the CRS of the
        source data. If that CRS differs from the requested one, the interpolated values are placed back onto the
        originally requested coordinates (minus the dropped dimensions).

        Parameters
        ----------
        coordinates : :class:`podpac.Coordinates`
            The requested coordinates. Extra dimensions not present in the source data are dropped.
        output : :class:`podpac.UnitsDataArray`, optional
            Array to interpolate into. If None, a new output array is created for the (possibly reduced and
            transformed) requested coordinates.
        _selector :
            Accepted for interface compatibility; not used by this method, which always passes the interpolation
            manager's own ``select_coordinates`` to the source.

        Returns
        -------
        :class:`podpac.UnitsDataArray`
            The interpolated output. When the source data is empty, the (un-interpolated) output array is returned.

        Notes
        -----
        Exceptions raised by the source evaluation or by the interpolation manager are not caught here.
        """
        # Lazy %-formatting: the message is only rendered if DEBUG logging is enabled.
        _logger.debug("Evaluating %s data source", self.__class__.__name__)

        # store requested coordinates for debugging
        if settings["DEBUG"]:
            self._original_requested_coordinates = coordinates

        # store input coordinates to evaluated coordinates
        self._evaluated_coordinates = deepcopy(coordinates)

        # reset interpolation
        self._set_interpolation()

        selector = self._interpolation.select_coordinates

        source_out = self._source_eval(self._evaluated_coordinates, selector)
        source_coords = Coordinates.from_xarray(source_out)

        # Drop extra coordinates (unstacked dimension names, so stacked requests are handled too)
        extra_dims = [d for d in coordinates.udims if d not in source_coords.udims]
        coordinates = coordinates.udrop(extra_dims)

        # Transform so that interpolation happens on the source data coordinate system
        if source_coords.crs.lower() != coordinates.crs.lower():
            coordinates = coordinates.transform(source_coords.crs)

        # Fix source coordinates in the case where some dimension are not being interpolated
        coordinates = self._interpolation._fix_coordinates_for_none_interp(coordinates, source_coords)

        if output is None:
            # Propagate multi-output labels from the source data before allocating the output array
            if "output" in source_out.dims:
                self.set_trait("outputs", source_out.coords["output"].data.tolist())
            output = self.create_output_array(coordinates)

        if source_out.size == 0:  # short cut
            return output

        # interpolate data into output
        output = self._interpolation.interpolate(source_coords, source_out, coordinates, output)

        # if requested crs is different than coordinates,
        # fabricate a new output with the original coordinates and new values
        if self._evaluated_coordinates.crs != coordinates.crs:
            # `extra_dims` holds unstacked names, so `udrop` (as used above) is required here:
            # `drop` only addresses whole dimensions and cannot remove part of a stacked dimension.
            output = self.create_output_array(
                self._evaluated_coordinates.udrop(extra_dims), data=output[:].values
            )

        # save output to private for debugging
        # NOTE(review): storing `_source_xr` here makes `_source_eval` reuse this data on every later
        # evaluation while DEBUG is enabled -- confirm this is intended.
        if settings["DEBUG"]:
            self._output = output
            self._source_xr = source_out

        return output

    def _source_eval(self, coordinates, selector, output=None):
        """Return the source data to interpolate.

        If `_source_xr` already holds a :class:`UnitsDataArray` (as set by :class:`InterpolationMixin`), it is
        returned directly; otherwise the `source` node is evaluated with the given coordinates and selector.
        """
        if isinstance(self._source_xr, UnitsDataArray):
            return self._source_xr
        else:
            return self.source.eval(coordinates, output=output, _selector=selector)

    def find_coordinates(self):
        """
        Get the available coordinates for the Node. This delegates to the source node's ``find_coordinates``.

        Returns
        -------
        coords_list : list
            the value returned by ``self.source.find_coordinates()``
        """
        return self.source.find_coordinates()

    def get_bounds(self, crs="default"):
        """Get the full available coordinate bounds for the Node. This delegates to the source node's ``get_bounds``.

        Arguments
        ---------
        crs : str
            Desired CRS for the bounds. Use 'source' to use the native source crs.
            If not specified, the default CRS in the podpac settings is used. Optional.

        Returns
        -------
        bounds : dict
            Bounds for each dimension. Keys are dimension names and values are tuples of bounds, in the order
            defined by the source node's ``get_bounds``.
        crs : str
            The crs for the bounds.
        """
        return self.source.get_bounds(crs=crs)