Source code for romcomma.user.contexts

#  BSD 3-Clause License.
# 
#  Copyright (c) 2019-2024 Robert A. Milton. All rights reserved.
# 
#  Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
# 
#  1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
# 
#  2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the
#     documentation and/or other materials provided with the distribution.
# 
#  3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this
#     software without specific prior written permission.
# 
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
#  THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
#  CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
#  PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
#  LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
#  EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

""" **Context managers** """

from __future__ import annotations

from romcomma.base.definitions import *
from time import time
from datetime import timedelta
from contextlib import contextmanager, nullcontext


@contextmanager
def Timer(name: str = '', is_inline: bool = True):
    """ Context Manager reporting how long the body of its ``with`` statement took to run.

    Nothing is printed when ``name`` is empty. If the body raises, the closing report is not printed.

    Args:
        name: The name of this context, ``print``ed as what is being timed. The (default) empty string will not be timed.
        is_inline: Whether to report timing inline (the default), or with linebreaks to top and tail a paragraph.
    """
    started = time()
    announce = name != ''
    heading = f'Running {name}'
    if announce and is_inline:
        # No newline (and an explicit flush) so the closing report completes this same line.
        print(heading, end='', flush=True)
    elif announce:
        print(heading + '...')
    yield
    if not announce:
        return
    # Elapsed time is truncated to whole seconds before formatting.
    elapsed = timedelta(seconds=int(time() - started))
    print(f' took {elapsed}.' if is_inline else f'...took {elapsed}.')
@contextmanager
def Environment(name: str = '', device: str = '', **kwargs):
    """ Context Manager setting up the environment to run operations.

    On entry this sets ``tf.config.run_functions_eagerly``, optionally pins a TensorFlow device, and enters a GPFlow config context.
    The previous ``run_functions_eagerly`` setting is restored on exit, including when the body raises.

    Args:
        name: The name of this context, ``print``ed as what is being run. The (default) empty string will not be timed.
        device: The device to run on. If this contains ``CPU`` or ``GPU``, the text from the last such occurrence onwards is used as the
            logical device, prefixed by ``/`` (e.g. ``'GPU:0'`` and ``'/device:GPU:0'`` both give ``'/GPU:0'``).
            Otherwise device allocation is automatic.
        **kwargs: Passed straight to ``gf.config.Config``. Note, however, that ``float`` is always forced to ``'float64'``,
            so ``float=float32`` is inoperative (due to SciPy). ``eager=bool`` is removed from kwargs and passed to
            `tf.config.run_functions_eagerly <https://www.tensorflow.org/api_docs/python/tf/config/run_functions_eagerly>`_;
            if omitted, ``None`` is passed.
    """
    with Timer(name):
        # float64 is forced, overriding any caller-supplied ``float``.
        kwargs = kwargs | {'float': 'float64'}
        eager = kwargs.pop('eager', None)
        # run_functions_eagerly is process-global, so remember it and put it back when this context ends.
        previous_eager = tf.config.functions_run_eagerly()
        tf.config.run_functions_eagerly(eager)
        try:
            print(' using GPFlow(' + ', '.join([f'{k}={v!r}' for k, v in kwargs.items()]), end=')')
            position = max(device.rfind('CPU'), device.rfind('GPU'))
            if position >= 0:
                device = '/' + device[position:]
                device_manager = tf.device(device)
                print(f' on {device}', end='')
            else:
                device_manager = nullcontext()
            implementation_manager = gf.config.as_context(gf.config.Config(**kwargs))
            print('...')
            with device_manager, implementation_manager:
                yield
        finally:
            tf.config.run_functions_eagerly(previous_eager)
        # Timer completes this line with ' took ...' only when name is non-empty, so only start it then.
        if name != '':
            print('...Running ' + name, end='')