"""
.. module:: iterfunction
:synopsis: Functions that work with iterables.
See https://docs.python.org/2/library/itertools.html
"""
import six
import itertools as itt
import threading as t
import collections as cl
from six.moves import queue as q
from six import advance_iterator
from six.moves import map, filter, filterfalse
[docs]def length(iterable):
"""
Return number of elements in iterable. Consumes iterable!
>>> length(range(10))
10
:param iterable iterable: Any iterable, e.g. list, range, ...
:return: Length of iterable.
:rtype: int
"""
return sum(1 for _ in iterable)
[docs]def interleave(*iterables):
"""
Return generator that interleaves the elements of the iterables.
>>> list(interleave(range(5), 'abc'))
[0, 'a', 1, 'b', 2, 'c', 3, 4]
>>> list(interleave('12', 'abc', '+-'))
['1', 'a', '+', '2', 'b', '-', 'c']
:param iterable iterables: Collection of iterables, e.g. lists, range, ...
:return: Interleaved iterables.
:rtype: iterator
"""
pending = len(iterables)
fnext = lambda it: lambda: advance_iterator(it)
nexts = itt.cycle(fnext(iter(it)) for it in iterables)
while pending:
try:
for nxt in nexts:
yield nxt()
except StopIteration:
pending -= 1
nexts = itt.cycle(itt.islice(nexts, pending))
[docs]def take(iterable, n):
"""
Return iterator over last n elements of given iterable.
>>> list(take(range(10), 3))
[0, 1, 2]
See: https://docs.python.org/2/library/itertools.html#itertools.islice
:param iterable iterable: Any iterable, e.g. list, range, ...
:param int n: Number of elements to take
:return: Iterator over last n elements
:rtype: iterator
"""
return itt.islice(iterable, n)
[docs]def nth(iterable, n, default=None):
"""
Return n-th element of iterable. Consumes iterable!
>>> nth(range(10), 2)
2
>>> nth(range(10), 100, default=-1)
-1
https://docs.python.org/2/library/itertools.html#itertools.islice
:param iterable iterable: Any iterable, e.g. list, range, ...
:param n: Index of element to retrieve.
:param default: Value to return when iterator is depleted
:return: nth element
:rtype: Any or default value.
"""
return next(itt.islice(iterable, n, None), default)
[docs]def unique(iterable, key=None):
"""
Return only unique elements in iterable. Potentially high mem. consumption!
>>> list(unique([2,3,1,1,2,4]))
[2, 3, 1, 4]
>>> ''.join(unique('this is a test'))
'this ae'
>>> data = [(1,'a'), (2,'a'), (3,'b')]
>>> list(unique(data, key=lambda t: t[1]))
[(1, 'a'), (3, 'b')]
:param iterable iterable: Any iterable, e.g. list, range, ...
:param key: Function used to compare for equality.
:return: Iterator over unique elements.
:rtype: Iterator
"""
seen = set()
for e in iterable:
k = key(e) if key else e
if k not in seen:
seen.add(k)
yield e
[docs]def chunked(iterable, n):
"""
Split iterable in chunks of size n, where each chunk is also an iterator.
for chunk in chunked(range(10), 3):
for element in chunk:
print element
>>> it = chunked(range(7), 2)
>>> list(map(tuple, it))
[(0, 1), (2, 3), (4, 5), (6,)]
:param iterable iterable: Any iterable, e.g. list, range, ...
:param n: Chunk size
:return: Chunked iterable
:rtype: Iterator over iterators
"""
it = iter(iterable)
while True:
chunk_it = itt.islice(it, n)
try:
first_el = next(chunk_it)
except StopIteration:
return
yield itt.chain((first_el,), chunk_it)
[docs]def consume(iterable, n=None):
"""
Consume n elements of the iterable.
>>> it = iter([1,2,3,4])
>>> consume(it, 2)
>>> next(it)
3
See https://docs.python.org/2/library/itertools.html
:param iterable iterable: Any iterable, e.g. list, range, ...
:param n: Number of elements to consume. For n=None all are consumed.
"""
if n is None:
cl.deque(iterable, maxlen=0)
else:
next(itt.islice(iterable, n, n), None)
[docs]def flatten(iterable):
"""
Return flattened iterable.
>>> list(flatten([(1,2), (3,4,5)]))
[1, 2, 3, 4, 5]
:param iterable iterable:
:return: Iterator over flattened elements of iterable
:rtype: Iterator
"""
return itt.chain(*iterable)
[docs]def flatmap(func, iterable):
"""
Map function to iterable and flatten.
>>> f = lambda n: str(n) * n
>>> list( flatmap(f, [1, 2, 3]) )
['1', '2', '2', '3', '3', '3']
>>> list( map(f, [1, 2, 3]) ) # map instead of flatmap
['1', '22', '333']
:param function func: Function to map on iterable.
:param iterable iterable: Any iterable, e.g. list, range, ...
:return: Iterator of iterable elements transformed via func and flattened.
:rtype: Iterator
"""
return itt.chain.from_iterable(map(func, iterable))
[docs]def partition(iterable, pred):
"""
Split iterable into two partitions based on predicate function
>>> pred = lambda x: x < 6
>>> smaller, larger = partition(range(10), pred)
>>> list(smaller)
[0, 1, 2, 3, 4, 5]
>>> list(larger)
[6, 7, 8, 9]
:param iterable: Any iterable, e.g. list, range, ...
:param pred: Predicate function.
:return: Partition iterators
:rtype: Two iterators
"""
t1, t2 = itt.tee(iterable)
return filter(pred, t1), filterfalse(pred, t2)
[docs]class PrefetchIterator(t.Thread, six.Iterator):
"""
Wrap an iterable in an iterator that prefetches elements.
Typically used to fetch samples or batches while the the GPU processes
the batch. Keeps the CPU busy pre-processing data and not waiting for the
GPU to finish the batch.
>>> from __future__ import print_function
>>> for i in PrefetchIterator(range(4)):
... print(i)
0
1
2
3
"""
[docs] def __init__(self, iterable, num_prefetch=1):
"""
Constructor.
:param iterable iterable: Iterable elements are fetched from.
:param int num_prefetch: Number of elements to pre-fetch.
"""
t.Thread.__init__(self)
self.queue = q.Queue(num_prefetch)
self.iterable = iterable
self.daemon = True
self.lock = t.Lock()
self.start()
[docs] def run(self):
"""
Put elements in input iterable into queue.
"""
for item in self.iterable:
self.queue.put(item)
self.queue.put(None)
def __next__(self):
"""
Return next element from pre-fetch iterator.
:return: element from iterator
:rtype: same as element type of input iterable.
"""
with self.lock:
next_item = self.queue.get()
if next_item is None:
raise StopIteration
return next_item
def __iter__(self):
"""
Return pre-fetch iterator
:return: pre-fetch iterator
:rtype: PrefetchIterator
"""
return self