Source code for fret.workspace

import copy
import inspect as ins
import json
import logging
import pathlib
import pickle

import toml

from .common import configurables, plugins, NotConfiguredError
from .util import Configuration, stateful, Iterator, date_str


[docs]class Workspace: """Workspace utilities. One can save/load configurations, build models with specific configuration, save snapshots, open results, etc., using workspace objects.""" def __init__(self, path, config=None, config_dict=None): self._path = pathlib.Path(path) self._modules = dict() conf = None if self.config_path.exists(): conf = toml.load(self.config_path.open()) if config_dict is not None: if conf is None: conf = config_dict else: conf.update(config_dict) self._config_dict = None if conf is not None: self._config_dict = copy.deepcopy(conf) for name, cfg in conf.items(): cls_name = cfg['__module'] del cfg['__module'] self._modules[name] = (cls_name, cfg) if config: self._modules.update(config) for plugin in plugins: plugin.apply(self) def config_dict(self): return {name: dict({'__module': cls_name}, **cfg) for name, (cls_name, cfg) in self._modules.items()} @property def path(self): """Workspace root path.""" if not self._path.exists(): self._path.mkdir(parents=True) return self._path @property def config_path(self): """Workspace configuration path.""" cp = self.path.joinpath('config.toml') return cp
[docs] def log(self, *filename): """Get log file path within current workspace. Args: filename (str or list): relative path to file; if ommited, returns root path of logs. """ path = self.path.joinpath('log', *filename) _mkdir(path, not filename or filename[-1].endswith('/')) return path
[docs] def result(self, *filename): """Get result file path within current workspace. Args: filename (str or list): relative path to file; if ommited, returns root path of results. """ path = self.path.joinpath('result', *filename) _mkdir(path, not filename or filename[-1].endswith('/')) return path
[docs] def snapshot(self, *filename): """Get snapshot file path within current workspace. Args: filename (str or list): relative path to file; if ommited, returns root path of snapshots. """ path = self.path.joinpath('snapshot', *filename) _mkdir(path, not filename or filename[-1].endswith('/')) return path
[docs] def register(self, name, module, **kwargs): """Register and save module configuration.""" if not ins.isclass(module): cfg = module.config._dict() # pylint: disable=protected-access cfg.update(kwargs) self._modules[name] = (module.__class__.__name__, cfg) else: self._modules[name] = (module.__name__, kwargs)
[docs] def write(self): """Save module configuration of this workspace to file.""" toml.dump(self.config_dict(), self.config_path.open('w'))
def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self.write() def _try_get_module(self, name='main'): if name in self._modules: return self._modules[name] else: raise NotConfiguredError('module %s not configured' % name)
[docs] def build(self, name='main', **kwargs): """Build module according to the configurations in current workspace.""" cls_name, cfg = self._try_get_module(name) cfg = cfg.copy() cfg.update(kwargs) try: cls = configurables[cls_name] except KeyError: raise KeyError('definition of module %s not found' % cls_name) for sub in cls.submodules: if sub not in cfg or isinstance(cfg[sub], str): if cls._build_subs: # pylint: disable=protected-access cfg[sub] = self.build(cfg.get(sub) or sub) else: cfg[sub] = Builder(self, cfg.get(sub) or sub) # noinspection PyCallingNonCallable obj = cls(**cfg) obj.ws = self obj.build_name = name if kwargs: obj.spec = Configuration(kwargs) return obj
[docs] def save(self, obj, tag): """Save module as a snapshot. Args: tag (str or pathlib.Path) : snapshot tag or path.""" # pylint: disable=protected-access env = self.config_dict() args = obj.spec._dict() if hasattr(obj, 'spec') else dict() state = obj.state_dict() if isinstance(tag, str) and not tag.endswith('.pt'): f = self.snapshot(obj.build_name + '.' + tag + '.pt') else: f = pathlib.Path(tag) self.save_to_file({'env': env, 'args': args, 'state': state}, str(f))
[docs] def load(self, name='main', tag=None, path=None): """Load module from a snapshot. Args: tag (str or pathlib.Path) : snapshot tag or path.""" if tag is None and path is None: f = self.snapshot(name + '.pt') elif path: f = pathlib.Path(path) else: f = self.snapshot(name + '.' + tag + '.pt') state = self.load_from_file(str(f)) last_ws = Workspace(self._path, config_dict=state['env']) obj = last_ws.build(name, **state['args']) obj.load_state_dict(state['state']) return obj
def save_to_file(self, obj, fn): with open(fn, 'wb') as f: pickle.dump(obj, f) def load_from_file(self, fn): return pickle.load(open(fn, 'rb'))
[docs] def logger(self, name: str): """Get a logger that logs to a file under workspace. Notice that same logger instance is returned for same names. Args: name(str): logger name """ logger = logging.getLogger('fret.' + name) if logger.handlers: # previously configured, remain unchanged return logger file_formatter = logging.Formatter('%(levelname)s [%(name)s] ' '%(asctime)s %(message)s', datefmt='%Y-%m-%d %H:%M:%S') file_handler = logging.FileHandler( str(self.log(name + '.log'))) file_handler.setFormatter(file_formatter) logger.addHandler(file_handler) return logger
[docs] def run(self, tag, resume=True): """Initiate a context manager that provides a persistent running environment. Mainly used to suspend and resume a time consuming process.""" return Run(self, tag, resume)
def record(self, value, metrics, descending=None, **kwargs): is_des = descending is True or \ (descending is None and metrics.endswith('-')) metrics = metrics.rstrip('+-') + ('-' if is_des else '+') data = {} for name, cfg in self.config_dict().items(): for k, v in cfg.items(): data[name + '.' + k] = v data.update({'metrics': metrics, 'value': value}) data.update(kwargs) with self.result(date_str + '.json-lines').open('a') as of: print(json.dumps(data), file=of) def __str__(self): return str(self.path) def __repr__(self): return 'Workspace(path=' + str(self.path) + ')'
[docs]class Run: """Class designed for running state persistency.""" __slots__ = ['_ws', '_id', '_states', '_index', '_seen'] def __init__(self, ws, tag, resume): self._ws = ws self._id = None self._states = dict() self._index = 0 self._seen = set() # only load once from file if resume: # TODO: accurate name search ids = [fn.name for fn in ws.snapshot().iterdir() if fn.is_dir() and fn.name.startswith(tag + '-')] if ids: self._id = max(ids) # most recent if self._id is None: self._id = tag + '-' + date_str if not resume: while ws.snapshot(self._id).exists(): self._id = self._id + '_' def __enter__(self): # load state if possible state_file = self._ws.snapshot(self._id, '.states.pt') if state_file.exists(): self._states = self._ws.load_from_file(str(state_file)) return self def __exit__(self, exc_type, exc_val, exc_tb): state_file = self._ws.snapshot(self._id, '.states.pt') for k in self._states: if hasattr(self._states[k], 'state_dict'): self._states[k] = self._states[k].state_dict() self._ws.save_to_file(self._states, str(state_file)) @property def id(self): return self._id def value(self, value, name=None): if name is None: name = str(self._index) self._index += 1 if name in self._states and name not in self._seen: self._seen.add(name) return self._states[name] else: self._states[name] = value return value def register(self, obj, name=None): if name is None: name = str(self._index) self._index += 1 if name in self._states and name not in self._seen: obj.load_state_dict(self._states[name]) self._seen.add(name) self._states[name] = obj return obj def iter(self, data, *label, name=None, **kwargs): return self.register(Iterator(data, *label, **kwargs), name) def acc(self, name=None): return self.register(Accumulator(), name)
[docs] def range(self, *args, name=None): """Works like normal range but with position recorded. Next time start from next loop, as current loop is finished.""" return self.register(Range(*args), name)
[docs] def brange(self, *args, name=None): """Breakable range. Works like normal range but with position recorded. Next time start from current position, as this loop isn't finished.""" return self.register(Range(*args, breakable=True), name)
def log(self, *filename): path = self._ws.path.joinpath('log', self._id, *filename) _mkdir(path, not filename or filename[-1].endswith('/')) return path def result(self, *filename): path = self._ws.path.joinpath('result', self._id, *filename) _mkdir(path, not filename or filename[-1].endswith('/')) return path def snapshot(self, *filename): path = self._ws.path.joinpath('snapshot', self._id, *filename) _mkdir(path, not filename or filename[-1].endswith('/')) return path
[docs]@stateful class Accumulator: """A stateful accumulator.""" __slots__ = ['_sum', '_cnt'] def __init__(self): self._sum = 0 self._cnt = 0 def __iadd__(self, other): self._sum += other self._cnt += 1 return self def __int__(self): return int(self._sum) def __float__(self): return float(self._sum) def clear(self): self._sum = 0 self._cnt = 0 def sum(self): return self._sum def mean(self): return self._sum / self._cnt if self._cnt > 0 else self._sum
[docs]@stateful('start', '_breakable') class Range: """A stateful range object that mimics built-in ``range``.""" __slots__ = ['start', 'step', 'stop', '_start', '_breakable'] def __init__(self, *args, breakable=False): r = range(*args) self.start = r.start self.step = r.step self.stop = r.stop self._start = r.start self._breakable = breakable def __iter__(self): for i in range(self.start, self.stop, self.step): self.start = i + (0 if self._breakable else self.step) yield i def clear(self): self.start = self._start
[docs]class Builder: """Class for building a specific module, with preset ws configuration.""" def __init__(self, ws, name): self.ws = ws self._name = name def __eq__(self, other): # pylint: disable=protected-access return self.ws._modules[self._name] == other.ws._modules[other._name] def __call__(self, **kwargs): return self.ws.build(self._name, **kwargs) def __str__(self): # pylint: disable=protected-access cls_name, cfg = self.ws._modules[self._name] return cls_name + '(' + str(Configuration(cfg)) + ')' def __repr__(self): return str(self) def __getattr__(self, item): # pylint: disable=protected-access cls_name, _ = self.ws._try_get_module(self._name) try: # pylint: disable=protected-access cls = configurables[cls_name] except KeyError: raise KeyError('definition of module %s not found' % cls_name) return getattr(cls, item)
def _mkdir(p, is_dir=False): if is_dir: if not p.exists(): p.mkdir(parents=True) else: if not p.parent.exists(): p.parent.mkdir(parents=True) __all__ = ['Workspace', 'Run', 'Accumulator', 'Range', 'Builder']