Source code for glitchgan.tf.gan_models

"""TensorFlow/Keras GAN model classes for cDVGAN.

Includes:
- cWGAN   — conditional Wasserstein GAN with gradient penalty
- cDVGAN  — adds a first-derivative discriminator
- cDVGAN2 — adds first and second derivative discriminators
- build_gan() — factory function
"""

import tensorflow as tf
import keras
from keras.optimizers import RMSprop

from glitchgan.tf.model_components import (
    get_discriminator_model,
    get_derivative_discriminator_model,
    get_second_derivative_discriminator_model,
    get_generator_model,
)

[docs] NUM_CLASSES = 7
def _diff(x): """First-order finite difference along last axis. Replaces calculate_derivative.""" return x[..., 1:] - x[..., :-1] def _gradient_penalty(discriminator, real, fake, classes, batch_size): """WGAN-GP gradient penalty for a given discriminator.""" alpha = tf.random.normal([batch_size, 1], 0.0, 1.0, dtype=tf.float32) interpolated = real + alpha * (fake - real) with tf.GradientTape() as gp_tape: gp_tape.watch(interpolated) pred = discriminator([interpolated, classes], training=True) grads = gp_tape.gradient(pred, [interpolated])[0] norm = tf.sqrt(tf.reduce_sum(tf.square(grads))) return tf.reduce_mean((norm - 1.0) ** 2) # --------------------------------------------------------------------------- # cWGAN # ---------------------------------------------------------------------------
class cWGAN(keras.Model):
    """Conditional Wasserstein GAN with gradient penalty.

    Holds one generator and one discriminator, each with its own RMSprop
    optimizer, and implements the alternating update in :meth:`train_step`.

    Parameters
    ----------
    signal_length : int
        Forwarded to ``get_discriminator_model``.
    num_classes : int
        Forwarded to both the generator and the discriminator factories.
    noise_dim : int
        Width of the noise vector fed to the generator.
    d_steps : int
        Discriminator updates performed per generator update.
    gp_weight : float
        Multiplier applied to the gradient penalty in the discriminator loss.
    lr : float
        Learning rate shared by both optimizers.
    """

    def __init__(self, signal_length=8192, num_classes=NUM_CLASSES,
                 noise_dim=100, d_steps=5, gp_weight=10.0, lr=1e-4):
        super().__init__()

        # Hyper-parameters read by train_step.
        self.noise_dim = noise_dim
        self.num_classes = num_classes
        self.d_steps = d_steps
        self.gp_weight = gp_weight

        # Networks.
        self.generator = get_generator_model(noise_dim, num_classes)
        self.discriminator = get_discriminator_model(signal_length, num_classes)

        # Optimizers: identical settings, independent state.
        rms_settings = dict(learning_rate=lr, rho=0.9, epsilon=1e-7)
        self.g_optimizer = RMSprop(**rms_settings)
        self.d_optimizer = RMSprop(**rms_settings)

    def train_step(self, data):
        """Run ``d_steps`` discriminator updates, then one generator update.

        ``data`` is ``(signals_batch, classes_batch)`` from a
        ``tf.data.Dataset``. Returns a dict with ``"d_loss"`` (from the last
        discriminator update) and ``"g_loss"``.
        """
        signals = tf.cast(data[0], tf.float32)
        labels = tf.cast(data[1], tf.float32)
        n_samples = tf.shape(signals)[0]

        # ---- discriminator updates -------------------------------------
        for _ in range(self.d_steps):
            latent = tf.random.normal((n_samples, self.noise_dim))
            with tf.GradientTape() as d_tape:
                generated = self.generator([latent, labels], training=True)
                score_real = self.discriminator([signals, labels], training=True)
                score_fake = self.discriminator([generated, labels], training=True)
                mean_fake = tf.reduce_mean(score_fake)
                mean_real = tf.reduce_mean(score_real)
                d_cost = mean_fake - mean_real
                penalty = _gradient_penalty(
                    self.discriminator, signals, generated, labels, n_samples)
                d_loss = d_cost + self.gp_weight * penalty
            d_vars = self.discriminator.trainable_variables
            d_grads = d_tape.gradient(d_loss, d_vars)
            self.d_optimizer.apply_gradients(
                zip(d_grads, self.discriminator.trainable_variables))

        # ---- generator update ------------------------------------------
        latent = tf.random.normal((n_samples, self.noise_dim))
        with tf.GradientTape() as g_tape:
            generated = self.generator([latent, labels], training=True)
            score_fake = self.discriminator([generated, labels], training=True)
            g_loss = -tf.reduce_mean(score_fake)
        g_vars = self.generator.trainable_variables
        g_grads = g_tape.gradient(g_loss, g_vars)
        self.g_optimizer.apply_gradients(
            zip(g_grads, self.generator.trainable_variables))

        return {"d_loss": d_loss, "g_loss": g_loss}
# ---------------------------------------------------------------------------
# cDVGAN
# ---------------------------------------------------------------------------
class cDVGAN(keras.Model):
    """Conditional WGAN-GP with an extra first-derivative discriminator.

    One generator is trained against two discriminators: one scores the
    signals themselves, the other scores their first-order finite
    differences (see ``_diff``). Each network has its own RMSprop optimizer.

    Parameters
    ----------
    signal_length : int
        Forwarded to ``get_discriminator_model``; the derivative
        discriminator is built for ``signal_length - 1``.
    num_classes : int
        Forwarded to every model factory.
    noise_dim : int
        Width of the noise vector fed to the generator.
    d_steps : int
        Discriminator updates performed per generator update.
    gp_weight : float
        Multiplier applied to each gradient penalty.
    lr : float
        Learning rate shared by all optimizers.
    """

    def __init__(self, signal_length=8192, num_classes=NUM_CLASSES,
                 noise_dim=100, d_steps=5, gp_weight=10.0, lr=1e-4):
        super().__init__()

        # Hyper-parameters read by train_step.
        self.noise_dim = noise_dim
        self.num_classes = num_classes
        self.d_steps = d_steps
        self.gp_weight = gp_weight

        # Networks.
        self.generator = get_generator_model(noise_dim, num_classes)
        self.discriminator = get_discriminator_model(signal_length, num_classes)
        self.deriv_discriminator = get_derivative_discriminator_model(
            signal_length - 1, num_classes)

        # Optimizers: identical settings, independent state.
        rms_settings = dict(learning_rate=lr, rho=0.9, epsilon=1e-7)
        self.g_optimizer = RMSprop(**rms_settings)
        self.d_optimizer = RMSprop(**rms_settings)
        self.d2d_optimizer = RMSprop(**rms_settings)

    def train_step(self, data):
        """Update both discriminators ``d_steps`` times, then the generator.

        ``data`` is ``(signals_batch, derivs_batch, classes_batch)``.
        Returns a dict of the last discriminator losses (``"d_loss"``,
        ``"d2d_loss"``) and the generator losses (``"g_loss"``,
        ``"g_loss2d"``, ``"g_loss_combined"``).
        """
        signals = tf.cast(data[0], tf.float32)
        signal_derivs = tf.cast(data[1], tf.float32)
        labels = tf.cast(data[2], tf.float32)
        n_samples = tf.shape(signals)[0]

        # ---- discriminator updates -------------------------------------
        for _ in range(self.d_steps):
            latent = tf.random.normal((n_samples, self.noise_dim))
            # Persistent: two separate losses are differentiated from it.
            with tf.GradientTape(persistent=True) as d_tape:
                generated = self.generator([latent, labels], training=True)
                generated_derivs = _diff(generated)

                score_real = self.discriminator(
                    [signals, labels], training=True)
                score_fake = self.discriminator(
                    [generated, labels], training=True)
                score_real_1d = self.deriv_discriminator(
                    [signal_derivs, labels], training=True)
                score_fake_1d = self.deriv_discriminator(
                    [generated_derivs, labels], training=True)

                d_cost = tf.reduce_mean(score_fake) - tf.reduce_mean(score_real)
                d2d_cost = (tf.reduce_mean(score_fake_1d)
                            - tf.reduce_mean(score_real_1d))

                penalty = _gradient_penalty(
                    self.discriminator, signals, generated, labels, n_samples)
                penalty_1d = _gradient_penalty(
                    self.deriv_discriminator, signal_derivs, generated_derivs,
                    labels, n_samples)

                d_loss = d_cost + self.gp_weight * penalty
                d2d_loss = d2d_cost + self.gp_weight * penalty_1d

            d_grads = d_tape.gradient(
                d_loss, self.discriminator.trainable_variables)
            d2d_grads = d_tape.gradient(
                d2d_loss, self.deriv_discriminator.trainable_variables)
            del d_tape  # release the persistent tape's resources

            self.d_optimizer.apply_gradients(
                zip(d_grads, self.discriminator.trainable_variables))
            self.d2d_optimizer.apply_gradients(
                zip(d2d_grads, self.deriv_discriminator.trainable_variables))

        # ---- generator update ------------------------------------------
        latent = tf.random.normal((n_samples, self.noise_dim))
        with tf.GradientTape() as g_tape:
            generated = self.generator([latent, labels], training=True)
            generated_derivs = _diff(generated)
            score_fake = self.discriminator([generated, labels], training=True)
            g_loss = -tf.reduce_mean(score_fake)
            score_fake_1d = self.deriv_discriminator(
                [generated_derivs, labels], training=True)
            g_loss2d = -tf.reduce_mean(score_fake_1d)
            # Both discriminators contribute equally to the generator loss.
            g_loss_combined = 0.5 * (g_loss + g_loss2d)
        g_grads = g_tape.gradient(
            g_loss_combined, self.generator.trainable_variables)
        self.g_optimizer.apply_gradients(
            zip(g_grads, self.generator.trainable_variables))

        return {
            "d_loss": d_loss,
            "d2d_loss": d2d_loss,
            "g_loss": g_loss,
            "g_loss2d": g_loss2d,
            "g_loss_combined": g_loss_combined,
        }
# ---------------------------------------------------------------------------
# cDVGAN2
# ---------------------------------------------------------------------------
class cDVGAN2(keras.Model):
    """cDVGAN with an additional second-derivative discriminator.

    One generator is trained against three discriminators that score the
    signals, their first finite differences and their second finite
    differences respectively. Each network has its own RMSprop optimizer.

    Parameters
    ----------
    signal_length : int
        Forwarded to ``get_discriminator_model``; the derivative
        discriminators are built for ``signal_length - 1`` and
        ``signal_length - 2``.
    num_classes : int
        Forwarded to every model factory.
    noise_dim : int
        Width of the noise vector fed to the generator.
    d_steps : int
        Discriminator updates performed per generator update.
    gp_weight : float
        Multiplier applied to each gradient penalty.
    lr : float
        Learning rate shared by all optimizers.
    """

    def __init__(self, signal_length=8192, num_classes=NUM_CLASSES,
                 noise_dim=100, d_steps=5, gp_weight=10.0, lr=1e-4):
        super().__init__()

        # Hyper-parameters read by train_step.
        self.noise_dim = noise_dim
        self.num_classes = num_classes
        self.d_steps = d_steps
        self.gp_weight = gp_weight

        # Networks.
        self.generator = get_generator_model(noise_dim, num_classes)
        self.discriminator = get_discriminator_model(signal_length, num_classes)
        self.deriv_discriminator = get_derivative_discriminator_model(
            signal_length - 1, num_classes)
        self.deriv2_discriminator = get_second_derivative_discriminator_model(
            signal_length - 2, num_classes)

        # Optimizers: identical settings, independent state.
        rms_settings = dict(learning_rate=lr, rho=0.9, epsilon=1e-7)
        self.g_optimizer = RMSprop(**rms_settings)
        self.d_optimizer = RMSprop(**rms_settings)
        self.d2d_optimizer = RMSprop(**rms_settings)
        self.d2d2_optimizer = RMSprop(**rms_settings)

    def train_step(self, data):
        """Update all three discriminators ``d_steps`` times, then the generator.

        ``data`` is ``(signals_batch, derivs_batch, derivs2_batch,
        classes_batch)``. Returns a dict of the last discriminator losses
        (``"d_loss"``, ``"d2d_loss"``, ``"d2d2_loss"``) and the generator
        losses (``"g_loss"``, ``"g_loss2d"``, ``"g_loss2d2"``,
        ``"g_loss_combined"``).
        """
        signals = tf.cast(data[0], tf.float32)
        signal_derivs = tf.cast(data[1], tf.float32)
        signal_derivs2 = tf.cast(data[2], tf.float32)
        labels = tf.cast(data[3], tf.float32)
        n_samples = tf.shape(signals)[0]

        # ---- discriminator updates -------------------------------------
        for _ in range(self.d_steps):
            latent = tf.random.normal((n_samples, self.noise_dim))
            # Persistent: three separate losses are differentiated from it.
            with tf.GradientTape(persistent=True) as d_tape:
                generated = self.generator([latent, labels], training=True)
                generated_derivs = _diff(generated)
                generated_derivs2 = _diff(generated_derivs)

                score_real = self.discriminator(
                    [signals, labels], training=True)
                score_fake = self.discriminator(
                    [generated, labels], training=True)
                score_real_1d = self.deriv_discriminator(
                    [signal_derivs, labels], training=True)
                score_fake_1d = self.deriv_discriminator(
                    [generated_derivs, labels], training=True)
                score_real_2d = self.deriv2_discriminator(
                    [signal_derivs2, labels], training=True)
                score_fake_2d = self.deriv2_discriminator(
                    [generated_derivs2, labels], training=True)

                d_cost = tf.reduce_mean(score_fake) - tf.reduce_mean(score_real)
                d2d_cost = (tf.reduce_mean(score_fake_1d)
                            - tf.reduce_mean(score_real_1d))
                d2d2_cost = (tf.reduce_mean(score_fake_2d)
                             - tf.reduce_mean(score_real_2d))

                penalty = _gradient_penalty(
                    self.discriminator, signals, generated, labels, n_samples)
                penalty_1d = _gradient_penalty(
                    self.deriv_discriminator, signal_derivs, generated_derivs,
                    labels, n_samples)
                penalty_2d = _gradient_penalty(
                    self.deriv2_discriminator, signal_derivs2,
                    generated_derivs2, labels, n_samples)

                d_loss = d_cost + self.gp_weight * penalty
                d2d_loss = d2d_cost + self.gp_weight * penalty_1d
                d2d2_loss = d2d2_cost + self.gp_weight * penalty_2d

            d_grads = d_tape.gradient(
                d_loss, self.discriminator.trainable_variables)
            d2d_grads = d_tape.gradient(
                d2d_loss, self.deriv_discriminator.trainable_variables)
            d2d2_grads = d_tape.gradient(
                d2d2_loss, self.deriv2_discriminator.trainable_variables)
            del d_tape  # release the persistent tape's resources

            self.d_optimizer.apply_gradients(
                zip(d_grads, self.discriminator.trainable_variables))
            self.d2d_optimizer.apply_gradients(
                zip(d2d_grads, self.deriv_discriminator.trainable_variables))
            self.d2d2_optimizer.apply_gradients(
                zip(d2d2_grads, self.deriv2_discriminator.trainable_variables))

        # ---- generator update ------------------------------------------
        latent = tf.random.normal((n_samples, self.noise_dim))
        with tf.GradientTape() as g_tape:
            generated = self.generator([latent, labels], training=True)
            generated_derivs = _diff(generated)
            generated_derivs2 = _diff(generated_derivs)
            score_fake = self.discriminator([generated, labels], training=True)
            g_loss = -tf.reduce_mean(score_fake)
            score_fake_1d = self.deriv_discriminator(
                [generated_derivs, labels], training=True)
            g_loss2d = -tf.reduce_mean(score_fake_1d)
            score_fake_2d = self.deriv2_discriminator(
                [generated_derivs2, labels], training=True)
            g_loss2d2 = -tf.reduce_mean(score_fake_2d)
            # All three discriminators contribute equally.
            g_loss_combined = (g_loss + g_loss2d + g_loss2d2) / 3.0
        g_grads = g_tape.gradient(
            g_loss_combined, self.generator.trainable_variables)
        self.g_optimizer.apply_gradients(
            zip(g_grads, self.generator.trainable_variables))

        return {
            "d_loss": d_loss,
            "d2d_loss": d2d_loss,
            "d2d2_loss": d2d2_loss,
            "g_loss": g_loss,
            "g_loss2d": g_loss2d,
            "g_loss2d2": g_loss2d2,
            "g_loss_combined": g_loss_combined,
        }
# ---------------------------------------------------------------------------
# GlitchGAN — cDVGAN configured for LIGO glitch data
# ---------------------------------------------------------------------------
class GlitchGAN(cDVGAN):
    """cDVGAN trained on LIGO gravitational-wave glitch data.

    Pins the LIGO-specific defaults (signal length and number of glitch
    classes) as class attributes so call sites do not have to repeat them.
    Architecture and training logic are inherited unchanged from
    :class:`cDVGAN`.
    """

    # Fixed values passed to cDVGAN.__init__ as signal_length / num_classes.
    SIGNAL_LENGTH = 8192
    NUM_CLASSES = 7

    def __init__(self, noise_dim=100, d_steps=5, gp_weight=10.0, lr=1e-4):
        """Build a cDVGAN with the class-level LIGO defaults filled in."""
        ligo_defaults = {
            "signal_length": self.SIGNAL_LENGTH,
            "num_classes": self.NUM_CLASSES,
        }
        super().__init__(
            noise_dim=noise_dim,
            d_steps=d_steps,
            gp_weight=gp_weight,
            lr=lr,
            **ligo_defaults,
        )
# ---------------------------------------------------------------------------
# Factory
# ---------------------------------------------------------------------------
def build_gan(variant="cDVGAN", signal_length=8192, num_classes=NUM_CLASSES,
              noise_dim=100, d_steps=5, gp_weight=10.0, lr=1e-4):
    """Instantiate a TF GAN variant by name.

    Parameters
    ----------
    variant : str
        One of ``"cWGAN"``, ``"cDVGAN"``, ``"cDVGAN2"``.
    signal_length, num_classes, noise_dim, d_steps, gp_weight, lr
        Forwarded unchanged, as keyword arguments, to the selected class.

    Returns
    -------
    keras.Model

    Raises
    ------
    ValueError
        If ``variant`` is not one of the registered names.
    """
    registry = {"cWGAN": cWGAN, "cDVGAN": cDVGAN, "cDVGAN2": cDVGAN2}

    model_cls = registry.get(variant)
    if model_cls is None:
        raise ValueError(f"Unknown variant '{variant}'. Choose from {list(registry)}")

    return model_cls(
        signal_length=signal_length,
        num_classes=num_classes,
        noise_dim=noise_dim,
        d_steps=d_steps,
        gp_weight=gp_weight,
        lr=lr,
    )