"""
Interpolator implementations
"""
from __future__ import division, unicode_literals, print_function, absolute_import
from six import string_types
import numpy as np
import xarray as xr
import traitlets as tl
from scipy.spatial import cKDTree
# Optional dependencies
# podac imports
from podpac.core.interpolation.interpolator import COMMON_INTERPOLATOR_DOCS, Interpolator, InterpolatorException
from podpac.core.coordinates import Coordinates, UniformCoordinates1d, StackedCoordinates
from podpac.core.coordinates.utils import make_coord_delta, make_coord_value
from podpac.core.utils import common_doc
from podpac.core.coordinates.utils import get_timedelta
from podpac.core.interpolation.selector import Selector, _higher_precision_time_coords1d, _higher_precision_time_stack
[docs]@common_doc(COMMON_INTERPOLATOR_DOCS)
class NearestNeighbor(Interpolator):
    """Nearest Neighbor Interpolation
    {nearest_neighbor_attributes}
    """
    dims_supported = ["lat", "lon", "alt", "time"]
    methods_supported = ["nearest"]
    # defined at instantiation
    method = tl.Unicode(default_value="nearest")
    spatial_tolerance = tl.Float(default_value=np.inf, allow_none=True)
    time_tolerance = tl.Union([tl.Unicode(), tl.Instance(np.timedelta64, allow_none=True)])
    alt_tolerance = tl.Float(default_value=np.inf, allow_none=True)
    # spatial_scale only applies when the source is stacked with time or alt. The supplied value will be assigned a distance of "1'"
    spatial_scale = tl.Float(default_value=1, allow_none=True)
    # time_scale only applies when the source is stacked with lat, lon, or alt. The supplied value will be assigned a distance of "1'"
    time_scale = tl.Union([tl.Unicode(), tl.Instance(np.timedelta64, allow_none=True)])
    # alt_scale only applies when the source is stacked with lat, lon, or time. The supplied value will be assigned a distance of "1'"
    alt_scale = tl.Float(default_value=1, allow_none=True)
    respect_bounds = tl.Bool(True)
    remove_nan = tl.Bool(False)
    use_selector = tl.Bool(True)
    def __repr__(self):
        rep = super(NearestNeighbor, self).__repr__()
        # rep += '\n\tspatial_tolerance: {}\n\ttime_tolerance: {}'.format(self.spatial_tolerance, self.time_tolerance)
        return rep
[docs]    @common_doc(COMMON_INTERPOLATOR_DOCS)
    def can_interpolate(self, udims, source_coordinates, eval_coordinates):
        """
        {interpolator_interpolate}
        """
        udims_subset = self._filter_udims_supported(udims)
        return udims_subset 
[docs]    def can_select(self, udims, source_coordinates, eval_coordinates):
        selector = super().can_select(udims, source_coordinates, eval_coordinates)
        if self.use_selector:
            return selector
        return () 
[docs]    @common_doc(COMMON_INTERPOLATOR_DOCS)
    def interpolate(self, udims, source_coordinates, source_data, eval_coordinates, output_data):
        """
        {interpolator_interpolate}
        """
        # Note, some of the following code duplicates code in the Selector class.
        # This duplication is for the sake of optimization
        def is_stacked(d):
            return "_" in d
        if hasattr(source_data, "attrs") and "bounds" in source_data.attrs:
            bounds = source_data.attrs["bounds"]
            if "time" in bounds and bounds["time"]:
                if "time" in eval_coordinates.udims:
                    bounds["time"] = [
                        self._atime_to_float(b, source_coordinates["time"], eval_coordinates["time"])
                        for b in bounds["time"]
                    ]
                else:
                    bounds["time"] = [
                        self._atime_to_float(b, source_coordinates["time"], source_coordinates["time"])
                        for b in bounds["time"]
                    ]
        else:
            bounds = None
        if self.remove_nan:
            # Eliminate nans from the source data. Note, this could turn a uniform griddted dataset into a stacked one
            source_data, source_coordinates = self._remove_nans(source_data, source_coordinates)
        data_index = []
        for d in source_coordinates.dims:
            # Make sure we're supposed to do nearest neighbor interpolation for this UDIM, otherwise skip this dimension
            if len([dd for dd in d.split("_") if dd in udims]) == 0:
                index = self._resize_unstacked_index(np.arange(source_coordinates[d].size), d, eval_coordinates)
                data_index.append(index)
                continue
            source = source_coordinates[d]
            if is_stacked(d):
                if bounds is not None:
                    bound = np.stack([bounds[dd] for dd in d.split("_")], axis=1)
                else:
                    bound = None
                index = self._get_stacked_index(d, source, eval_coordinates, bound)
                if len(source.shape) == 2:  # Handle case of 2D-stacked coordinates
                    ncols = source.shape[1]
                    index1 = index // ncols
                    index1 = self._resize_stacked_index(index1, d, eval_coordinates)
                    # With nD stacked coordinates, there are 'n' indices in the tuple
                    # All of these need to get into the data_index, and in the right order
                    data_index.append(index1)  # This is a hack
                    index = index % ncols  # The second half can go through the usual machinery
                elif len(source.shape) > 2:  # Handle case of nD-stacked coordinates
                    raise NotImplementedError
                index = self._resize_stacked_index(index, d, eval_coordinates)
            elif source_coordinates[d].is_uniform:
                request = eval_coordinates[d]
                if bounds is not None:
                    bound = bounds[d]
                else:
                    bound = None
                index = self._get_uniform_index(d, source, request, bound)
                index = self._resize_unstacked_index(index, d, eval_coordinates)
            else:  # non-uniform coordinates... probably an optimization here
                request = eval_coordinates[d]
                if bounds is not None:
                    bound = bounds[d]
                else:
                    bound = None
                index = self._get_nonuniform_index(d, source, request, bound)
                index = self._resize_unstacked_index(index, d, eval_coordinates)
            data_index.append(index)
        index = tuple(data_index)
        output_data.data[:] = np.array(source_data)[index]
        bool_inds = sum([i == -1 for i in index]).astype(bool)
        output_data.data[bool_inds] = np.nan
        return output_data 
    def _remove_nans(self, source_data, source_coordinates):
        index = np.array(np.isnan(source_data), bool)
        if not np.any(index):
            return source_data, source_coordinates
        data = source_data.data[~index]
        coords = np.meshgrid(
            *[source_coordinates[d.split("_")[0]].coordinates for d in source_coordinates.dims], indexing="ij"
        )
        repeat_shape = coords[0].shape
        coords = [c[~index] for c in coords]
        final_dims = [d.split("_")[0] for d in source_coordinates.dims]
        # Add back in any stacked coordinates
        for i, d in enumerate(source_coordinates.dims):
            dims = d.split("_")
            if len(dims) == 1:
                continue
            reshape = np.ones(len(coords), int)
            reshape[i] = -1
            repeats = list(repeat_shape)
            repeats[i] = 1
            for dd in dims[1:]:
                crds = source_coordinates[dd].coordinates.reshape(*reshape)
                for j, r in enumerate(repeats):
                    crds = crds.repeat(r, axis=j)
                coords.append(crds[~index])
                final_dims.append(dd)
        return data, Coordinates([coords], dims=[final_dims])
    def _get_tol(self, dim, source, request):
        if dim in ["lat", "lon"]:
            return self.spatial_tolerance
        if dim == "alt":
            return self.alt_tolerance
        if dim == "time":
            if self.time_tolerance == "":
                return np.inf
            return self._time_to_float(self.time_tolerance, source, request)
        raise NotImplementedError()
    def _get_scale(self, dim, source_1d, request_1d):
        if dim in ["lat", "lon"]:
            return 1 / self.spatial_scale
        if dim == "alt":
            return 1 / self.alt_scale
        if dim == "time":
            if self.time_scale == "":
                return 1.0
            return 1 / self._time_to_float(self.time_scale, source_1d, request_1d)
        raise NotImplementedError()
    def _time_to_float(self, time, time_source, time_request):
        dtype0 = time_source.coordinates[0].dtype
        dtype1 = time_request.coordinates[0].dtype
        dtype = dtype0 if dtype0 > dtype1 else dtype1
        time = make_coord_delta(time)
        if isinstance(time, np.timedelta64):
            time1 = (time + np.datetime64("2000")).astype(dtype).astype(float) - (
                np.datetime64("2000").astype(dtype).astype(float)
            )
        return time1
    def _atime_to_float(self, time, time_source, time_request):
        dtype0 = time_source.coordinates[0].dtype
        dtype1 = time_request.coordinates[0].dtype
        dtype = dtype0 if dtype0 > dtype1 else dtype1
        time = make_coord_value(time)
        if isinstance(time, np.datetime64):
            time = time.astype(dtype).astype(float)
        return time
    def _get_stacked_index(self, dim, source, request, bounds=None):
        # The udims are in the order of the request so that the meshgrid calls will be in the right order
        udims = [ud for ud in request.udims if ud in source.udims]
        time_source = time_request = None
        if "time" in udims:
            time_source = source["time"]
            time_request = request["time"]
        tols = np.array([self._get_tol(d, time_source, time_request) for d in udims])[None, :]
        scales = np.array([self._get_scale(d, time_source, time_request) for d in udims])[None, :]
        tol = np.linalg.norm((tols * scales).squeeze())
        src_coords, req_coords_diag = _higher_precision_time_stack(source, request, udims)
        # We need to unwravel the nD stacked coordinates
        ckdtree_source = cKDTree(src_coords.reshape(src_coords.shape[0], -1).T * scales)
        # if the udims are all stacked in the same stack as part of the request coordinates, then we're done.
        # Otherwise we have to evaluate each unstacked set of dimensions independently
        # Note, part of this code is duplicated in the selector
        indep_evals = [ud for ud in udims if not request.is_stacked(ud)]
        # two udims could be stacked, but in different dim groups, e.g. source (lat, lon), request (lat, time), (lon, alt)
        stacked = {d for d in request.dims for ud in udims if ud in d and request.is_stacked(ud)}
        if (len(indep_evals) + len(stacked)) <= 1:  # output is stacked in the same way
            # The ckdtree call below needs the lat/lon pairs in the last axis position
            req_coords = np.moveaxis(req_coords_diag, 0, -1)
        elif (len(stacked) == 0) | (len(indep_evals) == 0 and len(stacked) == len(udims)):
            req_coords = np.stack([i.ravel() for i in np.meshgrid(*req_coords_diag, indexing="ij")], axis=1)
        else:
            # Rare cases? E.g. lat_lon_time_alt source to lon, time_alt, lat destination
            sizes = [request[d].size for d in request.dims]
            reshape = np.ones(len(request.dims), int)
            coords = [None] * len(udims)
            for i in range(len(udims)):
                ii = [ii for ii in range(len(request.dims)) if udims[i] in request.dims[ii]][0]
                reshape[:] = 1
                reshape[ii] = -1
                coords[i] = req_coords_diag[i].reshape(*reshape)
                for j, d in enumerate(request.dims):
                    if udims[i] in d:  # Then we don't need to repeat
                        continue
                    coords[i] = coords[i].repeat(sizes[j], axis=j)
            req_coords = np.stack([i.ravel() for i in coords], axis=1)
        dist, index = ckdtree_source.query(req_coords * scales, k=1)
        if self.respect_bounds:
            if bounds is None:
                bounds = np.stack(
                    [
                        src_coords.reshape(src_coords.shape[0], -1).T.min(0),
                        src_coords.reshape(src_coords.shape[0], -1).T.max(0),
                    ],
                    axis=1,
                )
            # Fix order of bounds
            bounds = bounds[:, [source.udims.index(dim) for dim in udims]]
            index[np.any((req_coords > bounds[1]), axis=-1) | np.any((req_coords < bounds[0]), axis=-1)] = -1
        if tol and tol != np.inf:
            index[dist > tol] = -1
        return index
    def _get_uniform_index(self, dim, source, request, bounds=None):
        tol = self._get_tol(dim, source, request)
        index = (request.coordinates - source.start) / source.step
        rindex = np.around(index).astype(int)
        stop_ind = int(source.size)
        if self.respect_bounds:
            rindex[(rindex < 0) | (rindex >= stop_ind)] = -1
        else:
            rindex = np.clip(rindex, 0, stop_ind - 1)
        if tol and tol != np.inf:
            if dim == "time":
                step = self._time_to_float(source.step, source, request)
            else:
                step = source.step
            rindex[np.abs(index - rindex) * np.abs(step) > tol] = -1
        return rindex
    def _get_nonuniform_index(self, dim, source, request, bounds=None):
        tol = self._get_tol(dim, source, request)
        src, req = _higher_precision_time_coords1d(source, request)
        ckdtree_source = cKDTree(src.reshape(-1, 1))
        dist, index = ckdtree_source.query(req[:].reshape(-1, 1), k=1)
        index[index == source.coordinates.size] = -1
        if self.respect_bounds:
            if bounds is None:
                bounds = [src.min(), src.max()]
            index[(req.ravel() > bounds[1]) | (req.ravel() < bounds[0])] = -1
        if tol and tol != np.inf:
            index[dist > tol] = -1
        return index
    def _resize_unstacked_index(self, index, source_dim, request):
        # When the request is stacked, and the stacked dimensions are n-dimensions where n > 1,
        # Then len(request.shape) != len(request.dims), so it take s a little bit of footwork
        # to get the correct shape for the index
        reshape = np.array(request.shape)
        i = 0
        for dim in request.dims:
            addnext = len(request[dim].shape)
            if source_dim not in dim:
                reshape[i : i + addnext] = 1
            i += addnext
        return index.reshape(*reshape)
    def _resize_stacked_index(self, index, source_dim, request):
        reshape = np.array(request.shape)
        i = 0
        for dim in request.dims:
            addnext = len(request[dim].shape)
            d = dim.split("_")
            if not any([dd in source_dim for dd in d]):
                reshape[i : i + addnext] = 1
            i += addnext
        return index.reshape(*reshape) 
[docs]@common_doc(COMMON_INTERPOLATOR_DOCS)
class NearestPreview(NearestNeighbor):
    """Nearest Neighbor (Preview) Interpolation
    {nearest_neighbor_attributes}
    """
    methods_supported = ["nearest_preview"]
    method = tl.Unicode(default_value="nearest_preview")
    spatial_tolerance = tl.Float(read_only=True, allow_none=True, default_value=None)
[docs]    @common_doc(COMMON_INTERPOLATOR_DOCS)
    def can_select(self, udims, source_coordinates, eval_coordinates):
        """
        {interpolator_can_select}
        """
        udims_subset = self._filter_udims_supported(udims)
        # confirm that udims are in source and eval coordinates
        # TODO: handle stacked coordinates
        if self._dim_in(udims_subset, source_coordinates):
            return udims_subset
        else:
            return tuple() 
[docs]    @common_doc(COMMON_INTERPOLATOR_DOCS)
    def select_coordinates(self, udims, source_coordinates, eval_coordinates, index_type="numpy"):
        """
        {interpolator_select}
        """
        new_coords = []
        new_coords_idx = []
        source_coords, source_coords_index = source_coordinates.intersect(
            eval_coordinates, outer=True, return_index=True
        )
        if source_coords.size == 0:
            return source_coords, source_coords_index
        # iterate over the source coordinate dims in case they are stacked
        for src_dim, idx in zip(source_coords, source_coords_index):
            # TODO: handle stacked coordinates
            if isinstance(source_coords[src_dim], StackedCoordinates):
                raise InterpolatorException("NearestPreview select does not yet support stacked dimensions")
            if src_dim in eval_coordinates.dims:
                src_coords = source_coords[src_dim]
                dst_coords = eval_coordinates[src_dim]
                if isinstance(dst_coords, UniformCoordinates1d):
                    dst_start = dst_coords.start
                    dst_stop = dst_coords.stop
                    dst_delta = dst_coords.step
                else:
                    dst_start = dst_coords.coordinates[0]
                    dst_stop = dst_coords.coordinates[-1]
                    with np.errstate(invalid="ignore"):
                        dst_delta = (dst_stop - dst_start) / (dst_coords.size - 1)
                if isinstance(src_coords, UniformCoordinates1d):
                    src_start = src_coords.start
                    src_stop = src_coords.stop
                    src_delta = src_coords.step
                else:
                    src_start = src_coords.coordinates[0]
                    src_stop = src_coords.coordinates[-1]
                    with np.errstate(invalid="ignore"):
                        src_delta = (src_stop - src_start) / (src_coords.size - 1)
                ndelta = max(1, np.round(np.abs(dst_delta / src_delta)))
                idx_offset = 0
                if src_coords.size == 1:
                    c = src_coords.copy()
                else:
                    c_test = UniformCoordinates1d(src_start, src_stop, ndelta * src_delta, **src_coords.properties)
                    bounds = source_coordinates[src_dim].bounds
                    # The delta/2 ensures the endpoint is included when there is a floating point rounding error
                    # the delta/2 is more than needed, but does guarantee.
                    src_stop = np.clip(src_stop + ndelta * src_delta / 2, bounds[0], bounds[1])
                    c = UniformCoordinates1d(src_start, src_stop, ndelta * src_delta, **src_coords.properties)
                    if c.size > c_test.size:  # need to adjust the index as well
                        idx_offset = int(ndelta)
                idx_start = idx.start if isinstance(idx, slice) else idx[0]
                idx_stop = idx.stop if isinstance(idx, slice) else idx[-1]
                if idx_stop is not None:
                    idx_stop += idx_offset
                idx = slice(idx_start, idx_stop, int(ndelta))
            else:
                c = source_coords[src_dim]
            new_coords.append(c)
            new_coords_idx.append(idx)
        return Coordinates(new_coords, validate_crs=False), tuple(new_coords_idx)