# Source code for nutsml.plotter

"""
.. module:: plotter
   :synopsis: Plotting of data, e.g. loss over epochs
"""
import time

import itertools as itt
import matplotlib.pyplot as plt

from six.moves import range
from nutsflow import NutFunction
from nutsflow.common import as_tuple, as_list


class PlotLines(NutFunction):  # pragma no coverage
    """
    Plot line graph for selected data columns.
    """

    def __init__(self, ycols, xcols=None, layout=(1, None), titles=None,
                 every_sec=0, every_n=0, filterfunc=lambda data: True,
                 figsize=None, filepath=None):
        """
        iterable >> PlotLines(ycols) >> Consume()

        >>> import os
        >>> import numpy as np
        >>> from nutsflow import Consume

        >>> fp = 'tests/data/temp_plotter.png'
        >>> xs = np.arange(0, 6.3, 1.2)
        >>> ysin, ycos = np.sin(xs), np.cos(xs)
        >>> data = list(zip(xs, ysin, ycos))

        >>> data >> PlotLines(1, 0, filepath=fp) >> Consume()

        >>> list(ycos) >> PlotLines(0, filepath=fp) >> Consume()

        >>> data >> PlotLines(ycols=(1,2), filepath=fp) >> Consume()

        >>> data >> PlotLines(ycols=(1,2), xcols=0, filepath=fp) >> Consume()

        >>> ysin.tolist() >> PlotLines(ycols=None, filepath=fp) >> Consume()

        >>> if os.path.exists(fp): os.remove(fp)

        :param int|tuple|None ycols: Index or tuple of indices of the data
          columns that contain the y-data for the plot. Each y column gets
          its own subplot. If None data is used directly.
        :param int|tuple|function|iterator|None xcols: Index or tuple of
          indices of the data columns that contain the x-data for the plot.
          A single index is shared by all y columns; a tuple provides one
          x column per y column.
          Alternatively an iterator or a function can be provided that
          generates the x-data for the plot, e.g. xcols = itertools.count()
          or xcols = lambda: epoch
          For xcols==None, itertools.count() will be used.
        :param tuple layout: Rows and columns of the plotter layout, e.g.
          a layout of (2,3) means that 6 plots in the data are arranged in
          2 rows and 3 columns. The number of columns can be None, in which
          case it is set to the number of ycols.
        :param list|tuple|None titles: Titles for the subplots, one per
          y column (indexed in the same order as ycols). If None or empty
          no titles are set.
        :param float every_sec: Plot at most every given number of
          seconds, measured since construction or the last plot,
          e.g. to plot every 2.5 sec every_sec = 2.5
        :param int every_n: Plot every n-th call. Note that data of calls
          that are not plotted is not added to the plot.
        :param function filterfunc: Boolean function applied to each input.
          Data is only plotted if it returns True.
        :param tuple figsize: Figure size in inch.
        :param filepath: Path to a file to draw plot to. If provided the
          figure is saved to this file on every plot instead of calling
          plt.pause() to refresh it on screen.
        :return: Returns input unaltered
        :rtype: any
        """
        self.ycols = [-1] if ycols is None else as_list(ycols)
        self.xcols = itt.count() if xcols is None else xcols
        self.filepath = filepath
        self.figsize = figsize
        self.titles = titles
        self.cnt = 0
        self.time = time.time()
        self.filterfunc = filterfunc
        self.every_sec = every_sec
        self.every_n = every_n

        r, c, n = layout[0], layout[1], len(self.ycols)
        if c is None:
            c = n
        self.figure = plt.figure(figsize=figsize)
        self.axes = [self.figure.add_subplot(r, c, i + 1) for i in range(n)]
        self.reset()

    def __delta_sec(self):
        """Return time in seconds (float) consumed between plots so far"""
        return time.time() - self.time

    def __should_plot(self, data):
        """Return true if data should be plotted"""
        # The counter advances on every call, including filtered ones.
        self.cnt += 1
        return (self.filterfunc(data) and
                self.cnt >= self.every_n and
                self.__delta_sec() >= self.every_sec)

    def reset(self):
        """Reset plot data: one empty x and y buffer per y column"""
        self.xdata, self.ydata = [], []
        for _ in self.ycols:
            self.xdata.append([])
            self.ydata.append([])

    def _add_data(self, data):
        """Add data point to data buffer"""
        if hasattr(data, 'ndim'):  # is it a numpy array?
            data = data.tolist() if data.ndim else [data.item()]
        else:
            data = as_list(data)

        n = len(self.ycols)
        xcols = self.xcols
        # Test for the iterator protocol (__next__ on Py3, next on Py2)
        # rather than for __iter__: a tuple of column indices is iterable
        # as well and must not be consumed with next().
        if hasattr(xcols, '__next__') or hasattr(xcols, 'next'):
            xvalues = [next(xcols)] * n
        elif callable(xcols):
            xvalues = [xcols()] * n
        else:
            indices = as_tuple(xcols)
            if len(indices) == 1:
                # A single x column is shared by all y columns; otherwise
                # the buffers of the other series would stay empty.
                indices = indices * n
            xvalues = [data[xcol] for xcol in indices]
        for i, x in enumerate(xvalues):
            self.xdata[i].append(x)

        for i, ycol in enumerate(self.ycols):
            # A negative column index (ycols=None) means "use data as is".
            self.ydata[i].append(data if ycol < 0 else data[ycol])

    def __call__(self, data):
        """Plot data"""
        if not self.__should_plot(data):
            return data

        self.cnt = 0  # reset counter
        self.time = time.time()  # reset timer

        self._add_data(data)
        for i, ax in enumerate(self.axes):
            ax.clear()
            if self.titles:
                ax.set_title(self.titles[i])
            ax.plot(self.xdata[i], self.ydata[i], '-')
        # All axes live on the same figure; redraw its canvas once
        # instead of once per axis.
        self.figure.canvas.draw()

        if self.filepath:
            self.figure.savefig(self.filepath, bbox_inches='tight')
        else:
            plt.pause(0.0001)  # Needed to draw

        return data