# Source code for glitchgan.tf.train

"""Entry point for training cDVGAN models with TensorFlow/Keras.

Usage::

    python -m glitchgan.tf.train --variant cDVGAN --data-dir data/ --epochs 500

"""

import argparse
import os

# Cap TF's CPU thread pool before EagerContext is created.
# On shared HPC nodes the per-user thread limit (RLIMIT_NPROC) is typically
# ~1024, but TF 2.x defaults to one Eigen thread per physical core which can
# easily exceed that on many-core nodes.  Override with env vars so users can
# still increase them if needed (e.g. export TF_NUM_INTRAOP_THREADS=32).
#
# Every setting below uses ``setdefault``: a value already exported in the
# environment is left untouched, and the default here applies only when the
# variable is absent.  These lines are deliberately placed above the
# ``import tensorflow`` statement further down; keep them there.
os.environ.setdefault("TF_NUM_INTEROP_THREADS", "4")
os.environ.setdefault("TF_NUM_INTRAOP_THREADS", "8")
# Disable XLA auto-JIT clustering.
# NOTE(review): TensorFlow's flag help describes ``--tf_xla_auto_jit`` values
# as 0 = "use ConfigProto setting" and -1 = "off".  If an unconditional
# disable is intended, -1 may be the right value; confirm against the TF
# version in use before changing it.
os.environ.setdefault("TF_XLA_FLAGS", "--tf_xla_auto_jit=0")
# XLA's HloEvaluator (constant folding at compile time) creates a global Eigen
# thread pool sized to hardware_concurrency() — easily 128 threads on A100 nodes.
# This hits RLIMIT_NPROC on shared clusters.  Disable its multi-threading; only
# compile-time constant evaluation is affected, not GPU kernel execution speed.
# Because of ``setdefault``, a user who already exports XLA_FLAGS (or
# TF_XLA_FLAGS above) for another purpose does NOT get this flag appended; they
# must add it to their own value themselves.
os.environ.setdefault("XLA_FLAGS", "--xla_cpu_multi_thread_eigen=false")

import numpy as np
import tensorflow as tf

from glitchgan.tf.gan_models import build_gan
from glitchgan.tf.utils import train_gan


def _pick_best_gpu(gpus):
    """Return the index into *gpus* of the GPU with the most free memory.

    Free memory is read with ``nvidia-smi``, which reports every GPU on the
    node regardless of ``CUDA_VISIBLE_DEVICES``.  When that variable is a
    plain comma-separated list of device indices, the readings are mapped
    through it so that position ``i`` of the result describes ``gpus[i]``.
    For any other form (unset, UUIDs, MIG identifiers, out-of-range indices)
    the first ``len(gpus)`` readings are used, as before.

    Args:
        gpus: Sequence of GPU devices visible to TensorFlow.  Only its length
            is used.

    Returns:
        int: Index of the selected device within ``gpus``.  Falls back to 0
        when ``nvidia-smi`` is missing, fails, prints something unparsable,
        or yields no readings.
    """
    import subprocess

    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits"],
            capture_output=True, text=True, check=True,
        )
        free_all = [int(line.strip()) for line in result.stdout.strip().splitlines()]

        # Translate nvidia-smi's node-wide list into the devices this process
        # can actually see.  NOTE(review): this assumes nvidia-smi's ordering
        # matches CUDA's enumeration (true with CUDA_DEVICE_ORDER=PCI_BUS_ID
        # and, in practice, on homogeneous nodes) — confirm for mixed-GPU hosts.
        tokens = [tok.strip() for tok in os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",")]
        if (all(tok.isdecimal() for tok in tokens)
                and all(int(tok) < len(free_all) for tok in tokens)):
            free = [free_all[int(tok)] for tok in tokens]
        else:
            free = free_all
        # Only consider GPUs visible to TF
        free = free[:len(gpus)]

        if not free:
            # Routed through the handler below so the caller still gets GPU 0.
            raise ValueError("no GPU memory readings available")

        best = free.index(max(free))
        print(f"GPU free memory (MiB): {free} — selecting GPU {best}")
        return best
    except Exception as exc:
        # Deliberately broad: GPU selection is best-effort and must never stop
        # a training run from starting.
        print(f"nvidia-smi query failed ({exc}), defaulting to GPU 0")
        return 0


def _setup_gpu():
    """Configure TensorFlow's GPU usage for this process.

    With at least one GPU present, memory growth is switched on for every
    physical GPU, then TensorFlow is restricted to the single device chosen
    by ``_pick_best_gpu``.  Without a GPU a notice is printed and nothing is
    configured.  The TensorFlow version is printed in both cases.
    """
    devices = tf.config.list_physical_devices("GPU")

    if not devices:
        print("No GPU found, running on CPU.")
    else:
        # Use the top-level API when this TensorFlow build has it, otherwise
        # the ``experimental`` one.
        if hasattr(tf.config, "set_memory_growth"):
            enable_growth = tf.config.set_memory_growth
        else:
            enable_growth = tf.config.experimental.set_memory_growth

        for device in devices:
            enable_growth(device, True)

        chosen = devices[_pick_best_gpu(devices)]
        tf.config.set_visible_devices(chosen, "GPU")
        print(f"Using GPU: {chosen}")

    print(f"TensorFlow version: {tf.__version__}")


def parse_args(argv=None):
    """Parse the command-line options for a training run.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line.  The default ``None`` makes argparse read
            ``sys.argv[1:]``, so existing ``parse_args()`` callers are
            unaffected.

    Returns:
        argparse.Namespace: One attribute per option, with dashes turned
        into underscores (``--data-dir`` becomes ``data_dir``).
    """
    parser = argparse.ArgumentParser(
        description="Train a cDVGAN model (TensorFlow backend)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--variant", type=str, default="cDVGAN",
        choices=["cWGAN", "cDVGAN", "cDVGAN2"],
        help="GAN variant to train",
    )
    parser.add_argument("--data-dir", type=str, default="data/",
                        help="Directory containing .npy data files")
    parser.add_argument("--output-dir", type=str, default="GAN_outputs_tf/",
                        help="Directory to save checkpoints and plots")
    parser.add_argument("--epochs", type=int, default=500)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--noise-dim", type=int, default=100)
    parser.add_argument("--num-classes", type=int, default=7)
    parser.add_argument("--d-steps", type=int, default=5,
                        help="Discriminator updates per generator update")
    parser.add_argument("--gp-weight", type=float, default=10.0,
                        help="Gradient penalty weight (lambda)")
    parser.add_argument("--lr", type=float, default=1e-4,
                        help="Learning rate for RMSprop optimisers")
    parser.add_argument("--save-every", type=int, default=10,
                        help="Save model checkpoint every N epochs")
    parser.add_argument("--monitor-every", type=int, default=1,
                        help="Save a monitor plot every N epochs (0 to disable)")
    parser.add_argument("--resume-epoch", type=int, default=None,
                        help="Resume training from this checkpoint epoch")
    return parser.parse_args(argv)
def main():
    """Run one training session from the command line.

    Steps, in order:

    1. Parse the command-line options.
    2. Configure the GPU via ``_setup_gpu``.
    3. Load the signal and label arrays from ``--data-dir``.  The derivative
       array is loaded only for the ``cDVGAN`` and ``cDVGAN2`` variants; for
       ``cDVGAN2`` a second array is computed as the first difference of the
       derivative array along its last axis.
    4. Build the GAN with ``build_gan`` and train it with ``train_gan``,
       writing outputs to ``<output-dir>/<variant>``.

    Raises:
        SystemExit: From argparse on ``--help`` or invalid arguments.
        OSError: From ``np.load`` when a required ``.npy`` file is missing or
            unreadable.
    """
    # Parse before touching TensorFlow devices so that ``--help`` and argument
    # errors exit immediately without querying or configuring any GPU.
    args = parse_args()
    _setup_gpu()

    # Load data
    signals_path = os.path.join(args.data_dir, "glitch_GAN_samples_scaled_balanced.npy")
    classes_path = os.path.join(args.data_dir, "glitch_GAN_labels_balanced.npy")
    derivs_path = os.path.join(args.data_dir, "glitch_GAN_deriv_samples_balanced.npy")

    print("Loading data...")
    signals = np.load(signals_path)
    classes = np.load(classes_path)

    # cWGAN trains without derivative inputs, so both stay None for it.
    derivs = None
    derivs2 = None
    if args.variant in ("cDVGAN", "cDVGAN2"):
        derivs = np.load(derivs_path)
    if args.variant == "cDVGAN2":
        # np.diff shortens the last axis by one sample relative to ``derivs``.
        derivs2 = np.diff(derivs, axis=-1)

    signal_length = signals.shape[-1]
    print(f"Signal length: {signal_length}, N samples: {len(signals)}")

    # Build GAN
    output_dir = os.path.join(args.output_dir, args.variant)
    gan = build_gan(
        variant=args.variant,
        signal_length=signal_length,
        num_classes=args.num_classes,
        noise_dim=args.noise_dim,
        d_steps=args.d_steps,
        gp_weight=args.gp_weight,
        lr=args.lr,
    )

    # Train
    start = args.resume_epoch if args.resume_epoch is not None else 1
    print(f"Training {args.variant} from epoch {start} to {args.epochs}...")
    train_gan(
        gan=gan,
        signals=signals,
        classes=classes,
        derivs=derivs,
        derivs2=derivs2,
        epochs=args.epochs,
        batch_size=args.batch_size,
        variant=args.variant,
        save_every=args.save_every,
        monitor_every=args.monitor_every,
        output_dir=output_dir,
        noise_dim=args.noise_dim,
        num_classes=args.num_classes,
        resume_epoch=args.resume_epoch,
    )
    print(f"Training complete. Outputs saved to {output_dir}")
# Start training only when this module is executed directly (for example with
# ``python -m glitchgan.tf.train``), not when it is imported.
if __name__ == "__main__":
    main()