glitchgan.utils

Training utilities for cDVGAN.

Includes:

  • GlitchDataset – PyTorch Dataset for GAN training data

  • train_gan() – main training loop

  • generate_examples() – vertex / simplex / uniform class sampling

  • plot_losses() – loss curve plotting

  • plot_q_transform() – Q-transform plot via gwpy

  • save_checkpoint() – save model state dicts

  • load_checkpoint() – restore model state dicts

  • whitened_snr_scaling() – scale a signal to a target SNR in the whitened frame

Module Contents

class glitchgan.utils.GlitchDataset(signals, class_array, derivs=None, derivs2=None)

Bases: torch.utils.data.Dataset

Dataset for cWGAN / cDVGAN training.

Parameters:
  • signals (np.ndarray (N, L)) – Raw glitch time series.

  • class_array (np.ndarray (N, num_classes)) – One-hot class labels.

  • derivs (np.ndarray (N, L-1) or None) – First derivatives (required for cDVGAN / cDVGAN2).

  • derivs2 (np.ndarray (N, L-2) or None) – Second derivatives (required for cDVGAN2 only).

signals
classes
derivs
derivs2

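The array shapes listed above can be illustrated with a minimal sketch. It assumes the derivatives are plain finite differences along the time axis, which matches the documented shapes (N, L-1) and (N, L-2) but is not confirmed by this page. The sizes and the random placeholder data are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
N, L, num_classes = 8, 1024, 7  # illustrative sizes only

signals = rng.standard_normal((N, L))          # placeholder for real glitch data
labels = rng.integers(0, num_classes, size=N)  # integer class ids
class_array = np.eye(num_classes)[labels]      # one-hot labels, shape (N, num_classes)

# Assumption: derivatives are finite differences along the time axis.
derivs = np.diff(signals, n=1, axis=1)   # shape (N, L-1)
derivs2 = np.diff(signals, n=2, axis=1)  # shape (N, L-2)
```

Arrays prepared this way have the shapes expected by GlitchDataset(signals, class_array, derivs=derivs, derivs2=derivs2).
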
glitchgan.utils.train_gan(gan, dataset, epochs=500, batch_size=64, save_every=25, monitor_every=1, output_dir='GAN_outputs', variant='cDVGAN', noise_dim=100, num_classes=7, start_epoch=1)

Train a GAN model.

Parameters:
  • gan (cWGAN | cDVGAN | cDVGAN2) – A GAN instance from glitchgan.gan_models.

  • dataset (GlitchDataset)

  • epochs (int) – Total number of epochs to reach (not additional epochs to run).

  • batch_size (int)

  • save_every (int) – Save full checkpoint and multi-sample example plots every N epochs.

  • monitor_every (int) – Save a single vertex-sample monitor plot every N epochs (default 1).

  • output_dir (str) – Directory to save checkpoints, loss history and example plots.

  • variant (str) – One of "cWGAN", "cDVGAN", "cDVGAN2".

  • noise_dim (int)

  • num_classes (int)

  • start_epoch (int) – Epoch to start from: 1 for fresh training, or the checkpoint's epoch + 1 when resuming.

Returns:

Loss history — keys are loss names, values are lists over epochs.

Return type:

dict
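Because epochs is the total to reach rather than the number of additional epochs, the interaction with start_epoch is worth spelling out. The sketch below uses a hypothetical helper, epochs_to_run, and assumes the final epoch is inclusive; neither is part of glitchgan.

```python
def epochs_to_run(epochs, start_epoch=1):
    """Hypothetical helper: epoch numbers visited when training up to `epochs`.

    Assumes the last epoch is inclusive, i.e. start_epoch, ..., epochs.
    """
    return range(start_epoch, epochs + 1)


# Fresh run: epochs 1 through 500.
fresh = epochs_to_run(epochs=500)

# Resuming from a checkpoint saved at epoch 100: epochs 101 through 500,
# so only 400 further epochs are run, not 500.
resumed = epochs_to_run(epochs=500, start_epoch=101)

print(len(fresh), len(resumed))
```
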

glitchgan.utils.save_checkpoint(gan, variant, output_dir, epoch='last')

Save model and optimizer state dicts to output_dir.

glitchgan.utils.load_checkpoint(gan, output_dir, epoch='last', device='cpu')

Load model and optimizer state dicts from output_dir into a GAN instance.

glitchgan.utils.generate_examples(gan, noise_dim=100, num_classes=7, num_signals=10, sampling='vertex', device='cpu')

Generate signals using a trained generator.

Parameters:
  • gan (GAN instance)

  • noise_dim (int)

  • num_classes (int)

  • num_signals (int)

  • sampling (str) –

    One of "vertex", "simplex", "uniform".

    • "vertex" — pure one-hot class vectors (hard class assignment)

    • "simplex" — random convex combinations (sum to 1)

    • "uniform" — independent uniform draws per class dimension

  • device (str or torch.device)

Returns:

  • signals (np.ndarray (num_signals, signal_length))

  • class_vectors (np.ndarray (num_signals, num_classes))
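The three sampling modes can be sketched in NumPy as follows. This is an illustration of the described behaviour, not the library's implementation: sample_class_vectors is a hypothetical helper, and drawing the "simplex" vectors from a flat Dirichlet distribution is an assumption (any scheme producing non-negative rows that sum to 1 would fit the description).

```python
import numpy as np


def sample_class_vectors(num_signals, num_classes, sampling="vertex", rng=None):
    """Hypothetical helper returning an array of shape (num_signals, num_classes)."""
    rng = np.random.default_rng() if rng is None else rng
    if sampling == "vertex":
        # One-hot rows: each signal belongs to exactly one class.
        idx = rng.integers(0, num_classes, size=num_signals)
        return np.eye(num_classes)[idx]
    if sampling == "simplex":
        # Assumption: flat Dirichlet draws, i.e. random convex combinations.
        return rng.dirichlet(np.ones(num_classes), size=num_signals)
    if sampling == "uniform":
        # Independent draws in [0, 1) per class dimension; rows need not sum to 1.
        return rng.random((num_signals, num_classes))
    raise ValueError(f"unknown sampling mode: {sampling!r}")


rng = np.random.default_rng(0)
vertex = sample_class_vectors(10, 7, "vertex", rng)
simplex = sample_class_vectors(10, 7, "simplex", rng)
uniform = sample_class_vectors(10, 7, "uniform", rng)
```
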

glitchgan.utils.plot_losses(history, variant, output_dir)

Plot and save training loss curves.

Parameters:
  • history (dict) – Keys are loss names, values are lists of per-epoch values.

  • variant (str)

  • output_dir (str)
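The expected layout of history can be shown with a small sketch. The loss names ("d_loss", "g_loss") and the append_epoch_means helper are hypothetical, and averaging batch losses into one value per epoch is an assumption about how the per-epoch values are produced.

```python
from collections import defaultdict


def append_epoch_means(history, batch_losses):
    """Hypothetical helper: append each loss's mean over one epoch's batches."""
    for name, values in batch_losses.items():
        history[name].append(sum(values) / len(values))
    return history


history = defaultdict(list)

# Two epochs, two batches each; the loss names are made up for illustration.
append_epoch_means(history, {"d_loss": [1.0, 3.0], "g_loss": [2.0, 4.0]})
append_epoch_means(history, {"d_loss": [0.5, 1.5], "g_loss": [1.0, 3.0]})

# history now maps each loss name to a list with one entry per epoch.
```
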

glitchgan.utils.plot_q_transform(data, srate=4096.0, crop=None, whiten=True, ax=None, colourbar=True, qrange=(4, 64), frange=(10, 1200), clim=(0, 25.5))

Plot the Q-transform of a time series using gwpy.

Parameters:
  • data (array-like) – Input time-domain data.

  • srate (float) – Sample rate in Hz.

  • crop (tuple or None) – (center_time, duration) in seconds. center_time is measured from the start of data (i.e. t0=0). If None the full segment is used.

  • whiten (bool) – If True, apply gwpy’s internal whitening before the Q-transform. Pass False when the signal is already whitened or noise-free.

  • ax (matplotlib.axes.Axes or None) – Axes to plot on. A new figure is created if None.

  • colourbar (bool) – If True, add a colour bar to the plot.

  • qrange (tuple) – (q_min, q_max) range for the Q-transform.

  • frange (tuple) – (f_min, f_max) frequency range in Hz.

  • clim (tuple) – (vmin, vmax) colour axis limits for normalised energy.

Return type:

matplotlib.axes.Axes
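To make the crop convention concrete, the sketch below converts a (center_time, duration) pair into sample indices, with time measured from the start of data. The helper crop_to_indices is hypothetical, and it assumes the window is centred on center_time and clipped to the available samples; the function itself may handle the crop differently.

```python
def crop_to_indices(crop, srate, n_samples):
    """Hypothetical helper: map (center_time, duration) in seconds to [start, stop)."""
    center_time, duration = crop
    # Assumption: the window is symmetric about center_time.
    start = int(round((center_time - duration / 2) * srate))
    stop = int(round((center_time + duration / 2) * srate))
    # Clip to the data so that windows near the edges stay valid.
    return max(start, 0), min(stop, n_samples)


# One second of data at 4096 Hz, cropped to 0.25 s centred on t = 0.5 s.
start, stop = crop_to_indices((0.5, 0.25), srate=4096.0, n_samples=4096)
```
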

glitchgan.utils.whitened_snr_scaling(glitch, snr, srate=4096)

Scale a glitch signal to a target SNR in the whitened frame.

Computes the optimal SNR of the signal from its one-sided power spectral density, then rescales the signal so that, once injected, it has the requested SNR.

Parameters:
  • glitch (array-like, shape (..., N)) – Time-domain glitch signal(s).

  • snr (float or None) – Target SNR. If None the signal is returned unchanged.

  • srate (int) – Sample rate in Hz (default 4096).

Returns:

Rescaled glitch signal(s), same shape as input.

Return type:

numpy.ndarray
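The rescaling can be sketched as below. This is not the library's code: scale_to_snr_sketch is a hypothetical stand-in, and it assumes "whitened frame" means the noise is white with unit variance per sample, so its one-sided PSD is 2 / srate. The optimal SNR is then evaluated from the signal's one-sided spectrum as SNR² = 4 ∫ |h(f)|² / S_n df.

```python
import numpy as np


def scale_to_snr_sketch(glitch, snr, srate=4096):
    """Hypothetical sketch: rescale whitened signal(s) of shape (..., N) to `snr`."""
    glitch = np.asarray(glitch, dtype=float)
    if snr is None:
        return glitch  # documented behaviour: unchanged when no target is given

    n = glitch.shape[-1]
    dt, df = 1.0 / srate, srate / n
    h_f = np.fft.rfft(glitch, axis=-1) * dt  # approximate continuous Fourier transform

    # One-sided weights: every bin counts twice except DC and (for even N) Nyquist.
    weights = np.full(h_f.shape[-1], 2.0)
    weights[0] = 1.0
    if n % 2 == 0:
        weights[-1] = 1.0

    noise_psd = 2.0 / srate  # assumption: unit-variance white noise
    # 4 * integral over f >= 0 equals 2 * the weighted one-sided sum.
    snr_sq = 2.0 * np.sum(weights * np.abs(h_f) ** 2, axis=-1, keepdims=True) * df / noise_psd
    return glitch * (snr / np.sqrt(snr_sq))


x = np.random.default_rng(0).standard_normal((3, 1000))
y = scale_to_snr_sketch(x, snr=12.0)
```

Under the unit-variance assumption, Parseval's theorem reduces the current SNR to the Euclidean norm of the whitened samples, so each row of y has norm 12. An all-zero input would divide by zero here and needs separate handling.
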