# Source code for sciunit.base

"""The base class for many SciUnit objects."""

import sys

# Platform identifier of the running interpreter (e.g. 'linux', 'win32').
PLATFORM = sys.platform
# Major version number of the running interpreter.
PYTHON_MAJOR_VERSION = sys.version_info[0]
# Refuse to be imported under Python 2.
if not PYTHON_MAJOR_VERSION >= 3:
    raise Exception('Only Python 3 is supported')

import json, git, pickle, hashlib

import numpy as np
import pandas as pd
from pathlib import Path
from git.exc import GitCommandError, InvalidGitRepositoryError
from git.cmd import Git
from git.remote import Remote
from git.repo.base import Repo
from typing import Dict, List, Optional, Tuple, Union, Any
from io import StringIO
try:
    import tkinter
except ImportError:
    tkinter = None

# True when the ipykernel module has already been imported into this process.
KERNEL = 'ipykernel' in sys.modules
# Name of the directory that contains this source file.
HERE = Path(__file__).resolve().parents[0].name


class Versioned(object):
    """A Mixin class for SciUnit objects.

    Provides a version string based on the Git repository where the model is
    tracked. Provided in part by Andrew Davison in issue #53.

    The repository, version and remote URL are cached as the class attributes
    ``_repo``, ``_version`` and ``_remote_url`` of the concrete class of the
    instance. Only a value cached on that exact class is reused. A value
    cached on a base class may describe a different repository, for example
    the one containing the base class rather than the one containing the
    subclass.
    """

    def get_repo(self, cached: bool = True) -> Optional[Repo]:
        """Get a git repository object for this instance.

        Args:
            cached (bool, optional): Whether to use cached data.
                Defaults to True.

        Returns:
            Repo: The git repo for this instance, or None if the module
            defining this class is not inside a git repository.
        """
        cls = self.__class__
        module = sys.modules[self.__module__]
        # We use module.__file__ instead of module.__path__[0]
        # to include modules without a __path__ attribute.
        module_file = getattr(module, '__file__', None)
        # vars(cls) instead of hasattr(cls, ...): hasattr would also see a
        # `_repo` cached on a base class defined in another repository.
        if cached and '_repo' in vars(cls):
            repo = cls._repo
        elif module_file is not None:
            path = Path(module_file).resolve()
            try:
                repo = git.Repo(path, search_parent_directories=True)
            except (InvalidGitRepositoryError, git.exc.NoSuchPathError):
                repo = None
        else:
            repo = None
        cls._repo = repo
        return repo

    def get_version(self, cached: bool = True) -> Optional[str]:
        """Get a git version (i.e. a git commit hash) for this instance.

        Args:
            cached (bool, optional): Whether to use the cached data.
                Defaults to True.

        Returns:
            str: The hexsha of the HEAD commit, suffixed with "*" when the
            working tree is dirty. None if there is no repository or it has
            no commits yet.
        """
        cls = self.__class__
        if cached and '_version' in vars(cls):
            return cls._version
        # Honour `cached` for the repository lookup as well.
        repo = self.get_repo(cached=cached)
        version = None
        if repo is not None:
            try:
                version = repo.head.commit.hexsha
            except ValueError:
                # HEAD cannot be resolved (repository without commits).
                version = None
            else:
                if repo.is_dirty():
                    version += "*"
        cls._version = version
        return version

    version = property(get_version)

    def get_remote(self, remote: str = 'origin') -> Optional[Remote]:
        """Get a git remote object for this instance.

        Args:
            remote (str, optional): The remote Git repo. Defaults to 'origin'.

        Returns:
            Remote: The remote named `remote` if it exists, otherwise the
            first remote of the repository. None when there is no repository
            or it has no remotes.
        """
        repo = self.get_repo()
        if repo is None or not repo.remotes:
            return None
        by_name = {r.name: r for r in repo.remotes}
        return by_name.get(remote, repo.remotes[0])

    def get_remote_url(self, remote: str = 'origin',
                       cached: bool = True) -> Optional[str]:
        """Get a git remote URL for this instance.

        Args:
            remote (str, optional): The remote Git repo. Defaults to 'origin'.
            cached (bool, optional): Whether to use cached data.
                Defaults to True.

        Raises:
            GitCommandError: A Git command error other than missing access
                rights.

        Returns:
            str: The git remote URL for this instance. "git@host:path" URLs
            are rewritten to "http://host/path". None when no remote URL is
            available.
        """
        cls = self.__class__
        if cached and '_remote_url' in vars(cls):
            return cls._remote_url
        r = self.get_remote(remote)
        url = None
        if r is not None:
            try:
                url = next(iter(r.urls), None)
            except GitCommandError as ex:
                if 'correct access rights' not in str(ex):
                    raise
                # If ssh is not setup to access this repository, read the
                # configured URL directly. Run git inside the repository,
                # not in the current working directory.
                cmd = ['git', 'config', '--get', 'remote.%s.url' % r.name]
                url = Git(r.repo.working_dir).execute(cmd)
            except AttributeError:
                url = None
        if url is not None and url.startswith('git@'):
            domain = url.split('@')[1].split(':')[0]
            path = url.split(':')[1]
            url = "http://%s/%s" % (domain, path)
        cls._remote_url = url
        return url

    remote_url = property(get_remote_url)
class SciUnit(Versioned):
    """Abstract base class for models, tests, and scores."""

    def __init__(self):
        """Instantiate a SciUnit object."""
        self.unpicklable = []

    #: A list of attributes that cannot or should not be pickled.
    unpicklable = []

    #: A URL where the code for this object can be found.
    _url = None

    #: A verbosity level for printing information.
    verbose = 1

    def __getstate__(self) -> dict:
        """Copy the object's state from self.__dict__.

        Contains all of the instance attributes. Always uses the dict.copy()
        method to avoid modifying the original state.

        Returns:
            dict: The state of this instance, without the attributes listed
            in `unpicklable`.
        """
        state = self.__dict__.copy()
        # Remove the unpicklable entries.
        if hasattr(self, 'unpicklable'):
            for key in set(self.unpicklable).intersection(state):
                del state[key]
        return state

    def _state(self, state: Optional[dict] = None, keys: Optional[list] = None,
               exclude: Optional[List[str]] = None) -> dict:
        """Get the state of the instance.

        Args:
            state (dict, optional): The dict instance that contains a part of
                state info of this instance. Defaults to None, meaning the
                result of `__getstate__`.
            keys (list, optional): Some keys of `state`. Values in `state`
                associated with these keys will be kept and others will be
                discarded. Defaults to None.
            exclude (List[str], optional): Keys whose values are removed from
                `state`. Tuples in this list are treated as nested key paths
                and handled by `deep_exclude`. Defaults to None.

        Returns:
            dict: The state of the current instance.
        """
        if state is None:
            state = self.__getstate__()
        if keys:
            state = {key: state[key] for key in keys if key in state}
        if exclude:
            state = {key: value for key, value in state.items()
                     if key not in exclude}
            state = deep_exclude(state, exclude)
        return state

    def _properties(self, keys: Optional[list] = None,
                    exclude: Optional[list] = None) -> dict:
        """Get the properties of the instance.

        Args:
            keys (list, optional): If not None, only the properties that are
                in `keys` will be included in the return data.
                Defaults to None.
            exclude (list, optional): The list of properties that will not be
                included in return data. The caller's list is not modified.
                Defaults to None.

        Returns:
            dict: The dict of properties of the instance. 'state', 'id' and
            'properties' are always left out.
        """
        # Build a private set rather than `exclude += [...]`: that form
        # mutated the caller's list (and failed on tuples).
        excluded = set(exclude) if exclude else set()
        excluded.update(('state', 'id'))
        result = {}
        # raw_props() is sorted (it comes from dir()), so the output order is
        # deterministic, unlike iterating over a set difference.
        for prop in self.raw_props():
            if prop in excluded or prop == 'properties':
                continue  # 'properties' would recurse infinitely
            if not keys or prop in keys:
                result[prop] = getattr(self, prop)
        return result

    def raw_props(self) -> list:
        """Get the raw properties of the instance.

        Returns:
            list: The names of all class attributes that are `property`
            objects, in `dir()` order.
        """
        class_attrs = dir(self.__class__)
        return [p for p in class_attrs
                if isinstance(getattr(self.__class__, p, None), property)]

    @property
    def state(self) -> dict:
        """Get the state of the instance.

        Returns:
            dict: The state of the instance.
        """
        return self._state()

    @property
    def properties(self) -> dict:
        """Get the properties of the instance.

        Returns:
            dict: The properties of the instance.
        """
        return self._properties()

    @classmethod
    def dict_hash(cls, d: dict) -> str:
        """SHA224 encoded value of `d`.

        The items of `d` are sorted by key and pickled. If pickling fails,
        they are serialized with `SciUnitEncoder` instead.

        Args:
            d (dict): The dict instance to be SHA224 encoded.

        Returns:
            str: SHA224 encoded value of `d`.
        """
        od = [(key, d[key]) for key in sorted(d)]
        try:
            s = pickle.dumps(od)
        except (AttributeError, TypeError, pickle.PicklingError):
            # Unpicklable values (local objects, locks, ...) raise one of
            # these depending on the object; fall back to JSON.
            s = json.dumps(od, cls=SciUnitEncoder).encode('utf-8')
        return hashlib.sha224(s).hexdigest()

    @property
    def hash(self) -> str:
        """A unique numeric identifier of the current model state.

        Returns:
            str: The unique numeric identifier of the current model state.
        """
        return self.dict_hash(self.state)

    def json(self, add_props: bool = False, keys: Optional[list] = None,
             exclude: Optional[list] = None, string: bool = True,
             indent: Optional[Union[int, str]] = None) -> Union[str, dict]:
        """Generate a Json format encoded sciunit instance.

        Args:
            add_props (bool, optional): Whether to add additional properties
                of the object to the serialization. Defaults to False.
            keys (list, optional): Only the keys in `keys` will be included in
                the json content. Defaults to None.
            exclude (list, optional): The keys in `exclude` will be excluded
                from the json content. Defaults to None.
            string (bool, optional): The json content will be `str` type if
                True, `dict` type otherwise. Defaults to True.
            indent (int or str, optional): If indent is a non-negative integer
                or string, then JSON array elements and object members will be
                pretty-printed with that indent level. An indent level of 0,
                negative, or "" will only insert newlines. None (the default)
                selects the most compact representation. Using a positive
                integer indent indents that many spaces per level. If indent
                is a string (such as "\\t"), that string is used to indent each
                level (source:
                https://docs.python.org/3/library/json.html#json.dump).
                Defaults to None.

        Returns:
            str: The Json format encoded sciunit instance (a dict when
            `string` is False).
        """
        result = json.dumps(self, cls=SciUnitEncoder, add_props=add_props,
                            keys=keys, exclude=exclude, indent=indent)
        if not string:
            result = json.loads(result)
        return result

    @property
    def _id(self) -> Any:
        return id(self)

    @property
    def _class(self) -> dict:
        url = '' if self.url is None else self.url
        import_path = '{}.{}'.format(
            self.__class__.__module__,
            self.__class__.__name__
        )
        return {'name': self.__class__.__name__,
                'import_path': import_path,
                'url': url}

    @property
    def id(self) -> str:
        # NOTE(review): this stringifies the bound method `self.json`, not its
        # result, so it contains the object's repr rather than its JSON.
        # Kept as-is because calling self.json() could raise on circular
        # state; confirm the intended semantics before changing.
        return str(self.json)

    @property
    def url(self) -> str:
        return self._url if self._url else self.remote_url
class SciUnitEncoder(json.JSONEncoder):
    """Custom JSON encoder for SciUnit objects.

    Accepts three keyword arguments in addition to those of
    ``json.JSONEncoder``. They control how SciUnit objects are serialized:

    - ``add_props``: also include the object's properties.
    - ``keys``: keep only these state keys.
    - ``exclude``: drop these state keys.
    """

    #: Defaults for the SciUnit-specific options (overridden per instance).
    add_props = False
    keys = None
    exclude = None

    def __init__(self, *args, **kwargs):
        # Store the options on the instance. Storing them on the class made
        # them leak into every later json.dumps(..., cls=SciUnitEncoder).
        self.add_props = kwargs.pop('add_props', False)
        self.keys = kwargs.pop('keys', None)
        self.exclude = kwargs.pop('exclude', None)
        super(SciUnitEncoder, self).__init__(*args, **kwargs)

    def default(self, obj: Any) -> Any:
        """Try to encode the object.

        Args:
            obj (Any): Any object to be encoded.

        Raises:
            Exception: Could not JSON serialize the object. The message is
                printed and the original exception re-raised.

        Returns:
            A JSON-serializable stand-in for `obj`:
            - DataFrames become their 'split' dict.
            - NumPy arrays become lists.
            - NumPy scalars become Python scalars.
            - SciUnit objects become their (filtered) state.
            - Unknown objects become a truncated string.
        """
        try:
            if isinstance(obj, pd.DataFrame):
                o = obj.to_dict(orient='split')
                if isinstance(obj, SciUnit):
                    for old, new in [('data', 'scores'),
                                     ('columns', 'tests'),
                                     ('index', 'models')]:
                        o[new] = o.pop(old)
            elif isinstance(obj, np.ndarray):
                # tolist() also handles 0-d arrays (returns a Python scalar).
                o = obj.tolist()
            elif isinstance(obj, np.generic):
                # NumPy scalars (np.int64, np.bool_, ...) are not JSON-native.
                o = obj.item()
            elif isinstance(obj, SciUnit):
                state = obj.state
                if self.add_props:
                    state.update(obj.properties)
                o = obj._state(state=state, keys=self.keys,
                               exclude=self.exclude)
            elif isinstance(obj, (dict, list, tuple, str, type(None), bool,
                                  float, int)):
                o = json.JSONEncoder.default(self, obj)
            else:
                # Something we don't know how to serialize;
                # just represent it as truncated string
                o = "%.20s..." % obj
        except Exception:
            print("Could not JSON encode object %s" % obj)
            raise
        return o
class TestWeighted(object):
    """Base class for objects with test weights."""

    @property
    def weights(self) -> List[float]:
        """Return the test weights normalized to sum to one.

        When ``self.weights_`` is truthy, its entries are validated (all
        non-negative, positive total) and divided by their sum. Otherwise
        every test in ``self.tests`` receives an equal weight.

        Returns:
            List[float]: The normalized list of test weights.
        """
        num_tests = len(self.tests)
        raw = self.weights_
        if not raw:
            # Uniform weighting; an empty test list yields an empty list.
            return [1.0 / num_tests for _ in range(num_tests)]
        assert all([w >= 0 for w in raw]), \
            "All test weights must be >=0"
        total = sum(raw)
        assert total > 0, "Sum of test weights must be > 0"
        return [w / total for w in raw]
def deep_exclude(state: dict, exclude: list) -> dict:
    """Mark nested entries of `state` as removed.

    Every tuple in `exclude` is a path of keys into `state`. The value at the
    end of each path is replaced by the string ``'*removed*'``. Non-tuple
    entries of `exclude` are ignored. Paths that cannot be fully resolved are
    skipped without touching `state`.

    `state` itself is modified in place and returned. Nested containers along
    a path are shallow-copied before being written to. This matters because
    `state` is often a shallow copy of an object's ``__dict__`` (as produced
    by ``SciUnit.__getstate__``), so writing into the nested containers would
    otherwise corrupt the live object.

    Args:
        state (dict): A dict that represents the state of an instance.
        exclude (list): Attributes to exclude. Tuples are nested key paths
            whose final value will be marked as '*removed*'.

    Returns:
        dict: `state`, with the excluded nested values marked as removed.
    """
    import copy

    for path in exclude:
        if not isinstance(path, tuple) or not path:
            continue
        # Resolve every container along the path first. Every path starts
        # from the top-level `state`, and a missing key part-way means the
        # path is skipped.
        containers = [state]
        try:
            for key in path[:-1]:
                containers.append(containers[-1][key])
            containers[-1][path[-1]]
        except (LookupError, TypeError, ValueError):
            continue
        # Copy-on-write: replace each nested container with a shallow copy
        # so objects shared with the live instance are left untouched.
        for depth in range(1, len(containers)):
            containers[depth] = copy.copy(containers[depth])
            containers[depth - 1][path[depth - 1]] = containers[depth]
        containers[-1][path[-1]] = '*removed*'
    return state