glitchgan.utils
Training utilities for cDVGAN.
Includes:
- GlitchDataset – PyTorch Dataset for GAN training data
- train_gan() – main training loop
- generate_examples() – vertex / simplex / uniform class sampling
- plot_losses() – loss curve plotting
- plot_q_transform() – Q-transform plot via gwpy
- save_checkpoint() – save model state dicts
- load_checkpoint() – restore model state dicts
- whitened_snr_scaling() – scale a signal to a target SNR in the whitened frame
Module Contents
- class glitchgan.utils.GlitchDataset(signals, class_array, derivs=None, derivs2=None)
Bases: torch.utils.data.Dataset
Dataset for cWGAN / cDVGAN training.
- Parameters:
signals (np.ndarray (N, L)) – Raw glitch time series.
class_array (np.ndarray (N, num_classes)) – One-hot class labels.
derivs (np.ndarray (N, L-1) or None) – First derivatives (required for cDVGAN / cDVGAN2).
derivs2 (np.ndarray (N, L-2) or None) – Second derivatives (required for cDVGAN2 only).
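The shapes listed above (L-1 and L-2 samples) are what first and second finite differences along the time axis produce. The following is a minimal sketch of preparing the four input arrays, assuming the derivative arrays are plain `np.diff` outputs; the source does not say how they are computed, and any scaling by the sample rate is left out. The array sizes and random data are placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)

# Placeholder data: N=8 glitches of L=1024 samples, 7 classes (illustrative sizes only).
N, L, num_classes = 8, 1024, 7
signals = rng.standard_normal((N, L))
labels = rng.integers(0, num_classes, size=N)

# One-hot class labels, shape (N, num_classes).
class_array = np.eye(num_classes)[labels]

# Assumption: derivatives are finite differences along the time axis.
derivs = np.diff(signals, n=1, axis=1)   # shape (N, L-1)
derivs2 = np.diff(signals, n=2, axis=1)  # shape (N, L-2)

# With glitchgan installed, these arrays would be passed as:
# dataset = GlitchDataset(signals, class_array, derivs=derivs, derivs2=derivs2)
```

For cWGAN, which needs no derivatives, `derivs` and `derivs2` can stay at their default of None.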
- glitchgan.utils.train_gan(gan, dataset, epochs=500, batch_size=64, save_every=25, monitor_every=1, output_dir='GAN_outputs', variant='cDVGAN', noise_dim=100, num_classes=7, start_epoch=1)
Train a GAN model.
- Parameters:
gan (cWGAN | cDVGAN | cDVGAN2) – A GAN instance from glitchgan.gan_models.
dataset (GlitchDataset)
epochs (int) – Total number of epochs to reach (not additional epochs to run).
batch_size (int)
save_every (int) – Save full checkpoint and multi-sample example plots every N epochs.
monitor_every (int) – Save a single vertex-sample monitor plot every N epochs (default 1).
output_dir (str) – Directory to save checkpoints, loss history and example plots.
variant (str) – One of "cWGAN", "cDVGAN", "cDVGAN2".
noise_dim (int)
num_classes (int)
start_epoch (int) – Epoch to start from (1 for fresh training, or resumed checkpoint epoch + 1).
- Returns:
Loss history — keys are loss names, values are lists over epochs.
- Return type:
dict
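Because epochs is a target rather than a count, a resumed run only executes the epochs that remain. The hypothetical helper below (not part of glitchgan) illustrates that bookkeeping. It assumes the loop covers start_epoch through epochs inclusive and that a full checkpoint is written whenever the epoch number is a multiple of save_every; the actual loop may differ in these details.

```python
def epochs_to_run(epochs, start_epoch=1, save_every=25):
    """Return (epochs executed, epochs at which a checkpoint is saved).

    Assumption: training covers start_epoch..epochs inclusive and a full
    checkpoint is written whenever epoch % save_every == 0.
    """
    run = list(range(start_epoch, epochs + 1))
    saves = [e for e in run if e % save_every == 0]
    return run, saves


# A fresh run to epoch 500, then a resume from a checkpoint saved at epoch 300.
fresh_run, fresh_saves = epochs_to_run(500)
resumed_run, resumed_saves = epochs_to_run(500, start_epoch=301)
```

Under these assumptions the resumed call with epochs=500 and start_epoch=301 trains for 200 epochs, not 500.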
- glitchgan.utils.save_checkpoint(gan, variant, output_dir, epoch='last')
Save model and optimizer state dicts to output_dir.
- glitchgan.utils.load_checkpoint(gan, output_dir, epoch='last', device='cpu')
Load model and optimizer state dicts from output_dir into a GAN instance.
- glitchgan.utils.generate_examples(gan, noise_dim=100, num_classes=7, num_signals=10, sampling='vertex', device='cpu')
Generate signals using a trained generator.
- Parameters:
gan (GAN instance)
noise_dim (int)
num_classes (int)
num_signals (int)
sampling (str) – One of "vertex" (pure one-hot class vectors, i.e. hard class assignment), "simplex" (random convex combinations that sum to 1), or "uniform" (independent uniform draws per class dimension).
device (str or torch.device)
- Returns:
signals (np.ndarray (num_signals, signal_length))
class_vectors (np.ndarray (num_signals, num_classes))
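The three sampling modes differ only in how the class vectors fed to the generator are drawn. The sketch below shows one way to draw them with NumPy. `sample_class_vectors` is a hypothetical helper, not the library's implementation: it assumes random class choice for "vertex", a flat Dirichlet distribution for "simplex", and draws from [0, 1) for "uniform", none of which the source confirms.

```python
import numpy as np


def sample_class_vectors(num_signals, num_classes, sampling="vertex", rng=None):
    """Draw class-conditioning vectors of shape (num_signals, num_classes)."""
    rng = np.random.default_rng() if rng is None else rng
    if sampling == "vertex":
        # Hard assignment: one-hot rows (assumed: class chosen at random per signal).
        idx = rng.integers(0, num_classes, size=num_signals)
        return np.eye(num_classes)[idx]
    if sampling == "simplex":
        # Convex combinations: non-negative rows summing to 1 (assumed: flat Dirichlet).
        return rng.dirichlet(np.ones(num_classes), size=num_signals)
    if sampling == "uniform":
        # Independent draws in [0, 1) per class dimension; rows need not sum to 1.
        return rng.random((num_signals, num_classes))
    raise ValueError(f"unknown sampling mode: {sampling!r}")


rng = np.random.default_rng(0)
vertex = sample_class_vectors(10, 7, "vertex", rng)
simplex = sample_class_vectors(10, 7, "simplex", rng)
uniform = sample_class_vectors(10, 7, "uniform", rng)
```

Only "vertex" corresponds to the one-hot labels seen during training; "simplex" and "uniform" ask the generator for mixtures of classes.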
- glitchgan.utils.plot_losses(history, variant, output_dir)
Plot and save training loss curves.
- glitchgan.utils.plot_q_transform(data, srate=4096.0, crop=None, whiten=True, ax=None, colourbar=True, qrange=(4, 64), frange=(10, 1200), clim=(0, 25.5))
Plot the Q-transform of a time series using gwpy.
- Parameters:
data (array-like) – Input time-domain data.
srate (float) – Sample rate in Hz.
crop (tuple or None) – (center_time, duration) in seconds. center_time is measured from the start of data (i.e. t0=0). If None, the full segment is used.
whiten (bool) – If True, apply gwpy’s internal whitening before the Q-transform. Pass False when the signal is already whitened or noise-free.
ax (matplotlib.axes.Axes or None) – Axes to plot on. A new figure is created if None.
colourbar (bool) – If True, add a colourbar to the plot.
qrange (tuple) – (q_min, q_max) range for the Q-transform.
frange (tuple) – (f_min, f_max) frequency range in Hz.
clim (tuple) – (vmin, vmax) colour axis limits for normalised energy.
- Return type:
matplotlib.axes.Axes
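To make the crop convention concrete, the sketch below converts a (center_time, duration) pair into sample indices. `crop_window` is a hypothetical helper; it assumes the window is centred on center_time and clipped to the available data, which the source does not spell out.

```python
import numpy as np


def crop_window(data, srate=4096.0, crop=None):
    """Return the samples selected by crop=(center_time, duration), in seconds.

    Assumption: the window spans center_time +/- duration/2, with times
    measured from the first sample (t0 = 0) and clipped to the data length.
    """
    data = np.asarray(data)
    if crop is None:
        return data  # full segment
    center_time, duration = crop
    start = int(round((center_time - duration / 2) * srate))
    stop = int(round((center_time + duration / 2) * srate))
    start = max(start, 0)
    stop = min(stop, data.shape[-1])
    return data[..., start:stop]


# Two seconds of data at 4096 Hz; keep 0.5 s centred on t = 1.0 s.
data = np.arange(2 * 4096)
window = crop_window(data, srate=4096.0, crop=(1.0, 0.5))
```

Here the window covers samples 3072 to 5119, i.e. 0.75 s to 1.25 s.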
- glitchgan.utils.whitened_snr_scaling(glitch, snr, srate=4096)
Scale a glitch signal to a target SNR in the whitened frame.
Computes the true optimal SNR of the signal via its one-sided power spectral density, then rescales so that the injected signal has the requested SNR.
- Parameters:
- Returns:
Rescaled glitch signal(s), same shape as input.
- Return type:
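The description above can be illustrated with the standard optimal SNR integral, SNR² = 4 ∫ |h(f)|² / S_n(f) df, evaluated with a one-sided noise PSD. The sketch below is not the library's code. Both function names are hypothetical, and it assumes the whitened noise is white with unit variance per sample, so that S_n(f) = 2 / srate; under that assumption the SNR reduces to the Euclidean norm of the time series.

```python
import numpy as np


def optimal_snr_whitened(h, srate=4096.0):
    """Optimal SNR of h against unit-variance white noise (assumed noise model)."""
    h = np.asarray(h, dtype=float)
    n = h.shape[-1]
    df = srate / n
    h_f = np.fft.rfft(h, axis=-1) / srate  # approximates the continuous Fourier transform
    noise_psd = 2.0 / srate                # one-sided PSD of unit-variance white noise
    integrand = 4.0 * np.abs(h_f) ** 2 / noise_psd
    # DC and Nyquist bins have no mirrored negative-frequency partner, so halve them.
    integrand[..., 0] *= 0.5
    if n % 2 == 0:
        integrand[..., -1] *= 0.5
    return np.sqrt(np.sum(integrand, axis=-1) * df)


def scale_to_whitened_snr(glitch, snr, srate=4096.0):
    """Rescale glitch (1-D, or batched along the first axis) to the target SNR."""
    glitch = np.asarray(glitch, dtype=float)
    current = optimal_snr_whitened(glitch, srate)
    return glitch * (snr / np.expand_dims(current, -1))


rng = np.random.default_rng(0)
glitches = rng.standard_normal((3, 1024))
scaled = scale_to_whitened_snr(glitches, snr=10.0)
```

The output keeps the input shape, and each row of `scaled` has an optimal SNR of 10 under the stated noise model.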