GlitchGAN Evaluation

Evaluates the GlitchGAN generator (TensorFlow/Keras, epoch 210) on LIGO gravitational-wave glitch data.

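Sections:

  1. Training loss history

  2. Waveform visualisation — real vs generated signals, with Q-scans of the signals injected into whitened LIGO noise

  3. UMAP 3D embedding (correlation metric, 2,000 real + 2,000 generated samples per class)

  4. Mixed class-vector sampling (simplex / uniform) and its UMAP embedding

  5. GravitySpy classification — inject generated glitches and classify with the O3 CNN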

Prerequisites:

  • Data files in data/ (see README for download instructions)

  • GravitySpy installed; set PATH_TO_REPO below to your local clone

Dependencies

Install the eval extras:

pip install "glitchgan[eval]"
[9]:
import os
import sys
from pathlib import Path

# Assumes the kernel's working directory is one level below the repo root
# (e.g. notebooks/); every other path is built relative to PROJECT_ROOT.
PROJECT_ROOT = Path(".").resolve().parent
DATA_DIR = PROJECT_ROOT / "data"
PLOTS_DIR = PROJECT_ROOT / "evaluation_plots"
PLOTS_DIR.mkdir(parents=True, exist_ok=True)

SEED = 56
# NOTE: PYTHONHASHSEED is read only at interpreter start-up, so this affects
# subprocesses launched afterwards, not hash randomisation in this kernel.
os.environ["PYTHONHASHSEED"] = f"{SEED}"

# Put the repo's src/ directory first on sys.path so the local `glitchgan`
# package is the one imported.
sys.path.insert(0, str(PROJECT_ROOT / "src"))
from glitchgan.tf.model_components import ArgmaxLayer, ReduceSumDotLayer
# Name -> class mapping for the generator's custom Keras layers; presumably
# intended as `custom_objects` when deserialising the saved model (TODO confirm).
_co = {"ArgmaxLayer": ArgmaxLayer, "ReduceSumDotLayer": ReduceSumDotLayer}

# ── Generator ─────────────────────────────────────────────────────────────────
GENERATOR_EPOCH   = 210
# The weight file name is derived from the epoch so the two cannot drift apart.
GENERATOR_PATH    = (PROJECT_ROOT / "weights" / "tensorflow"
                     / f"generator_{GENERATOR_EPOCH}_keras3.keras")
NOISE_DIM         = 100   # latent vector length fed to the generator
NUM_CLASSES       = 7
SAMPLES_PER_CLASS = 100

# Class names, in the column order of the one-hot label arrays.
LABEL_ORDER = [
    "Blip", "Fast_Scattering", "Koi_Fish",
    "Low_Frequency_Burst", "Scattered_Light", "Tomte", "Whistle",
]

# ── GravitySpy ────────────────────────────────────────────────────────────────
IFO           = "H1"                           # detector
SRATE         = 4096                           # sample rate (Hz)
GW_START      = 1262540000                     # GPS start of the background segment
GW_END        = GW_START + 40                  # 40 s of background strain
CHANNEL       = f"{IFO}:GDS-CALIB_STRAIN"
PATH_TO_MODEL = PROJECT_ROOT / "models" / "sidd-cqg-paper-O3-model.h5"
# Update to your local GravitySpy repository clone:
PATH_TO_REPO  = "/path/to/GravitySpy"
NUM_CLASSIFY  = 10
SNR_TARGET    = 50
INIT_TIME     = -20
EVENT_TIME    = 0

print("Project root :", PROJECT_ROOT)
print("Generator    :", GENERATOR_PATH.name)
print("Model exists :", GENERATOR_PATH.exists())

Project root : /Users/tomdooney/Documents/Work/Projects/glitchgan
Generator    : generator_210_keras3.keras
Model exists : True
[10]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import umap.umap_ as umap

try:
    import scienceplots  # importing registers the "science" matplotlib styles
    plt.style.use(["science", "no-latex"])
    # plt.style.use(["science"])  # full LaTeX rendering, if LaTeX is installed
except ImportError:
    pass  # scienceplots not installed; fall back to matplotlib defaults

import random
import tensorflow as tf

# Seed every global RNG the notebook may draw from (including any used inside
# generate_examples — not visible here) so Restart & Run All reproduces the same
# generated samples and figures.
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

1. Load Data

[11]:
X_real   = np.load(DATA_DIR / "glitch_GAN_samples_scaled_balanced.npy")
y_onehot = np.load(DATA_DIR / "glitch_GAN_labels_balanced.npy")

# Fail fast if signals, labels and LABEL_ORDER disagree; otherwise class indices
# could be decoded against the wrong class names.
assert len(X_real) == len(y_onehot), (
    f"{len(X_real)} signals but {len(y_onehot)} labels")
assert y_onehot.ndim == 2 and y_onehot.shape[1] == len(LABEL_ORDER), (
    f"label array shape {y_onehot.shape} does not match "
    f"{len(LABEL_ORDER)} classes in LABEL_ORDER")
y_real   = np.array(LABEL_ORDER)[np.argmax(y_onehot, axis=1)]

print(f"Real signals : {X_real.shape}")
print(f"Classes      : {LABEL_ORDER}")

Real signals : (35000, 8192)
Classes      : ['Blip', 'Fast_Scattering', 'Koi_Fish', 'Low_Frequency_Burst', 'Scattered_Light', 'Tomte', 'Whistle']

2. Load Generator

[12]:
import keras
from glitchgan.tf import GlitchGAN

if not GENERATOR_PATH.exists():
    raise FileNotFoundError(f"Generator weights not found: {GENERATOR_PATH}")

gan = GlitchGAN()
# Supply the project's custom layers (`_co`) explicitly instead of relying on
# them having been registered as serializable elsewhere.
gan.generator = keras.models.load_model(
    str(GENERATOR_PATH), custom_objects=_co, compile=False)

print(f"Loaded GlitchGAN generator from: {GENERATOR_PATH.name}")

Loaded GlitchGAN generator from: generator_210_keras3.keras

3. Training Loss History

[13]:
import json
import numpy as np
import matplotlib.pyplot as plt

HISTORY_PATH = PROJECT_ROOT / "history.json"
# Per-epoch loss curves, keyed by loss name ("d_loss", "g_loss", ...).
history = json.loads(HISTORY_PATH.read_text())

n_epochs = len(history["d_loss"])
epochs = np.arange(1, n_epochs + 1)   # 1-based epoch numbers for the x-axis
SMOOTH = 10                           # moving-average window, in epochs

def smooth(x, w):
    """Moving average of ``x`` over a window of ``w`` samples.

    Uses ``np.convolve(..., mode="valid")``, so the result has ``len(x) - w + 1``
    points; entry ``k`` averages ``x[k : k + w]``.

    Raises:
        ValueError: if ``w < 1`` or ``w > len(x)``. In the latter case
            ``np.convolve`` would silently swap its operands and return an
            array unrelated to ``x``.
    """
    x = np.asarray(x, dtype=float)
    if not 1 <= w <= len(x):
        raise ValueError(f"window w={w} must satisfy 1 <= w <= len(x)={len(x)}")
    return np.convolve(x, np.ones(w) / w, mode="valid")

epochs_s = epochs[SMOOTH - 1:]  # x-values matching smooth(history[key], SMOOTH)

series = [
    ("d_loss",          "Discriminator",             "C0"),
    ("d2d_loss",        "Derivative discriminator",  "C1"),
    ("g_loss",          "Generator",                 "C2"),
    ("g_loss2d",        "Generator (derivative)",    "C3"),
    ("g_loss_combined", "Generator (combined)",      "C4"),
]

fig, ax = plt.subplots(figsize=(9, 5))
for key, label, color in series:
    ax.plot(epochs, history[key], color=color, alpha=1, lw=0.6, label=label)

ax.axhline(0, color="k", lw=0.6, ls="--", alpha=0.4)
ax.set_xlabel("Epoch", fontsize=20)
ax.set_ylabel("Wasserstein Loss", fontsize=20)
# Span the epochs actually recorded in history.json rather than a hard-coded 500.
ax.set_xlim(1, len(epochs))
ax.legend(fontsize=14, frameon=True, loc="lower left")
plt.tight_layout()
save_path = PLOTS_DIR / "training_loss_history.pdf"
plt.savefig(save_path, bbox_inches="tight")
print(f"Saved: {save_path}")
plt.show()

Saved: /Users/tomdooney/Documents/Work/Projects/glitchgan/evaluation_plots/training_loss_history.pdf
../_images/notebooks_evaluation_9_1.png

4. Generate Signals

[14]:
from glitchgan.tf.utils import generate_examples

_n_generate = NUM_CLASSES * SAMPLES_PER_CLASS
X_fake, class_vecs = generate_examples(
    gan,
    sampling="vertex",
    noise_dim=NOISE_DIM,
    num_classes=NUM_CLASSES,
    num_signals=_n_generate,
)
# Name each generated signal after the largest entry of its class vector.
_label_names = np.array(LABEL_ORDER)
y_fake = _label_names[np.argmax(class_vecs, axis=1)]

print(f"Generated : {X_fake.shape}  ({SAMPLES_PER_CLASS} per class)")

Generated : (700, 8192)  (100 per class)

5. Waveform Visualisation

[15]:
from gwpy.timeseries import TimeSeries
import numpy as np
import matplotlib.pyplot as plt

# Background strain is cached on disk so re-runs don't need network access.
# NOTE(review): the cache name encodes only the GPS span, not IFO — delete the
# file after changing detectors.
_bg_cache = PLOTS_DIR / f"strain_bg_{GW_START}_{GW_END}.hdf5"


def _fetch_background():
    """Download IFO open data for [GW_START, GW_END], resample to SRATE and cache it."""
    ts = TimeSeries.fetch_open_data(IFO, GW_START, GW_END).resample(SRATE)
    ts.write(_bg_cache, format="hdf5", overwrite=True)
    print(f"Fetched and cached: {_bg_cache}")
    return ts


strain_bg = None
if _bg_cache.exists():
    try:
        strain_bg = TimeSeries.read(_bg_cache, format="hdf5")
        print(f"Loaded from cache: {_bg_cache}")
    except Exception as exc:  # unreadable cache: report it, then re-fetch
        print(f"Could not read cache {_bg_cache} ({exc!r}); re-fetching")
if strain_bg is None:
    strain_bg = _fetch_background()
# Guard against a cache that was written at a different sample rate.
if strain_bg.sample_rate.value != SRATE:
    strain_bg = strain_bg.resample(SRATE)

# Match reference exactly: use gwpy's to_pycbc(), keep whitened result as pycbc TimeSeries
noise = strain_bg.to_pycbc()
# pycbc whiten(max_filter_duration, trunc_duration) in seconds: half and a
# quarter of the segment duration respectively.
white_noise_pycbc, psd = noise.whiten(
    len(noise) / (2 * SRATE),
    len(noise) / (4 * SRATE),
    remove_corrupted=False,
    return_psd=True,
)
# numpy copy for plotting / Q-scan visualisation only
white_noise = np.asarray(white_noise_pycbc)
print(f"white_noise: {white_noise.shape}  type: {type(white_noise)}")

plt.figure(figsize=(10, 3))
# Skip 4 s at each end, where whitening edge effects are not removed.
plt.plot(white_noise[SRATE * 4 : -SRATE * 4])
plt.xlabel("Samples"); plt.ylabel("Amplitude")
plt.title("Whitened background noise")
plt.tight_layout(); plt.show()

Loaded from cache: /Users/tomdooney/Documents/Work/Projects/glitchgan/evaluation_plots/strain_bg_1262540000_1262540040.hdf5
white_noise: (163840,)  type: <class 'numpy.ndarray'>
../_images/notebooks_evaluation_13_1.png
[16]:
from matplotlib.lines import Line2D
from glitchgan.utils import plot_q_transform, whitened_snr_scaling

try:
    plt.style.use(["science", "no-latex"])  # remove "no-latex" if latex is on PATH
except OSError:
    # matplotlib raises OSError for an unknown style name, e.g. when
    # scienceplots is not installed; keep the current style in that case.
    pass

def _inject(glitch, white_noise):
    len_g    = len(glitch)
    id_start = len(white_noise) // 2 - len_g // 2
    injected = white_noise.copy()
    injected[id_start : id_start + len_g] += glitch
    return injected


def plot_comparison_grid(sources, label_order, n_examples=1, save_name=None,
                          white_noise=None, snr_default=50, snr_overrides=None):
    """Plot example waveforms per class for several sources, with optional Q-scans.

    Each source gets one time-series row, plus a Q-scan row beneath it when
    ``white_noise`` is given. Each class gets ``n_examples`` columns.

    Args:
        sources: sequence of ``(X, y, label, colour)`` tuples. ``X[i]`` is one
            waveform, ``y`` is the array of per-sample class names, ``label`` is
            the legend text and ``colour`` is the line colour.
        label_order: class names, one column group per class.
        n_examples: waveforms drawn (without replacement) per class.
        save_name: if given, the figure is saved to ``PLOTS_DIR / f"{save_name}.pdf"``.
        white_noise: whitened background array. When given, each waveform is
            SNR-scaled with ``whitened_snr_scaling``, added at the centre of the
            noise, and its Q-transform is drawn in the row below.
        snr_default: target SNR used for classes not in ``snr_overrides``.
        snr_overrides: optional ``{class_name: snr}`` mapping.

    Sample selection uses a fresh ``default_rng(SEED)``, so it is reproducible.
    """
    # Local RNG so sample selection is reproducible regardless of upstream state
    rng = np.random.default_rng(SEED)
    snr_overrides = snr_overrides or {}

    show_qscans  = white_noise is not None
    n_sources    = len(sources)
    n_classes    = len(label_order)
    n_cols       = n_classes * n_examples
    rows_per_src = 2 if show_qscans else 1
    n_rows       = n_sources * rows_per_src
    ts_height    = 1.4
    qs_height    = 1.8
    row_heights  = ([ts_height, qs_height] if show_qscans else [ts_height]) * n_sources

    # squeeze=False keeps `axes` 2-D for every grid shape (single row, single
    # column or 1x1), so the [row, col] indexing below always works.
    fig, axes = plt.subplots(
        n_rows, n_cols,
        figsize=(2.4 * n_cols, sum(row_heights)),
        gridspec_kw={"hspace": 0.08, "wspace": 0.08,
                     "height_ratios": row_heights},
        squeeze=False,
    )

    legend_handles = []
    for src_idx, (X, y, src_label, color) in enumerate(sources):
        legend_handles.append(Line2D([0], [0], color=color, lw=1.5, label=src_label))
        ts_row = src_idx * rows_per_src
        qs_row = ts_row + 1 if show_qscans else None

        for ci, lbl in enumerate(label_order):
            idx    = np.where(y == lbl)[0]
            chosen = rng.choice(idx, min(n_examples, len(idx)), replace=False)
            snr    = snr_overrides.get(lbl, snr_default)

            for j, sample_idx in enumerate(chosen):
                col   = ci * n_examples + j
                ax_ts = axes[ts_row, col]

                ax_ts.plot(X[sample_idx], lw=0.6, color=color)
                ax_ts.set_xticks([]); ax_ts.set_yticks([])
                for spine in ax_ts.spines.values():
                    spine.set_linewidth(0.4)
                if src_idx == 0 and j == 0:
                    ax_ts.set_title(lbl.replace("_", " "), fontsize=15, pad=3)

                if show_qscans:
                    ax_qs    = axes[qs_row, col]
                    glitch   = whitened_snr_scaling(X[sample_idx], snr, srate=SRATE)
                    injected = _inject(glitch, white_noise)
                    plot_q_transform(injected, srate=SRATE, crop=(20, 2),
                                     whiten=False, ax=ax_qs, colourbar=False)
                    ax_qs.set_xlabel(""); ax_qs.set_ylabel("")
                    ax_qs.set_xticks([]); ax_qs.set_yticks([])
                    for spine in ax_qs.spines.values():
                        spine.set_linewidth(0.4)

    plt.tight_layout(rect=[0, 0, 1, 0.99])

    # Extra gap between every pair of consecutive sources: the downward offset
    # grows with src_idx, so within-pair spacing is unchanged and each later
    # pair moves down by one more step.
    if show_qscans and n_sources > 1:
        for src_idx in range(1, n_sources):
            dy = 0.08 * src_idx
            for row_offset in range(rows_per_src):
                row = src_idx * rows_per_src + row_offset
                for col in range(n_cols):
                    pos = axes[row, col].get_position()
                    axes[row, col].set_position(
                        [pos.x0, pos.y0 - dy, pos.width, pos.height])

    fig.legend(handles=legend_handles, loc="upper center", ncol=n_sources,
               fontsize=15, frameon=False, bbox_to_anchor=(0.5, 0.99))
    if save_name:
        plt.savefig(PLOTS_DIR / f"{save_name}.pdf", bbox_inches="tight")
    plt.show()


_comparison_sources = [
    (X_real, y_real, "Real",      "C0"),
    (X_fake, y_fake, "GlitchGAN", "C1"),
]
# Koi Fish and Whistle examples are injected at a higher target SNR than the rest.
_snr_overrides = {"Koi_Fish": 150, "Whistle": 150}

plot_comparison_grid(
    _comparison_sources,
    LABEL_ORDER,
    n_examples=1,
    save_name=f"waveform_comparison_ep{GENERATOR_EPOCH}",
    white_noise=white_noise,
    snr_default=50,
    snr_overrides=_snr_overrides,
)

/var/folders/gv/z_2s63x116vbz6s2hmnl7mkc0000gn/T/ipykernel_10356/2254928735.py:73: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.
  plt.tight_layout(rect=[0, 0, 1, 0.99])
../_images/notebooks_evaluation_14_1.png

6. UMAP Embedding

3D UMAP embedding (correlation metric) of 2,000 real + 2,000 generated samples per class. Points outside the 1st–99th percentile range on any UMAP axis are removed before plotting. Reproduces Figures 5 & 6 from the paper.

[17]:
import tensorflow as tf

N_FULL = 2000

# Subsample up to N_FULL real signals per class with a fixed-seed Generator.
rng_full = np.random.default_rng(42)
real_idx = []
for lbl in LABEL_ORDER:
    cls_idx = np.where(y_real == lbl)[0]
    real_idx.extend(rng_full.choice(cls_idx, size=min(N_FULL, len(cls_idx)), replace=False))
X_real_full = X_real[real_idx]
y_real_full  = y_real[real_idx]
print(f"Real subsampled : {X_real_full.shape}")

# Generate N_FULL fake signals per class. Latent noise comes from its own seeded
# Generator, so the fake set (and hence the UMAP figures) is reproducible
# regardless of the global np.random state.
n_fake_full = NUM_CLASSES * N_FULL
rng_noise   = np.random.default_rng(SEED)
noise_full  = rng_noise.standard_normal((n_fake_full, NOISE_DIM), dtype=np.float32)
# Class-major one-hot conditioning: N_FULL rows of class 0, then class 1, ...
class_full  = np.repeat(np.eye(NUM_CLASSES, dtype="float32"), N_FULL, axis=0)

CHUNK = 2048  # generate in chunks to avoid OOM on a single large batch
chunks_fake = []
for i in range(0, n_fake_full, CHUNK):
    out = gan.generator([noise_full[i:i+CHUNK], class_full[i:i+CHUNK]], training=False)
    chunks_fake.append(out.numpy())
X_fake_full = np.concatenate(chunks_fake, axis=0)
# Labels follow the same class-major order as class_full.
y_fake_full  = np.repeat(LABEL_ORDER, N_FULL)
print(f"Fake generated  : {X_fake_full.shape}")

X_all_full         = np.concatenate([X_real_full, X_fake_full], axis=0)
y_all_full         = np.concatenate([y_real_full, y_fake_full], axis=0)
domain_labels_full = np.array(["Real"] * len(X_real_full) + ["Fake"] * len(X_fake_full))
print(f"UMAP input      : {X_all_full.shape}  ({X_all_full.nbytes / 1e9:.2f} GB)")

Real subsampled : (14000, 8192)
Fake generated  : (14000, 8192)
UMAP input      : (28000, 8192)  (1.84 GB)
[18]:
# 3-D UMAP with correlation distance. Fixing random_state pins the layout
# (UMAP then runs single-threaded).
_umap_kwargs = dict(
    n_components=3, n_neighbors=15, min_dist=0.6,
    metric="correlation", random_state=SEED, low_memory=True,
)
reducer_full = umap.UMAP(**_umap_kwargs)
embedding_full_corr = reducer_full.fit_transform(X_all_full)
print(f"Embedding shape : {embedding_full_corr.shape}")

/opt/homebrew/Caskroom/miniforge/base/envs/glitchgan_test/lib/python3.11/site-packages/umap/umap_.py:1952: UserWarning: n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.
  warn(
Embedding shape : (28000, 3)
[19]:
# Keep only points inside the [low, high] percentile band on every UMAP axis.
low, high = 1, 99
x1, x2, x3 = (embedding_full_corr[:, axis] for axis in range(3))
mask_clean = np.ones(len(embedding_full_corr), dtype=bool)
for coord in (x1, x2, x3):
    mask_clean &= (coord >= np.percentile(coord, low)) & (coord <= np.percentile(coord, high))
n_removed = np.sum(~mask_clean)
print(f"Removed {n_removed} outliers ({n_removed/len(mask_clean)*100:.2f}%)")

embedding_full_clean = embedding_full_corr[mask_clean]
y_all_clean          = y_all_full[mask_clean]
domain_labels_clean  = domain_labels_full[mask_clean]

Removed 1680 outliers (6.00%)
[20]:
# ── Figure 5: two-view 3D UMAP ────────────────────────────────────────────────
from matplotlib.lines import Line2D

# Subsample at most n_per_class real and fake points per class. A dedicated
# seeded Generator keeps the plotted subset identical across runs, independent
# of the global RNG state.
rng_plot = np.random.default_rng(SEED)
n_per_class = 800
plot_indices = []
for lbl in LABEL_ORDER:
    in_class = y_all_clean == lbl
    idx_real = np.where(in_class & (domain_labels_clean == "Real"))[0]
    idx_fake = np.where(in_class & (domain_labels_clean == "Fake"))[0]
    plot_indices.extend(rng_plot.choice(idx_real, min(len(idx_real), n_per_class), replace=False))
    plot_indices.extend(rng_plot.choice(idx_fake, min(len(idx_fake), n_per_class), replace=False))
plot_indices   = np.array(plot_indices)
embedding_plot = embedding_full_clean[plot_indices]
y_plot         = y_all_clean[plot_indices]
domain_plot    = domain_labels_clean[plot_indices]

unique_classes = np.array(LABEL_ORDER)
colors = plt.cm.tab10(np.linspace(0, 1, len(unique_classes)))

fig = plt.figure(figsize=(14, 6))
ax1 = fig.add_subplot(1, 2, 1, projection="3d")
ax2 = fig.add_subplot(1, 2, 2, projection="3d")

for lbl, color in zip(unique_classes, colors):
    mask_real = (y_plot == lbl) & (domain_plot == "Real")
    mask_fake = (y_plot == lbl) & (domain_plot == "Fake")
    # Left view: (UMAP-1, UMAP-2, UMAP-3); right view swaps the first two axes.
    ax1.scatter(embedding_plot[mask_real, 0], embedding_plot[mask_real, 1], embedding_plot[mask_real, 2],
                c=[color], s=1, alpha=0.4, label=lbl.replace("_", " "))
    ax1.scatter(embedding_plot[mask_fake, 0], embedding_plot[mask_fake, 1], embedding_plot[mask_fake, 2],
                c=[color], s=1, alpha=0.4)
    ax2.scatter(embedding_plot[mask_real, 1], embedding_plot[mask_real, 0], embedding_plot[mask_real, 2],
                c=[color], s=1, alpha=0.4)
    ax2.scatter(embedding_plot[mask_fake, 1], embedding_plot[mask_fake, 0], embedding_plot[mask_fake, 2],
                c=[color], s=1, alpha=0.4)

ax1.set_xlabel("UMAP-1", fontsize=18); ax1.set_ylabel("UMAP-2", fontsize=18); ax1.set_zlabel("UMAP-3", fontsize=18)
ax2.set_xlabel("UMAP-2", fontsize=18); ax2.set_ylabel("UMAP-1", fontsize=18); ax2.set_zlabel("UMAP-3", fontsize=18)

handles, labels_leg = ax1.get_legend_handles_labels()
fig.legend(handles, labels_leg, loc="lower center", ncol=len(unique_classes) // 2,
           fontsize=16, title="Glitch Classes", title_fontsize=18,
           frameon=False, bbox_to_anchor=(0.45, 0.02), markerscale=10)

plt.tight_layout(rect=[0, 0.12, 0.83, 1.3])
plt.subplots_adjust(wspace=0.25)
save_path = PLOTS_DIR / "umap_two_views_corr_2k.pdf"
plt.savefig(save_path, bbox_inches="tight")
print(f"Saved: {save_path}")
plt.show()
plt.close()

Saved: /Users/tomdooney/Documents/Work/Projects/glitchgan/evaluation_plots/umap_two_views_corr_2k.pdf
../_images/notebooks_evaluation_19_1.png
[21]:
# ── Figure 6: per-class 3D UMAP ──────────────────────────────────────────────
n_class_panels = len(unique_classes)
fig = plt.figure(figsize=(4 * n_class_panels, 5))

for panel, lbl in enumerate(unique_classes, start=1):
    ax = fig.add_subplot(1, n_class_panels, panel, projection="3d")
    in_class = y_plot == lbl
    # Real is drawn first, so Fake points sit on top.
    for domain, colour in (("Real", "C0"), ("Fake", "C1")):
        pts = embedding_plot[in_class & (domain_plot == domain)]
        ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2],
                   c=colour, label=domain, s=3, alpha=0.6)
    ax.view_init(elev=25, azim=45)
    for set_axis_label, axis_name in ((ax.set_xlabel, "UMAP-1"),
                                      (ax.set_ylabel, "UMAP-2"),
                                      (ax.set_zlabel, "UMAP-3")):
        set_axis_label(axis_name, fontsize=17, labelpad=2)
    ax.set_title(lbl.replace("_", " "), fontsize=30, pad=10)
    ax.set_xticks([]); ax.set_yticks([]); ax.set_zticks([])
    if panel == 1:
        # A single legend on the first panel serves the whole row.
        ax.legend(fontsize=25, loc="lower left", markerscale=5, frameon=False)

plt.tight_layout(rect=[0.01, 0.01, 0.99, 1.1])
plt.subplots_adjust(wspace=0.25)
save_path = PLOTS_DIR / "umap_per_class_corr_2k.pdf"
plt.savefig(save_path, bbox_inches="tight")
print(f"Saved: {save_path}")
plt.show()
plt.close()

Saved: /Users/tomdooney/Documents/Work/Projects/glitchgan/evaluation_plots/umap_per_class_corr_2k.pdf
../_images/notebooks_evaluation_20_1.png
[22]:
# --- Grid visualisation: SAMPLES_GRID examples of each mixed sampling mode ---
SAMPLES_GRID = 9

_grid_kwargs = dict(noise_dim=NOISE_DIM, num_classes=NUM_CLASSES,
                    num_signals=SAMPLES_GRID)
# Order matters for RNG consumption: simplex first, then uniform.
sigs_sp, cvecs_sp = generate_examples(gan, sampling="simplex", **_grid_kwargs)
sigs_un, cvecs_un = generate_examples(gan, sampling="uniform", **_grid_kwargs)


def _mixed_grid(signals, class_vecs, title, ncols=3):
    """Plot generated signals in a grid, annotating each with its class vector.

    Each panel is labelled with the class of the largest class-vector entry
    (looked up in LABEL_ORDER) and the full vector to two decimals. The figure
    is saved as a PDF in PLOTS_DIR, under a filename derived from ``title``,
    and then shown.

    Args:
        signals: sequence of 1-D generated waveforms.
        class_vecs: matching sequence of class vectors.
        title: figure title; also used to build the output filename.
        ncols: panels per row. The number of rows is derived from
            ``len(signals)``, so no signal is dropped. The default reproduces
            a 3x3 layout for 9 signals.
    """
    n_panels = len(signals)
    nrows = max(1, -(-n_panels // ncols))  # ceiling division
    fig, axes = plt.subplots(nrows, ncols, figsize=(3 * ncols, 2 * nrows),
                             sharex=True, squeeze=False)
    flat_axes = axes.flatten()
    for ax, sig, cvec in zip(flat_axes, signals, class_vecs):
        ax.plot(sig, lw=0.8, color="C0")
        ax.set_xticks([])
        ax.set_yticks([])
        dominant = LABEL_ORDER[np.argmax(cvec)]
        vec_str  = ", ".join(f"{v:.2f}" for v in cvec)
        ax.text(0.5, 1.04, f"{dominant}\n[{vec_str}]",
                transform=ax.transAxes, fontsize=6.5, ha="center", va="bottom")
    # Hide leftover cells when len(signals) is not a multiple of ncols.
    for ax in flat_axes[n_panels:]:
        ax.set_visible(False)
    fig.suptitle(title, fontsize=12)
    plt.tight_layout(pad=1.0, h_pad=2.2)
    safe = title.lower().replace(" ", "_").replace("—", "").replace("(", "").replace(")", "").replace("=", "eq").strip("_")
    plt.savefig(PLOTS_DIR / f"{safe}.pdf", bbox_inches="tight")
    plt.show()


for _sigs, _cvecs, _title in (
    (sigs_sp, cvecs_sp, "Generated Glitches — Simplex Mixing (sum=1)"),
    (sigs_un, cvecs_un, "Generated Glitches — Uniform Mixing ([0,1])"),
):
    _mixed_grid(_sigs, _cvecs, _title)
../_images/notebooks_evaluation_21_0.png
../_images/notebooks_evaluation_21_1.png
[23]:
# --- Generate vertex / simplex / uniform samples for a joint UMAP embedding ---
SAMPLES_PER_TYPE      = 1500
# Largest multiple of NUM_CLASSES not exceeding SAMPLES_PER_TYPE, so the vertex
# samples presumably split evenly across classes.
SAMPLES_VERTEX_TOTAL  = SAMPLES_PER_TYPE - SAMPLES_PER_TYPE % NUM_CLASSES

_mix_kwargs = dict(noise_dim=NOISE_DIM, num_classes=NUM_CLASSES)
# Generation order (simplex, uniform, vertex) is kept fixed for RNG consumption.
sigs_simplex, cvecs_simplex = generate_examples(
    gan, num_signals=SAMPLES_PER_TYPE, sampling="simplex", **_mix_kwargs)
sigs_uniform, cvecs_uniform = generate_examples(
    gan, num_signals=SAMPLES_PER_TYPE, sampling="uniform", **_mix_kwargs)
sigs_vertex,  cvecs_vertex  = generate_examples(
    gan, num_signals=SAMPLES_VERTEX_TOTAL, sampling="vertex", **_mix_kwargs)

y_vertex = np.array(LABEL_ORDER)[np.argmax(cvecs_vertex, axis=1)]

X_mix = np.concatenate([sigs_vertex, sigs_simplex, sigs_uniform], axis=0)
_n_vertex = len(sigs_vertex)
domain_labels = np.concatenate([
    np.full(_n_vertex,        "Vertex"),
    np.full(SAMPLES_PER_TYPE, "Simplex"),
    np.full(SAMPLES_PER_TYPE, "Uniform"),
])
# Mixed samples have no single class, so their domain name doubles as the label.
y_mix = np.concatenate([
    y_vertex,
    np.full(SAMPLES_PER_TYPE, "Simplex"),
    np.full(SAMPLES_PER_TYPE, "Uniform"),
])

print(f"Total: {len(X_mix)}  "
      f"({len(sigs_vertex)} vertex, {SAMPLES_PER_TYPE} simplex, {SAMPLES_PER_TYPE} uniform)")

reducer_mix = umap.UMAP(
    n_components=3, random_state=SEED,
    n_neighbors=30, min_dist=0.2, metric="euclidean",
)
embedding_mix = reducer_mix.fit_transform(X_mix)
print("UMAP done.")
Total: 4498  (1498 vertex, 1500 simplex, 1500 uniform)
/opt/homebrew/Caskroom/miniforge/base/envs/glitchgan_test/lib/python3.11/site-packages/umap/umap_.py:1952: UserWarning: n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.
  warn(
UMAP done.
[24]:
# --- 3D UMAP: vertex samples coloured by class, simplex/uniform mixes in black ---
# Same tab10 sampling as the Figure 5 cell, so each class keeps its colour.
_class_colors_mix = dict(zip(LABEL_ORDER,
                             plt.cm.tab10(np.linspace(0, 1, len(LABEL_ORDER)))))

fig = plt.figure(figsize=(10, 7))
ax  = fig.add_subplot(111, projection="3d")

for domain, marker, legend_label in (("Simplex", "^", "Simplex mix"),
                                     ("Uniform", "x", "Uniform mix")):
    m = domain_labels == domain
    ax.scatter(embedding_mix[m, 0], embedding_mix[m, 1], embedding_mix[m, 2],
               c="k", marker=marker, s=10, alpha=0.35, label=legend_label)

for lbl in LABEL_ORDER:
    m = (domain_labels == "Vertex") & (y_mix == lbl)
    # Legend labels use spaces, consistent with the other figures.
    ax.scatter(embedding_mix[m, 0], embedding_mix[m, 1], embedding_mix[m, 2],
               c=[_class_colors_mix[lbl]], s=15, alpha=0.7,
               label=lbl.replace("_", " "))

ax.set_xlabel("UMAP-1", fontsize=12)
ax.set_ylabel("UMAP-2", fontsize=12)
ax.set_zlabel("UMAP-3", fontsize=12)
ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=9)
plt.tight_layout()
plt.savefig(PLOTS_DIR / "simplex_uniform_umap.pdf", bbox_inches="tight")
plt.show()
../_images/notebooks_evaluation_23_0.png
[25]:
# --- Persist the mixed-sampling UMAP embedding and its labels for later reuse ---
_umap_mix_path = PLOTS_DIR / "simplex_uniform_umap_data.npz"
np.savez_compressed(
    _umap_mix_path,
    embedding=embedding_mix,
    y_all=y_mix,
    domain_labels=domain_labels,
    label_order=np.array(LABEL_ORDER),
)
print(f"Saved → {_umap_mix_path}")
Saved → /Users/tomdooney/Documents/Work/Projects/glitchgan/evaluation_plots/simplex_uniform_umap_data.npz
[ ]: