"""
.. module:: stratify
:synopsis: Stratification of sample sets
"""
from __future__ import absolute_import
import random as rnd
from nutsflow import nut_processor, nut_sink, Sort
from nutsml.datautil import upsample, random_downsample
@nut_processor
def Stratify(iterable, labelcol, labeldist, rand=None):
    """
    iterable >> Stratify(labelcol, labeldist, rand=None)

    Stratifies samples by randomly down-sampling according to the given
    label distribution. Let min_n be the smallest count in labeldist.
    A sample whose label has count n is kept with probability min_n / n.
    Samples of the least frequent class are therefore kept with
    probability one. Every class contributes min_n samples in expectation;
    the exact count varies from run to run.

    Note that in contrast to SplitRandom, which generates the same random
    split per default, Stratify generates different stratifications.
    Furthermore, while the down-sampling is random, the order of samples
    remains the same!

    While labeldist needs to be provided or computed upfront, the actual
    stratification happens online: samples are streamed one at a time and
    only the current sample is held in memory.

    >>> from nutsflow import CountValues
    >>> from nutsflow.common import StableRandom
    >>> fix = StableRandom(1)  # Stable random numbers for doctest
    >>> samples = [('pos', 1), ('pos', 1), ('neg', 0)]
    >>> labeldist = samples >> CountValues(1)
    >>> samples >> Stratify(1, labeldist, rand=fix) >> Sort()
    [('neg', 0), ('pos', 1)]

    :param iterable over tuples iterable: Iterable of tuples where column
      labelcol contains a sample label that is used for stratification
    :param int labelcol: Column of tuple/samples that contains the label
    :param dict labeldist: Dictionary with numbers of different labels,
      e.g. {'good':12, 'bad':27, 'ugly':3}. Every label occurring in the
      samples must be a key, otherwise a KeyError is raised.
    :param Random|None rand: Random number generator used for down-sampling.
      If None, random.Random() is used.
    :return: Stratified samples
    :rtype: Generator over tuples
    """
    if rand is None:
        rand = rnd.Random()

    # float() forces true division even on Python 2 (the module keeps
    # __future__ imports for py2 compatibility).
    smallest = float(min(labeldist.values()))
    keep_prob = {}
    for label, count in labeldist.items():
        keep_prob[label] = smallest / count

    for sample in iterable:
        label = sample[labelcol]
        # Exactly one random draw per sample, taken before the lookup, so
        # seeded generators produce reproducible stratifications.
        draw = rand.random()
        if draw < keep_prob[label]:
            yield sample
@nut_sink
def CollectStratified(iterable, labelcol, mode='downrnd', container=list,
                      rand=None):
    """
    iterable >> CollectStratified(labelcol, mode='downrnd', container=list,
                                  rand=None)

    Collects samples in a container and stratifies them by either randomly
    down-sampling classes or up-sampling classes by duplicating samples.
    In contrast to Stratify, all samples are loaded into memory first.

    >>> samples = [('pos', 1), ('pos', 1), ('neg', 0)]
    >>> samples >> CollectStratified(1) >> Sort()
    [('neg', 0), ('pos', 1)]

    :param iterable over tuples iterable: Iterable of tuples where column
      labelcol contains a sample label that is used for stratification
    :param int labelcol: Column of tuple/samples that contains the label
    :param string mode:
      'downrnd' : randomly down-sample
      'up' : up-sample
    :param container container: Some container, e.g. list, set, dict
      that can be filled from an iterable
    :param Random|None rand: Random number generator used for sampling.
      If None, random.Random() is used.
    :return: Stratified samples
    :rtype: Instance of container (list by default) filled with tuples
    :raises ValueError: If mode is neither 'up' nor 'downrnd'. The mode is
      checked before the iterable is consumed.
    """
    # Resolve the sampler before draining the iterable, so an invalid mode
    # fails fast instead of after reading a possibly large input stream.
    if mode == 'up':
        sampler = upsample
    elif mode == 'downrnd':
        sampler = random_downsample
    else:
        # format() instead of '+' so that a non-string mode (e.g. None)
        # raises the documented ValueError rather than a TypeError.
        raise ValueError('Unknown mode: {}'.format(mode))

    rand = rnd.Random() if rand is None else rand
    samples = list(iterable)
    return container(sampler(samples, labelcol, rand))