glitchgan.tf.utils
==================

.. py:module:: glitchgan.tf.utils

.. autoapi-nested-parse::

   Training utilities for the TensorFlow cDVGAN implementation.

   Includes:

   - build_dataset(): build a tf.data.Dataset from NumPy arrays
   - GANMonitor: Keras callback for per-epoch signal plots and model checkpoints
   - train_gan(): wraps model.fit() with checkpointing and loss history
   - generate_examples(): signal generation with vertex, simplex, or uniform class sampling
   - save_models(): save all model components in .keras format
   - load_models(): restore model components from .keras files
   - plot_losses(): loss-curve plotting
   - plot_examples(): grid plot of generated signals



Module Contents
---------------

.. py:function:: build_dataset(signals, classes, derivs=None, derivs2=None, batch_size=64, shuffle_buffer=1000)

   Build a tf.data.Dataset for GAN training.

   :param signals: Training signals, one per row.
   :type signals: np.ndarray (N, L)
   :param classes: Class vector for each signal.
   :type classes: np.ndarray (N, num_classes)
   :param derivs: First-derivative signals, included for cDVGAN and cDVGAN2.
   :type derivs: np.ndarray (N, L-1) or None
   :param derivs2: Second-derivative signals, included for cDVGAN2.
   :type derivs2: np.ndarray (N, L-2) or None
   :param batch_size: Number of samples per batch.
   :type batch_size: int
   :param shuffle_buffer: Buffer size used when shuffling the dataset.
   :type shuffle_buffer: int

   :returns: Each element is a tuple matching what the GAN's ``train_step`` expects:

             - cWGAN: ``(signals, classes)``
             - cDVGAN: ``(signals, derivs, classes)``
             - cDVGAN2: ``(signals, derivs, derivs2, classes)``
   :rtype: tf.data.Dataset


.. py:class:: GANMonitor(noise_dim=100, num_classes=7, output_dir='monitor', save_model_every=10)

   Bases: :py:obj:`keras.callbacks.Callback`


   Save a generated signal plot and model checkpoints during training.


   .. py:attribute:: noise_dim
      :value: 100


   .. py:attribute:: num_classes
      :value: 7


   .. py:attribute:: output_dir
      :value: 'monitor'


   .. py:attribute:: save_model_every
      :value: 10


   .. py:method:: on_epoch_end(epoch, logs=None)

      Called at the end of an epoch.

      Subclasses should override for any actions to run. This function should only be called during TRAIN mode.

      :param epoch: Integer, index of epoch.
      :param logs: Dict, metric results for this training epoch, and for the validation epoch if validation is performed. Validation result keys are prefixed with ``val_``. For the training epoch, the values of the ``Model``'s metrics are returned. Example: ``{'loss': 0.2, 'accuracy': 0.7}``.



.. py:function:: train_gan(gan, signals, classes, derivs=None, derivs2=None, epochs=500, batch_size=64, variant='cDVGAN', save_every=10, monitor_every=1, output_dir='GAN_outputs', noise_dim=100, num_classes=7, resume_epoch=None)

   Train a TF GAN model.

   :param gan: Model to train.
   :type gan: keras.Model (cWGAN / cDVGAN / cDVGAN2)
   :param signals: Training signals, one per row.
   :type signals: np.ndarray (N, L)
   :param classes: Class vector for each signal.
   :type classes: np.ndarray (N, num_classes)
   :param derivs: First-derivative signals (cDVGAN and cDVGAN2).
   :type derivs: np.ndarray or None
   :param derivs2: Second-derivative signals (cDVGAN2).
   :type derivs2: np.ndarray or None
   :param epochs: Number of training epochs.
   :type epochs: int
   :param batch_size: Number of samples per batch.
   :type batch_size: int
   :param variant: Model variant name, e.g. ``'cDVGAN'``.
   :type variant: str
   :param save_every: Checkpoint interval, in epochs.
   :type save_every: int
   :param monitor_every: Interval, in epochs, between monitor plots.
   :type monitor_every: int (0 to disable per-epoch plots)
   :param output_dir: Directory for training outputs.
   :type output_dir: str
   :param noise_dim: Dimension of the generator's noise input.
   :type noise_dim: int
   :param num_classes: Number of signal classes.
   :type num_classes: int
   :param resume_epoch:
   :type resume_epoch: int or None

   :returns: Loss history.
   :rtype: dict


.. py:function:: save_models(gan, output_dir, epoch='last')

   Save all model components in .keras format, plus optimizer states.


.. py:function:: load_models(gan, output_dir, epoch='last')

   Restore model weights and optimizer states from a checkpoint.


.. py:function:: generate_examples(gan, noise_dim=100, num_classes=7, num_signals=10, sampling='vertex')

   Generate signals from the trained generator.

   :param sampling: One of ``"vertex"``, ``"simplex"``, ``"uniform"``.
   :type sampling: str

   :returns: * **signals** (*np.ndarray (num_signals, signal_length)*)
             * **class_vectors** (*np.ndarray (num_signals, num_classes)*)


.. py:function:: plot_losses(history, variant, output_dir)

   Plot and save training loss curves.


.. py:function:: plot_examples(signals, classes, path, n=9)

   Plot up to ``n`` generated signals and save them to ``path``.
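
build_dataset() takes derivative arrays of shape (N, L-1) and (N, L-2) alongside the signals and packs them into a per-variant tuple. The sketch below shows one way to prepare those inputs and assemble them in the documented tuple order. It assumes the derivatives are plain first- and second-order finite differences (``np.diff``), which may not match how glitchgan preprocesses its data, and ``make_training_tuples`` is a hypothetical helper, not part of this module.

.. code-block:: python

   import numpy as np


   def make_training_tuples(signals, classes, variant="cDVGAN"):
       """Hypothetical helper: return arrays in the order build_dataset()
       documents for each variant.

       Assumption: derivatives are approximated with finite differences
       (np.diff), not necessarily glitchgan's own preprocessing.
       """
       signals = np.asarray(signals, dtype=np.float32)
       classes = np.asarray(classes, dtype=np.float32)
       if variant == "cWGAN":
           return (signals, classes)
       derivs = np.diff(signals, n=1, axis=1)  # shape (N, L-1)
       if variant == "cDVGAN":
           return (signals, derivs, classes)
       if variant == "cDVGAN2":
           derivs2 = np.diff(signals, n=2, axis=1)  # shape (N, L-2)
           return (signals, derivs, derivs2, classes)
       raise ValueError(f"unknown variant: {variant!r}")


   # Usage with random stand-in data: 8 signals of length 256, 7 one-hot classes.
   rng = np.random.default_rng(0)
   signals = rng.standard_normal((8, 256))
   classes = np.eye(7)[rng.integers(0, 7, size=8)]
   sig, d1, d2, cls = make_training_tuples(signals, classes, "cDVGAN2")
   print(d1.shape, d2.shape)  # (8, 255) (8, 254)

The resulting arrays could then be passed as ``build_dataset(sig, cls, derivs=d1, derivs2=d2)``; building the dataset itself requires TensorFlow and is not shown here.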
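
generate_examples() names three class-sampling modes. The sketch below illustrates one plausible reading of them with NumPy: ``"vertex"`` as one-hot vectors (the vertices of the probability simplex), ``"simplex"`` as points drawn uniformly on the simplex via a Dirichlet(1, ..., 1) distribution, and ``"uniform"`` as independent U(0, 1) entries. These distributions are assumptions, and ``sample_class_vectors`` is a hypothetical helper; the module's own implementation may differ.

.. code-block:: python

   import numpy as np


   def sample_class_vectors(num_signals, num_classes=7, sampling="vertex", rng=None):
       """Hypothetical sketch of the three sampling modes (assumed distributions)."""
       rng = np.random.default_rng() if rng is None else rng
       if sampling == "vertex":
           # One-hot rows: each signal gets exactly one class.
           labels = rng.integers(0, num_classes, size=num_signals)
           return np.eye(num_classes, dtype=np.float32)[labels]
       if sampling == "simplex":
           # Dirichlet(1, ..., 1) is uniform over the simplex; rows sum to 1.
           alpha = np.ones(num_classes)
           return rng.dirichlet(alpha, size=num_signals).astype(np.float32)
       if sampling == "uniform":
           # Independent U(0, 1) entries; rows are not normalised.
           return rng.uniform(0.0, 1.0, size=(num_signals, num_classes)).astype(np.float32)
       raise ValueError(f"unknown sampling mode: {sampling!r}")


   # Usage: class vectors plus a matching noise batch (noise_dim=100).
   rng = np.random.default_rng(0)
   class_vectors = sample_class_vectors(10, num_classes=7, sampling="simplex", rng=rng)
   noise = rng.standard_normal((10, 100)).astype(np.float32)
   # How the generator consumes (noise, class_vectors) is model-specific and not shown.

Vertex sampling reproduces the training classes, while simplex and uniform sampling produce mixed class vectors that the conditional generator never saw as one-hot inputs.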