# Source code for glitchgan.tf.utils

"""Training utilities for the TensorFlow cDVGAN implementation.

Includes:
- build_dataset()       — build a tf.data.Dataset from numpy arrays
- GANMonitor            — Keras callback for per-epoch signal plots
- train_gan()           — wraps model.fit() with checkpointing and history
- generate_examples()   — vertex / simplex / uniform class sampling
- save_models()         — save all model components in .keras format
- load_models()         — restore model components from .keras files
- plot_losses()         — loss curve plotting
- plot_examples()       — grid plot of generated signals
"""

import json
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import keras

from glitchgan.tf.model_components import ArgmaxLayer, ReduceSumDotLayer

# Name -> class lookup handed to keras.models.load_model as ``custom_objects``.
# Supplying it explicitly lets the custom layers resolve even for checkpoints
# that were written before the @register_keras_serializable decorator was
# present on them.
_CUSTOM_OBJECTS = dict(
    ArgmaxLayer=ArgmaxLayer,
    ReduceSumDotLayer=ReduceSumDotLayer,
)


# ---------------------------------------------------------------------------
# Dataset builder
# ---------------------------------------------------------------------------

def build_dataset(signals, classes, derivs=None, derivs2=None,
                  batch_size=64, shuffle_buffer=1000):
    """Build a shuffled, batched tf.data.Dataset for GAN training.

    Parameters
    ----------
    signals : np.ndarray (N, L)
    classes : np.ndarray (N, num_classes)
    derivs : np.ndarray (N, L-1) or None
    derivs2 : np.ndarray (N, L-2) or None
        Only valid together with ``derivs``.
    batch_size : int
    shuffle_buffer : int

    Returns
    -------
    tf.data.Dataset
        Each element is a tuple matching what the GAN's train_step expects:
        - cWGAN   : (signals, classes)
        - cDVGAN  : (signals, derivs, classes)
        - cDVGAN2 : (signals, derivs, derivs2, classes)

    Raises
    ------
    ValueError
        If ``derivs2`` is given without ``derivs``.

    Notes
    -----
    Every array is converted to float32. Batching uses
    ``drop_remainder=True``, so a trailing partial batch is discarded.
    """
    # Without this check a lone ``derivs2`` would produce the 3-tuple
    # (signals, derivs2, classes), which is indistinguishable from the cDVGAN
    # layout and would feed second derivatives where first ones are expected.
    if derivs2 is not None and derivs is None:
        raise ValueError(
            "derivs2 was provided without derivs; the cDVGAN2 layout "
            "requires both first and second derivatives.")

    arrays = [np.asarray(signals, dtype=np.float32)]
    if derivs is not None:
        arrays.append(np.asarray(derivs, dtype=np.float32))
    if derivs2 is not None:
        arrays.append(np.asarray(derivs2, dtype=np.float32))
    # Class vectors always come last in the element tuple.
    arrays.append(np.asarray(classes, dtype=np.float32))

    dataset = tf.data.Dataset.from_tensor_slices(tuple(arrays))
    dataset = (dataset
               .shuffle(shuffle_buffer)
               .batch(batch_size, drop_remainder=True)
               .prefetch(tf.data.AUTOTUNE))

    # Limit tf.data's private thread pool on shared HPC nodes where
    # RLIMIT_NPROC is a hard per-user limit. private_threadpool_size=0
    # falls back to the global pool; a small positive value caps new threads.
    options = tf.data.Options()
    options.threading.private_threadpool_size = 4
    options.threading.max_intra_op_parallelism = 1
    dataset = dataset.with_options(options)

    return dataset
# ---------------------------------------------------------------------------
# Keras callback: per-epoch monitor
# ---------------------------------------------------------------------------
class GANMonitor(keras.callbacks.Callback):
    """Save a generated signal plot and model checkpoints during training.

    Parameters
    ----------
    noise_dim : int
        Size of the noise vector passed to the generator.
    num_classes : int
        Number of classes; one is drawn at random for each epoch's plot.
    output_dir : str
        Directory receiving the PNG plots and the checkpoints. Created on
        construction if it does not exist.
    save_model_every : int
        Checkpoint period in epochs. ``0`` or ``None`` disables
        checkpointing.
    """

    def __init__(self, noise_dim=100, num_classes=7,
                 output_dir="monitor", save_model_every=10):
        super().__init__()
        self.noise_dim = noise_dim
        self.num_classes = num_classes
        self.output_dir = output_dir
        self.save_model_every = save_model_every
        os.makedirs(output_dir, exist_ok=True)

    def on_epoch_end(self, epoch, logs=None):
        """Plot one generated signal and, periodically, checkpoint the model.

        ``epoch`` is zero-based; file names and the checkpoint schedule use
        ``epoch + 1``.
        """
        # One random class, one-hot encoded, with a single noise vector.
        idx = np.random.randint(0, self.num_classes)
        class_vec = tf.one_hot([idx], self.num_classes,
                               on_value=1.0, off_value=0.0)
        noise = tf.random.normal((1, self.noise_dim))
        signal = self.model.generator(
            [noise, class_vec], training=False).numpy()[0]

        fig, ax = plt.subplots(figsize=(8, 3))
        try:
            ax.plot(signal)
            ax.set_title(f"Epoch {epoch + 1} — class {idx}")
            ax.set_xlabel("Sample")
            fig.tight_layout()
            fig.savefig(os.path.join(self.output_dir,
                                     f"epoch_{epoch + 1:04d}.png"))
        finally:
            # Always release the figure so figures do not accumulate across
            # epochs if saving fails.
            plt.close(fig)

        # A period of 0 would raise ZeroDivisionError in the modulo; treat a
        # falsy period as "checkpointing disabled".
        if self.save_model_every and (epoch + 1) % self.save_model_every == 0:
            save_models(self.model, self.output_dir, epoch=(epoch + 1))
# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------
def train_gan(gan, signals, classes, derivs=None, derivs2=None,
              epochs=500, batch_size=64, variant="cDVGAN",
              save_every=10, monitor_every=1,
              output_dir="GAN_outputs", noise_dim=100, num_classes=7,
              resume_epoch=None):
    """Train a TF GAN model.

    Parameters
    ----------
    gan : keras.Model (cWGAN / cDVGAN / cDVGAN2)
    signals : np.ndarray (N, L)
    classes : np.ndarray (N, num_classes)
    derivs : np.ndarray or None
    derivs2 : np.ndarray or None
    epochs : int
    batch_size : int
    variant : str
        Used in the loss-plot title and file name.
    save_every : int
        Period, in epochs, of the checkpoints written to
        ``<output_dir>/monitor``.
    monitor_every : int (0 to disable per-epoch plots)
        Any positive value attaches a GANMonitor callback. Periodic
        checkpoints are still written when plots are disabled.
    output_dir : str
    noise_dim : int
    num_classes : int
    resume_epoch : int or None
        When given, models and optimizer state are loaded from the
        ``<output_dir>/monitor`` checkpoint of that epoch, and the value is
        passed to ``fit`` as ``initial_epoch``.

    Returns
    -------
    dict — loss history (the ``history`` of the ``fit`` call, as floats)
    """
    os.makedirs(output_dir, exist_ok=True)
    monitor_dir = os.path.join(output_dir, "monitor")

    initial_epoch = 0
    if resume_epoch is not None:
        print(f"Resuming from epoch {resume_epoch}...")
        load_models(gan, monitor_dir, epoch=resume_epoch)
        initial_epoch = resume_epoch

    dataset = build_dataset(signals, classes, derivs=derivs, derivs2=derivs2,
                            batch_size=batch_size,
                            shuffle_buffer=len(signals))

    callbacks = []
    if monitor_every > 0:
        callbacks.append(GANMonitor(
            noise_dim=noise_dim,
            num_classes=num_classes,
            output_dir=monitor_dir,
            save_model_every=save_every,
        ))
    elif save_every:
        # GANMonitor is the only place periodic checkpoints are written, so
        # disabling plots used to disable checkpoints too, leaving nothing
        # for resume_epoch to load. Keep checkpointing alive on its own.
        def _checkpoint(epoch, logs=None):
            if (epoch + 1) % save_every == 0:
                save_models(gan, monitor_dir, epoch=epoch + 1)

        callbacks.append(
            keras.callbacks.LambdaCallback(on_epoch_end=_checkpoint))

    # Keras model.fit() requires compile() to have been called, but our custom
    # train_step manages its own optimizers. Call compile with no args.
    # jit_compile=False: Keras 3.x defaults to XLA JIT on GPU, which spawns
    # a large Eigen thread pool during compilation that exceeds RLIMIT_NPROC
    # on shared HPC nodes. Native CUDA/cuDNN kernels are fast enough here.
    gan.compile(jit_compile=False)
    history = gan.fit(dataset, epochs=epochs, initial_epoch=initial_epoch,
                      callbacks=callbacks)

    # Save final models and history
    save_models(gan, output_dir, epoch="final")
    history_dict = {k: [float(v) for v in vals]
                    for k, vals in history.history.items()}
    with open(os.path.join(output_dir, "history.json"), "w",
              encoding="utf-8") as f:
        json.dump(history_dict, f)
    plot_losses(history_dict, variant, output_dir)

    # Generate example plots
    for sampling in ("vertex", "simplex", "uniform"):
        sigs, classes_out = generate_examples(
            gan, noise_dim=noise_dim, num_classes=num_classes,
            sampling=sampling)
        path = os.path.join(output_dir, f"{sampling}_examples_final.png")
        plot_examples(sigs, classes_out, path)

    return history_dict
# ---------------------------------------------------------------------------
# Checkpointing
# ---------------------------------------------------------------------------
def save_models(gan, output_dir, epoch="last"):
    """Write every model component (.keras) and the optimizer states.

    The generator and discriminator are always saved; the derivative
    discriminators and their optimizers are saved only when ``gan`` has the
    corresponding attribute. ``epoch`` is embedded in each file name.
    """
    os.makedirs(output_dir, exist_ok=True)

    generator_path = os.path.join(output_dir, f"generator_{epoch}.keras")
    gan.generator.save(generator_path)

    discriminator_path = os.path.join(
        output_dir, f"discriminator_{epoch}.keras")
    gan.discriminator.save(discriminator_path)

    # Optional components, present only on the derivative-aware variants.
    if hasattr(gan, "deriv_discriminator"):
        deriv_path = os.path.join(
            output_dir, f"deriv_discriminator_{epoch}.keras")
        gan.deriv_discriminator.save(deriv_path)
    if hasattr(gan, "deriv2_discriminator"):
        deriv2_path = os.path.join(
            output_dir, f"deriv2_discriminator_{epoch}.keras")
        gan.deriv2_discriminator.save(deriv2_path)

    # Optimizer state goes into a single TF checkpoint next to the models.
    optimizers = {
        name: getattr(gan, name) for name in ("g_optimizer", "d_optimizer")
    }
    for name in ("d2d_optimizer", "d2d2_optimizer"):
        if hasattr(gan, name):
            optimizers[name] = getattr(gan, name)

    checkpoint = tf.train.Checkpoint(**optimizers)
    checkpoint.write(os.path.join(output_dir, f"optimizers_{epoch}"))
def load_models(gan, output_dir, epoch="last"):
    """Reload model components and optimizer states saved for ``epoch``.

    Each component on ``gan`` is replaced by the model loaded from its
    ``.keras`` file in ``output_dir``. Derivative discriminators and their
    optimizers are handled only when ``gan`` already has those attributes.
    """
    def _read(filename):
        # All components share the same directory and custom-layer lookup.
        return keras.models.load_model(
            os.path.join(output_dir, filename),
            custom_objects=_CUSTOM_OBJECTS)

    gan.generator = _read(f"generator_{epoch}.keras")
    gan.discriminator = _read(f"discriminator_{epoch}.keras")
    if hasattr(gan, "deriv_discriminator"):
        gan.deriv_discriminator = _read(f"deriv_discriminator_{epoch}.keras")
    if hasattr(gan, "deriv2_discriminator"):
        gan.deriv2_discriminator = _read(
            f"deriv2_discriminator_{epoch}.keras")

    optimizers = {
        name: getattr(gan, name) for name in ("g_optimizer", "d_optimizer")
    }
    for name in ("d2d_optimizer", "d2d2_optimizer"):
        if hasattr(gan, name):
            optimizers[name] = getattr(gan, name)

    # Optimizer slot variables are created lazily on first apply_gradients;
    # the restore is deferred so they get populated as soon as they exist.
    # expect_partial() marks the not-yet-matched values as intentional.
    status = tf.train.Checkpoint(**optimizers).restore(
        os.path.join(output_dir, f"optimizers_{epoch}"))
    status.expect_partial()
# ---------------------------------------------------------------------------
# Example generation
# ---------------------------------------------------------------------------
def generate_examples(gan, noise_dim=100, num_classes=7, num_signals=10,
                      sampling="vertex"):
    """Generate signals from the trained generator.

    Parameters
    ----------
    gan : object exposing a ``generator`` model
    noise_dim : int
    num_classes : int
    num_signals : int
    sampling : str
        One of ``"vertex"``, ``"simplex"``, ``"uniform"``.

        - ``"vertex"``  : one-hot vectors of randomly chosen classes.
        - ``"simplex"`` : random non-negative weights normalised so that
          each row sums to 1.
        - ``"uniform"`` : independent draws from [0, 1) per class.

    Returns
    -------
    signals : np.ndarray (num_signals, signal_length)
    class_vectors : np.ndarray (num_signals, num_classes)

    Raises
    ------
    ValueError
        If ``sampling`` is not one of the supported modes.
    """
    # Build the class vectors first, so an invalid ``sampling`` fails before
    # any noise is drawn or the generator is called.
    if sampling == "vertex":
        indices = np.random.randint(0, num_classes, size=num_signals)
        class_vectors = np.eye(num_classes)[indices]
    elif sampling == "simplex":
        raw = np.random.randint(
            0, 100, size=(num_signals, num_classes)).astype(float)
        totals = raw.sum(axis=1, keepdims=True)
        # A row of all zeros would divide by zero and yield NaN class
        # vectors; replace such rows with equal weights instead.
        degenerate = totals[:, 0] == 0
        raw[degenerate] = 1.0
        totals[degenerate] = num_classes
        class_vectors = raw / totals
    elif sampling == "uniform":
        class_vectors = np.random.uniform(
            0.0, 1.0, size=(num_signals, num_classes))
    else:
        raise ValueError(f"Unknown sampling '{sampling}'. "
                         "Choose from 'vertex', 'simplex', 'uniform'.")

    noise = tf.random.normal((num_signals, noise_dim))
    class_tensor = tf.cast(class_vectors, tf.float32)
    signals = gan.generator([noise, class_tensor], training=False).numpy()
    return signals, class_vectors
# ---------------------------------------------------------------------------
# Plotting
# ---------------------------------------------------------------------------
def plot_losses(history, variant, output_dir):
    """Draw every loss series in ``history`` on one axis and save it as PNG.

    ``history`` maps a loss name to its sequence of values. Known loss names
    get a fixed colour; any other name falls back to matplotlib's default.
    The figure is written to ``<output_dir>/<variant>_loss_plot.png``.
    """
    palette = {
        "d_loss": "C0",
        "d2d_loss": "C1",
        "d2d2_loss": "C3",
        "g_loss": "C2",
        "g_loss2d": "C4",
        "g_loss2d2": "C5",
        "g_loss_combined": "C6",
    }
    target = os.path.join(output_dir, f"{variant}_loss_plot.png")

    figure, axis = plt.subplots(figsize=(8, 4))
    for name in history:
        axis.plot(history[name], label=name, color=palette.get(name))

    axis.set_title(f"{variant} training losses (TF)")
    axis.set_xlabel("Epoch")
    axis.set_ylabel("Loss")
    axis.grid(True)
    axis.legend()

    figure.tight_layout()
    figure.savefig(target)
    plt.close(figure)
def plot_examples(signals, classes, path, n=9):
    """Plot up to n generated signals and save to path.

    Parameters
    ----------
    signals : sequence of 1-D arrays
    classes : sequence of class vectors
        ``classes[i]``, rounded to two decimals, becomes the title of
        panel ``i``.
    path : str
        Output image path.
    n : int
        Maximum number of signals to draw; capped at ``len(signals)``.

    Notes
    -----
    Panels are laid out three per row. The grid is at least 3x3 (unused
    panels are blanked) and gains rows when ``n`` exceeds 9, so a larger
    ``n`` is no longer silently truncated to nine plots.
    """
    n = min(n, len(signals))

    n_cols = 3
    # Ceiling division; never fewer than 3 rows so the layout for n <= 9 is
    # the same 3x3 grid as before.
    n_rows = max(3, (n + n_cols - 1) // n_cols)
    # 7 inches for 3 rows; scale the height to keep the panel size constant.
    fig, axes = plt.subplots(n_rows, n_cols,
                             figsize=(12, 7 * n_rows / 3))

    for i, ax in enumerate(axes.flat):
        if i >= n:
            ax.axis("off")
            continue
        ax.plot(signals[i])
        ax.set_title(np.round(classes[i], 2), fontsize=7)

    fig.subplots_adjust(hspace=0.4)
    fig.savefig(path)
    plt.close(fig)