Source code for sigmaepsilon.plotting.mpl.parallel

# -*- coding: utf-8 -*-
from typing import Iterable, TypeVar, Hashable, Union

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.path import Path
from matplotlib.patches import PathPatch
import matplotlib.gridspec as gridspec
from matplotlib.widgets import Slider
from matplotlib.figure import Figure
import numpy as np
from numpy import ndarray

from linkeddeepdict import LinkedDeepDict
from sigmaepsilon.core.formatting import float_to_str_sig as str_sig


__all__ = ["parallel", "aligned_parallel"]


TScalar = TypeVar("TScalar", int, float, complex)
TReal = TypeVar("TReal", int, float)
TVector = Iterable[TScalar]
TRealVector = Iterable[TReal]
TColor = TypeVar("TColor", str, TRealVector)


[docs]def parallel( data: Union[dict, list, ndarray], *, labels:Iterable=None, padding:float=0.05, colors:Iterable=None, lw:float=0.2, bezier:bool=True, figsize:tuple=None, title:str=None, ranges:Iterable=None, return_figure:bool=True, **_, ) -> Union[Figure, None]: """ Parameters ---------- data: Union[list, dict, numpy.ndarray] A list of numpy.ndarray for each column. Each array is 1d with a length of N, where N is the number of data records (the number of lines). labels: Iterable, Optional Labels for the columns. If provided, it must have the same length as `data`. padding: float, Optional Controls the padding around the axes. colors: list of float, Optional A value for each record. Default is None. lw: float, Optional Linewidth. bezier: bool, Optional If True, bezier curves are created instead of a linear polinomials. Default is True. figsize: tuple, Optional A tuple to control the size of the figure. Default is None. title: str, Optional The title of the figure. ranges: list of list, Optional Ranges of the axes. If not provided, it is inferred from the input values, but this may result in an error. Default is False. return_figure: bool, Optional If True, the figure is returned. Default is False. Example ------- >>> from sigmaepsilon.mesh.mpl import parallel >>> colors = np.random.rand(150, 3) >>> labels = [str(i) for i in range(10)] >>> values = [np.random.rand(150) for i in range(10)] >>> parallel(values, labels=labels, padding=0.05, lw=0.2, ... colors=colors, title='Parallel Plot with Random Data') <matplotlib.figure.Figure object at ...> """ if isinstance(data, dict): if labels is None: labels = list(data.keys()) ys = np.dstack(list(data.values()))[0] elif isinstance(data, np.ndarray): assert labels is not None ys = data.T elif isinstance(data, Iterable): assert labels is not None ys = np.dstack(data)[0] else: raise TypeError("Invalid data type!") ynames = labels N, nY = ys.shape figsize = (7.5, 3) if figsize is None else figsize fig, host = plt.subplots(figsize=figsize) if ranges is None: ymins = ys.min(axis=0) ymaxs = ys.max(axis=0) ranges = np.stack((ymins, ymaxs), axis=1) else: ranges = np.array(ranges) # make sure that upper and lower ranges are not equal for i in range(nY): rmin, rmax = ranges[i] if abs(rmin - rmax) < 1e-12: rmin -= 1.0 rmax += 1.0 ranges[i] = [rmin, rmax] ymins = ranges[:, 0] ymaxs = ranges[:, 1] dys = ymaxs - ymins ymins -= dys * padding ymaxs += dys * padding dys = ymaxs - ymins # transform all data to be compatible with the main axis zs = np.zeros_like(ys) zs[:, 0] = ys[:, 0] zs[:, 1:] = (ys[:, 1:] - ymins[1:]) / dys[1:] * dys[0] + ymins[0] axes = [host] + [host.twinx() for i in range(nY - 1)] for i, ax in enumerate(axes): ax.set_ylim(ymins[i], ymaxs[i]) ax.spines["top"].set_visible(False) ax.spines["bottom"].set_visible(False) if ax != host: ax.spines["left"].set_visible(False) ax.yaxis.set_ticks_position("right") ax.spines["right"].set_position(("axes", i / (nY - 1))) host.set_xlim(0, nY - 1) host.set_xticks(range(nY)) host.set_xticklabels(ynames, fontsize=8) host.tick_params(axis="x", which="major", pad=7) host.spines["right"].set_visible(False) host.xaxis.tick_top() if title is not None: host.set_title(title, fontsize=12) for j in range(N): if not bezier: # to just draw straight lines between the axes: host.plot(range(nY), zs[j, :], c=colors[j], lw=lw) else: # create bezier curves # for each axis, there will a control vertex at the point itself, one at 1/3rd towards the previous and one # at one third towards the next axis; the first and last axis have one less control vertex # x-coordinate of the control vertices: at each integer (for the axes) and two inbetween # y-coordinate: repeat every point three times, except the first and last only twice verts = list( zip( [ x for x in np.linspace( 0, len(ys) - 1, len(ys) * 3 - 2, endpoint=True ) ], np.repeat(zs[j, :], 3)[1:-1], ) ) # for x,y in verts: host.plot(x, y, 'go') # to show the control points of the beziers codes = [Path.MOVETO] + [Path.CURVE4 for _ in range(len(verts) - 1)] path = Path(verts, codes) patch = PathPatch(path, facecolor="none", lw=lw, edgecolor=colors[j]) host.add_patch(patch) if return_figure: return fig
[docs]def aligned_parallel( data, datapos, *, yticks=None, labels=None, sharelimits=False, texlabels=None, xticksrotation=0, suptitle=None, slider=False, slider_label=None, hlines=None, vlines=None, y=None, xoffset=0.0, yoffset=0.0, **kwargs, ) -> Union[Figure, None]: """ Parameters ---------- data: numpy.ndarray or dict The values to plot. If it is a numpy array, labels must be provided with the argument `labels`, if it is a sictionary, the keys of the dictionary are used as labels. datapos: Iterable Positions of the provided data values. yticks: Iterable, Optional Positions of ticks on the vertical axes. Default is None. labels: Iterable, Optional An iterable of strings specifying labels for the datasets. Default is None. sharelimits: bool, Optional If True, the axes share limits of the vertical axes. Default is False. texlabels: Itrable, Optional TeX-formatted labels. If provided, it must have the same length as `labels`. Default is None. xticksrotation: int, Optional Rotation of the ticks along the vertical axes. Expected in degrees. Default is 0. suptitle: str, Optional See Matplotlib's docs for the details. Default is None. slider: bool, Optional If True, a slider is added to the figure for interactive plots. Default is False. slider_label: str, Optional A label for the slider. Only if `slider` is true. Default is None. hlines: Iterable, Optional A list of data values where horizontal lines are to be added to the axes. Default is None. vlines: Iterable, Optional A list of data values where vertical lines are to be added to the axes. Default is None. y: float or int, Optional Value for the vertical axis. Default is the average of the limits of the vertical axis (0.5*(datapos[0] + datapos[-1])). xoffset: float, Optional Margin of the plot in the vertical direction. Default is 0. yoffset: float, Optional Margin of the plot in the horizontal direction. Default is 0. **kwargs: dict, Optional Extra keyword arguments are forwarded to the creator of the matplotlib figure. Default is None. Example ------- >>> from sigmaepsilon.plotting.mpl import aligned_parallel >>> labels = ['a', 'b', 'c'] >>> values = np.array([np.random.rand(150) for _ in labels]).T >>> datapos = np.linspace(-1, 1, 150) >>> aligned_parallel(values, datapos, labels=labels, yticks=[-1, 1]) <matplotlib.figure.Figure object at ...> """ # init fig = plt.figure(**kwargs) suptitle = "" if suptitle is None else suptitle fig.suptitle(suptitle) plotdata = LinkedDeepDict() axcolor = "lightgoldenrodyellow" ymin, ymax = np.min(datapos), np.max(datapos) if y is None: y0 = 0.5 * (ymin + ymax) else: y0 = y hlines = [] if hlines is None else hlines vlines = [] if vlines is None else vlines # init slider if slider: if slider_label is None: slider_label = "" # read data if isinstance(data, dict): if labels is None: labels = list(data.keys()) elif isinstance(data, np.ndarray): nData = data.shape[1] if labels is None: labels = list(map(str, range(nData))) data = {labels[i]: data[:, i] for i in range(nData)} for lbl in labels: plotdata[lbl]["values"] = data[lbl] # set min and max values _min, _max = [], [] for lbl in labels: plotdata[lbl]["min"] = np.min(data[lbl]) plotdata[lbl]["max"] = np.max(data[lbl]) _min.append(plotdata[lbl]["min"]) _max.append(plotdata[lbl]["max"]) # set global min and max vmin = np.min(_min) vmax = np.max(_max) del _min del _max # setting up figure layout nData = len(labels) if slider: nAxes = nData + 1 # +1 for the Slider else: nAxes = nData width_ratios = [1 for i in range(nData)] if slider: width_ratios.append(0.15) spec = gridspec.GridSpec( ncols=nAxes, nrows=1, width_ratios=width_ratios, figure=fig, wspace=0.2, left=0.1, ) # create axes for i in range(nData): plotid = int("{}{}{}".format(1, nAxes, i + 1)) plotid = spec[0, i] ax = fig.add_subplot(plotid, facecolor=axcolor) ax.grid(False) ax.patch.set_edgecolor("black") ax.patch.set_linewidth("0.5") if i == 0: if yticks is not None: ax.set_yticks(yticks) ax.set_yticklabels([str_sig(val, sig=3) for val in yticks]) else: ax.set_yticks([]) ax.set_yticklabels([]) hline = ax.axhline(y=y0, color="#d62728", linewidth=1) bbox = dict(boxstyle="round", ec="black", fc="yellow", alpha=0.8) txt = ax.text( 0.0, 0.0, "NaN", size=10, ha="center", va="center", visible=False, bbox=bbox ) # horizontal lines ax.axhline(y=yticks[0], color="black", linewidth=0.5, linestyle="-") ax.axhline(y=yticks[-1], color="black", linewidth=0.5, linestyle="-") for hl in hlines: ax.axhline(y=hl, color="black", linewidth=0.5, linestyle="-") # a vertical lines for vl in vlines: ax.axvline(x=vl, color="black", linewidth=0.5, linestyle="-") # store objects plotdata[labels[i]]["ax"] = ax plotdata[labels[i]]["text"] = txt plotdata[labels[i]]["hline"] = hline # create slider if slider: sliderax = fig.add_subplot(spec[0, nAxes - 1], fc=axcolor) """sliderax.patch.set_edgecolor('black') sliderax.patch.set_linewidth('1.0') sliderax.patch.zorder = 10""" slider_ = Slider( sliderax, slider_label, valmin=ymin, valmax=ymax, valinit=0.0, orientation="vertical", valfmt="%.3f", closedmin=True, closedmax=True, ) # track_color=axcolor, color=axcolor """slider_rect = slider_.track s_xy = slider_rect.get_xy() s_w = slider_rect.get_width() s_h = slider_rect.get_height() sliderax.add_patch(Rectangle(s_xy, s_w, s_h, fill=False, ec='k'))""" def _approx_at_y(y: float, plotkey: Hashable): lines2D = plotdata[plotkey]["lines"] values = lines2D.get_xdata() locations = lines2D.get_ydata() return np.interp(y, locations, values) def _set_yval(y): for axkey in plotdata.keys(): if "hline" in plotdata[axkey]: plotdata[axkey]["hline"].set_ydata(y) v_at_y = _approx_at_y(y, axkey) txtparams = { "visible": True, "x": v_at_y, "y": y, "text": str_sig(v_at_y, sig=4), } plotdata[axkey]["text"].update(txtparams) fig.canvas.draw_idle() def _update_slider(y=None): if y is None: y = slider.val _set_yval(y) def _set_xlim(axs: mpl.axes, vmin: float, vmax: float): voffset = (vmax - vmin) * xoffset if abs(vmin - vmax) > 1e-7: axs.set_xlim(vmin - voffset, vmax + voffset) xticks = [vmin, vmax] axs.set_xticks(xticks) rotation = kwargs.get("rotation", xticksrotation) axs.set_xticklabels([str_sig(val, sig=3) for val in xticks], rotation=rotation) def _set_ylim(axs: mpl.axes, vmin: float, vmax: float): voffset = (vmax - vmin) * yoffset axs.set_ylim(vmin - voffset, vmax + voffset) # plot axes for i, axkey in enumerate(plotdata.keys()): axis = plotdata[axkey]["ax"] # set limits if sharelimits == True: _set_xlim(axis, vmin, vmax) else: _set_xlim(axis, plotdata[axkey]["min"], plotdata[axkey]["max"]) _set_ylim(axis, ymin, ymax) # set labels if texlabels is not None: axis.set_title(texlabels[i]) else: axis.set_title(str(axkey)) # plot lines = axis.plot(plotdata[axkey]["values"], datapos, picker=5)[0] plotdata[axkey]["lines"] = lines # connect events if slider: slider_.on_changed(_update_slider) fig._slider = slider_ # to keep reference, otherwise slider is not responsive _set_yval(y0) return fig