Source code for pychopmarg.plot

"""
General plotting utilities for PyChOpMarg.

Original author: David Banas <capn.freako@gmail.com>

Original date:   January 17, 2025

Copyright (c) 2025 David Banas; all rights reserved World wide.
"""

from enum   import Enum
from random import sample
from typing import Any, Callable, Optional

from matplotlib         import pyplot as plt  # type: ignore
from matplotlib.axes    import Axes           # type: ignore
import numpy as np

from pychopmarg.com     import COM
from pychopmarg.common  import Cvec
from pychopmarg.utility import s2p_pulse_response


[docs] class ZoomMode(Enum): "Plot zoom extent." FULL = 1 # Use all available data. ISI = 2 # Show the full ISI sampling used in Rx FFE tap weight optimization. PULSE = 3 # Zoom in on the pulse, to inspect its shape in detail. MANUAL = 4 # x-axis min. & max. specified by caller. RELATIVE = 5 # x-axis min. & max. specified by caller, relative to pulse peak location.
[docs] def plot_group_samps( # noqa=501 pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals,too-many-branches,too-many-statements plot_func: Callable[[COM, str, str, str, dict[str, str], Any, Any], None], x_lbl: str, y_lbls: tuple[str, str], coms: list[tuple[str, dict[str, str], dict[str, dict[str, COM]]]], maxRows: int = 3, dx=4, dy=3, chnls_by_grp: Optional[dict[str, list[str]]] = None, auto_yscale: bool = False ) -> dict[str, list[str]]: """ Call the given plotting function for several randomly chosen channel sets from each available group. Args: plot_func: The plotting function to use. Should take the following arguments: - com: The COM object to use for plotting. - grp: Channel set group name - lbl: Channel set name - name: COM set name - opts: Plotting options to use - ax1: First y-axis - ax2: Second y-axis x_lbl: Label for x-axis y_lbls: Pair of labels, one for each y-axis coms: List of tuples, each containing: - name: Identifying name, - opts: Plotting options to use, - coms: dictionary of COM objects to select from. Should be indexed first by group name then by channel set name. Keyword Args: maxRows: Maximum number of rows desired in resultant plot matrix. (Number of columns is equal to number of groups.) Default: 3 dx: Width of individual plots (in.) Default: 4 dy: Height of individual plots (in.) Default: 3 chnls_by_grp: Dictionary of key/value pairs of the form: <group name>: [<channel set name>]. (Used to enforce identical channel set choices across multiple calls.) Default: None auto_yscale: When ``True``, scale the y-axis to just accommodate the visible portion of the plotted waveforms. Default: False Returns: Dictionary containing lists of channel sets used by group name (for subsequent calls). Raises: KeyError: If there are any inconsistencies in dictionary key naming, either within the list of COMs given or between those COMs and the ``chnls_by_grp`` keyword argument, if provided. """ group_names = list(coms[0][2].keys()) nCols = len(group_names) nRows = min(maxRows, min(list(map(lambda grp: len(list(coms[0][2][grp].keys())), # pylint: disable=nested-min-max group_names)))) _, axs = plt.subplots(nRows, nCols, figsize=(dx * nCols, dy * nRows)) # Handle singleton case gracefully. if not isinstance(axs[0][0], Axes): axs = [[axs,],] n = 0 print(" ", end="") chnls_used: dict[str, list[str]] = {} for grp in group_names: # pylint: disable=too-many-nested-blocks print(f"{grp : ^45s}", end="") chnls_used.update({grp: []}) if chnls_by_grp: chnls = chnls_by_grp[grp] else: chnls = sample(sorted(coms[0][2][grp].keys()), nRows) # `sorted` is necessary. chnls = sorted(chnls) for lbl in chnls: chnls_used[grp].append(lbl) col, row = divmod(n, nRows) ax1 = axs[row][col] ax2 = ax1.twinx() plt.tight_layout() for nm, opts, d in coms: com = d[grp][lbl] plot_func(com, grp, lbl, nm, opts, ax1, ax2) if auto_yscale: # Set y-limits automatically. for ax in [ax1, ax2]: xmin, xmax = ax.get_xlim() ymin = 1e6 ymax = -1e6 for line in ax.lines: xdata, ydata = line.get_data() xmin_ixs = np.where(xdata >= max(xdata[0], xmin))[0] if len(xmin_ixs): xmin_ix = xmin_ixs[0] else: continue xmax_ix = np.where(xdata >= min(xdata[-1], xmax))[0][0] y_values = ydata[xmin_ix: xmax_ix] if len(y_values): # Ignore the `pylint` suggestion; it causes an exception: ymin = min(min(y_values), ymin) ymax = max(max(y_values), ymax) delta_y = ymax - ymin if delta_y > 0: ymin -= 0.1 * delta_y ymax += 0.1 * delta_y else: ymin = ymax = 0 ax.axis(ymin=ymin, ymax=ymax) if row == nRows - 1: ax1.set_xlabel(x_lbl) if col == 0: ax1.set_ylabel(y_lbls[0]) if col == nCols - 1: ax2.set_ylabel(y_lbls[1]) n += 1 plt.show() return chnls_used
[docs] def plot_pulse_resps_gen( # pylint: disable=too-many-statements zoom: ZoomMode, noeq: bool = False, nopkg: bool = False, plot_ntwk: bool = True, xlims: Optional[tuple[float, float]] = None, ) -> Callable[[COM, str, str, str, dict[str, str], Any, Any], None]: """ Generate a pulse response plotting function for use with ``plot_group_samps()``. Args: zoom: Zoom mode. Keyword Args: noeq: Plot unequalized pulse response when ``True``. Default: ``False`` nopkg: Plot raw channel pulse response when ``True``. (Takes priority over ``noeq``.) Default: ``False`` plot_ntwk: Add SciKit-RF pulse response estimate to plot when ``True``. (Only valid when either ``nopkg`` or ``noeq`` is ``True``.) Default: ``True`` xlims: X-axis min. & max., for use w/ `zoom` = MANUAL. Default: None Returns: Pulse response plotting function suitable for sending to ``plot_group_samps()``. ToDo: 1. Add a fourth pulse response option: pre-FFE. """ def plot_pulse_resps( # noqa=501 pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals,too-many-statements,unused-argument com: COM, grp: str, lbl: str, nm: str, opts: dict[str, str], ax1: Any, ax2: Any ) -> None: """ Plot pulse response. Args: com: The COM instance to use for plotting. (Should already have been called, to optimize EQ.) grp: The channel group to use for plotting. (Should match the name of an immediate subfolder of your top-level channel data folder.) lbl: The channel set name within the channel group. (Should match the stem of the "<lbl>_{THRU,NEXT,FEXT}.s4p" file names.) nm: Extra identification information available to caller (e.g. - "MMSE" vs. "PRZF"). opts: Plotting options. (See the ``matplotlib.pyplot`` documentation.) ax1: Axis to use for plotting against the left y-axis. ax2: Axis to use for plotting against the right y-axis. """ ui = com.ui nspui = com.nspui t = com.times Av = com.com_params.A_v nRxTaps = len(com.com_params.rx_taps_min) nRxPreTaps = com.com_params.dw # Find cursor location. if nopkg: curs_ix = np.argmax(com.pulse_resps_nopkg[0]) elif noeq: curs_ix = np.argmax(com.pulse_resps_noeq[0]) else: curs_ix = com.fom_rslts["cursor_ix"] # Plot the data. clr = opts["color"] if nopkg: ax1.plot(t * 1e9, com.pulse_resps_nopkg[0] * 1e3, label=nm, color=clr) if plot_ntwk and nm in ["MMSE", "PyChOpMarg"]: _, y = s2p_pulse_response(com.chnls_noPkg[0][0][0], ui, t) ax1.plot(t * 1e9, y * Av * 1e3, label="SciKit-RF", color=clr, linestyle="dashed") elif noeq: ax1.plot(t * 1e9, com.pulse_resps_noeq[0] * 1e3, label=nm, color=clr) if plot_ntwk and nm == "MMSE": _, y = s2p_pulse_response(com.chnls[0][0][0], ui, t) ax1.plot(t * 1e9, y * Av * 1e3, label="SciKit-RF", color=clr, linestyle="dashed") else: ax1.plot(t * 1e9, com.com_rslts["pulse_resps"][0] * 1e3, label=nm, color=clr) # Set x-limits appropriately, as per user requested zoom option. match zoom: case ZoomMode.FULL: xmin = t[0] * 1e9 xmax = t[-1] * 1e9 case ZoomMode.ISI: first_ix = curs_ix - 2 * nRxPreTaps * nspui last_ix = first_ix + 4 * nRxTaps * nspui xmin = t[first_ix] * 1e9 xmax = t[last_ix] * 1e9 case ZoomMode.PULSE: first_ix = curs_ix - nRxPreTaps * nspui last_ix = first_ix + nRxTaps * nspui xmin = t[first_ix] * 1e9 xmax = t[last_ix] * 1e9 plt.axvline(t[curs_ix] * 1e9, color=clr, linestyle="-") if (not (nopkg or noeq)) and nm == "MMSE": ax1.plot(com.com_rslts["tISI"] * 1e9, com.com_rslts["hISI"] * 1e3, "xk", label="ISI") case ZoomMode.MANUAL: assert xlims, ValueError( "X-axis limits must be provided in manual zoom mode!" ) xmin = xlims[0] xmax = xlims[1] case ZoomMode.RELATIVE: assert xlims, ValueError( "X-axis limits must be provided in relative zoom mode!" ) curs_t = t[curs_ix] xmin = xlims[0] + curs_t * 1e9 xmax = xlims[1] + curs_t * 1e9 case _: raise RuntimeError( f"Unrecognized zoom mode value ({zoom}) received!" ) # Finalize plot configuration. ax1.axis(xmin=xmin, xmax=xmax) ax1.legend(loc="upper right") ax1.grid() plt.title(f"{lbl[-25: -5]}") return plot_pulse_resps
[docs] def plot_H(H: Cvec) -> None: "Plot magnitude and phase of given frequency response." plt.figure(figsize=(10, 3)) plt.subplot(121) plt.plot(abs(H)) plt.subplot(122) plt.plot(np.angle(H))