GlitchGAN Evaluation¶
Evaluates the GlitchGAN generator (TensorFlow/Keras, epoch 210) on LIGO gravitational-wave glitch data.
Sections:
Waveform visualisation — real vs generated signals
UMAP 3D embedding (correlation metric, 2 000 samples per class)
GravitySpy classification — inject generated glitches and classify with the O3 CNN
Prerequisites:
Data files in data/ (see README for download instructions)
GravitySpy installed; set PATH_TO_REPO below to your local clone
Dependencies¶
Install the eval extras:
pip install "glitchgan[eval]"
[9]:
import os
import sys
from pathlib import Path
# ── Paths ─────────────────────────────────────────────────────────────────────
# The project root is taken to be the parent of the current working directory,
# so the notebook must be launched from a direct sub-folder of the repository.
PROJECT_ROOT = Path(".").resolve().parent
DATA_DIR = PROJECT_ROOT / "data"
PLOTS_DIR = PROJECT_ROOT / "evaluation_plots"
PLOTS_DIR.mkdir(parents=True, exist_ok=True)

# ── Seed / import path ────────────────────────────────────────────────────────
SEED = 56
os.environ.update(PYTHONHASHSEED=str(SEED))

# Put the in-repo sources first on the import path.
sys.path.insert(0, os.fspath(PROJECT_ROOT / "src"))
from glitchgan.tf.model_components import ArgmaxLayer, ReduceSumDotLayer
# Custom Keras layers, keyed by class name.
_co = dict(ArgmaxLayer=ArgmaxLayer, ReduceSumDotLayer=ReduceSumDotLayer)

# ── Generator ─────────────────────────────────────────────────────────────────
GENERATOR_EPOCH = 210
GENERATOR_PATH = PROJECT_ROOT / "weights" / "tensorflow" / "generator_210_keras3.keras"
NOISE_DIM = 100            # passed as `noise_dim` when generating examples
NUM_CLASSES = 7            # must equal len(LABEL_ORDER)
SAMPLES_PER_CLASS = 100
# Class names; a one-hot argmax index is mapped to a name through this list.
LABEL_ORDER = [
    "Blip",
    "Fast_Scattering",
    "Koi_Fish",
    "Low_Frequency_Burst",
    "Scattered_Light",
    "Tomte",
    "Whistle",
]

# ── GravitySpy ────────────────────────────────────────────────────────────────
IFO = "H1"
SRATE = 4096               # background strain is resampled to this rate
GW_START = 1262540000
GW_END = GW_START + 40
CHANNEL = f"{IFO}:GDS-CALIB_STRAIN"
PATH_TO_MODEL = PROJECT_ROOT / "models" / "sidd-cqg-paper-O3-model.h5"
# Update to your local GravitySpy repository clone:
PATH_TO_REPO = "/path/to/GravitySpy"
NUM_CLASSIFY = 10
SNR_TARGET = 50
INIT_TIME = -20
EVENT_TIME = 0

# Quick sanity report so a missing weights file is noticed early.
print("Project root :", PROJECT_ROOT)
print("Generator :", GENERATOR_PATH.name)
print("Model exists :", GENERATOR_PATH.exists())
Project root : /Users/tomdooney/Documents/Work/Projects/glitchgan
Generator : generator_210_keras3.keras
Model exists : True
[10]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import umap.umap_ as umap
# Optional publication styling. scienceplots is imported before the "science"
# style is requested (presumably the import is what makes the style available
# to matplotlib -- confirm against the scienceplots docs). "no-latex" is used so
# that a LaTeX installation is not needed.
try:
    import scienceplots
    plt.style.use(["science", "no-latex"])
    # plt.style.use(["science"])
except ImportError:
    pass  # scienceplots is not installed; keep the matplotlib defaults
import random
import tensorflow as tf
# random.seed(SEED)
# np.random.seed(SEED)
# tf.random.set_seed(SEED)
1. Load Data¶
[11]:
# Real glitches and their one-hot labels, stored as two parallel .npy arrays.
X_real = np.load(DATA_DIR / "glitch_GAN_samples_scaled_balanced.npy")
y_onehot = np.load(DATA_DIR / "glitch_GAN_labels_balanced.npy")

# One-hot row -> class name, using the column order given by LABEL_ORDER.
y_real = np.take(LABEL_ORDER, y_onehot.argmax(axis=1))

print(f"Real signals : {X_real.shape}")
print(f"Classes : {LABEL_ORDER}")
Real signals : (35000, 8192)
Classes : ['Blip', 'Fast_Scattering', 'Koi_Fish', 'Low_Frequency_Burst', 'Scattered_Light', 'Tomte', 'Whistle']
2. Load Generator¶
[12]:
import keras
from glitchgan.tf import GlitchGAN
# Build the GAN wrapper and attach the pretrained generator.
# The custom layers collected in `_co` are handed to the loader explicitly, so
# deserialisation does not depend on them having been registered as a side
# effect of importing glitchgan. `compile=False` because the model is only used
# for inference here.
gan = GlitchGAN()
gan.generator = keras.models.load_model(
    str(GENERATOR_PATH),
    custom_objects=_co,
    compile=False,
)
print(f"Loaded GlitchGAN generator from: {GENERATOR_PATH.name}")
Loaded GlitchGAN generator from: generator_210_keras3.keras
3. Training Loss History¶
[13]:
import json
import numpy as np
import matplotlib.pyplot as plt
# Per-epoch training losses, stored as a JSON object of equally long lists.
HISTORY_PATH = PROJECT_ROOT / "history.json"
with HISTORY_PATH.open() as f:
    history = json.load(f)

# 1-based epoch axis, sized from the discriminator loss series.
epochs = np.arange(len(history["d_loss"])) + 1

# Window length (in epochs) for the optional moving average.
SMOOTH = 10
def smooth(x, w):
    """Return the moving average of ``x`` over a window of ``w`` samples.

    Only fully overlapping windows are kept (``mode="valid"``), so the result
    has ``len(x) - w + 1`` entries.
    """
    box_kernel = np.full(w, 1.0 / w)
    return np.convolve(x, box_kernel, mode="valid")
# Epoch axis matching smooth(..., SMOOTH): a "valid" moving average loses the
# first SMOOTH - 1 points.
epochs_s = epochs[SMOOTH - 1:]

# (history key, legend label, line colour) for each curve.
series = [
    ("d_loss", "Discriminator", "C0"),
    ("d2d_loss", "Derivative discriminator", "C1"),
    ("g_loss", "Generator", "C2"),
    ("g_loss2d", "Generator (derivative)", "C3"),
    ("g_loss_combined", "Generator (combined)", "C4"),
]

fig, ax = plt.subplots(figsize=(9, 5))

# Raw (unsmoothed) curves; swap in smooth(history[key], SMOOTH) against
# epochs_s to plot the moving average instead.
for key, label, color in series:
    ax.plot(epochs, history[key], color=color, alpha=1, lw=0.6, label=label)

# Zero reference line.
ax.axhline(0, color="k", lw=0.6, ls="--", alpha=0.4)

ax.set_xlabel("Epoch", fontsize=20)
ax.set_ylabel("Wasserstein Loss", fontsize=20)
ax.set_xlim(1, 500)
ax.legend(fontsize=14, frameon=True, loc="lower left")
fig.tight_layout()

save_path = PLOTS_DIR / "training_loss_history.pdf"
fig.savefig(save_path, bbox_inches="tight")
print(f"Saved: {save_path}")
plt.show()
Saved: /Users/tomdooney/Documents/Work/Projects/glitchgan/evaluation_plots/training_loss_history.pdf
4. Generate Signals¶
[14]:
from glitchgan.tf.utils import generate_examples
# Draw SAMPLES_PER_CLASS generated signals per class using "vertex" sampling.
X_fake, class_vecs = generate_examples(
    gan,
    num_signals=NUM_CLASSES * SAMPLES_PER_CLASS,
    num_classes=NUM_CLASSES,
    noise_dim=NOISE_DIM,
    sampling="vertex",
)

# Name each generated signal after the largest entry of its class vector.
y_fake = np.take(LABEL_ORDER, np.argmax(class_vecs, axis=1))

print(f"Generated : {X_fake.shape} ({SAMPLES_PER_CLASS} per class)")
Generated : (700, 8192) (100 per class)
4. Waveform Visualisation¶
[15]:
from gwpy.timeseries import TimeSeries
import numpy as np
import matplotlib.pyplot as plt
# Background strain for the injections, cached on disk so the open-data server
# is contacted only once.
# NOTE: the cache file name encodes only the GPS span. Delete the file after
# changing IFO or SRATE, otherwise the stale cache is loaded.
_bg_cache = PLOTS_DIR / f"strain_bg_{GW_START}_{GW_END}.hdf5"
try:
    strain_bg = TimeSeries.read(_bg_cache, format="hdf5")
    print(f"Loaded from cache: {_bg_cache}")
except Exception as exc:
    # Best-effort cache: any read failure (missing or unreadable file) falls
    # back to a fresh download, but the reason is reported instead of hidden.
    print(f"Cache not usable ({type(exc).__name__}: {exc}); fetching open data")
    # Use the configured detector rather than a hard-coded "H1".
    strain_bg = TimeSeries.fetch_open_data(IFO, GW_START, GW_END)
    strain_bg = strain_bg.resample(SRATE)
    strain_bg.write(_bg_cache, format="hdf5", overwrite=True)
    print(f"Fetched and cached: {_bg_cache}")
# Same route as the reference pipeline: convert with gwpy's to_pycbc() and
# whiten the resulting pycbc series.
noise = strain_bg.to_pycbc()

# The two positional whitening arguments are half and a quarter of the series
# length expressed in units of SRATE samples.
_half_span = len(noise) / (2 * SRATE)
_quarter_span = len(noise) / (4 * SRATE)
white_noise_pycbc, psd = noise.whiten(
    _half_span,
    _quarter_span,
    return_psd=True,
    remove_corrupted=False,
)

# Plain NumPy view of the whitened series, used only for plots and Q-scans.
white_noise = np.asarray(white_noise_pycbc)
print(f"white_noise: {white_noise.shape} type: {type(white_noise)}")

# Quick look at the whitened background, leaving out 4 * SRATE samples at
# either end.
_fig_bg, _ax_bg = plt.subplots(figsize=(10, 3))
_ax_bg.plot(white_noise[SRATE * 4 : -SRATE * 4])
_ax_bg.set_xlabel("Samples")
_ax_bg.set_ylabel("Amplitude")
_ax_bg.set_title("Whitened background noise")
_fig_bg.tight_layout()
plt.show()
Loaded from cache: /Users/tomdooney/Documents/Work/Projects/glitchgan/evaluation_plots/strain_bg_1262540000_1262540040.hdf5
white_noise: (163840,) type: <class 'numpy.ndarray'>
[16]:
from matplotlib.lines import Line2D
from glitchgan.utils import plot_q_transform, whitened_snr_scaling
# Re-apply the publication style for this cell. This is deliberately
# best-effort: any failure (e.g. the "science" style not being available) is
# ignored and the current matplotlib style stays in effect.
try:
    plt.style.use(["science", "no-latex"])  # remove "no-latex" if latex is on PATH
except Exception:
    pass
def _inject(glitch, white_noise):
    """Add ``glitch`` to the centre of a copy of ``white_noise``.

    Parameters
    ----------
    glitch : 1-D array-like
        Samples to add.
    white_noise : 1-D array with a ``copy()`` method
        Background series; it is not modified.

    Returns
    -------
    Copy of ``white_noise`` with ``glitch`` added so that the midpoints of the
    two series coincide.

    Raises
    ------
    ValueError
        If ``glitch`` has more samples than ``white_noise``.
    """
    len_g = len(glitch)
    len_n = len(white_noise)
    # Fail with a readable message instead of the opaque broadcasting error the
    # in-place addition would otherwise raise.
    if len_g > len_n:
        raise ValueError(
            f"glitch ({len_g} samples) is longer than the background "
            f"({len_n} samples)"
        )
    id_start = len_n // 2 - len_g // 2
    injected = white_noise.copy()
    injected[id_start : id_start + len_g] += glitch
    return injected
def plot_comparison_grid(sources, label_order, n_examples=1, save_name=None,
                         white_noise=None, snr_default=50, snr_overrides=None):
    """Plot example waveforms (and optionally Q-scans) per class for each source.

    The figure has one column per (class, example) pair. Every source adds a
    time-series row and, when ``white_noise`` is given, a Q-scan row below it.

    Parameters
    ----------
    sources : sequence of (X, y, label, color)
        ``X`` is indexed by sample, ``y`` holds the class name of each sample,
        ``label`` is the legend entry and ``color`` the line colour.
    label_order : sequence of str
        Class names in column order.
    n_examples : int
        Examples drawn per class and source; fewer are drawn when the class has
        fewer samples, leaving the remaining axes empty.
    save_name : str or None
        When truthy, the figure is saved to ``PLOTS_DIR / f"{save_name}.pdf"``.
    white_noise : array or None
        Background into which each example is injected before the Q-scan.
        ``None`` disables the Q-scan rows.
    snr_default : float
        Value passed to ``whitened_snr_scaling`` for classes without override.
    snr_overrides : dict or None
        Optional ``{class name: snr}`` overrides.

    Returns
    -------
    None. The figure is shown with ``plt.show()``.
    """
    # Local RNG so sample selection is reproducible regardless of upstream state
    rng = np.random.default_rng(SEED)
    overrides = snr_overrides or {}

    show_qscans = white_noise is not None
    n_sources = len(sources)
    n_classes = len(label_order)
    n_cols = n_classes * n_examples
    rows_per_src = 2 if show_qscans else 1
    n_rows = n_sources * rows_per_src

    ts_height = 1.4
    qs_height = 1.8
    row_heights = ([ts_height, qs_height] if show_qscans else [ts_height]) * n_sources

    # squeeze=False always yields a 2-D axes array, so axes[row, col] also works
    # for a single row, a single column or a single axes.
    fig, axes = plt.subplots(
        n_rows, n_cols,
        figsize=(2.4 * n_cols, sum(row_heights)),
        squeeze=False,
        gridspec_kw={"hspace": 0.08, "wspace": 0.08,
                     "height_ratios": row_heights},
    )

    legend_handles = []
    for src_idx, (X, y, src_label, color) in enumerate(sources):
        legend_handles.append(Line2D([0], [0], color=color, lw=1.5, label=src_label))
        ts_row = src_idx * rows_per_src
        qs_row = ts_row + 1 if show_qscans else None

        for ci, lbl in enumerate(label_order):
            idx = np.where(y == lbl)[0]
            chosen = rng.choice(idx, min(n_examples, len(idx)), replace=False)
            snr = overrides.get(lbl, snr_default)

            for j, sample_idx in enumerate(chosen):
                col = ci * n_examples + j

                ax_ts = axes[ts_row, col]
                ax_ts.plot(X[sample_idx], lw=0.6, color=color)
                ax_ts.set_xticks([])
                ax_ts.set_yticks([])
                for spine in ax_ts.spines.values():
                    spine.set_linewidth(0.4)
                # Class title only once: top source, first example of the class.
                if src_idx == 0 and j == 0:
                    ax_ts.set_title(lbl.replace("_", " "), fontsize=15, pad=3)

                if show_qscans:
                    ax_qs = axes[qs_row, col]
                    glitch = whitened_snr_scaling(X[sample_idx], snr, srate=SRATE)
                    injected = _inject(glitch, white_noise)
                    plot_q_transform(injected, srate=SRATE, crop=(20, 2),
                                     whiten=False, ax=ax_qs, colourbar=False)
                    ax_qs.set_xlabel("")
                    ax_qs.set_ylabel("")
                    ax_qs.set_xticks([])
                    ax_qs.set_yticks([])
                    for spine in ax_qs.spines.values():
                        spine.set_linewidth(0.4)

    plt.tight_layout(rect=[0, 0, 1, 0.99])

    # Extra gap between source pairs (does not affect within-pair spacing).
    # The shift grows with the source index so that every later pair is
    # separated from the one above it, not only the second from the first.
    if show_qscans and n_sources > 1:
        for src_idx in range(1, n_sources):
            shift = 0.08 * src_idx
            for row_offset in range(rows_per_src):
                row = src_idx * rows_per_src + row_offset
                for col in range(n_cols):
                    pos = axes[row, col].get_position()
                    axes[row, col].set_position(
                        [pos.x0, pos.y0 - shift, pos.width, pos.height])

    fig.legend(handles=legend_handles, loc="upper center", ncol=n_sources,
               fontsize=15, frameon=False, bbox_to_anchor=(0.5, 0.99))
    if save_name:
        fig.savefig(PLOTS_DIR / f"{save_name}.pdf", bbox_inches="tight")
    plt.show()
# Real vs generated: one example per class, with Q-scans after injection into
# the whitened background. Koi_Fish and Whistle are scaled to a higher SNR.
plot_comparison_grid(
    [
        (X_real, y_real, "Real", "C0"),
        (X_fake, y_fake, "GlitchGAN", "C1"),
    ],
    LABEL_ORDER,
    n_examples=1,
    save_name=f"waveform_comparison_ep{GENERATOR_EPOCH}",
    white_noise=white_noise,
    snr_default=50,
    snr_overrides=dict(Koi_Fish=150, Whistle=150),
)
/var/folders/gv/z_2s63x116vbz6s2hmnl7mkc0000gn/T/ipykernel_10356/2254928735.py:73: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.
plt.tight_layout(rect=[0, 0, 1, 0.99])
UMAP Embedding¶
3D UMAP embedding (correlation metric) of 2 000 real + 2 000 fake samples per class. Reproduces Figures 5 & 6 from the paper.
[17]:
import tensorflow as tf
# Samples per class for both the real subsample and the generated set.
N_FULL = 2000

# One seeded generator drives both the real subsample and the latent noise, so
# the whole UMAP input is reproducible.
rng_full = np.random.default_rng(42)

# Subsample real data: up to N_FULL indices per class, without replacement.
real_idx = []
for lbl in LABEL_ORDER:
    cls_idx = np.where(y_real == lbl)[0]
    real_idx.extend(rng_full.choice(cls_idx, size=min(N_FULL, len(cls_idx)), replace=False))
X_real_full = X_real[real_idx]
y_real_full = y_real[real_idx]
print(f"Real subsampled : {X_real_full.shape}")

# Generate fake data in chunks (avoids OOM on large batch).
n_fake_full = NUM_CLASSES * N_FULL
# Latent noise comes from the seeded generator rather than the global
# np.random state, which is never seeded in this notebook.
noise_full = rng_full.standard_normal((n_fake_full, NOISE_DIM), dtype=np.float32)
# One-hot class vectors: N_FULL consecutive rows per class, in LABEL_ORDER order.
class_full = np.repeat(np.eye(NUM_CLASSES, dtype="float32"), N_FULL, axis=0)

CHUNK = 2048
chunks_fake = []
for i in range(0, n_fake_full, CHUNK):
    out = gan.generator([noise_full[i:i+CHUNK], class_full[i:i+CHUNK]], training=False)
    chunks_fake.append(out.numpy())
X_fake_full = np.concatenate(chunks_fake, axis=0)
# Labels follow the same block layout as class_full.
y_fake_full = np.repeat(LABEL_ORDER, N_FULL)
print(f"Fake generated : {X_fake_full.shape}")

# Stack real on top of fake and keep a parallel real/fake marker.
X_all_full = np.concatenate([X_real_full, X_fake_full], axis=0)
y_all_full = np.concatenate([y_real_full, y_fake_full], axis=0)
domain_labels_full = np.array(["Real"] * len(X_real_full) + ["Fake"] * len(X_fake_full))
print(f"UMAP input : {X_all_full.shape} ({X_all_full.nbytes / 1e9:.2f} GB)")
Real subsampled : (14000, 8192)
Fake generated : (14000, 8192)
UMAP input : (28000, 8192) (1.84 GB)
[18]:
# 3-component UMAP on the stacked real + generated signals, using the
# correlation metric. The reducer is seeded with SEED.
reducer_full = umap.UMAP(
    metric="correlation",
    n_components=3,
    n_neighbors=15,
    min_dist=0.6,
    low_memory=True,
    random_state=SEED,
)
embedding_full_corr = reducer_full.fit_transform(X_all_full)

print(f"Embedding shape : {embedding_full_corr.shape}")
/opt/homebrew/Caskroom/miniforge/base/envs/glitchgan_test/lib/python3.11/site-packages/umap/umap_.py:1952: UserWarning: n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.
warn(
Embedding shape : (28000, 3)
[19]:
# Keep only points lying inside the [low, high] percentile band on all three
# UMAP axes at once.
low, high = 1, 99
x1, x2, x3 = embedding_full_corr[:, 0], embedding_full_corr[:, 1], embedding_full_corr[:, 2]

mask_clean = np.ones(len(x1), dtype=bool)
for _coord in (x1, x2, x3):
    _floor = np.percentile(_coord, low)
    _ceil = np.percentile(_coord, high)
    mask_clean &= (_coord >= _floor) & (_coord <= _ceil)

print(f"Removed {np.sum(~mask_clean)} outliers ({np.sum(~mask_clean)/len(mask_clean)*100:.2f}%)")

# Apply the same mask to the embedding and to both label arrays.
embedding_full_clean = embedding_full_corr[mask_clean]
y_all_clean = y_all_full[mask_clean]
domain_labels_clean = domain_labels_full[mask_clean]
Removed 1680 outliers (6.00%)
[20]:
# ── Figure 5: two-view 3D UMAP ────────────────────────────────────────────────
from matplotlib.lines import Line2D
# Points plotted per (class, domain) pair; fewer if the cleaned set is smaller.
n_per_class = 800

# Seeded local generator: the plotted subsample (shared with the per-class
# figure that reuses embedding_plot) is the same on every run.
rng_plot = np.random.default_rng(SEED)

plot_indices = []
for lbl in LABEL_ORDER:
    idx_real = np.where((y_all_clean == lbl) & (domain_labels_clean == "Real"))[0]
    idx_fake = np.where((y_all_clean == lbl) & (domain_labels_clean == "Fake"))[0]
    plot_indices.extend(rng_plot.choice(idx_real, size=min(len(idx_real), n_per_class), replace=False))
    plot_indices.extend(rng_plot.choice(idx_fake, size=min(len(idx_fake), n_per_class), replace=False))
plot_indices = np.array(plot_indices)

embedding_plot = embedding_full_clean[plot_indices]
y_plot = y_all_clean[plot_indices]
domain_plot = domain_labels_clean[plot_indices]

unique_classes = np.array(LABEL_ORDER)
colors = plt.cm.tab10(np.linspace(0, 1, len(unique_classes)))

fig = plt.figure(figsize=(14, 6))
ax1 = fig.add_subplot(1, 2, 1, projection="3d")
ax2 = fig.add_subplot(1, 2, 2, projection="3d")

# Real and fake points of a class share one colour. Only the real scatter on
# the left panel carries a label, so the legend has one entry per class.
# The right panel shows the same points with the first two axes swapped.
for lbl, color in zip(unique_classes, colors):
    mask_real = (y_plot == lbl) & (domain_plot == "Real")
    mask_fake = (y_plot == lbl) & (domain_plot == "Fake")
    ax1.scatter(embedding_plot[mask_real, 0], embedding_plot[mask_real, 1], embedding_plot[mask_real, 2],
                c=[color], s=1, alpha=0.4, label=lbl.replace("_", " "))
    ax1.scatter(embedding_plot[mask_fake, 0], embedding_plot[mask_fake, 1], embedding_plot[mask_fake, 2],
                c=[color], s=1, alpha=0.4)
    ax2.scatter(embedding_plot[mask_real, 1], embedding_plot[mask_real, 0], embedding_plot[mask_real, 2],
                c=[color], s=1, alpha=0.4)
    ax2.scatter(embedding_plot[mask_fake, 1], embedding_plot[mask_fake, 0], embedding_plot[mask_fake, 2],
                c=[color], s=1, alpha=0.4)

ax1.set_xlabel("UMAP-1", fontsize=18)
ax1.set_ylabel("UMAP-2", fontsize=18)
ax1.set_zlabel("UMAP-3", fontsize=18)
ax2.set_xlabel("UMAP-2", fontsize=18)
ax2.set_ylabel("UMAP-1", fontsize=18)
ax2.set_zlabel("UMAP-3", fontsize=18)

handles, labels_leg = ax1.get_legend_handles_labels()
fig.legend(handles, labels_leg, loc="lower center", ncol=len(unique_classes) // 2,
           fontsize=16, title="Glitch Classes", title_fontsize=18,
           frameon=False, bbox_to_anchor=(0.45, 0.02), markerscale=10)
plt.tight_layout(rect=[0, 0.12, 0.83, 1.3])
plt.subplots_adjust(wspace=0.25)

save_path = PLOTS_DIR / "umap_two_views_corr_2k.pdf"
plt.savefig(save_path, bbox_inches="tight")
print(f"Saved: {save_path}")
plt.show()
plt.close()
Saved: /Users/tomdooney/Documents/Work/Projects/glitchgan/evaluation_plots/umap_two_views_corr_2k.pdf
[21]:
# ── Figure 6: per-class 3D UMAP ──────────────────────────────────────────────
# One 3-D panel per class, real and generated points overlaid.
n_panels = len(unique_classes)
fig = plt.figure(figsize=(4 * n_panels, 5))

for i, lbl in enumerate(unique_classes):
    ax = fig.add_subplot(1, n_panels, i + 1, projection="3d")

    # Real first, then fake, each in its own colour.
    mask_real = (y_plot == lbl) & (domain_plot == "Real")
    mask_fake = (y_plot == lbl) & (domain_plot == "Fake")
    for mask, colour, tag in ((mask_real, "C0", "Real"), (mask_fake, "C1", "Fake")):
        pts = embedding_plot[mask]
        ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2],
                   c=colour, label=tag, s=3, alpha=0.6)

    ax.view_init(elev=25, azim=45)
    ax.set_xlabel("UMAP-1", fontsize=17, labelpad=2)
    ax.set_ylabel("UMAP-2", fontsize=17, labelpad=2)
    ax.set_zlabel("UMAP-3", fontsize=17, labelpad=2)
    ax.set_title(lbl.replace("_", " "), fontsize=30, pad=10)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])

    # Only the first panel carries the Real/Fake legend.
    if i == 0:
        ax.legend(fontsize=25, loc="lower left", markerscale=5, frameon=False)

plt.tight_layout(rect=[0.01, 0.01, 0.99, 1.1])
plt.subplots_adjust(wspace=0.25)

save_path = PLOTS_DIR / "umap_per_class_corr_2k.pdf"
plt.savefig(save_path, bbox_inches="tight")
print(f"Saved: {save_path}")
plt.show()
plt.close()
Saved: /Users/tomdooney/Documents/Work/Projects/glitchgan/evaluation_plots/umap_per_class_corr_2k.pdf
[22]:
# --- 5a. Grid visualisation — 9 examples of each sampling mode ---------------
# Number of signals drawn per sampling mode for the 3x3 grids.
SAMPLES_GRID = 9

# Simplex first, then uniform: same call, only the `sampling` mode differs.
sigs_sp, cvecs_sp = generate_examples(
    gan,
    sampling="simplex",
    num_signals=SAMPLES_GRID,
    num_classes=NUM_CLASSES,
    noise_dim=NOISE_DIM,
)
sigs_un, cvecs_un = generate_examples(
    gan,
    sampling="uniform",
    num_signals=SAMPLES_GRID,
    num_classes=NUM_CLASSES,
    noise_dim=NOISE_DIM,
)
def _mixed_grid(signals, class_vecs, title):
    """Draw up to nine signals on a 3x3 grid, captioned with their class vector.

    Each panel is captioned with the name of the largest class-vector entry and
    the full vector to two decimals. The figure is saved under ``PLOTS_DIR``
    with a file name derived from ``title``, then shown.
    """
    fig, axes = plt.subplots(3, 3, figsize=(9, 6), sharex=True)

    for panel, sig, cvec in zip(axes.flat, signals, class_vecs):
        panel.plot(sig, lw=0.8, color="C0")
        panel.set_xticks([])
        panel.set_yticks([])

        dominant = LABEL_ORDER[np.argmax(cvec)]
        vec_str = ", ".join(f"{v:.2f}" for v in cvec)
        panel.text(0.5, 1.04, f"{dominant}\n[{vec_str}]",
                   transform=panel.transAxes, fontsize=6.5, ha="center", va="bottom")

    fig.suptitle(title, fontsize=12)
    plt.tight_layout(pad=1.0, h_pad=2.2)

    # Turn the title into a file stem: lower-case, then apply the substitutions
    # in this exact order and trim leading/trailing underscores.
    safe = title.lower()
    for old, new in ((" ", "_"), ("—", ""), ("(", ""), (")", ""), ("=", "eq")):
        safe = safe.replace(old, new)
    safe = safe.strip("_")

    plt.savefig(PLOTS_DIR / f"{safe}.pdf", bbox_inches="tight")
    plt.show()
# One grid per mixing mode: simplex first, then uniform.
for _sigs, _cvecs, _title in (
    (sigs_sp, cvecs_sp, "Generated Glitches — Simplex Mixing (sum=1)"),
    (sigs_un, cvecs_un, "Generated Glitches — Uniform Mixing ([0,1])"),
):
    _mixed_grid(_sigs, _cvecs, _title)
[23]:
# --- 5b. Generate vertex / simplex / uniform for UMAP embedding --------------
SAMPLES_PER_TYPE = 1500
# Largest multiple of NUM_CLASSES not exceeding SAMPLES_PER_TYPE.
SAMPLES_VERTEX_TOTAL = (SAMPLES_PER_TYPE // NUM_CLASSES) * NUM_CLASSES  # divisible

# Generation order: simplex, uniform, vertex.
sigs_simplex, cvecs_simplex = generate_examples(
    gan,
    sampling="simplex",
    num_signals=SAMPLES_PER_TYPE,
    num_classes=NUM_CLASSES,
    noise_dim=NOISE_DIM,
)
sigs_uniform, cvecs_uniform = generate_examples(
    gan,
    sampling="uniform",
    num_signals=SAMPLES_PER_TYPE,
    num_classes=NUM_CLASSES,
    noise_dim=NOISE_DIM,
)
sigs_vertex, cvecs_vertex = generate_examples(
    gan,
    sampling="vertex",
    num_signals=SAMPLES_VERTEX_TOTAL,
    num_classes=NUM_CLASSES,
    noise_dim=NOISE_DIM,
)
y_vertex = np.take(LABEL_ORDER, np.argmax(cvecs_vertex, axis=1))

# Stacking order: vertex, simplex, uniform.
X_mix = np.concatenate([sigs_vertex, sigs_simplex, sigs_uniform], axis=0)

# Sampling mode of every row of X_mix.
domain_labels = np.repeat(
    ["Vertex", "Simplex", "Uniform"],
    [len(sigs_vertex), SAMPLES_PER_TYPE, SAMPLES_PER_TYPE],
)
# Class name for vertex rows; mixed rows are tagged with their sampling mode.
y_mix = np.concatenate([y_vertex, domain_labels[len(sigs_vertex):]])

print(f"Total: {len(X_mix)} "
      f"({len(sigs_vertex)} vertex, {SAMPLES_PER_TYPE} simplex, {SAMPLES_PER_TYPE} uniform)")

reducer_mix = umap.UMAP(
    metric="euclidean",
    n_components=3,
    n_neighbors=30,
    min_dist=0.2,
    random_state=SEED,
)
embedding_mix = reducer_mix.fit_transform(X_mix)
print("UMAP done.")
Total: 4498 (1498 vertex, 1500 simplex, 1500 uniform)
/opt/homebrew/Caskroom/miniforge/base/envs/glitchgan_test/lib/python3.11/site-packages/umap/umap_.py:1952: UserWarning: n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.
warn(
UMAP done.
[24]:
# --- 5c. 3D UMAP: vertex colored by class, simplex/uniform in black ----------
# Class name -> colour, from seven evenly spaced positions of the tab10 map.
_class_colors_mix = dict(
    zip(LABEL_ORDER, plt.cm.tab10(np.linspace(0, 1, len(LABEL_ORDER))))
)

fig = plt.figure(figsize=(10, 7))
ax = fig.add_subplot(111, projection="3d")

# Mixed-class samples first, in black, told apart by marker shape.
for _mode, _marker in (("Simplex", "^"), ("Uniform", "x")):
    _pts = embedding_mix[domain_labels == _mode]
    ax.scatter(_pts[:, 0], _pts[:, 1], _pts[:, 2],
               c="k", marker=_marker, s=10, alpha=0.35, label=f"{_mode} mix")

# Vertex samples on top, coloured by class.
_is_vertex = domain_labels == "Vertex"
for lbl in LABEL_ORDER:
    _pts = embedding_mix[_is_vertex & (y_mix == lbl)]
    ax.scatter(_pts[:, 0], _pts[:, 1], _pts[:, 2],
               c=[_class_colors_mix[lbl]], s=15, alpha=0.7, label=lbl)

ax.set_xlabel("UMAP-1", fontsize=12)
ax.set_ylabel("UMAP-2", fontsize=12)
ax.set_zlabel("UMAP-3", fontsize=12)
ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=9)

fig.tight_layout()
fig.savefig(PLOTS_DIR / "simplex_uniform_umap.pdf", bbox_inches="tight")
plt.show()
[25]:
# --- 5d. Save UMAP embedding for later reuse ---------------------------------
# Persist the embedding together with the labels needed to redraw the figure.
_umap_arrays = dict(
    embedding=embedding_mix,
    y_all=y_mix,
    domain_labels=domain_labels,
    label_order=np.array(LABEL_ORDER),
)
np.savez_compressed(PLOTS_DIR / "simplex_uniform_umap_data.npz", **_umap_arrays)
print(f"Saved → {PLOTS_DIR / 'simplex_uniform_umap_data.npz'}")
Saved → /Users/tomdooney/Documents/Work/Projects/glitchgan/evaluation_plots/simplex_uniform_umap_data.npz
[ ]: