Source code for podpac.core.data.ogr

import os.path

import numpy as np
import traitlets as tl

from lazy_import import lazy_module

gdal = lazy_module("osgeo.gdal")
ogr = lazy_module("osgeo.ogr")

from podpac import Node, Coordinates, cached_property, settings, clinspace
from podpac.core.utils import common_doc
from podpac.core.node import COMMON_NODE_DOC
from podpac.core.interpolation.interpolation import InterpolationMixin


class OGRRaw(Node):
    """ """

    source = tl.Unicode().tag(attr=True, required=True)
    layer = tl.Unicode().tag(attr=True, required=True)
    attribute = tl.Unicode().tag(attr=True)
    nan_vals = tl.List().tag(attr=True)
    nan_val = tl.Any(np.nan).tag(attr=True)
    driver = tl.Unicode()

    _repr_keys = ["source", "layer", "attribute"]

    # debug traits
    _requested_coordinates = tl.Instance(Coordinates, allow_none=True)
    _evaluated_coordinates = tl.Instance(Coordinates, allow_none=True)

    @tl.validate("driver")
    def _validate_driver(self, d):
        ogr.GetDriverByName(d["value"])
        return d["value"]

    @tl.validate("source")
    def _validate_source(self, d):
        if not os.path.exists(d["value"]):
            raise ValueError("OGR source not found '%s'" % d["value"])
        return d["value"]

    @cached_property
    def datasource(self):
        driver = ogr.GetDriverByName(self.driver)
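        # open the datasource read-only (update flag 0)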
        return driver.Open(self.source, 0)

    @cached_property
    def extents(self):
        layer = self.datasource.GetLayerByName(self.layer)
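        # layer extent as (xmin, xmax, ymin, ymax) in the layer's spatial reference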
        return layer.GetExtent()

    def get_source_data(self, bounds={}):
        """
        Not available for OGR nodes.

        Arguments
        ---------
        bounds : dict
            Dictionary of bounds by dimension, optional.
            Keys must be dimension names, and values are (min, max) tuples, e.g. ``{'lat': (10, 20)}``.

        Raises
        ------
        AttributeError : Cannot get source data for OGR datasources
        """

        raise AttributeError(
            "Cannot get source data for OGR datasources. "
            "The source data is a vector-based shapefile without a native resolution."
        )

    def find_coordinates(self):
        """
        Not available for OGR nodes.

        Raises
        ------
        AttributeError : Cannot get available coordinates for OGR datasources
        """

        raise AttributeError(
            "Cannot get available coordinates for OGR datasources. "
            "The source data is a vector-based shapefile without native coordinates."
        )

    @common_doc(COMMON_NODE_DOC)
    def _eval(self, coordinates, output=None, _selector=None):
        if "lat" not in coordinates.udims or "lon" not in coordinates.udims:
            raise RuntimeError("OGR source requires lat and lon dims")

        requested_coordinates = coordinates
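        # rasterization is purely spatial; drop any time/alt dimensions before evaluating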
        coordinates = coordinates.udrop(["time", "alt"], ignore_missing=True)

        if coordinates.size == 1 or "lat_lon" in coordinates or "lon_lat" in coordinates:
            # point or points
            eps = 1e-6
            data = np.empty(coordinates.size)
            for i, (lat, lon) in enumerate(zip(coordinates["lat"].coordinates, coordinates["lon"].coordinates)):
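                # 1x1 pixel of size eps centered on the requested (lat, lon);
                # GDAL geotransform order: [x_ul, dx, x_rot, y_ul, y_rot, -dy]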
                geotransform = [lon - eps / 2.0, eps, 0.0, lat + eps / 2.0, 0.0, -1.0 * eps]
                data[i] = self._get_data(1, 1, geotransform)
            data = data.reshape(coordinates.shape)

        else:
            # resample non-uniform coordinates if necessary
            if not coordinates["lat"].is_uniform:
                coordinates["lat"] = clinspace(
                    coordinates["lat"].bounds[0], coordinates["lat"].bounds[1], coordinates["lat"].size, name="lat"
                )
            if not coordinates["lon"].is_uniform:
                coordinates["lon"] = clinspace(
                    coordinates["lon"].bounds[0], coordinates["lon"].bounds[1], coordinates["lon"].size, name="lon"
                )

            # evaluate uniform grid
            data = self._get_data(coordinates["lon"].size, coordinates["lat"].size, coordinates.geotransform)

        if output is None:
            output = self.create_output_array(coordinates, data=data)
        else:
            output.data[:] = data

        # nan values
        output.data[np.isin(output.data, self.nan_vals)] = self.nan_val

        if settings["DEBUG"]:
            self._requested_coordinates = requested_coordinates
            self._evaluated_coordinates = coordinates

        return output

    def _get_data(self, xsize, ysize, geotransform):
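        # 0 is used as a temporary fill/no-data value; pixels not covered by any feature
        # keep this value and are converted to nan below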
        nan_val = 0

        # create target datasource
        driver = gdal.GetDriverByName("MEM")
        target = driver.Create("", xsize, ysize, 1, gdal.GDT_Float64)
        target.SetGeoTransform(geotransform)
        band = target.GetRasterBand(1)
        band.SetNoDataValue(nan_val)
        band.Fill(nan_val)

        # rasterize
        layer = self.datasource.GetLayerByName(self.layer)
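        # burn each feature's attribute value into band 1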
        gdal.RasterizeLayer(target, [1], layer, options=["ATTRIBUTE=%s" % self.attribute])

        data = band.ReadAsArray(buf_type=gdal.GDT_Float64).copy()
        data[data == nan_val] = np.nan
        return data


class OGR(InterpolationMixin, OGRRaw):
    interpolation = "nearest"