Source code for quantumaudio.tools.plot

# Copyright 2024 Moth Quantum
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==========================================================================

from typing import Union

import matplotlib.pyplot as plt
import numpy as np

# ======================
# Plotting Functions
# ======================


def plot_1d(
    samples: np.ndarray,
    title: Union[str, None] = None,
    label: tuple[str, str] = ("original", "reconstructed"),
) -> None:
    """Plots the given samples as lines on the current matplotlib figure.

    Every sample is squeezed and drawn against an index axis of length
    ``samples[0].shape[-1]`` (the last axis of the first sample), and the
    figure is then shown with ``plt.show()``.

    Args:
        samples: The samples to plot. Either a single ``numpy`` array or a
            list of ``numpy`` arrays, one line per array.
        title: Title for the plot. Defaults to None (no title).
        label: Labels for the samples. A single non-tuple value is treated as
            one label and a list is treated like a tuple. If fewer labels
            than samples are given, generic labels ``"Signal 1"``,
            ``"Signal 2"``, ... are used instead. A falsy value (e.g.
            ``None``) draws no legend.
            Defaults to ("original", "reconstructed").

    Returns:
        None
    """
    if not isinstance(samples, list):
        samples = [samples]

    if label:
        if isinstance(label, list):
            label = tuple(label)
        elif not isinstance(label, tuple):
            label = (label,)
        # Avoid an IndexError when there are more samples than labels; the
        # fallback names match the ones used by ``plot``.
        if len(samples) > len(label):
            label = tuple(f"Signal {i + 1}" for i in range(len(samples)))

    num_samples = samples[0].shape[-1]
    x_axis = np.arange(0, num_samples)

    for i, y_axis in enumerate(samples):
        plt.plot(
            x_axis, y_axis.squeeze(), label=None if not label else label[i]
        )

    plt.xlabel("Index")
    plt.ylabel("Values")
    if label:
        plt.legend()
    if title:
        plt.title(title)
    plt.show()
def plot(
    samples: Union[np.ndarray, list[np.ndarray]],
    title: Union[str, None] = None,
    label: tuple[str, str] = ("original", "reconstructed"),
    figsize: tuple[int, int] = (6, 4),
) -> None:
    """Plots the given samples. It accepts multi-dimensional array and also
    multiple plots for comparisons.

    The layout is decided from the first sample. If it is 1-D, or if its
    second-to-last axis has size 1, all samples are squeezed and drawn on a
    single axes. Otherwise, one subplot is created per entry of that axis
    ("channel"). Each subplot overlays ``y_axis[c]`` of every sample. The
    index axis always has length ``samples[0].shape[-1]``.

    Args:
        samples: The samples to plot. Can be a single `numpy` array or a
            list of `numpy` arrays.
        title: Title for the plot. In the multi-channel layout this becomes
            the figure title, so each subplot keeps its "channel N" title.
            Defaults to None.
        label: Labels for the samples. A single non-tuple value is treated as
            one label and a list is treated like a tuple. If there are more
            samples than labels, generic labels ``"Signal 1"``,
            ``"Signal 2"``, ... are used instead. ``None`` draws no legend.
            Defaults to ("original", "reconstructed").
        figsize: Set the width and height for matplotlib plot.

    Returns:
        None
    """
    if not isinstance(samples, list):
        samples = [samples]

    if label:
        if isinstance(label, list):
            label = tuple(label)
        elif not isinstance(label, tuple):
            label = (label,)
    # ``None`` means "no legend" and must not reach len(); other values that
    # cover fewer samples than given fall back to generic names.
    if label is not None and len(samples) > len(label):
        label = [f"Signal {i+1}" for i in range(len(samples))]

    num_samples = samples[0].shape[-1]
    num_channels = 1 if samples[0].ndim == 1 else samples[0].shape[-2]
    x_axis = np.arange(0, num_samples)

    if num_channels > 1:
        fig, axs = plt.subplots(num_channels, 1, figsize=figsize)
        for c, ax in enumerate(axs):
            # Samples are drawn in order, so each axes' color cycle assigns
            # the same color to the same sample across channels.
            for i, y_axis in enumerate(samples):
                ax.plot(
                    x_axis,
                    y_axis[c][:num_samples],
                    label=None if not label else label[i],
                )
            # Decorate each axes once, after all samples are drawn.
            ax.set_xlabel("Index")
            ax.set_ylabel("Values")
            ax.set_title(f"channel {c+1}")
            if label:
                ax.legend(loc="upper right")
            ax.grid(True)
        if title:
            # plt.title() would target the last subplot and overwrite its
            # "channel N" title; title the whole figure instead.
            fig.suptitle(title)
        # Run layout after the suptitle is set so it can reserve space.
        plt.tight_layout()
    else:
        plt.figure(figsize=figsize)
        for i, y_axis in enumerate(samples):
            if isinstance(y_axis, np.ndarray):
                y_axis = y_axis.squeeze()
            plt.plot(x_axis, y_axis, label=None if not label else label[i])
        plt.xlabel("Index")
        plt.ylabel("Values")
        if label:
            plt.legend()
        plt.grid(True)
        if title:
            plt.title(title)

    plt.show()