glitchgan.tf.utils

Training utilities for the TensorFlow cDVGAN implementation.

Includes:
  • build_dataset() – build a tf.data.Dataset from numpy arrays

  • GANMonitor – Keras callback for per-epoch signal plots

  • train_gan() – wraps model.fit() with checkpointing and history

  • generate_examples() – vertex / simplex / uniform class sampling

  • save_models() – save all model components in .keras format

  • load_models() – restore model components from .keras files

  • plot_losses() – loss curve plotting

  • plot_examples() – grid plot of generated signals

Module Contents

glitchgan.tf.utils.build_dataset(signals, classes, derivs=None, derivs2=None, batch_size=64, shuffle_buffer=1000)

Build a tf.data.Dataset for GAN training.

Parameters:
  • signals (np.ndarray (N, L))

  • classes (np.ndarray (N, num_classes))

  • derivs (np.ndarray (N, L-1) or None)

  • derivs2 (np.ndarray (N, L-2) or None)

  • batch_size (int)

  • shuffle_buffer (int)

Returns:

Each element is a tuple matching what the GAN’s train_step expects:
  • cWGAN: (signals, classes)

  • cDVGAN: (signals, derivs, classes)

  • cDVGAN2: (signals, derivs, derivs2, classes)

Return type:

tf.data.Dataset
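
The array shapes listed above can be produced with plain NumPy. The following is a minimal sketch, not the library's own preprocessing: prepare_inputs is a hypothetical helper, and it assumes that derivs and derivs2 are first and second finite differences along the time axis (which matches the documented (N, L-1) and (N, L-2) shapes) and that classes are one-hot encoded.

```python
import numpy as np


def prepare_inputs(signals, labels, num_classes):
    """Hypothetical helper: build arrays with the shapes build_dataset() documents."""
    signals = np.asarray(signals, dtype=np.float32)           # (N, L)
    # One-hot class vectors, selected by integer label.
    classes = np.eye(num_classes, dtype=np.float32)[labels]   # (N, num_classes)
    # Assumption: derivatives are finite differences along the time axis.
    derivs = np.diff(signals, n=1, axis=1)                    # (N, L-1)
    derivs2 = np.diff(signals, n=2, axis=1)                   # (N, L-2)
    return signals, classes, derivs, derivs2


rng = np.random.default_rng(0)
raw_signals = rng.standard_normal((8, 128))   # toy data: N=8, L=128
labels = rng.integers(0, 7, size=8)           # integer class labels in [0, 7)

signals, classes, derivs, derivs2 = prepare_inputs(raw_signals, labels, num_classes=7)
print(signals.shape, classes.shape, derivs.shape, derivs2.shape)
```

Under these assumptions the four arrays could then be passed as build_dataset(signals, classes, derivs=derivs, derivs2=derivs2); for the cWGAN variant both derivative arguments would stay at None.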

class glitchgan.tf.utils.GANMonitor(noise_dim=100, num_classes=7, output_dir='monitor', save_model_every=10)

Bases: keras.callbacks.Callback

Save a generated signal plot and model checkpoints during training.

noise_dim = 100
num_classes = 7
output_dir = 'monitor'
save_model_every = 10
on_epoch_end(epoch, logs=None)

Called at the end of an epoch.

Subclasses should override for any actions to run. This function should only be called during TRAIN mode.

Parameters:
  • epoch – Integer, index of epoch.

  • logs – Dict, metric results for this training epoch, and for the validation epoch if validation is performed. Validation result keys are prefixed with val_. For a training epoch, the values of the Model’s metrics are returned. Example: {'loss': 0.2, 'accuracy': 0.7}.
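
Keras passes a zero-based epoch index to on_epoch_end, so a callback that acts every save_model_every epochs has to choose a counting convention. The sketch below shows one common choice; it is an assumption rather than the documented behaviour of GANMonitor, and should_checkpoint is a hypothetical helper introduced only for illustration.

```python
def should_checkpoint(epoch: int, save_model_every: int = 10) -> bool:
    """Return True when a zero-based epoch index completes a block of epochs.

    Assumption: a checkpoint is written after every `save_model_every`
    completed epochs, and a non-positive interval disables checkpointing.
    """
    if save_model_every <= 0:
        return False
    return (epoch + 1) % save_model_every == 0


# Epoch indices (zero-based) at which a 30-epoch run would checkpoint.
checkpoint_epochs = [e for e in range(30) if should_checkpoint(e, 10)]
print(checkpoint_epochs)  # [9, 19, 29]
```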

glitchgan.tf.utils.train_gan(gan, signals, classes, derivs=None, derivs2=None, epochs=500, batch_size=64, variant='cDVGAN', save_every=10, monitor_every=1, output_dir='GAN_outputs', noise_dim=100, num_classes=7, resume_epoch=None)

Train a TF GAN model.

Parameters:
  • gan (keras.Model (cWGAN / cDVGAN / cDVGAN2))

  • signals (np.ndarray (N, L))

  • classes (np.ndarray (N, num_classes))

  • derivs (np.ndarray or None)

  • derivs2 (np.ndarray or None)

  • epochs (int)

  • batch_size (int)

  • variant (str)

  • save_every (int)

  • monitor_every (int (0 to disable per-epoch plots))

  • output_dir (str)

  • noise_dim (int)

  • num_classes (int)

Return type:

dict — loss history

glitchgan.tf.utils.save_models(gan, output_dir, epoch='last')

Save all model components in .keras format plus optimizer states.

glitchgan.tf.utils.load_models(gan, output_dir, epoch='last')

Restore model weights and optimizer states from a checkpoint.

glitchgan.tf.utils.generate_examples(gan, noise_dim=100, num_classes=7, num_signals=10, sampling='vertex')

Generate signals from the trained generator.

Parameters:

sampling (str) – One of "vertex", "simplex", "uniform".

Returns:

  • signals (np.ndarray (num_signals, signal_length))

  • class_vectors (np.ndarray (num_signals, num_classes))
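
The three sampling modes differ in how the class vectors fed to the generator are drawn. The sketch below is one plausible reading and not the library's implementation: sample_class_vectors is a hypothetical helper, and it assumes that "vertex" means one-hot vectors, "simplex" means points on the probability simplex (drawn here from a flat Dirichlet distribution), and "uniform" means independent values in [0, 1).

```python
import numpy as np


def sample_class_vectors(num_signals, num_classes=7, sampling="vertex", seed=None):
    """Hypothetical helper: draw class vectors of shape (num_signals, num_classes)."""
    rng = np.random.default_rng(seed)
    if sampling == "vertex":
        # Assumption: one-hot vectors, i.e. the corners of the class space.
        idx = rng.integers(0, num_classes, size=num_signals)
        return np.eye(num_classes, dtype=np.float32)[idx]
    if sampling == "simplex":
        # Assumption: non-negative entries that sum to 1 (flat Dirichlet).
        return rng.dirichlet(np.ones(num_classes), size=num_signals).astype(np.float32)
    if sampling == "uniform":
        # Assumption: each entry drawn independently from [0, 1).
        return rng.uniform(0.0, 1.0, size=(num_signals, num_classes)).astype(np.float32)
    raise ValueError(f"unknown sampling mode: {sampling!r}")


vertex = sample_class_vectors(10, sampling="vertex", seed=0)
simplex = sample_class_vectors(10, sampling="simplex", seed=0)
uniform = sample_class_vectors(10, sampling="uniform", seed=0)
print(vertex.shape, simplex.shape, uniform.shape)
```

Each result has the (num_signals, num_classes) shape documented for class_vectors. Under this reading, vertex sampling reproduces pure classes, while the other two modes mix classes.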

glitchgan.tf.utils.plot_losses(history, variant, output_dir)

Plot and save training loss curves.

glitchgan.tf.utils.plot_examples(signals, classes, path, n=9)

Plot up to n generated signals and save to path.