Training

Data preparation

Download the GravitySpy balanced dataset and place it in data/:

data/
├── glitch_GAN_samples_scaled_balanced.npy   # (35000, 8192) whitened waveforms
└── glitch_GAN_labels_balanced.npy            # (35000, 7)   one-hot class labels

See the README for download instructions.
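Before starting a long training run, it can help to confirm that the downloaded arrays match the layout shown above. The following is a minimal sketch, not part of the package: `check_dataset` is a hypothetical helper, and the only facts it relies on are the documented signal length (8192) and one-hot label width (7).

```python
import numpy as np

EXPECTED_LENGTH = 8192   # samples per waveform, from the file listing
EXPECTED_CLASSES = 7     # one-hot label width, from the file listing


def check_dataset(samples: np.ndarray, labels: np.ndarray) -> None:
    """Hypothetical helper: raise ValueError if the arrays do not match
    the documented (N, 8192) / (N, 7) one-hot layout."""
    if samples.ndim != 2 or samples.shape[1] != EXPECTED_LENGTH:
        raise ValueError(f"unexpected samples shape: {samples.shape}")
    if labels.shape != (samples.shape[0], EXPECTED_CLASSES):
        raise ValueError(f"unexpected labels shape: {labels.shape}")
    # A one-hot row contains only 0s and 1s and sums to exactly 1.
    is_binary = np.isin(labels, (0, 1)).all()
    one_per_row = (labels.sum(axis=1) == 1).all()
    if not (is_binary and one_per_row):
        raise ValueError("labels are not one-hot encoded")
```

Calling `check_dataset(X, y)` on the two arrays returned by `np.load` raises early if a file is truncated or the two files were swapped.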

Training a model

glitchgan-train \
    --variant cDVGAN \
    --data-dir data/ \
    --epochs 500 \
    --output-dir outputs/

Available model variants

Variant    Description
cWGAN      Conditional Wasserstein GAN with gradient penalty (single discriminator)
cDVGAN     Dual-discriminator cWGAN with derivative discriminator (recommended)
cDVGAN2    Extended cDVGAN with additional second-derivative discriminator
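To give a sense of what the derivative discriminators in cDVGAN and cDVGAN2 look at, the sketch below computes first- and second-order finite differences of a batch of waveforms with NumPy. This is an illustration only: it assumes that "derivative" means a simple finite difference along the time axis, and the function names are hypothetical. The package may compute these inputs differently.

```python
import numpy as np


def first_difference(batch: np.ndarray) -> np.ndarray:
    """First-order finite difference along the time axis.

    Input shape (N, T) gives output shape (N, T - 1).
    """
    return np.diff(batch, n=1, axis=-1)


def second_difference(batch: np.ndarray) -> np.ndarray:
    """Second-order finite difference along the time axis.

    Input shape (N, T) gives output shape (N, T - 2).
    """
    return np.diff(batch, n=2, axis=-1)
```

Under this assumption, a batch of shape (64, 8192) yields arrays of shape (64, 8191) and (64, 8190). Differencing emphasises rapid changes in the signal, which is one plausible reason to give a discriminator this view in addition to the raw waveform.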

Python API

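import numpy as np

from glitchgan.tf import GlitchDataset, build_gan, train_gan

# Load the whitened waveforms and their one-hot class labels.
X = np.load("data/glitch_GAN_samples_scaled_balanced.npy")
y = np.load("data/glitch_GAN_labels_balanced.npy")

# Wrap the arrays in a batched dataset and build the recommended variant.
dataset = GlitchDataset(X, y, batch_size=64)
gan = build_gan("cDVGAN", noise_dim=100, num_classes=7, signal_length=8192)

# Train for 500 epochs, writing checkpoints to checkpoints/.
train_gan(gan, dataset, epochs=500, checkpoint_dir="checkpoints/")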

Checkpointing

Weights are saved every 10 epochs to checkpoint_dir/. To resume training, point --output-dir at an existing checkpoint directory.
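As a rough guide to how many checkpoints a run produces, the sketch below lists the epochs at which weights would be written. It assumes epochs are counted from 1 and that a save happens at the end of every tenth epoch. Neither detail is stated here, and `checkpoint_epochs` is a hypothetical helper rather than part of the package.

```python
def checkpoint_epochs(total_epochs: int, every: int = 10) -> list[int]:
    """Hypothetical helper: epochs at which a checkpoint would be written,
    assuming 1-based epoch numbering and a save after every `every` epochs."""
    return list(range(every, total_epochs + 1, every))
```

Under these assumptions, a 500-epoch run writes 50 checkpoints, and a run interrupted during epoch 25 would resume from the weights saved at epoch 20.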

Run glitchgan-train --help for the full list of arguments.