glitchgan.tf.gan_models
=======================

.. py:module:: glitchgan.tf.gan_models

.. autoapi-nested-parse::

   TensorFlow/Keras GAN model classes for cDVGAN.

   Includes:

   - cWGAN: conditional Wasserstein GAN with gradient penalty
   - cDVGAN: adds a first-derivative discriminator
   - cDVGAN2: adds first- and second-derivative discriminators
   - build_gan(): factory function

Module Contents
---------------

.. py:data:: NUM_CLASSES
   :value: 7

.. py:class:: cWGAN(signal_length=8192, num_classes=NUM_CLASSES, noise_dim=100, d_steps=5, gp_weight=10.0, lr=0.0001)

   Bases: :py:obj:`keras.Model`

   Conditional Wasserstein GAN with gradient penalty.

   .. py:attribute:: noise_dim
      :value: 100

   .. py:attribute:: num_classes
      :value: 7

   .. py:attribute:: d_steps
      :value: 5

   .. py:attribute:: gp_weight
      :value: 10.0

   .. py:attribute:: generator

   .. py:attribute:: discriminator

   .. py:attribute:: g_optimizer

   .. py:attribute:: d_optimizer

   .. py:method:: train_step(data)

.. py:class:: cDVGAN(signal_length=8192, num_classes=NUM_CLASSES, noise_dim=100, d_steps=5, gp_weight=10.0, lr=0.0001)

   Bases: :py:obj:`keras.Model`

   Conditional Dual-discriminator Variational GAN (first derivative).

   .. py:attribute:: noise_dim
      :value: 100

   .. py:attribute:: num_classes
      :value: 7

   .. py:attribute:: d_steps
      :value: 5

   .. py:attribute:: gp_weight
      :value: 10.0

   .. py:attribute:: generator

   .. py:attribute:: discriminator

   .. py:attribute:: deriv_discriminator

   .. py:attribute:: g_optimizer

   .. py:attribute:: d_optimizer

   .. py:attribute:: d2d_optimizer

   .. py:method:: train_step(data)

.. py:class:: cDVGAN2(signal_length=8192, num_classes=NUM_CLASSES, noise_dim=100, d_steps=5, gp_weight=10.0, lr=0.0001)

   Bases: :py:obj:`keras.Model`

   cDVGAN with an additional second-derivative discriminator.

   .. py:attribute:: noise_dim
      :value: 100

   .. py:attribute:: num_classes
      :value: 7

   .. py:attribute:: d_steps
      :value: 5

   .. py:attribute:: gp_weight
      :value: 10.0

   .. py:attribute:: generator

   .. py:attribute:: discriminator

   .. py:attribute:: deriv_discriminator

   .. py:attribute:: deriv2_discriminator

   .. py:attribute:: g_optimizer

   .. py:attribute:: d_optimizer

   .. py:attribute:: d2d_optimizer

   .. py:attribute:: d2d2_optimizer

   .. py:method:: train_step(data)

.. py:class:: GlitchGAN(noise_dim=100, d_steps=5, gp_weight=10.0, lr=0.0001)

   Bases: :py:obj:`cDVGAN`

   cDVGAN configured for LIGO gravitational-wave glitch data.

   Fixes the LIGO-specific defaults (signal length, number of glitch classes)
   so they don't need to be passed at every call site. All architecture and
   training logic lives in :class:`cDVGAN`.

   .. py:attribute:: SIGNAL_LENGTH
      :value: 8192

   .. py:attribute:: NUM_CLASSES
      :value: 7

.. py:function:: build_gan(variant='cDVGAN', signal_length=8192, num_classes=NUM_CLASSES, noise_dim=100, d_steps=5, gp_weight=10.0, lr=0.0001)

   Instantiate a TF GAN variant by name.

   :param variant: One of ``"cWGAN"``, ``"cDVGAN"``, ``"cDVGAN2"``.
   :type variant: str
   :returns: An instance of the selected variant.
   :rtype: keras.Model
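
For reference, the ``gp_weight`` argument in the signatures above plays the role of the penalty coefficient :math:`\lambda` in the standard conditional WGAN-GP objective written out below. This assumes ``cWGAN`` follows that textbook formulation; the exact losses computed in ``train_step`` are not documented on this page. The symbols are the critic :math:`D`, generator :math:`G`, noise :math:`z`, class label :math:`y`, and real signal :math:`x`. By the usual WGAN-GP convention, ``d_steps`` is presumably the number of critic updates per generator update.

.. math::

   \mathcal{L}_D = \mathbb{E}_{z,y}\big[D(G(z, y), y)\big] - \mathbb{E}_{x,y}\big[D(x, y)\big] + \lambda\, \mathbb{E}_{\hat{x},y}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}, y) \rVert_2 - 1)^2\big]

   \hat{x} = \epsilon\, x + (1 - \epsilon)\, G(z, y), \qquad \epsilon \sim \mathcal{U}[0, 1]

   \mathcal{L}_G = -\mathbb{E}_{z,y}\big[D(G(z, y), y)\big]

The derivative variants presumably apply an analogous critic loss to each derivative discriminator, but that extension is an assumption rather than something this page states.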
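
The gradient-penalty term of a WGAN-GP critic loss can be computed with ``tf.GradientTape``. The sketch below is a minimal illustration, not the ``glitchgan`` implementation: ``gradient_penalty`` and ``linear_critic`` are hypothetical names, and it assumes the critic is called as ``critic([signals, labels], training=True)`` on 2-D signals of shape ``(batch, signal_length)``, which may not match the real discriminator signature.

.. code-block:: python

   import tensorflow as tf


   def gradient_penalty(critic, real, fake, labels):
       """Return E[(||grad_x_hat D(x_hat, y)||_2 - 1)^2] for a batch.

       Assumption: signals are 2-D, (batch, signal_length), and the critic
       takes [signals, labels] as input.
       """
       batch_size = tf.shape(real)[0]
       # One interpolation coefficient per example, broadcast over time samples.
       eps = tf.random.uniform([batch_size, 1], 0.0, 1.0)
       x_hat = eps * real + (1.0 - eps) * fake
       with tf.GradientTape() as tape:
           tape.watch(x_hat)  # x_hat is a plain tensor, so watch it explicitly
           scores = critic([x_hat, labels], training=True)
       grads = tape.gradient(scores, x_hat)
       # Per-example L2 norm; the small constant keeps sqrt differentiable at 0.
       norms = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=1) + 1e-12)
       return tf.reduce_mean(tf.square(norms - 1.0))


   # Demo with a hypothetical linear critic D(x, y) = w . x. Its input gradient
   # is w everywhere, so the penalty is (||w||_2 - 1)^2 = (5 - 1)^2 = 16.
   w = tf.constant([3.0, 4.0])


   def linear_critic(inputs, training=False):
       signals, _labels = inputs
       return tf.reduce_sum(signals * w, axis=1)


   real = tf.ones([2, 2])
   fake = tf.zeros([2, 2])
   labels = tf.zeros([2, 7])  # e.g. one-hot labels for 7 classes (unused here)
   gp = gradient_penalty(linear_critic, real, fake, labels)

Multiplying the returned value by ``gp_weight`` and adding it to the Wasserstein critic loss gives the penalized critic objective. Drawing one interpolation coefficient per example, rather than per time sample, follows the usual WGAN-GP recipe.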
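
The derivative discriminators in ``cDVGAN`` and ``cDVGAN2`` judge how a signal changes over time rather than the raw signal. How ``glitchgan`` computes that derivative is not documented here; the sketch below assumes plain forward finite differences via ``numpy.diff``, and ``derivative_views`` is a hypothetical helper name.

.. code-block:: python

   import numpy as np


   def derivative_views(signals: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
       """Return first and second forward differences along the time axis.

       Assumption: the derivative discriminators see finite differences of the
       signal; the real preprocessing may scale by the sample rate or pad.
       """
       first = np.diff(signals, n=1, axis=-1)   # x[t+1] - x[t], length L - 1
       second = np.diff(signals, n=2, axis=-1)  # first[t+1] - first[t], length L - 2
       return first, second


   batch = np.array([[0.0, 1.0, 4.0, 9.0, 16.0]])  # one signal, x[t] = t**2
   d1, d2 = derivative_views(batch)
   print(d1)  # [[1. 3. 5. 7.]]
   print(d2)  # [[2. 2. 2.]]

Each difference shortens the time axis by one sample, so a derivative critic for ``signal_length=8192`` inputs would see 8191 (first) or 8190 (second) samples unless the differences are padded; which choice ``glitchgan`` makes is not stated here.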