"""
Base class for Algorithm Nodes
"""
from __future__ import division, unicode_literals, print_function, absolute_import
from collections import OrderedDict
import inspect
import numpy as np
import xarray as xr
import traitlets as tl
# Internal dependencies
from podpac.core.coordinates import Coordinates, union
from podpac.core.units import UnitsDataArray
from podpac.core.node import Node, NodeException, COMMON_NODE_DOC
from podpac.core.utils import common_doc, NodeTrait
from podpac.core.settings import settings
from podpac.core.managers.multi_threading import thread_manager
# Independent copy of COMMON_NODE_DOC (made with .copy()) so this module holds its own object;
# it is handed to the @common_doc decorator on Algorithm._eval below.
COMMON_DOC = COMMON_NODE_DOC.copy()
class BaseAlgorithm(Node):
    """Base class for algorithm nodes.

    Note: developers should generally use one of the Algorithm or UnaryAlgorithm child classes.
    """

    @property
    def inputs(self):
        """Input nodes currently set on this node, keyed by attribute name.

        A trait counts as an input when it declares a ``klass`` that has ``Node`` in its MRO and
        its current value on this instance is not None.

        Returns
        -------
        inputs : dict
            Mapping of attribute name to the value of that attribute.
        """
        found = {}
        for ref, trait in self.traits().items():
            # getattr(self, ref) can take a long time, so rule traits out by inspecting trait.klass first
            if not hasattr(trait, "klass") or Node not in inspect.getmro(trait.klass):
                continue

            # fetch the value exactly once and reuse it for both the None check and the result
            value = getattr(self, ref)
            if value is not None:
                found[ref] = value

        return found

    def find_coordinates(self):
        """
        Get the available coordinates for the inputs to the Node.

        Returns
        -------
        coords_list : list
            list of available coordinates (Coordinate objects), concatenated from the
            ``find_coordinates()`` result of every input
        """

        return [c for node in self.inputs.values() for c in node.find_coordinates()]
class Algorithm(BaseAlgorithm):
    """Base class for computation nodes with a custom algorithm.

    Notes
    ------
    Developers of new Algorithm nodes need to implement the `algorithm` method.
    """

    # not the best solution... hard to check for these attrs
    # abstract = tl.Bool(default_value=True, allow_none=True).tag(attr=True, required=False, hidden=True)

    def algorithm(self, inputs, coordinates):
        """
        Compute the result of this node from its evaluated inputs. Must be overridden by subclasses.

        Arguments
        ----------
        inputs : dict
            Evaluated outputs of the input nodes. The keys are the attribute names. Each item is a `UnitsDataArray`.
        coordinates : podpac.Coordinates
            Requested coordinates.
            Note that the ``inputs`` may contain different coordinates than the requested coordinates

        Returns
        -------
        result : xarray.DataArray
            The computed data. `_eval` rejects any return value that is not an `xarray.DataArray`.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """

        raise NotImplementedError

    @common_doc(COMMON_DOC)
    def _eval(self, coordinates, output=None, _selector=None):
        """Evaluates this node using the supplied coordinates.

        Parameters
        ----------
        coordinates : podpac.Coordinates
            {requested_coordinates}
        output : podpac.UnitsDataArray, optional
            {eval_output}
        _selector: callable(coordinates, request_coordinates)
            {eval_selector}

        Returns
        -------
        {eval_return}

        Raises
        ------
        NodeException
            If `algorithm` does not return an xarray DataArray, or if the provided `output` lacks a
            dimension that the result has.
        """

        self._requested_coordinates = coordinates

        # the `inputs` property rebuilds its dict on every access, so take a single snapshot
        input_nodes = self.inputs
        inputs = {}

        if settings["MULTITHREADING"]:
            n_threads = thread_manager.request_n_threads(len(input_nodes))
            if n_threads == 1:
                # a single thread is evaluated serially below, so hand it back right away
                thread_manager.release_n_threads(n_threads)
        else:
            n_threads = 0

        if settings["MULTITHREADING"] and n_threads > 1:
            # Create a function for each thread to execute asynchronously
            def f(node):
                return node.eval(coordinates, _selector=_selector)

            try:
                # Create pool of size n_threads, note, this may be created from a sub-thread (i.e. not the main thread)
                pool = thread_manager.get_thread_pool(processes=n_threads)
                try:
                    # Evaluate nodes in parallel/asynchronously
                    results = [pool.apply_async(f, [node]) for node in input_nodes.values()]

                    # Collect the results in dictionary; res.get() re-raises any error from the worker
                    for key, res in zip(input_nodes.keys(), results):
                        inputs[key] = res.get()
                finally:
                    # This prevents any more tasks from being submitted to the pool, and will close the
                    # workers once done. Done in `finally` so a failing input does not leave the pool open.
                    pool.close()
            finally:
                # Release these number of threads back to the thread pool, even when evaluation failed
                thread_manager.release_n_threads(n_threads)

            self._multi_threaded = True
        else:
            # Evaluate nodes in serial
            # NOTE(review): this path forwards `output` to every input node while the threaded path
            # above does not -- confirm that sharing the caller's output array across inputs is intended.
            for key, node in input_nodes.items():
                inputs[key] = node.eval(coordinates, output=output, _selector=_selector)
            self._multi_threaded = False

        result = self.algorithm(inputs, coordinates)

        if not isinstance(result, xr.DataArray):
            raise NodeException("algorithm returned unsupported type '%s'" % type(result))

        # when the result carries an "output" dimension and a specific output is selected, keep only that one
        if "output" in result.dims and self.output is not None:
            result = result.sel(output=self.output)

        if output is not None:
            missing = [dim for dim in result.dims if dim not in output.dims]
            if missing:
                raise NodeException("provided output is missing dims %s" % missing)

            # move the result's dims to the end (in the result's order), copy the raw data in,
            # then restore the caller's original dimension order
            output_dims = output.dims
            output = output.transpose(..., *result.dims)
            output[:] = result.data
            output = output.transpose(*output_dims)
        elif isinstance(result, UnitsDataArray):
            output = result
        else:
            # plain xarray result: wrap it in this node's output array type
            output_coordinates = Coordinates.from_xarray(result)
            output = self.create_output_array(output_coordinates, data=result.data)

        return output
class UnaryAlgorithm(BaseAlgorithm):
    """
    Base class for computation nodes that transform a single source node.

    Attributes
    ----------
    source : Node
        The node whose output is transformed.

    Notes
    ------
    Developers of new Algorithm nodes need to implement the `eval` method.
    The `outputs` and `style` defaults are taken from `source`.
    """

    source = NodeTrait().tag(attr=True, required=True)

    # attribute names that __repr__ and __str__ use to show minimal info about the node
    _repr_keys = ["source"]

    @tl.default("style")
    def _default_style(self):
        # unless set explicitly, reuse the style of the source node
        return self.source.style

    @tl.default("outputs")
    def _default_outputs(self):
        # unless set explicitly, reuse the outputs of the source node
        return self.source.outputs