{ "cells": [ { "cell_type": "markdown", "id": "aa000001", "metadata": {}, "source": [ "# GlitchGAN — GravitySpy Classification\n", "\n", "Standalone classification notebook. Injection pipeline follows `evaluation2.ipynb` exactly." ] }, { "cell_type": "markdown", "id": "3df13eb8", "metadata": {}, "source": [ "## Dependencies\n", "\n", "This notebook requires GravitySpy. Install it and its runtime deps, then apply three small patches for Python 3.11 / keras 3.x compatibility:\n", "\n", "```bash\n", "pip install \"glitchgan[eval]\" # core deps + gravityspy runtime deps\n", "pip install gravityspy==1.0.0 --no-deps # gravityspy (skip broken scipy pin)\n", "```\n", "\n", "Then patch three incompatibilities:\n", "\n", "```python\n", "import site, pathlib\n", "sp = pathlib.Path(site.getsitepackages()[0]) / 'gravityspy'\n", "\n", "# Fix 1: scipy.misc.imresize removed in scipy 1.3\n", "f = sp / 'ml/labelling_test_glitches.py'\n", "txt = f.read_text()\n", "txt = txt.replace('from scipy.misc import imresize',\n", " 'from skimage.transform import resize as imresize')\n", "# Fix 2: keras 3.x loads old .h5 models incorrectly — use tf_keras (legacy keras 2.x)\n", "txt = txt.replace('from keras.applications.vgg16 import preprocess_input',\n", " 'from tf_keras.applications.vgg16 import preprocess_input')\n", "txt = txt.replace('from keras.models import load_model',\n", " 'from tf_keras.models import load_model')\n", "txt = txt.replace('from keras import backend as K',\n", " 'from tf_keras import backend as K')\n", "f.write_text(txt)\n", "\n", "# Fix 3: gwpy 4.0 rejects bare truthiness checks on TimeSeries\n", "f = sp / 'utils/utils.py'\n", "f.write_text(f.read_text().replace(\n", " ' if timeseries:\\n',\n", " ' if timeseries is not None:\\n'))\n", "\n", "print('All patches applied')\n", "```" ] }, { "cell_type": "code", "execution_count": 1, "id": "aa000002", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Project root : 
/Users/tomdooney/Documents/Work/Projects/glitchgan\n", "Model exists : True\n" ] } ], "source": [ "import os, sys, io, shutil, warnings, logging\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "from pathlib import Path\n", "from tqdm.notebook import tqdm\n", "from IPython.display import Image as IPyImage, display as ipy_display\n", "\n", "PROJECT_ROOT = Path('..').resolve()\n", "sys.path.insert(0, str(PROJECT_ROOT / 'src'))\n", "\n", "from glitchgan.tf import GlitchGAN\n", "from glitchgan.tf.model_components import ArgmaxLayer, ReduceSumDotLayer\n", "from glitchgan.utils import whitened_snr_scaling\n", "\n", "# ── paths ─────────────────────────────────────────────────────────────────────\n", "DATA_DIR = PROJECT_ROOT / 'data'\n", "PLOTS_DIR = PROJECT_ROOT / 'evaluation_plots'\n", "GENERATOR_PATH = PROJECT_ROOT / 'weights' / 'tensorflow' / 'generator_210_keras3.keras'\n", "PATH_TO_MODEL = PROJECT_ROOT / 'models' / 'sidd-cqg-paper-O3-model.h5'\n", "#FIXME: set your local GravitySpy clone path\n", "PATH_TO_REPO = '/path/to/GravitySpy/'\n", "\n", "os.makedirs(PLOTS_DIR, exist_ok=True)\n", "\n", "# ── glitch classes ────────────────────────────────────────────────────────────\n", "LABEL_ORDER = [\n", " 'Blip', 'Fast_Scattering', 'Koi_Fish',\n", " 'Low_Frequency_Burst', 'Scattered_Light', 'Tomte', 'Whistle',\n", "]\n", "NUM_CLASSES = len(LABEL_ORDER)\n", "NOISE_DIM = 100\n", "\n", "# ── GravitySpy noise / classification config ─────────────────────────────────\n", "IFO = 'H1'\n", "SRATE = 4096\n", "GW_START, GW_END = 1262540000, 1262540040\n", "CHANNEL = f'{IFO}:GDS-CALIB_STRAIN'\n", "INIT_TIME = -20\n", "EVENT_TIME = 0\n", "SNR_TARGET = 50\n", "NUM_CLASSIFY = 10\n", "\n", "print('Project root :', PROJECT_ROOT)\n", "print('Model exists :', PATH_TO_MODEL.exists())" ] }, { "cell_type": "code", "execution_count": 2, "id": "aa000003", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", 
"text": [ "Loaded: generator_210_keras3.keras\n" ] } ], "source": [ "import keras\n", "\n", "gan = GlitchGAN()\n", "gan.generator = keras.models.load_model(\n", " str(GENERATOR_PATH), compile=False,\n", " custom_objects={'ArgmaxLayer': ArgmaxLayer, 'ReduceSumDotLayer': ReduceSumDotLayer}\n", ")\n", "print(f'Loaded: {GENERATOR_PATH.name}')" ] }, { "cell_type": "code", "execution_count": 3, "id": "aa000004", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "generated_signals: (70, 8192)\n", "labels : (70,)\n" ] } ], "source": [ "# Generate signals — same approach as evaluation2: explicit np.eye class vectors,\n", "# np.random.randn noise, no fixed seed\n", "class_vecs = np.repeat(np.eye(NUM_CLASSES, dtype='float32'), NUM_CLASSIFY, axis=0)\n", "noise_vecs = np.random.randn(NUM_CLASSES * NUM_CLASSIFY, NOISE_DIM).astype('float32')\n", "\n", "generated_signals = gan.generator([noise_vecs, class_vecs], training=False).numpy()\n", "labels = np.repeat(np.array(LABEL_ORDER), NUM_CLASSIFY)\n", "\n", "print(f'generated_signals: {generated_signals.shape}')\n", "print(f'labels : {labels.shape}')" ] }, { "cell_type": "markdown", "id": "aa000005", "metadata": {}, "source": [ "## GravitySpy Classification\n", "\n", "Injection pipeline follows `evaluation2.ipynb` exactly: fresh noise fetch → `to_pycbc()` → whiten → pycbc copy → pycbc `+=` → gwpy TimeSeries." ] }, { "cell_type": "code", "execution_count": 4, "id": "aa000006", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/homebrew/Caskroom/miniforge/base/envs/glitchgan_test/lib/python3.11/site-packages/gwpy/time/_ligotimegps.py:42: UserWarning: Wswiglal-redir-stdio:\n", "\n", "SWIGLAL standard output/error redirection is enabled in IPython.\n", "This may lead to performance penalties. 
To disable locally, use:\n", "\n", "with lal.no_swig_redirect_standard_output_error():\n", " ...\n", "\n", "To disable globally, use:\n", "\n", "lal.swig_redirect_standard_output_error(False)\n", "\n", "Note however that this will likely lead to error messages from\n", "LAL functions being either misdirected or lost when called from\n", "Jupyter notebooks.\n", "\n", "To suppress this warning, use:\n", "\n", "import warnings\n", "warnings.filterwarnings(\"ignore\", \"Wswiglal-redir-stdio\")\n", "import lal\n", "\n", " from lal import LIGOTimeGPS\n", "PyCBC.libutils: pkg-config call failed, setting NO_PKGCONFIG=1\n", "INFO:panoptes_client:libmagic not operational, likely due to lack of shared libraries. Media MIME type determination will be based on file extensions.\n" ] } ], "source": [ "sys.path.insert(0, PATH_TO_REPO)\n", "\n", "from gwpy.timeseries import TimeSeries\n", "from gravityspy.classify import classify\n", "import gravityspy.ml.labelling_test_glitches as _lgt\n", "\n", "warnings.filterwarnings('ignore')\n", "for _log in ['gravityspy', 'gwpy', 'astropy', 'tensorflow']:\n", " logging.getLogger(_log).setLevel(logging.ERROR)\n", "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n", "\n", "GSPY_PLOT_DIR = str(PLOTS_DIR / 'gspy_tmp')\n", "\n", "\n", "def classify_signals(generated_signals, labels, white_noise, noise, label_order, tag):\n", " \"\"\"Classify GlitchGAN signals using GravitySpy.\n", " Injection follows evaluation2 exactly: pycbc copy → pycbc += → gwpy TimeSeries.\n", " \"\"\"\n", " ifo = IFO\n", " srate = SRATE\n", " init_time = INIT_TIME\n", " channel_name = CHANNEL\n", " path_to_model = str(PATH_TO_MODEL)\n", " snr_target = SNR_TARGET\n", "\n", " shutil.rmtree(GSPY_PLOT_DIR, ignore_errors=True)\n", " os.makedirs(GSPY_PLOT_DIR, exist_ok=True)\n", "\n", " rows = []\n", " total = NUM_CLASSIFY * len(label_order)\n", " with tqdm(total=total, desc=f'Classifying [{tag}]', unit='glitch') as pbar:\n", " for class_label in label_order:\n", " class_indices = 
np.where(labels == class_label)[0]\n", " chosen_indices = np.random.choice(class_indices, NUM_CLASSIFY, replace=False)\n", "\n", " for idx in chosen_indices:\n", " glitch = generated_signals[idx].copy()\n", " glitch = whitened_snr_scaling(glitch, snr_target)\n", "\n", " len_glitch = len(glitch)\n", " length = noise.shape[-1]\n", " t_inj = 0.5 * length / srate\n", " id_start = int((t_inj * srate / length) * len(white_noise)) - len_glitch // 2\n", "\n", " injected_noise = white_noise.copy()\n", " injected_noise[id_start:id_start + len_glitch] += glitch\n", "\n", " glitch_series = TimeSeries(\n", " injected_noise, t0=init_time, sample_rate=srate, name=ifo\n", " )\n", "\n", " try:\n", " result = classify(\n", " event_time=EVENT_TIME,\n", " channel_name=channel_name,\n", " path_to_cnn=path_to_model,\n", " timeseries=glitch_series,\n", " plot_directory=GSPY_PLOT_DIR,\n", " )\n", " rows.append({\n", " 'true_label': class_label,\n", " 'pred_label': result['ml_label'].value[0],\n", " 'confidence': result['ml_confidence'].value[0],\n", " })\n", " except Exception as e:\n", " print(f' \\u26a0 {class_label}[{idx}]: {type(e).__name__}: {e}')\n", " rows.append({'true_label': class_label, 'pred_label': 'Error', 'confidence': 0.0})\n", " pbar.update(1)\n", "\n", " return pd.DataFrame(rows)\n", "\n", "\n", "def plot_confusion(df, tag, save_name):\n", " if df is None or len(df) == 0:\n", " print(f'No results for {tag}'); return None\n", " df = df[df['pred_label'] != 'Error']\n", " if len(df) == 0:\n", " print(f'All errors for {tag}'); return None\n", "\n", " pred_all = sorted(df['pred_label'].unique())\n", " for lbl in LABEL_ORDER:\n", " if lbl not in pred_all:\n", " pred_all.append(lbl)\n", " pred_cols = ([l for l in LABEL_ORDER if l in pred_all]\n", " + [l for l in pred_all if l not in LABEL_ORDER])\n", "\n", " count_matrix = pd.DataFrame(0, index=LABEL_ORDER, columns=pred_cols)\n", " conf_accum = {(t, p): [] for t in LABEL_ORDER for p in pred_cols}\n", " for t, p, c in 
zip(df['true_label'], df['pred_label'], df['confidence']):\n", " if t in LABEL_ORDER and p in pred_cols:\n", " count_matrix.loc[t, p] += 1\n", " conf_accum[(t, p)].append(c)\n", "\n", " annot = pd.DataFrame('', index=LABEL_ORDER, columns=pred_cols)\n", " for t in LABEL_ORDER:\n", " for p in pred_cols:\n", " n = count_matrix.loc[t, p]\n", " annot.loc[t, p] = '0' if n == 0 else f\"{n}\\n({np.mean(conf_accum[(t,p)]):.2f})\"\n", "\n", " total = count_matrix.values.sum()\n", " acc = np.trace(count_matrix.values) / total if total > 0 else 0.0\n", "\n", " fig_w = max(10, len(pred_cols) * 1.1)\n", " sns.set(style='whitegrid', font_scale=1.0)\n", " fig, ax = plt.subplots(figsize=(fig_w, 6))\n", " sns.heatmap(count_matrix, annot=annot, fmt='', cmap='Blues', cbar=True,\n", " linewidths=0.5, linecolor='gray',\n", " annot_kws={'size': 8, 'color': 'black'}, ax=ax)\n", " ax.set_xlabel('Predicted Label', fontsize=12)\n", " ax.set_ylabel('True Label', fontsize=12)\n", " ax.set_title(f'Gravity Spy \\u2014 {tag} (accuracy = {acc:.1%})', fontsize=13)\n", " plt.xticks(rotation=45, ha='right', fontsize=8)\n", " plt.yticks(rotation=0, fontsize=9)\n", " plt.tight_layout()\n", " fig.savefig(PLOTS_DIR / f'{save_name}.pdf', bbox_inches='tight')\n", " buf = io.BytesIO()\n", " fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')\n", " plt.close(fig)\n", " buf.seek(0)\n", " ipy_display(IPyImage(buf.read()))\n", " print(f'{tag} accuracy: {acc:.3f}')\n", " return acc" ] }, { "cell_type": "code", "execution_count": 5, "id": "aa000007", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fetching open data and whitening...\n", "white_noise: 163840 samples dtype: float64\n" ] } ], "source": [ "# Fetch and whiten noise — identical to evaluation2\n", "print('Fetching open data and whitening...')\n", "noise = TimeSeries.fetch_open_data(IFO, GW_START, GW_END, sample_rate=SRATE)\n", "noise = noise.to_pycbc()\n", "white_noise, psd = noise.whiten(\n", " len(noise) / 
(2 * SRATE),\n", " len(noise) / (4 * SRATE),\n", " remove_corrupted=False,\n", " return_psd=True,\n", ")\n", "print(f'white_noise: {len(white_noise)} samples dtype: {white_noise.dtype}')" ] }, { "cell_type": "code", "execution_count": 6, "id": "aa000008", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1af89a9c8e3b47ee932bb6c0e8e796b2", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Classifying [GlitchGAN (epoch 210, SNR=50)]: 0%| | 0/70 [00:00" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "GlitchGAN (epoch 210, SNR=50) accuracy: 0.700\n" ] } ], "source": [ "# Classify the generated glitches first: plot_confusion needs the resulting DataFrame,\n", "# and `df` is not defined anywhere else in this notebook.\n", "TAG = f'GlitchGAN (epoch 210, SNR={SNR_TARGET})'\n", "df = classify_signals(generated_signals, labels, white_noise, noise, LABEL_ORDER, TAG)\n", "acc = plot_confusion(df, TAG, 'gspy_cm')" ] }, { "cell_type": "code", "execution_count": null, "id": "49ff59cb", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "cdvgan", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.15" } }, "nbformat": 4, "nbformat_minor": 5 }