from __future__ import division, unicode_literals, print_function, absolute_import
from collections import OrderedDict
import io
from podpac.core.coordinates.array_coordinates1d import ArrayCoordinates1d
import re
from six import string_types
import traitlets as tl
import numpy as np
import pyproj
import logging
from lazy_import import lazy_module
rasterio = lazy_module("rasterio")
boto3 = lazy_module("boto3")
from podpac.core.utils import common_doc, cached_property
from podpac.core.coordinates import UniformCoordinates1d, Coordinates, merge_dims
from podpac.core.data.datasource import COMMON_DATA_DOC, DATA_DOC
from podpac.core.data.file_source import BaseFileSource
from podpac.core.authentication import S3Mixin
from podpac.core.interpolation.interpolation import InterpolationMixin
_logger = logging.getLogger(__name__)
@common_doc(COMMON_DATA_DOC)
class RasterioRaw(S3Mixin, BaseFileSource):
    """Create a DataSource using rasterio.

    Attributes
    ----------
    source : str, :class:`io.BytesIO`
        Path to the data source
    dataset : :class:`rasterio._io.RasterReader`
        A reference to the datasource opened by rasterio
    coordinates : :class:`podpac.Coordinates`
        {coordinates}
    band : int
        The 'band' or index for the variable being accessed in files such as GeoTIFFs. Use None for all bands.
    crs : str, optional
        The coordinate reference system. Normally this will come directly from the file, but this allows users to
        specify the crs in case this information is missing from the file.
    aws_https: bool
        Default is True. If False, will not use https when reading from AWS. This is useful for debugging when SSL
        certificates are invalid.
    prefer_overviews: bool, optional
        Default is False. If True, will pull data from an overview with the closest resolution (step size) matching
        the smallest resolution in the request.
    prefer_overviews_closest: bool, optional
        Default is False. If True, will find the closest overview instead of the closest overview that is not
        coarser than the request.

    See Also
    --------
    Rasterio : Interpolated rasterio datasource for general use.
    """

    band = tl.CInt(allow_none=True).tag(attr=True)
    crs = tl.Unicode(allow_none=True, default_value=None).tag(attr=True)
    driver = tl.Unicode(allow_none=True, default_value=None)
    # "slice" indices are consumed by get_data (.start/.stop); "numpy" indices by get_data_overviews (.min()/.max())
    coordinate_index_type = tl.Unicode()
    aws_https = tl.Bool(True).tag(attr=True)
    prefer_overviews = tl.Bool(False).tag(attr=True)
    prefer_overviews_closest = tl.Bool(False).tag(attr=True)

    @tl.default("coordinate_index_type")
    def _default_coordinate_index_type(self):
        """Default index type: "numpy" when overviews are preferred, otherwise "slice"."""
        if self.prefer_overviews:
            return "numpy"
        else:
            return "slice"

    @cached_property
    def dataset(self):
        """The dataset obtained by opening ``self.source`` with :meth:`open_dataset`."""
        return self.open_dataset(self.source)

    def open_dataset(self, source, overview_level=None):
        """Open a data source with rasterio.

        Parameters
        ----------
        source : str, :class:`io.BytesIO`
            Data source to open. Strings starting with ``s3://`` are opened inside a rasterio environment that is
            configured with this node's AWS settings.
        overview_level : int, optional
            If not None, forwarded to ``rasterio.open`` as the ``overview_level`` keyword.

        Returns
        -------
        The object returned by ``rasterio.open``.
        """
        kwargs = {}
        if overview_level is not None:
            kwargs["overview_level"] = overview_level

        # Only strings can carry an "s3://" prefix. Non-string sources (e.g. file-like objects) go straight to
        # rasterio.open rather than failing on a missing ``startswith`` attribute.
        if isinstance(source, string_types) and source.startswith("s3://"):
            envargs = {
                "AWS_HTTPS": self.aws_https,
                "session": rasterio.session.AWSSession(
                    aws_access_key_id=self.aws_access_key_id,
                    aws_secret_access_key=self.aws_secret_access_key,
                    region_name=self.aws_region_name,
                    requester_pays=self.aws_requester_pays,
                    aws_unsigned=self.anon,
                ),
            }
            with rasterio.env.Env(**envargs) as env:
                _logger.debug("Rasterio environment options: {}".format(env.options))
                return rasterio.open(source, **kwargs)

        return rasterio.open(source, **kwargs)

    @tl.default("band")
    def _band_default(self):
        """Default band, derived from ``outputs`` / ``output``.

        NOTE(review): the first branch returns a 0-based position within ``outputs`` whereas the single-band
        default is 1 -- confirm the intended band numbering.
        """
        if self.outputs is not None and self.output is not None:
            return self.outputs.index(self.output)
        elif self.outputs is None:
            return 1
        else:
            return None  # All bands

    # -------------------------------------------------------------------------
    # public api methods
    # -------------------------------------------------------------------------

    @cached_property
    def nan_vals(self):
        """List of the unique ``dataset.nodatavals`` after casting them to ``self.dtype``."""
        return np.unique(np.array(self.dataset.nodatavals).astype(self.dtype)).tolist()

    def close_dataset(self):
        """Closes the file for the datasource"""
        self.dataset.close()

    @common_doc(COMMON_DATA_DOC)
    def get_coordinates(self):
        """{get_coordinates}

        The default implementation tries to find the lat/lon coordinates based on dataset.transform.
        It cannot determine the alt or time dimensions, so child classes may
        have to overload this method.
        """
        # the affine transform is handed to Coordinates.from_geotransform in GDAL ordering
        affine = self.dataset.transform
        dataset_crs = self.dataset.crs

        validate_crs = True
        if self.crs is not None:
            # a user-supplied crs always wins over whatever the file reports
            crs = self.crs
        elif isinstance(dataset_crs, (rasterio.crs.CRS, dict)) and "init" in dataset_crs:
            crs = dataset_crs["init"].upper()
            # plain dicts have no ``is_valid``; in that case keep validation enabled
            if getattr(dataset_crs, "is_valid", False):
                validate_crs = False
        else:
            try:
                crs = pyproj.CRS(dataset_crs).to_wkt()
            except pyproj.exceptions.CRSError:
                raise RuntimeError("Unexpected rasterio crs '%s'" % dataset_crs)

        return Coordinates.from_geotransform(affine.to_gdal(), self.dataset.shape, crs, validate_crs)

    @common_doc(COMMON_DATA_DOC)
    def get_data(self, coordinates, coordinates_index):
        """{get_data}

        When ``prefer_overviews`` is True the request is delegated to ``get_data_overviews``.
        """
        if self.prefer_overviews:
            return self.get_data_overviews(coordinates, coordinates_index)

        data = self.create_output_array(coordinates)

        # read data within coordinates_index window
        rows = coordinates_index[0]
        cols = coordinates_index[1]
        window = ((rows.start, rows.stop), (cols.start, cols.stop))

        if self.outputs is not None:  # read all the bands
            raster_data = self.dataset.read(out_shape=(len(self.outputs),) + tuple(coordinates.shape), window=window)
            # rasterio returns the band axis first; move it last before flattening into the output array
            raster_data = np.moveaxis(raster_data, 0, 2)
        else:  # read the requested band
            raster_data = self.dataset.read(self.band, out_shape=tuple(coordinates.shape)[:2], window=window)

        # set raster data to output array
        data.data.ravel()[:] = raster_data.ravel()
        return data

    def get_data_overviews(self, coordinates, coordinates_index):
        """Read the requested region from the overview that best matches the request resolution.

        Parameters
        ----------
        coordinates : :class:`podpac.Coordinates`
            Requested coordinates; must contain ``lat`` and ``lon``.
        coordinates_index : sequence
            Index arrays for the first two dimensions; only their ``min()`` and ``max()`` are used.

        Returns
        -------
        The array produced by ``create_output_array`` for the coordinates of the selected overview window, filled
        with the raster data.

        Raises
        ------
        NotImplementedError
            If a ``lat``/``lon`` request coordinate is neither uniform nor a monotonic array.
        """
        # Figure out how much coarser the request is than the actual data
        reduction_factor = np.inf
        for c in ["lat", "lon"]:
            crd = coordinates[c]
            if crd.size == 1:
                # a single point has no step size: never use an overview
                reduction_factor = 0
                break
            if isinstance(crd, UniformCoordinates1d):
                min_delta = crd.step
            elif isinstance(crd, ArrayCoordinates1d) and crd.is_monotonic:
                min_delta = crd.deltas.min()
            else:
                raise NotImplementedError(
                    "The Rasterio node with prefer_overviews=True currently does not support request coordinates type {}".format(
                        coordinates
                    )
                )
            reduction_factor = min(
                reduction_factor, np.abs(min_delta / self.coordinates[c].step)  # self.coordinates is always uniform
            )

        # Find the overview that's closest to this reduction factor (default: full resolution)
        overview = 1
        overview_level = None
        if not ((reduction_factor < 2) or (len(self.overviews) == 0)):
            overviews = self.overviews
            diffs = reduction_factor - np.array(overviews)
            if self.prefer_overviews_closest:
                diffs = np.abs(diffs)
            else:
                # exclude overviews that are coarser than the request
                diffs[diffs < 0] = np.inf
            # If every candidate was excluded, argmin would silently pick the first (too coarse) overview;
            # stay at full resolution instead.
            if np.isfinite(diffs).any():
                overview_level = int(np.argmin(diffs))
                overview = overviews[overview_level]

        # Now read the data
        inds = coordinates_index
        if overview_level is None:
            dataset = self.dataset
        else:
            dataset = self.open_dataset(self.source, overview_level)

        try:
            # read data within coordinates_index window at the resolution of the overview
            # Rasterio will then automatically pull from the overview
            window = (
                ((inds[0].min() // overview), int(np.ceil(inds[0].max() / overview) + 1)),
                ((inds[1].min() // overview), int(np.ceil(inds[1].max() / overview) + 1)),
            )
            slc = (slice(window[0][0], window[0][1], 1), slice(window[1][0], window[1][1], 1))
            new_coords = Coordinates.from_geotransform(
                dataset.transform.to_gdal(), dataset.shape, crs=self.coordinates.crs
            )
            new_coords = new_coords[slc]
            missing_coords = self.coordinates.drop(["lat", "lon"])
            new_coords = merge_dims([new_coords, missing_coords])
            new_coords = new_coords.transpose(*self.coordinates.dims)
            coordinates_shape = new_coords.shape[:2]

            # The following lines are *nearly* copied/pasted from get_data
            if self.outputs is not None:  # read all the bands
                raster_data = dataset.read(out_shape=(len(self.outputs),) + coordinates_shape, window=window)
                raster_data = np.moveaxis(raster_data, 0, 2)
            else:  # read the requested band
                raster_data = dataset.read(self.band, out_shape=coordinates_shape, window=window)

            # set raster data to output array
            data = self.create_output_array(new_coords)
            data.data.ravel()[:] = raster_data.ravel()
        except Exception as e:
            # Log for context, then propagate: there is no data to return, and swallowing the error here would
            # only surface later as an unrelated unbound-variable failure.
            _logger.error("Error occurred when reading overview with Rasterio: {}".format(e))
            raise
        finally:
            # only close handles opened here; self.dataset is the shared, cached handle
            if overview_level is not None:
                dataset.close()

        return data

    # -------------------------------------------------------------------------
    # additional methods and properties
    # -------------------------------------------------------------------------

    @property
    def overviews(self):
        """The result of ``dataset.overviews`` for ``self.band``."""
        return self.dataset.overviews(self.band)

    @property
    def tags(self):
        """The result of ``dataset.tags()``."""
        return self.dataset.tags()

    @property
    def subdatasets(self):
        """The ``subdatasets`` attribute of the dataset."""
        return self.dataset.subdatasets

    @property
    def band_count(self):
        """The number of bands"""
        return self.dataset.count

    @cached_property
    def band_descriptions(self):
        """A description of each band contained in dataset.tags

        Returns
        -------
        OrderedDict
            Dictionary of band_number: band_description pairs. The band_description values are a dictionary, each
            containing a number of keys -- depending on the metadata. Keys are 0-based; entry ``i`` holds
            ``dataset.tags(i + 1)``.
        """
        return OrderedDict((i, self.dataset.tags(i + 1)) for i in range(self.band_count))

    @cached_property
    def band_keys(self):
        """An alternative view of band_descriptions based on the keys present in the metadata

        Returns
        -------
        dict
            Dictionary of metadata keys, where the values are the value of the key for each band.
            For example, band_keys['TIME'] = ['2015', '2016', '2017'] for a dataset with three bands.
            Bands that lack a key contribute None to that key's list.
        """
        descriptions = self.band_descriptions
        band_indices = range(self.band_count)
        keys = {k for i in band_indices for k in descriptions[i]}  # set
        return {k: [descriptions[i].get(k) for i in band_indices] for k in keys}

    def get_band_numbers(self, key, value):
        """Return the bands that have a key equal to a specified value.

        Parameters
        ----------
        key : str / list
            Key present in the metadata of the band. Can be a single key, or a list of keys.
        value : str / list
            Value of the key that should be returned. Can be a single value, or a list of values

        Returns
        -------
        np.ndarray
            An array of (1-based) band numbers that match the criteria
        """
        # wrap scalars (including strings, which are iterable) so keys and values can be paired up
        if not hasattr(key, "__iter__") or isinstance(key, string_types):
            key = [key]

        if not hasattr(value, "__iter__") or isinstance(value, string_types):
            value = [value]

        # a band matches only if every key/value pair matches
        match = np.ones(self.band_count, bool)
        for k, v in zip(key, value):
            match = match & (np.array(self.band_keys[k]) == v)

        # positions are 0-based, band numbers are 1-based
        matches = np.where(match)[0] + 1
        return matches
class Rasterio(InterpolationMixin, RasterioRaw):
    """Rasterio datasource with interpolation."""

    pass