glitchgan.tf.utils
Training utilities for the TensorFlow cDVGAN implementation.
Includes:
- build_dataset(): build a tf.data.Dataset from numpy arrays
- GANMonitor: Keras callback for per-epoch signal plots and periodic model checkpoints
- train_gan(): wraps model.fit() with checkpointing and history
- generate_examples(): generate signals with vertex / simplex / uniform class sampling
- save_models(): save all model components in .keras format
- load_models(): restore model components from .keras files
- plot_losses(): loss curve plotting
- plot_examples(): grid plot of generated signals
Module Contents
- glitchgan.tf.utils.build_dataset(signals, classes, derivs=None, derivs2=None, batch_size=64, shuffle_buffer=1000)
Build a tf.data.Dataset for GAN training.
- Parameters:
signals (np.ndarray (N, L))
classes (np.ndarray (N, num_classes))
derivs (np.ndarray (N, L-1) or None)
derivs2 (np.ndarray (N, L-2) or None)
batch_size (int)
shuffle_buffer (int)
- Returns:
Each element is a tuple matching what the GAN's train_step expects:
cWGAN: (signals, classes)
cDVGAN: (signals, derivs, classes)
cDVGAN2: (signals, derivs, derivs2, classes)
- Return type:
tf.data.Dataset
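The shapes listed for derivs (N, L-1) and derivs2 (N, L-2) are consistent with first- and second-order finite differences along the time axis. The following is a minimal numpy sketch, assuming the derivatives are plain `np.diff` outputs (the source does not say how they are produced) and using a hypothetical helper `make_element` to show which tuple each variant receives:

```python
import numpy as np


def finite_differences(signals):
    """First- and second-order differences along the time axis.

    Assumption: derivs/derivs2 are plain finite differences; the real
    pipeline may scale or normalize them differently.
    """
    derivs = np.diff(signals, n=1, axis=1)   # shape (N, L-1)
    derivs2 = np.diff(signals, n=2, axis=1)  # shape (N, L-2)
    return derivs, derivs2


def make_element(variant, signals, classes, derivs=None, derivs2=None):
    """Hypothetical helper: assemble the per-variant tuple listed above."""
    if variant == "cWGAN":
        return (signals, classes)
    if variant == "cDVGAN":
        return (signals, derivs, classes)
    if variant == "cDVGAN2":
        return (signals, derivs, derivs2, classes)
    raise ValueError(f"unknown variant: {variant!r}")


# Toy data: 8 signals of length 16, one-hot labels over 7 classes.
rng = np.random.default_rng(0)
signals = rng.standard_normal((8, 16)).astype("float32")
classes = np.eye(7, dtype="float32")[rng.integers(0, 7, size=8)]
derivs, derivs2 = finite_differences(signals)
element = make_element("cDVGAN2", signals, classes, derivs, derivs2)
print([a.shape for a in element])  # [(8, 16), (8, 15), (8, 14), (8, 7)]
```

With TensorFlow available, a tuple like this can be passed to `tf.data.Dataset.from_tensor_slices(element)` and then `.shuffle(shuffle_buffer).batch(batch_size)`; whether build_dataset chains exactly those calls is not stated in the source.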
- class glitchgan.tf.utils.GANMonitor(noise_dim=100, num_classes=7, output_dir='monitor', save_model_every=10)
Bases: keras.callbacks.Callback
Save a plot of generated signals and periodic model checkpoints (every save_model_every epochs) during training.
- on_epoch_end(epoch, logs=None)
Called at the end of an epoch.
Subclasses should override for any actions to run. This function should only be called during TRAIN mode.
- Parameters:
epoch – Integer, index of epoch.
logs – Dict, metric results for this training epoch, and for the validation epoch if validation is performed. Validation result keys are prefixed with val_. For training epoch, the values of the Model's metrics are returned. Example: {'loss': 0.2, 'accuracy': 0.7}.
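Keras passes a 0-based `epoch` index to `on_epoch_end`. A minimal, framework-free sketch of how a monitor with a `save_model_every` setting might decide when to checkpoint; the class `MonitorSketch` is hypothetical, and counting epochs from 1 for the modulo check is an assumption, not GANMonitor's confirmed behaviour:

```python
class MonitorSketch:
    """Hypothetical stand-in for a Keras callback's epoch-end hook.

    Records a plot every epoch and a checkpoint every `save_model_every`
    epochs, counting completed epochs from 1 (an assumption).
    """

    def __init__(self, save_model_every=10):
        self.save_model_every = save_model_every
        self.plotted = []       # completed-epoch numbers that produced a plot
        self.checkpointed = []  # completed-epoch numbers that saved models

    def on_epoch_end(self, epoch, logs=None):
        completed = epoch + 1  # Keras epoch index is 0-based
        # Real code would generate signals and save a figure here.
        self.plotted.append(completed)
        if self.save_model_every and completed % self.save_model_every == 0:
            # Real code would save the model components here.
            self.checkpointed.append(completed)


monitor = MonitorSketch(save_model_every=10)
for epoch in range(25):
    monitor.on_epoch_end(epoch, logs={"loss": 0.0})
print(monitor.checkpointed)  # [10, 20]
```

Guarding on a truthy `save_model_every` lets 0 act as "never checkpoint" without a division-by-zero error.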
- glitchgan.tf.utils.train_gan(gan, signals, classes, derivs=None, derivs2=None, epochs=500, batch_size=64, variant='cDVGAN', save_every=10, monitor_every=1, output_dir='GAN_outputs', noise_dim=100, num_classes=7, resume_epoch=None)
Train a TF GAN model.
- Parameters:
gan (keras.Model (cWGAN / cDVGAN / cDVGAN2))
signals (np.ndarray (N, L))
classes (np.ndarray (N, num_classes))
derivs (np.ndarray or None)
derivs2 (np.ndarray or None)
epochs (int)
batch_size (int)
variant (str)
save_every (int)
monitor_every (int) – set to 0 to disable per-epoch plots
output_dir (str)
noise_dim (int)
num_classes (int)
resume_epoch (int or None)
- Returns:
Loss history.
- Return type:
dict
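Keras `model.fit` accepts an `initial_epoch` argument and runs until epoch index `epochs` is reached; its returned `History` object exposes a `.history` dict mapping metric names to per-epoch lists. A minimal sketch, assuming `train_gan` uses `resume_epoch` in that way and appends new values onto a previously saved history; the helpers `remaining_epochs` and `merge_histories` and the key names `d_loss`/`g_loss` are hypothetical:

```python
def remaining_epochs(epochs, resume_epoch=None):
    """Epoch indices still to run, mirroring range(initial_epoch, epochs) in fit()."""
    start = resume_epoch or 0
    return list(range(start, epochs))


def merge_histories(old, new):
    """Append per-epoch values from `new` onto a copy of `old`, key by key."""
    merged = {key: list(values) for key, values in old.items()}
    for key, values in new.items():
        merged.setdefault(key, []).extend(values)
    return merged


old = {"d_loss": [1.2, 1.0], "g_loss": [0.8, 0.9]}  # from an earlier run
new = {"d_loss": [0.9], "g_loss": [1.1]}            # from the resumed run
print(remaining_epochs(5, resume_epoch=2))  # [2, 3, 4]
print(merge_histories(old, new))  # {'d_loss': [1.2, 1.0, 0.9], 'g_loss': [0.8, 0.9, 1.1]}
```

Copying `old` before extending keeps the earlier history intact if the resumed run is discarded.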
- glitchgan.tf.utils.save_models(gan, output_dir, epoch='last')
Save all model components in .keras format plus optimizer states.
- glitchgan.tf.utils.load_models(gan, output_dir, epoch='last')
Restore model weights and optimizer states from a checkpoint.
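The `epoch` argument (default 'last') suggests checkpoints are keyed either by epoch number or by a rolling "last" tag. A minimal sketch of one possible file layout; the helper `checkpoint_paths`, the component names, and the `{component}_{tag}.keras` pattern are all hypothetical, not glitchgan's actual naming:

```python
import os


def checkpoint_paths(output_dir, components, epoch="last"):
    """Map each component name to a hypothetical .keras checkpoint path."""
    tag = epoch if epoch == "last" else f"epoch_{int(epoch):04d}"
    return {
        name: os.path.join(output_dir, f"{name}_{tag}.keras")
        for name in components
    }


paths = checkpoint_paths("GAN_outputs", ["generator", "discriminator"], epoch=20)
print(paths["generator"])  # GAN_outputs/generator_epoch_0020.keras on POSIX
```

Under a layout like this, each path would be written with the standard Keras `model.save(path)` call and read back with `keras.models.load_model(path)`; how the optimizer states are stored alongside is not specified in the source.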
- glitchgan.tf.utils.generate_examples(gan, noise_dim=100, num_classes=7, num_signals=10, sampling='vertex')
Generate signals from the trained generator.
- Parameters:
sampling (str) – One of "vertex", "simplex", "uniform".
- Returns:
signals (np.ndarray (num_signals, signal_length))
class_vectors (np.ndarray (num_signals, num_classes))
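The source names the three sampling modes but does not define them. A minimal numpy sketch under stated assumptions: "vertex" as random one-hot vectors (the vertices of the class simplex), "simplex" as points drawn uniformly from the simplex via a Dirichlet distribution with all concentrations equal to 1, and "uniform" as independent U(0, 1) entries that need not sum to 1. The standard-normal latent noise and the helper `sample_class_vectors` are also assumptions:

```python
import numpy as np


def sample_class_vectors(num_signals, num_classes=7, sampling="vertex", rng=None):
    """Hypothetical sketch of the three class-sampling modes (interpretations assumed)."""
    rng = np.random.default_rng() if rng is None else rng
    if sampling == "vertex":
        # One-hot vectors: each signal conditioned on exactly one class.
        idx = rng.integers(0, num_classes, size=num_signals)
        return np.eye(num_classes, dtype="float32")[idx]
    if sampling == "simplex":
        # Dirichlet(1, ..., 1) is the uniform distribution on the simplex.
        return rng.dirichlet(np.ones(num_classes), size=num_signals).astype("float32")
    if sampling == "uniform":
        # Independent U(0, 1) entries; rows are not normalized.
        return rng.uniform(0.0, 1.0, size=(num_signals, num_classes)).astype("float32")
    raise ValueError(f"sampling must be 'vertex', 'simplex' or 'uniform', got {sampling!r}")


rng = np.random.default_rng(42)
noise = rng.standard_normal((10, 100)).astype("float32")  # assumed N(0, 1) latent noise
classes = sample_class_vectors(10, 7, "simplex", rng)
print(noise.shape, classes.shape, np.allclose(classes.sum(axis=1), 1.0))  # (10, 100) (10, 7) True
```

Vertex sampling reproduces the one-hot labels seen during training, while simplex sampling conditions the generator on mixtures of classes.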