glitchgan.tf.gan_models

TensorFlow/Keras GAN model classes for cDVGAN.

Includes:

- cWGAN: conditional Wasserstein GAN with gradient penalty
- cDVGAN: adds a first-derivative discriminator
- cDVGAN2: adds first and second derivative discriminators
- build_gan(): factory function

Module Contents

glitchgan.tf.gan_models.NUM_CLASSES = 7
class glitchgan.tf.gan_models.cWGAN(signal_length=8192, num_classes=NUM_CLASSES, noise_dim=100, d_steps=5, gp_weight=10.0, lr=0.0001)

Bases: keras.Model

Conditional Wasserstein GAN with gradient penalty.

noise_dim = 100
num_classes = 7
d_steps = 5
gp_weight = 10.0
generator
discriminator
g_optimizer
d_optimizer
train_step(data)
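
For orientation, the standard WGAN-GP objective that a class of this kind typically optimizes is written out below. This is the textbook formulation, not a transcription of train_step, whose body is not shown here. The symbols G (generator), D (discriminator), x (real signal), z (noise vector) and c (class label) are notation introduced for this sketch. The penalty coefficient λ is assumed to correspond to gp_weight, and d_steps is assumed to be the number of discriminator updates per generator update.

```latex
\begin{aligned}
\hat{x} &= \epsilon\, x + (1 - \epsilon)\, G(z, c), \qquad \epsilon \sim \mathcal{U}[0, 1] \\
L_D &= \mathbb{E}\big[D(G(z, c), c)\big] - \mathbb{E}\big[D(x, c)\big]
       + \lambda\, \mathbb{E}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}, c) \rVert_2 - 1\big)^2\Big] \\
L_G &= -\,\mathbb{E}\big[D(G(z, c), c)\big]
\end{aligned}
```

With the default gp_weight = 10.0, the penalty term pushes the gradient norm of D at the interpolated samples toward 1.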
class glitchgan.tf.gan_models.cDVGAN(signal_length=8192, num_classes=NUM_CLASSES, noise_dim=100, d_steps=5, gp_weight=10.0, lr=0.0001)

Bases: keras.Model

Conditional Derivative GAN: a cWGAN with an auxiliary discriminator on the first derivative of the signal.

noise_dim = 100
num_classes = 7
d_steps = 5
gp_weight = 10.0
generator
discriminator
deriv_discriminator
g_optimizer
d_optimizer
d2d_optimizer
train_step(data)
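
The input seen by deriv_discriminator can be illustrated with a forward finite difference. The following is a minimal sketch in plain Python; the helper name first_difference and the dt parameter are hypothetical, and the class itself may compute the derivative differently (for example on batched tensors).

```python
def first_difference(signal, dt=1.0):
    """Forward finite difference: (x[n+1] - x[n]) / dt.

    Hypothetical helper for illustration; the result is one sample
    shorter than the input.
    """
    return [(b - a) / dt for a, b in zip(signal, signal[1:])]


# Samples of n**2 for n = 0..4.
samples = [0.0, 1.0, 4.0, 9.0, 16.0]
derivative = first_difference(samples)
print(derivative)  # [1.0, 3.0, 5.0, 7.0]
```

Under this scheme a signal of length 8192 yields a derivative series of length 8191, so a derivative discriminator would need to accept the shorter input or the series would need padding.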
class glitchgan.tf.gan_models.cDVGAN2(signal_length=8192, num_classes=NUM_CLASSES, noise_dim=100, d_steps=5, gp_weight=10.0, lr=0.0001)

Bases: keras.Model

cDVGAN with an additional second-derivative discriminator.

noise_dim = 100
num_classes = 7
d_steps = 5
gp_weight = 10.0
generator
discriminator
deriv_discriminator
deriv2_discriminator
g_optimizer
d_optimizer
d2d_optimizer
d2d2_optimizer
train_step(data)
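
A second-derivative series of the kind deriv2_discriminator would inspect can be obtained by differencing twice. The sketch below uses a hypothetical helper, difference, with unit sample spacing; it only illustrates the idea and is not taken from the class.

```python
def difference(signal, order=1):
    """Apply the forward difference `order` times (unit sample spacing).

    Hypothetical helper for illustration; each pass shortens the
    series by one sample.
    """
    out = list(signal)
    for _ in range(order):
        out = [b - a for a, b in zip(out, out[1:])]
    return out


quadratic = [n * n for n in range(6)]    # 0, 1, 4, 9, 16, 25
first = difference(quadratic, order=1)   # odd numbers 1..9
second = difference(quadratic, order=2)  # constant series
print(first, second)
```

For a quadratic input the second difference is constant, which is a quick sanity check on any derivative routine.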
class glitchgan.tf.gan_models.GlitchGAN(noise_dim=100, d_steps=5, gp_weight=10.0, lr=0.0001)

Bases: cDVGAN

cDVGAN trained on LIGO gravitational-wave glitch data.

Fixes the LIGO-specific defaults (signal length, number of glitch classes) so they don’t need to be passed at every call site. All architecture and training logic lives in cDVGAN.

SIGNAL_LENGTH = 8192
NUM_CLASSES = 7
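
The "fixed defaults" pattern described for GlitchGAN can be sketched with a stand-in base class. BaseGAN and LigoGAN below are hypothetical names that only store hyperparameters; the real classes build Keras models, and their constructors may differ in detail.

```python
class BaseGAN:
    """Hypothetical stand-in for cDVGAN; stores hyperparameters only."""

    def __init__(self, signal_length, num_classes, noise_dim=100,
                 d_steps=5, gp_weight=10.0, lr=1e-4):
        self.signal_length = signal_length
        self.num_classes = num_classes
        self.noise_dim = noise_dim
        self.d_steps = d_steps
        self.gp_weight = gp_weight
        self.lr = lr


class LigoGAN(BaseGAN):
    """Subclass that pins dataset-specific values as class constants."""

    SIGNAL_LENGTH = 8192
    NUM_CLASSES = 7

    def __init__(self, noise_dim=100, d_steps=5, gp_weight=10.0, lr=1e-4):
        # Forward the fixed constants so callers never pass them.
        super().__init__(
            signal_length=self.SIGNAL_LENGTH,
            num_classes=self.NUM_CLASSES,
            noise_dim=noise_dim,
            d_steps=d_steps,
            gp_weight=gp_weight,
            lr=lr,
        )


model = LigoGAN(lr=2e-4)
print(model.signal_length, model.num_classes, model.lr)
```

Keeping the constants at class level means call sites pass only training hyperparameters, and a different dataset can be supported with another small subclass.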
glitchgan.tf.gan_models.build_gan(variant='cDVGAN', signal_length=8192, num_classes=NUM_CLASSES, noise_dim=100, d_steps=5, gp_weight=10.0, lr=0.0001)

Instantiate a TF GAN variant by name.

Parameters:

variant (str) – One of "cWGAN", "cDVGAN", "cDVGAN2".

Return type:

keras.Model
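
A name-based factory such as build_gan is commonly implemented as a dictionary lookup. The sketch below shows that pattern with placeholder classes; the names ending in Stub and build_gan_sketch are hypothetical, and the actual function may dispatch differently or validate its arguments in other ways.

```python
class _Stub:
    """Hypothetical placeholder that records constructor keyword arguments."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


class CWGANStub(_Stub):
    pass


class CDVGANStub(_Stub):
    pass


class CDVGAN2Stub(_Stub):
    pass


# Map the documented variant names to the placeholder classes.
_VARIANTS = {
    "cWGAN": CWGANStub,
    "cDVGAN": CDVGANStub,
    "cDVGAN2": CDVGAN2Stub,
}


def build_gan_sketch(variant="cDVGAN", **kwargs):
    """Look up `variant` and instantiate it with the remaining arguments."""
    try:
        cls = _VARIANTS[variant]
    except KeyError:
        raise ValueError(
            f"Unknown variant {variant!r}; expected one of {sorted(_VARIANTS)}"
        ) from None
    return cls(**kwargs)


gan = build_gan_sketch("cDVGAN2", noise_dim=100, lr=1e-4)
print(type(gan).__name__, gan.kwargs)
```

Going by the documented signature, the real call would take the same shape, for example build_gan(variant="cDVGAN2", noise_dim=100), and return a keras.Model.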