Source code for viscid.amr_field

"""For fields that consist of a list of fields + an AMRSkeleton

Note:
    An AMRField is NOT a subclass of Field, but it is a giant wrapper
    around a lot of Field functionality.
"""

from __future__ import print_function
import numpy as np

import viscid
# from viscid.compat import string_types
from viscid.field import Field


__all__ = ["is_list_of_fields"]


[docs]def is_list_of_fields(lst): """is a sequence a sequence of Field objects?""" for item in lst: if not isinstance(item, Field): return False return True
class _FieldListCallableAttrWrapper(object): objs = None attrname = None post_func = None def __init__(self, objs, attrname, post_func=None): # print(">>> runtime wrapping:", attrname) for o in objs: if not hasattr(o, attrname): raise AttributeError("{0} has no attribute {1}" "".format(o, attrname)) self.objs = objs self.attrname = attrname self.post_func = post_func def __call__(self, *args, **kwargs): lst = [getattr(o, self.attrname)(*args, **kwargs) for o in self.objs] if self.post_func: return self.post_func(lst) else: return lst class AMRField(object): """Field-like Contains an AMRSkeleton and a list of Fields. This mimiks a Field, but it is NOT a subclass of Field. Many methods of Field are wrapped and return a new AMRField. If an attribute of Field is not explicitly wrapped, this class will try to runtime-wrap that method and return a new AMRField or a list containing the result. This will not work for special methods since python will not send those through __getattr__ or __getattribute__. """ _TYPE = "amr" skeleton = None patches = None nr_patches = None def __init__(self, fields, skeleton): if not is_list_of_fields(fields): raise TypeError("AMRField can only contain Fields:", fields) self.skeleton = skeleton self.patches = fields self.nr_patches = len(fields) @property def xl(self): if self.skeleton: return np.min(self.skeleton.xl, axis=0) else: return np.min(np.vstack([p.xl for p in self.patches]), axis=0) @property def xh(self): if self.skeleton: return np.max(self.skeleton.xh, axis=0) else: return np.max(np.vstack([p.xh for p in self.patches]), axis=0) def get_slice_extent(self, selection): extent = self.patches[0]._src_crds.get_slice_extent(selection) for i in range(3): if np.isnan(extent[0, i]): extent[0, i] = self.xl[i] if np.isnan(extent[1, i]): extent[1, i] = self.xh[i] return extent ########### ## slicing def _prepare_amr_slice(self, selection): """ return list of patches that contain selection """ # FIXME: it's not good to reach in to src_field[0]'s private methods # like this, but it's also not good to implement these things twice # print("??", len(self.patches)) if len(self.patches) == 0: raise ValueError("AMR field must contain patches to be slicable") sel_list, _ = self.patches[0]._prepare_slice(selection) try: extent = self.patches[0]._src_crds.get_slice_extent(selection) except RuntimeError: raise RuntimeError("Slicing by global index is poorly defined " "for AMR fields; selection = '{0}'" "".format(selection)) inds = [] # these are patches that look like they contain selection # but might not due to finite precision errors when # calculating xh maybe = [] # detect dimensions that only have one cell (or node) and allow any # slice in that direction, this helps cases where a field mab be # defined or y = [-0.01], but the user tries to slice by 'y=0j' all_xl_nc = np.array([patch.crds.xl_nc for patch in self.patches]) all_xh_nc = np.array([patch.crds.xh_nc for patch in self.patches]) dim_is_2d = np.all(all_xl_nc == all_xh_nc, axis=0) for idim, is_2d in enumerate(dim_is_2d): if is_2d: all_xl_nc[:, idim] = -np.inf all_xh_nc[:, idim] = np.inf for i, fld in enumerate(self.patches): # - if xl - atol > the extent of the slice in any direction, then # there's no overlap # - if xh <= the lower corner of the slice in any direction, then # there's no overlap # the atol and equals are done to match cases where extent overlaps # the lower corner, but not the upper corner # logic goes this way cause extent has NaNs in # dimensions that aren't specified in selection... super-kludge # also, temporarily disable warnings on NaNs in numpy xl_nc = all_xl_nc[i] xh_nc = all_xh_nc[i] invalid_err_level = np.geterr()['invalid'] np.seterr(invalid='ignore') atol = 100 * np.finfo(xl_nc.dtype).eps if (not np.any(np.logical_or(xl_nc - atol > extent[1], xh_nc <= extent[0]))): if np.any(np.isclose(fld.crds.xh_nc, extent[0], atol=atol)): maybe.append(i) else: inds.append(i) np.seterr(invalid=invalid_err_level) # if we found some maybes, but no real hits, then use the maybes if maybe and not inds: inds = maybe if len(inds) == 0: viscid.logger.error("selection {0} not in any patch @ time {1}" "".format(selection, self.patches[0].time)) if self.skeleton: s = " xl= {0}".format(self.skeleton.global_xl) viscid.logger.error(s) s = " xh= {0}".format(self.skeleton.global_xh) viscid.logger.error(s) inds = None flds = None elif len(inds) == 1: inds = inds[0] flds = self.patches[inds] else: flds = [self.patches[i] for i in inds] return flds, inds def _finalize_amr_slice(self, fld_lst): # pylint: disable=no-self-use skeleton = None # FIXME for fld in fld_lst: if isinstance(fld, (int, float, np.number)): m = ("Trying to make an AMRField where 1+ patches " "is just a number... You probably slice_reduced " "a field down to a scalar value") viscid.logger.error(m) # prune out fields that got sliced to smithereens for i in reversed(range(len(fld_lst))): if fld_lst[i].size == 0: viscid.logger.debug("finalize amr slice, remove size 0 patch") fld_lst.pop(i) # look for uneven dimensions in the results and fill them back out axes = [patch.crds.axes for patch in fld_lst] ndims = np.array([len(ax) for ax in axes]) ref_axes = axes[np.argmax(ndims)] ref_ndims = np.max(ndims) for i, fld in enumerate(fld_lst): if ndims[i] < ref_ndims: putback = [] for _ax in ref_axes: if _ax in axes[i]: putback.append('{0}=:'.format(_ax)) else: putback.append('{0}=newaxis'.format(_ax)) putback_slice = ','.join(putback) viscid.logger.debug("putback: {0}".format(putback_slice)) fld_lst[i] = fld[putback_slice] return AMRField(fld_lst, skeleton) def patch_indices(self, selection): """get the indices of the patches that overlap selection Args: selection (slice, str): anything that can slice a field Returns: list of indices """ _, inds = self._prepare_amr_slice(selection) return inds def slice(self, selection): fld_lst, _ = self._prepare_amr_slice(selection) if not isinstance(fld_lst, list): return fld_lst.slice(selection) fld_lst = [fld.slice(selection) for fld in fld_lst] return self._finalize_amr_slice(fld_lst) def slice_reduce(self, selection): fld_lst, _ = self._prepare_amr_slice(selection) if not isinstance(fld_lst, list): return fld_lst.slice_reduce(selection) fld_lst = [fld.slice_reduce(selection) for fld in fld_lst] return self._finalize_amr_slice(fld_lst) def slice_and_keep(self, selection): fld_lst, _ = self._prepare_amr_slice(selection) if not isinstance(fld_lst, list): return fld_lst.slice_and_keep(selection) fld_lst = [fld.slice_and_keep(selection) for fld in fld_lst] return self._finalize_amr_slice(fld_lst) def interpolated_slice(self, selection): fld_lst, _ = self._prepare_amr_slice(selection) if not isinstance(fld_lst, list): raise RuntimeError("can't interpolate to that slice?") ret_lst = [fld.interpolated_slice(selection) for fld in fld_lst] return self._finalize_amr_slice(ret_lst) ################### ## special methods def __getitem__(self, item): return self.slice(item) def __setitem__(self, key, value): raise NotImplementedError() def __delitem__(self, item): raise NotImplementedError() def __enter__(self): return self def __exit__(self, exc_type, value, traceback): """clear all caches""" for blk in self.patches: blk.clear_cache() return None def wrap_field_method(self, attrname, *args, **kwargs): """Wrap methods whose args are Fields and return a Field""" # make sure all args have same number of patches as self is_field = [None] * len(args) for i, arg in enumerate(args): try: if arg.nr_patches != self.nr_patches and arg.nr_patches != 1: raise ValueError("AMR fields in math operations must " "have the same number of patches") is_field[i] = True except AttributeError: is_field[i] = False lst = [None] * self.nr_patches other = [None] * len(args) # FIXME: There must be a better way for i, patch in enumerate(self.patches): for j, arg in enumerate(args): if is_field[j]: try: other[j] = arg.patches[i] except IndexError: other[j] = arg.patches[0] else: other[j] = arg lst[i] = getattr(patch, attrname)(*other, **kwargs) if np.asarray(lst[0]).size == 1: # operation reduced to scalar arr = np.array(lst) return getattr(arr, attrname)(**kwargs) else: return AMRField(lst, self.skeleton) # TODO: as of numpy 1.10, this will be called on ufuncs... this # will help some of the FIXMEs in __array__ # def __numpy_ufunc__(self, ufunc, method, i, inputs, **kwargs): # pass def __array__(self, *args, **kwargs): # FIXME: This is heinously inefficient for large arrays because it # makes an copy of all the arrays... but I don't see # a way around this because ufuncs expect a single array # FIXME: adding a dimension to the arrays will break cases like # np.sum(fld, axis=-1), cause that -1 will now be the patch # dimension patches = [patch.__array__(*args, **kwargs) for patch in self.patches] for i, patch in enumerate(patches): patches[i] = np.expand_dims(patch, 0) # the vstack will copy all the arrays, this is what __numpy_ufunc__ # will be able to avoid arr = np.vstack(patches) # roll the patch dimension to the last dimension... this is for ufuncs # that take an axis argument... this way axis will only be confused # if it's negative, this is the main reason to use __numpy_ufunc__ # in the future arr = np.rollaxis(arr, 0, len(arr.shape)) return arr def __array_wrap__(self, arr, context=None): # pylint: disable=unused-argument # print(">> __array_wrap__", arr.shape, context) flds = [] for i in range(arr.shape[-1]): patch_arr = arr[..., i] fld = self.patches[i].__array_wrap__(patch_arr, context=context) flds.append(fld) return AMRField(flds, self.skeleton) def __add__(self, other): return self.wrap_field_method("__add__", other) def __sub__(self, other): return self.wrap_field_method("__sub__", other) def __mul__(self, other): return self.wrap_field_method("__mul__", other) def __div__(self, other): return self.wrap_field_method("__div__", other) def __truediv__(self, other): return self.wrap_field_method("__truediv__", other) def __floordiv__(self, other): return self.wrap_field_method("__floordiv__", other) def __mod__(self, other): return self.wrap_field_method("__mod__", other) def __divmod__(self, other): return self.wrap_field_method("__divmod__", other) def __pow__(self, other): return self.wrap_field_method("__pow__", other) def __lshift__(self, other): return self.wrap_field_method("__lshift__", other) def __rshift__(self, other): return self.wrap_field_method("__rshift__", other) def __and__(self, other): return self.wrap_field_method("__and__", other) def __xor__(self, other): return self.wrap_field_method("__xor__", other) def __or__(self, other): return self.wrap_field_method("__or__", other) def __radd__(self, other): return self.wrap_field_method("__radd__", other) def __rsub__(self, other): return self.wrap_field_method("__rsub__", other) def __rmul__(self, other): return self.wrap_field_method("__rmul__", other) def __rdiv__(self, other): return self.wrap_field_method("__rdiv__", other) def __rtruediv__(self, other): return self.wrap_field_method("__rtruediv__", other) def __rfloordiv__(self, other): return self.wrap_field_method("__rfloordiv__", other) def __rmod__(self, other): return self.wrap_field_method("__rmod__", other) def __rdivmod__(self, other): return self.wrap_field_method("__rdivmod__", other) def __rpow__(self, other): return self.wrap_field_method("__rpow__", other) def __iadd__(self, other): return self.wrap_field_method("__iadd__", other) def __isub__(self, other): return self.wrap_field_method("__isub__", other) def __imul__(self, other): return self.wrap_field_method("__imul__", other) def __idiv__(self, other): return self.wrap_field_method("__idiv__", other) def __itruediv__(self, other): return self.wrap_field_method("__itruediv__", other) def __ifloordiv__(self, other): return self.wrap_field_method("__ifloordiv__", other) def __imod__(self, other): return self.wrap_field_method("__imod__", other) def __ipow__(self, other): return self.wrap_field_method("__ipow__", other) def __neg__(self): return self.wrap_field_method("__neg__") def __pos__(self): return self.wrap_field_method("__pos__") def __abs__(self): return self.wrap_field_method("__abs__") def __invert__(self): return self.wrap_field_method("__invert__") def __lt__(self, other): return self.wrap_field_method("__lt__", other) def __le__(self, other): return self.wrap_field_method("__le__", other) def __eq__(self, other): return self.wrap_field_method("__eq__", other) def __ne__(self, other): return self.wrap_field_method("__ne__", other) def __gt__(self, other): return self.wrap_field_method("__gt__", other) def __ge__(self, other): return self.wrap_field_method("__ge__", other) def any(self, **kwargs): return self.wrap_field_method("any", **kwargs) def all(self, **kwargs): return self.wrap_field_method("all", **kwargs) def argmax(self, **kwargs): return self.wrap_field_method("argmax", **kwargs) def argmin(self, **kwargs): return self.wrap_field_method("argmin", **kwargs) def argpartition(self, **kwargs): return self.wrap_field_method("argpartition", **kwargs) def argsort(self, **kwargs): return self.wrap_field_method("argsort", **kwargs) def clip(self, **kwargs): return self.wrap_field_method("clip", **kwargs) def conj(self, **kwargs): return self.wrap_field_method("conj", **kwargs) def conjugate(self, **kwargs): return self.wrap_field_method("conjugate", **kwargs) def cumprod(self, **kwargs): return self.wrap_field_method("cumprod", **kwargs) def cumsum(self, **kwargs): return self.wrap_field_method("cumsum", **kwargs) def max(self, **kwargs): return self.wrap_field_method("max", **kwargs) def mean(self, **kwargs): return self.wrap_field_method("mean", **kwargs) def min(self, **kwargs): return self.wrap_field_method("min", **kwargs) def nonzero(self, **kwargs): return self.wrap_field_method("nonzero", **kwargs) def partition(self, **kwargs): return self.wrap_field_method("partition", **kwargs) def prod(self, **kwargs): return self.wrap_field_method("prod", **kwargs) def ptp(self, **kwargs): return self.wrap_field_method("ptp", **kwargs) def round(self, **kwargs): return self.wrap_field_method("round", **kwargs) def std(self, **kwargs): return self.wrap_field_method("std", **kwargs) def sum(self, **kwargs): return self.wrap_field_method("sum", **kwargs) def __getattr__(self, name): # define a callback to finalize # print("!! getting attr::", name) if callable(getattr(self.patches[0], name)): def _wrap(lst): try: return AMRField(lst, self.skeleton) except TypeError: return lst return _FieldListCallableAttrWrapper(self.patches, name, _wrap) else: # return [getattr(fld, name) for fld in self.patches] ret0 = getattr(self.patches[0], name) # Check that all patches have the same value. Maybe this should # have a debugging flag attached to it since it will take time. try: all_same = all(getattr(blk, name) == ret0 for blk in self.patches[1:]) except ValueError: all_same = all(np.all(getattr(blk, name) == ret0) for blk in self.patches[1:]) if not all_same: raise ValueError("different patches of the AMRField have " "different values for attribute: {0}" "".format(name)) return ret0 ## ## EOF ##