Source code for nutsml.checkpoint

"""
.. module:: checkpoint
   :synopsis: Conveniency class to create checkpoints for network training.
"""

import os
import time

from os.path import join, exists, isdir, getmtime
from nutsml.config import Config

"""
.. module:: checkpoint
   :synopsis: Creation of checkpoints with network weights and parameters.
"""


[docs]class Checkpoint(object): """ A factory for checkpoints to periodically save network weights and other hyper/configuration parameters. | Example usage: | | def create_network(lr=0.01, momentum=0.9): | model = Sequential() | ... | optimizer = opt.SGD(lr=lr, momentum=momentum) | model.compile(optimizer=optimizer, metrics=['accuracy']) | return KerasNetwork(model), optimizer | | def parameters(network, optimizer): | return dict(lr = optimizer.lr, momentum = optimizer.momentum) | | def train_network(): | checkpoint = Checkpoint(create_network, parameters) | network, optimizer = checkpoint.load() | | for epoch in xrange(EPOCHS): | train_err = train_network() | val_err = validate_network() | | if epoch % 10 == 0: # Reduce learning rate every 10 epochs | optimizer.lr /= 2 | | checkpoint.save_best(val_err) | Checkpoints can also be saved under different names, e.g. | checkpoint.save_best(val_err, 'checkpoint'+str(epoch)) And specific checkpoints can be loaded: | network, config = checkpoint.load('checkpoint103') If no checkpoint is specified the most recent one is loaded. """
[docs] def __init__(self, create_net, parameters, checkpointspath='checkpoints'): """ Create checkpoint factory. >>> def create_network(lr=0.1): ... return 'MyNetwork', lr >>> def parameters(network, lr): ... return dict(lr = lr) >>> checkpoint = Checkpoint(create_network, parameters) >>> network, lr = checkpoint.load() >>> network, lr ('MyNetwork', 0.1) :param function create_net: Function that takes keyword parameters and returns a nuts-ml Network and and any other values or objects needed to describe the state to be checkpointed. Note: parameters(*create_net()) must work! :param function parameters: Function that takes output of create_net() and returns dictionary with parameters (same as the one that are used in create_net(...)) :param string checkpointspath: Path to folder that will contain checkpoint folders. """ if not exists(checkpointspath): os.makedirs(checkpointspath) self.basepath = checkpointspath self.create_net = create_net self.parameters = parameters self.state = None # network and other objets self.network = None # only the network self.config = None # bestscore and other checkpoint params
[docs] def dirs(self): """ Return full paths to all checkpoint folders. :return: Paths to all folders under the basedir. :rtype: list """ paths = (join(self.basepath, d) for d in os.listdir(self.basepath)) return [p for p in paths if isdir(p)]
[docs] def latest(self): """ Find most recently modified/created checkpoint folder. :return: Full path to checkpoint folder if it exists otherwise None. :rtype: str | None """ dirs = sorted(self.dirs(), key=getmtime, reverse=True) return dirs[0] if dirs else None
[docs] def datapaths(self, checkpointname=None): """ Return paths to network weights, parameters and config files. If no checkpoints exist under basedir (None, None, None) is returned. :param str|None checkpointname: Name of checkpoint. If name is None the most recent checkpoint is used. :return: (weightspath, paramspath, configpath) or (None, None, None) :rtype: tuple """ name = checkpointname if name is None: path = self.latest() if path is None: return None, None, None else: path = join(self.basepath, name) if not exists(path): os.makedirs(path) return (join(path, 'weights'), join(path, 'params.json'), join(path, 'config.json'))
[docs] def save(self, checkpointname='checkpoint'): """ Save network weights and parameters under the given name. :param str checkpointname: Name of checkpoint folder. Path will be self.basepath/checkpointname :return: path to checkpoint folder :rtype: str """ weightspath, paramspath, configpath = self.datapaths(checkpointname) self.config.timestamp = time.time() self.network.save_weights(weightspath) state = self.state if hasattr(self.state, '__iter__') else [self.state] Config(self.parameters(*state)).save(paramspath) Config(self.config).save(configpath) return join(self.basepath, checkpointname)
[docs] def save_best(self, score, checkpointname='checkpoint', isloss=False): """ Save best network weights and parameters under the given name. :param float|int score: Some score indicating quality of network. :param str checkpointname: Name of checkpoint folder. :param bool isloss: True, score is a loss and lower is better otherwise higher is better. :return: path to checkpoint folder :rtype: str """ bestscore = self.config.bestscore if (bestscore is None or (isloss and score < bestscore) or (not isloss and score > bestscore)): self.config.bestscore = score self.config.isloss = isloss self.save(checkpointname) return join(self.basepath, checkpointname)
[docs] def load(self, checkpointname=None): """ Create network, load weights and parameters. :param str|none checkpointname: Name of checkpoint to load. If None the most recent checkpoint is used. If no checkpoint exists yet the network will be created but no weights loaded and the default configuration will be returned. :return: whatever self.create_net returns :rtype: object """ weightspath, paramspath, configpath = self.datapaths(checkpointname) params = Config().load(paramspath) if paramspath else None state = self.create_net(**params) if params else self.create_net() self.network = state[0] if hasattr(state, '__iter__') else state self.state = state if weightspath: self.network.load_weights(weightspath) defaultconfig = Config(bestscore=None, timestamp=None) self.config = Config().load(configpath) if configpath else defaultconfig return state