glitchgan.utils
===============

.. py:module:: glitchgan.utils

.. autoapi-nested-parse::

   Training utilities for cDVGAN.

   Includes:

   - GlitchDataset — PyTorch Dataset for GAN training data
   - train_gan() — main training loop
   - generate_examples() — vertex / simplex / uniform class sampling
   - plot_losses() — loss curve plotting
   - plot_q_transform() — Q-transform plot via gwpy
   - save_checkpoint() — save model state dicts
   - load_checkpoint() — restore model state dicts
   - whitened_snr_scaling() — scale a signal to a target SNR in the whitened frame


Module Contents
---------------

.. py:class:: GlitchDataset(signals, class_array, derivs=None, derivs2=None)

   Bases: :py:obj:`torch.utils.data.Dataset`

   Dataset for cWGAN / cDVGAN training.

   :param signals: Raw glitch time series.
   :type signals: np.ndarray (N, L)
   :param class_array: One-hot class labels.
   :type class_array: np.ndarray (N, num_classes)
   :param derivs: First derivatives (required for cDVGAN / cDVGAN2).
   :type derivs: np.ndarray (N, L-1) or None
   :param derivs2: Second derivatives (required for cDVGAN2 only).
   :type derivs2: np.ndarray (N, L-2) or None

   .. py:attribute:: signals

   .. py:attribute:: classes

   .. py:attribute:: derivs

   .. py:attribute:: derivs2


.. py:function:: train_gan(gan, dataset, epochs=500, batch_size=64, save_every=25, monitor_every=1, output_dir='GAN_outputs', variant='cDVGAN', noise_dim=100, num_classes=7, start_epoch=1)

   Train a GAN model.

   :param gan: A GAN instance from glitchgan.gan_models.
   :type gan: cWGAN | cDVGAN | cDVGAN2
   :param dataset:
   :type dataset: GlitchDataset
   :param epochs: Total number of epochs to reach (not additional epochs to run).
   :type epochs: int
   :param batch_size:
   :type batch_size: int
   :param save_every: Save full checkpoint and multi-sample example plots every N epochs.
   :type save_every: int
   :param monitor_every: Save a single vertex-sample monitor plot every N epochs (default 1).
   :type monitor_every: int
   :param output_dir: Directory to save checkpoints, loss history and example plots.
   :type output_dir: str
   :param variant: One of ``"cWGAN"``, ``"cDVGAN"``, ``"cDVGAN2"``.
   :type variant: str
   :param noise_dim:
   :type noise_dim: int
   :param num_classes:
   :type num_classes: int
   :param start_epoch: Epoch to start from (1 for fresh training, or resumed checkpoint epoch + 1).
   :type start_epoch: int

   :returns: Loss history — keys are loss names, values are lists over epochs.
   :rtype: dict


.. py:function:: save_checkpoint(gan, variant, output_dir, epoch='last')

   Save model and optimizer state dicts to output_dir.


.. py:function:: load_checkpoint(gan, output_dir, epoch='last', device='cpu')

   Load model and optimizer state dicts from output_dir into a GAN instance.


.. py:function:: generate_examples(gan, noise_dim=100, num_classes=7, num_signals=10, sampling='vertex', device='cpu')

   Generate signals using a trained generator.

   :param gan:
   :type gan: GAN instance
   :param noise_dim:
   :type noise_dim: int
   :param num_classes:
   :type num_classes: int
   :param num_signals:
   :type num_signals: int
   :param sampling: One of ``"vertex"``, ``"simplex"``, ``"uniform"``.

      - ``"vertex"`` — pure one-hot class vectors (hard class assignment)
      - ``"simplex"`` — random convex combinations (sum to 1)
      - ``"uniform"`` — independent uniform draws per class dimension
   :type sampling: str
   :param device:
   :type device: str or torch.device

   :returns: * **signals** (*np.ndarray (num_signals, signal_length)*)
             * **class_vectors** (*np.ndarray (num_signals, num_classes)*)


.. py:function:: plot_losses(history, variant, output_dir)

   Plot and save training loss curves.

   :param history: Keys are loss names, values are lists of per-epoch values.
   :type history: dict
   :param variant:
   :type variant: str
   :param output_dir:
   :type output_dir: str


.. py:function:: plot_q_transform(data, srate=4096.0, crop=None, whiten=True, ax=None, colourbar=True, qrange=(4, 64), frange=(10, 1200), clim=(0, 25.5))

   Plot the Q-transform of a time series using gwpy.

   :param data: Input time-domain data.
   :type data: array-like
   :param srate: Sample rate in Hz.
   :type srate: float
   :param crop: ``(center_time, duration)`` in seconds. ``center_time`` is measured
                from the start of ``data`` (i.e. t0=0). If None the full segment is used.
   :type crop: tuple or None
   :param whiten: If True, apply gwpy's internal whitening before the Q-transform.
                  Pass False when the signal is already whitened or noise-free.
   :type whiten: bool
   :param ax: Axes to plot on. A new figure is created if None.
   :type ax: matplotlib.axes.Axes or None
   :param colourbar: Add a colorbar to the plot.
   :type colourbar: bool
   :param qrange: (q_min, q_max) range for the Q-transform.
   :type qrange: tuple
   :param frange: (f_min, f_max) frequency range in Hz.
   :type frange: tuple
   :param clim: (vmin, vmax) colour axis limits for normalised energy.
   :type clim: tuple

   :rtype: matplotlib.axes.Axes


.. py:function:: whitened_snr_scaling(glitch, snr, srate=4096)

   Scale a glitch signal to a target SNR in the whitened frame.

   Computes the true optimal SNR of the signal via its one-sided power
   spectral density, then rescales so that the injected signal has the
   requested SNR.

   :param glitch: Time-domain glitch signal(s).
   :type glitch: array-like, shape (..., N)
   :param snr: Target SNR. If None the signal is returned unchanged.
   :type snr: float or None
   :param srate: Sample rate in Hz (default 4096).
   :type srate: int

   :returns: Rescaled glitch signal(s), same shape as input.
   :rtype: numpy.ndarray
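
The SNR rescaling described for ``whitened_snr_scaling`` can be illustrated with a minimal NumPy sketch. This is not the library's implementation: the function name ``scale_to_snr_sketch`` is hypothetical, and the sketch assumes that "whitened frame" means white noise with unit variance per sample, so that the one-sided noise PSD is the constant ``2 / srate``. Under that assumption the optimal SNR is evaluated from the one-sided spectrum of the signal and the signal is multiplied by ``snr / rho``.

```python
import numpy as np


def scale_to_snr_sketch(glitch, snr, srate=4096):
    """Illustrative only: rescale `glitch` to a target optimal SNR.

    Assumption (not taken from the documented source): the noise is white
    with unit variance per sample, so its one-sided PSD is 2 / srate.
    """
    glitch = np.asarray(glitch, dtype=float)
    if snr is None:
        # Mirror the documented behaviour: no target SNR, no rescaling.
        return glitch

    n = glitch.shape[-1]
    dt = 1.0 / srate
    df = srate / n

    # Approximate the continuous Fourier transform on the one-sided grid.
    h_f = np.fft.rfft(glitch, axis=-1) * dt

    # Interior bins stand for both +f and -f; DC (and Nyquist, for even n)
    # appear only once in the two-sided spectrum.
    weights = np.full(h_f.shape[-1], 2.0)
    weights[0] = 1.0
    if n % 2 == 0:
        weights[-1] = 1.0

    noise_psd = 2.0 / srate  # assumed flat one-sided PSD

    # rho^2 = 4 * int_0^inf |h(f)|^2 / S_n(f) df
    #       = 2 * (two-sided sum of |h(f)|^2) * df / S_n
    rho_sq = 2.0 * np.sum(weights * np.abs(h_f) ** 2, axis=-1, keepdims=True) * df / noise_psd
    rho = np.sqrt(rho_sq)

    # Leave all-zero signals untouched instead of dividing by zero.
    scale = np.divide(snr, rho, out=np.ones_like(rho), where=rho > 0)
    return glitch * scale


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    batch = rng.normal(size=(3, 256))
    scaled = scale_to_snr_sketch(batch, snr=8.0)
    print(scaled.shape)
```

With a flat unit-variance PSD, Parseval's theorem reduces the optimal SNR to the Euclidean norm of the time series, so the frequency-domain integral above mainly shows where a non-flat PSD would enter the calculation.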