# Source code for nutsml.booster

"""
.. module:: booster
   :synopsis: Boosting of wrongly predicted samples
"""

import random as rnd
import numpy as np

from nutsflow.common import StableRandom
from nutsflow import nut_processor, Tee, Collect, Flatten, Print


@nut_processor
def Boost(iterable, batcher, network, rand=None):
    """
    iterable >> Boost(batcher, network, rand=None)

    Boost samples with high softmax probability for incorrect class.
    Expects one-hot encoded targets and softmax predictions for output.

    Every sample is run through the network and is emitted with a
    probability of (1 - p), where p is the predicted probability of the
    sample's true class. Samples the network already predicts confidently
    and correctly are therefore mostly dropped, while poorly predicted
    samples are mostly kept.

    NOTE: prefetching of batches must be disabled when using boosting!

    |  network = Network()
    |  build_batch = BuildBatch(BATCHSIZE, prefetch=0).input(...).output(...)
    |  boost = Boost(build_batch, network)
    |  samples >> ... ?>> boost >> build_batch >> network.train() >> Consume()

    :param iterable iterable: Iterable with samples.
    :param nutsml.BuildBatch batcher: Batcher used for network training.
           Must have ``prefetch == 0``.
    :param nutsml.Network network: Network used for prediction
    :param Random|None rand: Random number generator used for down-sampling.
           If None, random.Random() is used.
    :return: Generator over samples to boost
    :rtype: generator
    :raises AssertionError: If prefetching is enabled on the batcher, if a
           target has fewer than two elements, or if target and prediction
           lengths differ.
    """

    def do_boost(probs, target):
        """Return True if the sample should be emitted (boosted)."""
        assert len(target) > 1, 'Expect one-hot encoded target: ' + str(target)
        assert len(target) == len(probs), 'Expect softmax probs: ' + str(probs)
        # The chance of returning True equals 1 - (probability predicted for
        # the true class): the worse the prediction, the likelier the boost.
        return rand.random() > probs[np.argmax(target)]

    assert batcher.prefetch == 0, 'Disable prefetch when boosting'
    # 'rand' is looked up by do_boost at call time, so binding it here,
    # after the helper's definition, is fine.
    rand = rnd.Random() if rand is None else rand

    # Two views of the same sample stream: one is consumed by the batcher,
    # the other provides the original samples that are yielded.
    samples1, samples2 = iterable >> Tee(2)
    for batch in samples1 >> batcher:
        inputs, targets = batch
        # NOTE(review): presumably the targets of the first (only) network
        # output - confirm against the batch layout produced by the batcher.
        tars = targets[0]
        preds = iter(inputs) >> network.predict() >> Collect()
        # 'preds' must stay the first argument of zip: zip pulls from its
        # arguments left to right and stops at the first exhausted one, so
        # 'samples2' is never advanced beyond the samples of this batch.
        for p, t, s in zip(preds, tars, samples2):
            if do_boost(p, t):
                yield s