Source code for viscid.vutil

#!/usr/bin/env python

# FIXME: this module is way too long and disorganized

from __future__ import print_function, division

import fnmatch
from glob import glob
from itertools import chain
import logging
from operator import itemgetter
import os.path
import re
import sys
from timeit import default_timer as time

import viscid
from viscid import logger
from viscid import sliceutil
from viscid.compat import izip

import numpy as np


__all__ = ["timeit", "resolve_path", "find_item", "find_items",
           "get_trilinear_field", "slice_globbed_filenames", "glob2",
           "interact"]


tree_prefix = ".   "


def find_field(vfile, fld_name_lst):
    """ convenience function to get a field that could be called many things
    returns the first fld_name in the list that is in the file """
    for fld_name in fld_name_lst:
        if fld_name in vfile:
            return vfile[fld_name]
    raise KeyError("file {0} contains none of {1}".format(vfile, fld_name_lst))

def split_floats(arg_str):
    return [float(s) for s in arg_str.split(',')]

def add_animate_arguments(parser):
    """ add common options for animating, you may want to make sure parser was
    constructed with conflict_handler='resolve' """
    anim = parser.add_argument_group("Options for creating animations")
    anim.add_argument("-a", "--animate", default=None,
                      help="animate results")
    anim.add_argument("--prefix", default=None,
                      help="Prefix of the output image filenames")
    anim.add_argument('-r', '--rate', dest='framerate', type=int, default=5,
                      help="animation frame rate (default 5).")
    anim.add_argument('--qscale', dest='qscale', default='2',
                      help="animation quality flag (default 2).")
    anim.add_argument('-k', dest='keep', action='store_true',
                      help="keep temporary files.")
    return parser

def add_mpl_output_arguments(parser):
    """ add common options for tuning matplotlib output, you may want to make
    sure parser was constructed with conflict_handler='resolve' """
    mplargs = parser.add_argument_group("Options for tuning matplotlib")
    mplargs.add_argument("-s", "--size", dest="plot_size", type=split_floats,
                         default=None, help="size of mpl plot (inches)")
    mplargs.add_argument("--dpi", dest="dpi", type=float, default=None,
                         help="dpi of plot")
    parser.add_argument("--prefix", default=None,
                        help="Prefix of the output image filenames")
    parser.add_argument("--format", "-f", default="png",
                        help="output format, as in 'png'|'pdf'|...")
    parser.add_argument('-w', '--show', dest='show', action="store_true",
                        help="show plots with plt.show()")
    return parser

def common_argparse(parser, default_verb=0):
    """ add some common verbosity stuff to argparse, parse the
    command line args, and setup the logging levels
    parser should be an ArgumentParser instance, and kwargs
    should be options that get passed to logger.basicConfig
    returns the args namespace  """
    general = parser.add_argument_group("Viscid general options")
    general.add_argument("--log", action="store", type=str, default=None,
                         help="Logging level (overrides verbosity)")
    general.add_argument("-v", action="count", default=default_verb,
                         help="increase verbosity")
    general.add_argument("-q", action="count", default=0,
                         help="decrease verbosity")
    args = parser.parse_args()

    # setup the logging level
    if args.log is not None:
        logger.setLevel(getattr(logging, args.log.upper()))
    else:
        # default = 30 WARNING
        verb = args.v - args.q
        logger.setLevel(int(30 - 10 * verb))

    return args
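
# Illustrative sketch (editor's addition): a typical script wires these
# helpers together roughly like so; argparse is the stdlib module and the
# attribute names come from the options added above.
#
#     parser = argparse.ArgumentParser(conflict_handler='resolve')
#     parser = add_mpl_output_arguments(parser)
#     args = common_argparse(parser)  # parses sys.argv, sets logger level
#     # args.plot_size, args.dpi, args.show, ... are now available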

def subclass_spider(cls):
    """ return recursive list of subclasses of cls (depth first) """
    lst = [cls]
    # reversed gives precedence to the more recently declared classes
    for c in reversed(cls.__subclasses__()):
        lst += subclass_spider(c)
    return lst
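
# Illustrative example (editor's addition; the toy classes are hypothetical):
#
#     >>> class Base(object): pass
#     >>> class A(Base): pass
#     >>> class B(Base): pass
#     >>> [c.__name__ for c in subclass_spider(Base)]
#     ['Base', 'B', 'A']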

def timereps(reps, func, *args, **kwargs):
    arr = [None] * reps
    for i in range(reps):
        start = time()
        func(*args, **kwargs)
        end = time()
        arr[i] = end - start
    return min(arr), max(arr), sum(arr) / reps
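
# Illustrative example (editor's addition): timereps returns the (min, max,
# mean) wall-clock time over `reps` calls; sum(range(10000)) is just a
# stand-in workload.
#
#     >>> tmin, tmax, tmean = timereps(3, sum, range(10000))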

def timeit(f, *args, **kwargs):
    """overly simple timeit wrapper

    Arguments:
        f: callable to timeit
        *args: positional arguments for `f`
        **kwargs: keyword arguments for `f`

    Keyword arguments:
        timeit_repeat (int): number of times to call `f` (Default: 1)
        timeit_print_stats (bool): print min/max/mean/median when done
        timeit_quiet (bool): quiets all output (useful if you only
            want the timeit_stats dict filled)
        timeit_stats (dict): Stats will be stuffed into here

    Returns:
        The result of `f(*args, **kwargs)`
    """
    timeit_repeat = kwargs.pop('timeit_repeat', 1)
    timeit_print_stats = kwargs.pop('timeit_print_stats', True)
    timeit_quiet = kwargs.pop('timeit_quiet', False)
    timeit_stats = kwargs.pop('timeit_stats', dict())

    times = np.empty((timeit_repeat,), dtype='f8')

    for i in range(timeit_repeat):
        ret = None
        t0 = time()
        ret = f(*args, **kwargs)
        t1 = time()
        s = "{0:.03g}".format(t1 - t0)
        times[i] = t1 - t0
        if not timeit_quiet and (timeit_repeat == 1 or not timeit_print_stats):
            secs = "second" if s == "1" else "seconds"
            print("<function {0}.{1}>".format(f.__module__, f.__name__),
                  "took", s, secs)

    timeit_stats['min'] = np.min(times)
    timeit_stats['max'] = np.max(times)
    timeit_stats['mean'] = np.mean(times)
    timeit_stats['median'] = np.median(times)
    timeit_stats['repeat'] = timeit_repeat

    if not timeit_quiet and timeit_repeat > 1 and timeit_print_stats:
        print("<function {0}.{1}> stats ({2} runs):"
              "".format(f.__module__, f.__name__, timeit_repeat))
        print("  Min: {min:.3g}, Mean: {mean:.3g}, Median: {median:.3g}, "
              "Max: {max:.3g}".format(**timeit_stats))

    return ret
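
# Illustrative example (editor's addition): fill a stats dict without
# printing anything; sum(range(1000)) is just a stand-in workload.
#
#     >>> stats = {}
#     >>> ret = timeit(sum, range(1000), timeit_repeat=5,
#     ...              timeit_quiet=True, timeit_stats=stats)
#     >>> sorted(stats)
#     ['max', 'mean', 'median', 'min', 'repeat']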

def resolve_path(dset, loc, first=False):
    """Search for globbed paths in a nested dict-like hierarchy

    Args:
        dset (dict): Root of some nested dict-like hierarchy
        loc (str): path as a glob pattern
        first (bool): Stop at first match and return a single value

    Raises:
        KeyError: If there are no glob matches

    Returns:
        If first == True, (value, path)
        else, ([value0, value1, ...], [path0, path1, ...])
    """
    try:
        if first:
            return dset[loc], loc
        else:
            return [dset[loc]], [loc]
    except KeyError:
        searches = [loc.strip('/').split('/')]
        dsets = [dset]
        paths = [[]]

        while any(searches):
            next_dsets = []
            next_searches = []
            next_paths = []
            for dset, search, path in izip(dsets, searches, paths):
                try:
                    next_dsets.append(dset[search[0]])
                    next_searches.append(search[1:])
                    next_paths.append(path + [search[0]])
                except (KeyError, TypeError, IndexError):
                    s = [{}.items()]
                    if hasattr(dset, 'items'):
                        s.append(dset.items())
                    if hasattr(dset, 'attrs'):
                        s.append(dset.attrs.items())
                    for key, val in chain(*s):
                        if fnmatch.fnmatchcase(key, search[0]):
                            next_dsets.append(val)
                            next_searches.append(search[1:])
                            next_paths.append(path + [key])
                            if first:
                                break
            dsets = next_dsets
            searches = next_searches
            paths = next_paths

    if dsets:
        dsets, paths = dsets, ['/'.join(p) for p in paths]
        if first:
            return dsets[0], paths[0]
        else:
            return dsets, paths
    else:
        raise KeyError("Path {0} has no matches".format(loc))
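
# Illustrative example (editor's addition): a plain nested dict stands in
# for a Viscid dataset hierarchy; keys and values are hypothetical.
#
#     >>> dset = {'run1': {'bx': 1, 'by': 2}, 'run2': {'bx': 3}}
#     >>> resolve_path(dset, 'run*/bx')
#     ([1, 3], ['run1/bx', 'run2/bx'])
#     >>> resolve_path(dset, 'run1/b*', first=True)
#     (1, 'run1/bx')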

def find_item(dset, loc):
    """Shortcut for first :py:func:`resolve_path`, item only"""
    return resolve_path(dset, loc, first=True)[0]

def find_items(dset, loc):
    """Shortcut for :py:func:`resolve_path`, items only"""
    return resolve_path(dset, loc)[0]
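
# Illustrative example (editor's addition), continuing the toy dset above:
#
#     >>> find_item(dset, 'run2/bx')
#     3
#     >>> find_items(dset, '*/bx')
#     [1, 3]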

def _hexchar2int(arr):
    # this np.char.decode(..., 'hex') doesn't work for py3k; kinda silly
    try:
        return np.frombuffer(np.char.decode(arr, 'hex'), dtype='u1')
    except LookupError:
        import codecs
        return np.frombuffer(codecs.decode(arr, 'hex_codec'), dtype='u1')

def _string_colors_as_hex(scalars):
    if isinstance(scalars, viscid.string_types):
        scalars = [scalars]

    # 24bit rgb or 32bit rgba strings only
    hex_char_re = r"#[0-9a-fA-F]{6,8}"
    allhex = all(re.match(hex_char_re, s) for s in scalars)

    # if not all hex color codes, then process all colors through
    # and make them all 24/32 bit rgba hex codes
    if not allhex:
        nompl_err_str = ("Matplotlib must be installed to use "
                         "color names. Please either install "
                         "matplotlib or use hex color codes.")
        cc = None

        scalars_as_hex = []
        for s in scalars:
            if re.match(hex_char_re, s):
                # already 24bit rgb or 32bit rgba hex
                s_hex = s
            elif re.match(r"#[0-9a-fA-F]{3,4}", s):
                # 12bit rgb or 16bit rgba hex
                s_hex = "#" + "".join(_s + _s for _s in s[1:])
            else:
                # matplotlib / html color string
                if cc is None:
                    try:
                        import matplotlib.colors
                        cc = matplotlib.colors.ColorConverter()
                    except ImportError:
                        raise RuntimeError(nompl_err_str)
                s_rgba = np.asarray(cc.to_rgba(s))
                s_rgba = np.round(255 * s_rgba).astype('i')
                s_hex = ("#{0[0]:02x}{0[1]:02x}{0[2]:02x}"
                         "{0[3]:02x}".format(s_rgba))
            scalars_as_hex.append(s_hex)
        scalars = scalars_as_hex

    # fill the alpha channel if any of the scalars have a 4th
    # bytes-worth of data
    if any(len(s) == 9 for s in scalars):
        scalars = [s.ljust(9, 'f') for s in scalars]

    return np.asarray(scalars)

def prepare_lines(lines, scalars=None, do_connections=False, other=None):
    """Concatenate and standardize a list of lines

    Args:
        lines (list): Must be a list of 3xN or 4xN ndarrays of xyz(s)
            data for N points along the line. N need not be the same
            for all lines. Can also be 6xN such that lines[:][3:, :]
            are interpreted as rgb colors
        scalars (sequence): can be one of::

            - single hex color (ex, `#FD7400`)
            - sequence of Nlines hex colors
            - single rgb(a) tuple
            - sequence of Nlines rgb(a) sequences
            - sequence of N values that go with each vertex and will
              be mapped with a colormap
            - sequence of Nlines values that go with each line and
              will be mapped with a colormap
            - Field. If `np.prod(fld.shape) in (N, nlines)` then the
              field is interpreted as a simple sequence of values
              (ie, the topology result from calc_streamlines for
              coloring each line). Otherwise, the field is
              interpolated onto each vertex.

        do_connections (bool): Whether or not to make connections array
        other (dict): a dictionary of other arrays that should be
            reshaped / broadcast the same way scalars is

    Returns:
        (vertices, scalars, connections, other)

        * vertices (ndarray): 3xN array of N xyz points. N is the sum
          of the lengths of all the lines
        * scalars (ndarray): N array of scalars, 3xN array of uint8
          rgb values, 4xN array of uint8 rgba values, or None
        * connections (ndarray): Nx2 array of ints (indices along
          axis 1 of vertices) describing the forward and backward
          connectedness of the lines, or None
        * other (dict): a dict of N length arrays

    Raises:
        ValueError: If rgb data is not in a valid range or the shape
            of scalars is not understood
    """
    nlines = len(lines)
    npts = [line.shape[1] for line in lines]
    N = np.sum(npts)
    first_idx = np.cumsum([0] + npts[:-1])
    vertices = [np.asarray(line) for line in lines]
    vertices = np.concatenate(lines, axis=1)
    if vertices.dtype.kind not in 'fc':
        vertices = np.asarray(vertices, dtype='f')

    if vertices.shape[0] > 3:
        if scalars is not None:
            viscid.logger.warning("Overriding line scalars with scalars kwarg")
        else:
            scalars = vertices[3:, :]
        vertices = vertices[:3, :]

    if scalars is not None:
        scalars_are_strings = False

        if isinstance(scalars, viscid.field.Field):
            if np.prod(scalars.shape) in (nlines, N):
                scalars = np.asarray(scalars).reshape(-1)
            else:
                scalars = viscid.interp_trilin(scalars, vertices)
                if scalars.size != N:
                    raise ValueError("Scalars was not a scalar field")
        elif isinstance(scalars, (list, tuple, viscid.string_types, np.ndarray)):
            # string types need some extra massaging
            if any(isinstance(s, viscid.string_types) for s in scalars):
                assert all(isinstance(s, viscid.string_types) for s in scalars)
                scalars_are_strings = True
                scalars = _string_colors_as_hex(scalars)
            elif isinstance(scalars, np.ndarray):
                scalars = scalars
            else:
                for i, si in enumerate(scalars):
                    if not isinstance(si, np.ndarray):
                        scalars[i] = np.asarray(si)
                    scalars[i] = np.atleast_2d(scalars[i])
                try:
                    scalars = np.concatenate(scalars, axis=0)
                except ValueError:
                    scalars = np.concatenate(scalars, axis=1)

        scalars = np.atleast_2d(scalars)

        if scalars.dtype == np.dtype('object'):
            raise RuntimeError("Scalars must be numbers, tuples of numbers "
                               "that indicate rgb(a), or hex strings - they "
                               "must not be python objects")

        if scalars.shape == (1, 1):
            scalars = scalars.repeat(N, axis=1)
        elif scalars.shape == (1, nlines) or scalars.shape == (nlines, 1):
            # one scalar for each line, so broadcast it
            scalars = scalars.reshape(nlines, 1)
            scalars = [scalars[i].repeat(ni) for i, ni in enumerate(npts)]
            scalars = np.concatenate(scalars, axis=0).reshape(1, N)
        elif scalars.shape == (N, 1) or scalars.shape == (1, N):
            # catch these so they're not interpreted as colors if
            # nlines == 1 and N == 3; ie. 1 line with 3 points
            scalars = scalars.reshape(1, N)
        elif scalars.shape in [(3, nlines), (nlines, 3),
                               (4, nlines), (nlines, 4)]:
            # one rgb(a) color for each line, so broadcast it
            if (scalars.shape in [(3, nlines), (4, nlines)] and
                    scalars.shape not in [(3, 3), (4, 4)]):
                # the guard against shapes (3, 3) and (4, 4) means that
                # these square shapes are assumed Nlines x {3,4}
                scalars = scalars.T
            nccomps = scalars.shape[1]  # 3 for rgb, 4 for rgba
            colors = []
            for i, ni in enumerate(npts):
                c = scalars[i].reshape(nccomps, 1).repeat(ni, axis=1)
                colors.append(c)
            scalars = np.concatenate(colors, axis=1)
        elif scalars.shape in [(3, N), (N, 3), (4, N), (N, 4)]:
            # one rgb(a) color for each vertex
            if (scalars.shape in [(3, N), (4, N)] and
                    scalars.shape not in [(3, 3), (4, 4)]):
                # the guard against shapes (3, 3) and (4, 4) means that
                # these square shapes are assumed N x {3,4}
                scalars = scalars.T
        elif scalars.shape in [(1, 3), (3, 1), (1, 4), (4, 1)]:
            # interpret a single rgb(a) color, and repeat/broadcast it
            scalars = scalars.reshape([-1, 1]).repeat(N, axis=1)
        else:
            scalars = scalars.reshape(-1, N)

        # scalars now has shape (1, N), (3, N), or (4, N)

        if scalars_are_strings:
            # translate hex colors (#ff00ff) into rgb(a) values
            scalars = np.char.lstrip(scalars, '#')
            strlens = np.char.str_len(scalars)
            min_strlen, max_strlen = np.min(strlens), np.max(strlens)
            if min_strlen == max_strlen == 8:
                # 32-bit rgba (two hex chars per channel)
                scalars = _hexchar2int(scalars.astype('S8')).reshape(-1, 4).T
            elif min_strlen == max_strlen == 6:
                # 24-bit rgb (two hex chars per channel)
                scalars = _hexchar2int(scalars.astype('S6')).reshape(-1, 3).T
            else:
                raise NotImplementedError("This should never happen as "
                                          "scalars as colors should already "
                                          "be preprocessed appropriately")
        elif scalars.shape[0] == 1:
            # normal scalars, cast them down to a single dimension
            scalars = scalars.reshape(-1)
        elif scalars.shape[0] in (3, 4):
            # The scalars encode rgb data, standardize the result to a
            # 3xN or 4xN ndarray of 1 byte unsigned ints [0..255]
            if np.all(scalars >= 0) and np.all(scalars <= 1):
                scalars = (255 * scalars).round().astype('u1')
            elif np.all(scalars >= 0) and np.all(scalars < 256):
                scalars = scalars.round().astype('u1')
            else:
                raise ValueError("Rgb data should be in range [0, 1] or "
                                 "[0, 255], range given is [{0}, {1}]"
                                 "".format(np.min(scalars), np.max(scalars)))
        else:
            raise ValueError("Scalars should either be a number, or set of "
                             "rgb values, shape is {0}".format(scalars.shape))

    # scalars should now have shape (N, ) or be a uint8 array with shape
    # (3, N) or (4, N) encoding an rgb(a) color for each point [0..255]
    # ... done with scalars...

    # broadcast / reshape additional arrays given in other
    if other:
        for key, arr in other.items():
            if arr is None:
                pass
            elif arr.shape == (1, nlines) or arr.shape == (nlines, 1):
                arr = arr.reshape(nlines, 1)
                arr = [arr[i].repeat(ni) for i, ni in enumerate(npts)]
                other[key] = np.concatenate(arr, axis=0).reshape(1, N)
            else:
                try:
                    other[key] = arr.reshape(-1, N)
                except ValueError:
                    viscid.logger.warning("Unknown dimension, dropping array {0}"
                                          "".format(key))

    if do_connections:
        connections = [None] * nlines
        for i, ni in enumerate(npts):
            # i0 is the index of the first point of the i'th line in lines
            i0 = first_idx[i]
            connections[i] = np.vstack([np.arange(i0, i0 + ni - 1.5),
                                        np.arange(i0 + 1, i0 + ni - 0.5)]).T
        connections = np.concatenate(connections, axis=0).astype('i')
    else:
        connections = None

    return vertices, scalars, connections, other
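
# Illustrative example (editor's addition): two short hypothetical polylines
# with one scalar per line; the scalar is broadcast onto each vertex and
# connections are built within each line.
#
#     >>> line0 = np.zeros((3, 4))
#     >>> line1 = np.ones((3, 2))
#     >>> verts, scals, conns, _ = prepare_lines([line0, line1],
#     ...                                        scalars=[10.0, 20.0],
#     ...                                        do_connections=True)
#     >>> verts.shape, scals.shape, conns.shape
#     ((3, 6), (6,), (4, 2))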

def get_trilinear_field():
    """get a generic trilinear field"""
    xl, xh, nx = -1.0, 1.0, 41
    yl, yh, ny = -1.5, 1.5, 41
    zl, zh, nz = -2.0, 2.0, 41
    x = np.linspace(xl, xh, nx)
    y = np.linspace(yl, yh, ny)
    z = np.linspace(zl, zh, nz)
    crds = viscid.wrap_crds("nonuniform_cartesian",
                            [('x', x), ('y', y), ('z', z)])
    b = viscid.empty(crds, name="f", nr_comps=3, center="Cell",
                     layout="interlaced")
    X, Y, Z = b.get_crds(shaped=True)

    x01, y01, z01 = 0.5, 0.5, 0.5
    x02, y02, z02 = 0.5, 0.5, 0.5
    x03, y03, z03 = 0.5, 0.5, 0.5

    b['x'][:] = (0.0 + 1.0 * (X - x01) + 1.0 * (Y - y01) + 1.0 * (Z - z01) +
                 1.0 * (X - x01) * (Y - y01) + 1.0 * (Y - y01) * (Z - z01) +
                 1.0 * (X - x01) * (Y - y01) * (Z - z01))
    b['y'][:] = (0.0 + 1.0 * (X - x02) - 1.0 * (Y - y02) + 1.0 * (Z - z02) +
                 1.0 * (X - x02) * (Y - y02) + 1.0 * (Y - y02) * (Z - z02) -
                 1.0 * (X - x02) * (Y - y02) * (Z - z02))
    b['z'][:] = (0.0 + 1.0 * (X - x03) + 1.0 * (Y - y03) - 1.0 * (Z - z03) +
                 1.0 * (X - x03) * (Y - y03) + 1.0 * (Y - y03) * (Z - z03) +
                 1.0 * (X - x03) * (Y - y03) * (Z - z03))
    return b
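
# Illustrative sketch (editor's addition): the synthetic field above is
# mainly useful as test input, e.g. for the trilinear interpolation used
# elsewhere in this module.
#
#     fld = get_trilinear_field()
#     pts = np.zeros((3, 5))  # five xyz points, all at the origin
#     vals = viscid.interp_trilin(fld['x'], pts)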

def slice_globbed_filenames(glob_pattern):
    """Apply a slice to a glob pattern

    Note:
        Slice by value works by adding an 'f' to a value, just like
        the rest of Viscid.

    Args:
        glob_pattern (str): A string

    Returns:
        list of filenames

    Examples:
        If a directory contains files,

        >>> os.listdir()
        ["file.010.txt", "file.020.txt", "file.030.txt", "file.040.txt"]

        then sliced globs can look like

        >>> slice_globbed_filenames("f*.[:2].txt")
        ["file.010.txt", "file.020.txt"]

        >>> slice_globbed_filenames("f*.[10.0j::2].txt")
        ["file.010.txt", "file.030.txt"]

        >>> slice_globbed_filenames("f*.[20j:2].txt")
        ["file.020.txt", "file.040.txt"]
    """
    glob_pattern = os.path.expanduser(os.path.expandvars(glob_pattern))
    glob_pattern = os.path.abspath(glob_pattern)

    # construct a regex to match the results
    # verify glob pattern has only one
    dtime_re = sliceutil.RE_DTIME_SLC_GROUP
    number_re = r"[-+]?[0-9]*\.?[0-9]+[fjFJ]?|[-+]?[0-9+]"
    el_re = r"(?:{0}|{1})".format(dtime_re, number_re)
    slc_re = r"\[({0})?(:({0}))?(:[-+]?[0-9]*)?\]".format(el_re)

    n_slices = len(re.findall(slc_re, glob_pattern))

    if n_slices > 1:
        viscid.logger.warning("Multiple filename slices found, only using the "
                              "first.")

    if n_slices:
        m = re.search(slc_re, glob_pattern)
        slcstr = glob_pattern[m.start() + 1:m.end() - 1]
        edited_glob = glob_pattern[:m.start()] + "*" + glob_pattern[m.end():]
        res_re = glob_pattern[:m.start()] + "TSLICE" + glob_pattern[m.end():]
        res_re = fnmatch.translate(res_re)
        res_re = res_re.replace("TSLICE", r"(?P<TSLICE>.*?)")
    else:
        edited_glob = glob_pattern
        slcstr = ""

    fnames = glob(edited_glob)

    if n_slices:
        if not fnames:
            raise IOError("the glob {0} matched no files".format(edited_glob))
        times = []
        _newfn = []
        for fn in fnames:
            try:
                times.append(float(re.match(res_re, fn).group('TSLICE')))
                _newfn.append(fn)
            except ValueError:
                pass
        fnames = _newfn
        times = [float(re.match(res_re, fn).group('TSLICE')) for fn in fnames]
        fnames = [fn for fn, t in sorted(zip(fnames, times), key=itemgetter(1))]
        times.sort()
        std_sel = sliceutil.standardize_sel(slcstr)
        slc = sliceutil.std_sel2index(std_sel, times, tdunit='s', epoch=None)
    else:
        times = [None] * len(fnames)
        slc = slice(None)

    idx_whitelist = np.asarray(np.arange(len(fnames))[slc]).reshape(-1)
    culled_fnames = [s for i, s in enumerate(fnames) if i in idx_whitelist]

    return culled_fnames

def glob2(glob_pattern, *args, **kwargs):
    """Wrap slice_globbed_filenames, but return [] on no match

    See Also:
        * :py:func:`slice_globbed_filenames`
    """
    try:
        return slice_globbed_filenames(glob_pattern, *args, **kwargs)
    except IOError:
        return []

def interact(banner=None, ipython=True, stack_depth=0, global_ns=None,
             local_ns=None, viscid_ns=True, mpl_ns=False, mvi_ns=False):
    """Start an interactive interpreter"""
    if banner is None:
        banner = "Interactive Viscid..."
        if mpl_ns:
            banner += "\n - Viscid's matplotlib interface available as `vlt`"
        if mvi_ns:
            banner += "\n - Viscid's mayavi interface available as `vlab`"
            banner += "\n - Use vlab.show(...) to interact with Mayavi"
            banner += "\n - FYI, all Mayavi objects have trait_names()"
        banner += "\n - Use Ctrl-D (eof) to end interaction"

    def _merge_ns(src, target):
        target.update(dict([(name, getattr(src, name)) for name in dir(src)]))
        target[src.__name__.split('.')[-1]] = src

    ns = dict()
    if viscid_ns:
        _merge_ns(viscid, ns)
    if mpl_ns:
        from viscid.plot import vpyplot as vlt
        _merge_ns(vlt, ns)
    if mvi_ns:
        from viscid.plot import vlab
        _merge_ns(vlab, ns)

    call_frame = sys._getframe(stack_depth).f_back  # pylint: disable=protected-access

    if global_ns is None:
        global_ns = call_frame.f_globals
    ns.update(global_ns)

    if local_ns is None:
        local_ns = call_frame.f_locals
    ns.update(local_ns)

    try:
        if not ipython:
            raise ImportError
        from IPython import embed
        embed(user_ns=ns, banner1=banner)
    except ImportError:
        import code
        code.interact(banner, local=ns)

    print("Resuming Script")
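
# Illustrative sketch (editor's addition): drop into an interpreter from a
# script to poke at local variables; assumes IPython and/or matplotlib are
# installed when those options are requested.
#
#     fld = get_trilinear_field()
#     interact(banner="inspecting fld...", mpl_ns=True)  # `fld` is in scope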

##
## EOF
##