{ "cells": [ { "cell_type": "markdown", "id": "aa000001", "metadata": {}, "source": [ "# GlitchGAN \u2014 GravitySpy Classification\n", "\n", "Standalone classification notebook. Injection pipeline follows `evaluation2.ipynb` exactly." ] }, { "cell_type": "markdown", "id": "3df13eb8", "metadata": {}, "source": "## Dependencies\n\nThis notebook requires GravitySpy for classification. The PyPI release is broken with modern scipy, so install from GitHub:\n\n```bash\npip install \"glitchgan[eval]\" # core deps + gravityspy runtime deps\npip install git+https://github.com/Gravity-Spy/GravitySpy.git --no-deps # gravityspy (PyPI broken)\n```\n\n> **Note:** `--no-deps` skips GravitySpy's broken `scipy<=1.2.1` pin. The GitHub version replaces the removed `scipy.misc.imresize` with `skimage`, which is why it is required over the PyPI release." }, { "cell_type": "code", "execution_count": 1, "id": "aa000002", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Project root : /Users/tomdooney/Documents/Work/Projects/glitchgan\n", "Model exists : True\n" ] } ], "source": [ "import os, sys, io, shutil, warnings, logging\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "from pathlib import Path\n", "from tqdm.notebook import tqdm\n", "from IPython.display import Image as IPyImage, display as ipy_display\n", "\n", "PROJECT_ROOT = Path('..').resolve()\n", "sys.path.insert(0, str(PROJECT_ROOT / 'src'))\n", "\n", "from glitchgan.tf import GlitchGAN\n", "from glitchgan.tf.model_components import ArgmaxLayer, ReduceSumDotLayer\n", "from glitchgan.utils import whitened_snr_scaling\n", "\n", "# \u2500\u2500 paths 
\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", "DATA_DIR = PROJECT_ROOT / 'data'\n", "PLOTS_DIR = PROJECT_ROOT / 'evaluation_plots'\n", "GENERATOR_PATH = PROJECT_ROOT / 'weights' / 'tensorflow' / 'generator_210_keras3.keras'\n", "PATH_TO_MODEL = PROJECT_ROOT / 'models' / 'sidd-cqg-paper-O3-model.h5'\n", "#FIXME: set your local GravitySpy clone path\n", "PATH_TO_REPO = '/path/to/GravitySpy/'\n", "\n", "os.makedirs(PLOTS_DIR, exist_ok=True)\n", "\n", "# \u2500\u2500 glitch classes \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", "LABEL_ORDER = [\n", " 'Blip', 'Fast_Scattering', 'Koi_Fish',\n", " 'Low_Frequency_Burst', 'Scattered_Light', 'Tomte', 'Whistle',\n", "]\n", "NUM_CLASSES = len(LABEL_ORDER)\n", "NOISE_DIM = 100\n", "\n", "# \u2500\u2500 GravitySpy noise / classification config \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", "IFO = 'H1'\n", "SRATE = 4096\n", "GW_START, GW_END = 1262540000, 1262540040\n", "CHANNEL = f'{IFO}:GDS-CALIB_STRAIN'\n", "INIT_TIME = -20\n", "EVENT_TIME = 0\n", "SNR_TARGET = 50\n", "NUM_CLASSIFY = 10\n", "\n", "print('Project root :', PROJECT_ROOT)\n", "print('Model exists :', PATH_TO_MODEL.exists())" ] }, { 
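   "cell_type": "markdown",
   "id": "bb000001",
   "metadata": {},
   "source": [
    "A quick sketch of a timing consistency check. It assumes, as `classify_signals` does, that each glitch is centred in the uncropped whitened segment, and it takes the glitch length of 8192 samples (2 s at `SRATE`) from the `(70, 8192)` shape printed when the signals are generated. `GLITCH_LEN` is a name introduced here for illustration only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb000002",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: EVENT_TIME should sit at the centre of the 40 s segment that starts at INIT_TIME,\n",
    "# which is where classify_signals places each injected glitch.\n",
    "GLITCH_LEN = 8192  # assumption: generator output length, from the (70, 8192) shape printed below\n",
    "duration = GW_END - GW_START  # 40 s of open data\n",
    "n_samples = duration * SRATE  # length of the whitened series (remove_corrupted=False)\n",
    "assert INIT_TIME + duration / 2 == EVENT_TIME, 'EVENT_TIME is not at the segment centre'\n",
    "\n",
    "# Same start index that classify_signals computes for a centred injection\n",
    "id_start = n_samples // 2 - GLITCH_LEN // 2\n",
    "assert 0 <= id_start and id_start + GLITCH_LEN <= n_samples, 'glitch does not fit in the segment'\n",
    "print(f'segment: {n_samples} samples, glitch at samples [{id_start}, {id_start + GLITCH_LEN})')"
   ]
  },
  {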
"cell_type": "code", "execution_count": 2, "id": "aa000003", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loaded: generator_210_keras3.keras\n" ] } ], "source": [ "import keras\n", "\n", "gan = GlitchGAN()\n", "gan.generator = keras.models.load_model(\n", " str(GENERATOR_PATH), compile=False,\n", " custom_objects={'ArgmaxLayer': ArgmaxLayer, 'ReduceSumDotLayer': ReduceSumDotLayer}\n", ")\n", "print(f'Loaded: {GENERATOR_PATH.name}')" ] }, { "cell_type": "code", "execution_count": 3, "id": "aa000004", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "generated_signals: (70, 8192)\n", "labels : (70,)\n" ] } ], "source": [ "# Generate signals \u2014 same approach as evaluation2: explicit np.eye class vectors,\n", "# np.random.randn noise, no fixed seed\n", "class_vecs = np.repeat(np.eye(NUM_CLASSES, dtype='float32'), NUM_CLASSIFY, axis=0)\n", "noise_vecs = np.random.randn(NUM_CLASSES * NUM_CLASSIFY, NOISE_DIM).astype('float32')\n", "\n", "generated_signals = gan.generator([noise_vecs, class_vecs], training=False).numpy()\n", "labels = np.repeat(np.array(LABEL_ORDER), NUM_CLASSIFY)\n", "\n", "print(f'generated_signals: {generated_signals.shape}')\n", "print(f'labels : {labels.shape}')" ] }, { "cell_type": "markdown", "id": "aa000005", "metadata": {}, "source": [ "## GravitySpy Classification\n", "\n", "Injection pipeline follows `evaluation2.ipynb` exactly: fresh noise fetch \u2192 `to_pycbc()` \u2192 whiten \u2192 pycbc copy \u2192 pycbc `+=` \u2192 gwpy TimeSeries." ] }, { "cell_type": "code", "execution_count": 4, "id": "aa000006", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/homebrew/Caskroom/miniforge/base/envs/cdvgan/lib/python3.11/site-packages/gwpy/time/_ligotimegps.py:42: UserWarning: Wswiglal-redir-stdio:\n", "\n", "SWIGLAL standard output/error redirection is enabled in IPython.\n", "This may lead to performance penalties. 
To disable locally, use:\n", "\n", "with lal.no_swig_redirect_standard_output_error():\n", " ...\n", "\n", "To disable globally, use:\n", "\n", "lal.swig_redirect_standard_output_error(False)\n", "\n", "Note however that this will likely lead to error messages from\n", "LAL functions being either misdirected or lost when called from\n", "Jupyter notebooks.\n", "\n", "To suppress this warning, use:\n", "\n", "import warnings\n", "warnings.filterwarnings(\"ignore\", \"Wswiglal-redir-stdio\")\n", "import lal\n", "\n", " from lal import LIGOTimeGPS\n", "PyCBC.libutils: pkg-config call failed, setting NO_PKGCONFIG=1\n", "INFO:panoptes_client:libmagic not operational, likely due to lack of shared libraries. Media MIME type determination will be based on file extensions.\n" ] } ], "source": [ "sys.path.insert(0, PATH_TO_REPO)\n", "\n", "from gwpy.timeseries import TimeSeries\n", "from gravityspy.classify import classify\n", "import gravityspy.ml.labelling_test_glitches as _lgt\n", "\n", "warnings.filterwarnings('ignore')\n", "for _log in ['gravityspy', 'gwpy', 'astropy', 'tensorflow']:\n", " logging.getLogger(_log).setLevel(logging.ERROR)\n", "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n", "\n", "GSPY_PLOT_DIR = str(PLOTS_DIR / 'gspy_tmp')\n", "\n", "\n", "def classify_signals(generated_signals, labels, white_noise, noise, label_order, tag):\n", " \"\"\"Classify GlitchGAN signals using GravitySpy.\n", " Injection follows evaluation2 exactly: pycbc copy \u2192 pycbc += \u2192 gwpy TimeSeries.\n", " \"\"\"\n", " ifo = IFO\n", " srate = SRATE\n", " init_time = INIT_TIME\n", " channel_name = CHANNEL\n", " path_to_model = str(PATH_TO_MODEL)\n", " snr_target = SNR_TARGET\n", "\n", " shutil.rmtree(GSPY_PLOT_DIR, ignore_errors=True)\n", " os.makedirs(GSPY_PLOT_DIR, exist_ok=True)\n", "\n", " rows = []\n", " total = NUM_CLASSIFY * len(label_order)\n", " with tqdm(total=total, desc=f'Classifying [{tag}]', unit='glitch') as pbar:\n", " for class_label in label_order:\n", " 
class_indices = np.where(labels == class_label)[0]\n", " chosen_indices = np.random.choice(class_indices, NUM_CLASSIFY, replace=False)\n", "\n", " for idx in chosen_indices:\n", " glitch = generated_signals[idx].copy()\n", " glitch = whitened_snr_scaling(glitch, snr_target)\n", "\n", " len_glitch = len(glitch)\n", " length = noise.shape[-1]\n", " t_inj = 0.5 * length / srate\n", " id_start = int((t_inj * srate / length) * len(white_noise)) - len_glitch // 2\n", "\n", " injected_noise = white_noise.copy()\n", " injected_noise[id_start:id_start + len_glitch] += glitch\n", "\n", " glitch_series = TimeSeries(\n", " injected_noise, t0=init_time, sample_rate=srate, name=ifo\n", " )\n", "\n", " try:\n", " result = classify(\n", " event_time=EVENT_TIME,\n", " channel_name=channel_name,\n", " path_to_cnn=path_to_model,\n", " timeseries=glitch_series,\n", " plot_directory=GSPY_PLOT_DIR,\n", " )\n", " rows.append({\n", " 'true_label': class_label,\n", " 'pred_label': result['ml_label'].value[0],\n", " 'confidence': result['ml_confidence'].value[0],\n", " })\n", " except Exception as e:\n", " print(f' \\u26a0 {class_label}[{idx}]: {type(e).__name__}: {e}')\n", " rows.append({'true_label': class_label, 'pred_label': 'Error', 'confidence': 0.0})\n", " pbar.update(1)\n", "\n", " return pd.DataFrame(rows)\n", "\n", "\n", "def plot_confusion(df, tag, save_name):\n", " if df is None or len(df) == 0:\n", " print(f'No results for {tag}'); return None\n", " df = df[df['pred_label'] != 'Error']\n", " if len(df) == 0:\n", " print(f'All errors for {tag}'); return None\n", "\n", " pred_all = sorted(df['pred_label'].unique())\n", " for lbl in LABEL_ORDER:\n", " if lbl not in pred_all:\n", " pred_all.append(lbl)\n", " pred_cols = ([l for l in LABEL_ORDER if l in pred_all]\n", " + [l for l in pred_all if l not in LABEL_ORDER])\n", "\n", " count_matrix = pd.DataFrame(0, index=LABEL_ORDER, columns=pred_cols)\n", " conf_accum = {(t, p): [] for t in LABEL_ORDER for p in pred_cols}\n", " for t, p, 
c in zip(df['true_label'], df['pred_label'], df['confidence']):\n", " if t in LABEL_ORDER and p in pred_cols:\n", " count_matrix.loc[t, p] += 1\n", " conf_accum[(t, p)].append(c)\n", "\n", " annot = pd.DataFrame('', index=LABEL_ORDER, columns=pred_cols)\n", " for t in LABEL_ORDER:\n", " for p in pred_cols:\n", " n = count_matrix.loc[t, p]\n", " annot.loc[t, p] = '0' if n == 0 else f\"{n}\\n({np.mean(conf_accum[(t,p)]):.2f})\"\n", "\n", " total = count_matrix.values.sum()\n", " acc = np.trace(count_matrix.values) / total if total > 0 else 0.0\n", "\n", " fig_w = max(10, len(pred_cols) * 1.1)\n", " sns.set(style='whitegrid', font_scale=1.0)\n", " fig, ax = plt.subplots(figsize=(fig_w, 6))\n", " sns.heatmap(count_matrix, annot=annot, fmt='', cmap='Blues', cbar=True,\n", " linewidths=0.5, linecolor='gray',\n", " annot_kws={'size': 8, 'color': 'black'}, ax=ax)\n", " ax.set_xlabel('Predicted Label', fontsize=12)\n", " ax.set_ylabel('True Label', fontsize=12)\n", " ax.set_title(f'Gravity Spy \\u2014 {tag} (accuracy = {acc:.1%})', fontsize=13)\n", " plt.xticks(rotation=45, ha='right', fontsize=8)\n", " plt.yticks(rotation=0, fontsize=9)\n", " plt.tight_layout()\n", " fig.savefig(PLOTS_DIR / f'{save_name}.pdf', bbox_inches='tight')\n", " buf = io.BytesIO()\n", " fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')\n", " plt.close(fig)\n", " buf.seek(0)\n", " ipy_display(IPyImage(buf.read()))\n", " print(f'{tag} accuracy: {acc:.3f}')\n", " return acc" ] }, { "cell_type": "code", "execution_count": 5, "id": "aa000007", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fetching open data and whitening...\n", "white_noise: 163840 samples dtype: float64\n" ] } ], "source": [ "# Fetch and whiten noise \u2014 identical to evaluation2\n", "print('Fetching open data and whitening...')\n", "noise = TimeSeries.fetch_open_data(IFO, GW_START, GW_END, sample_rate=SRATE)\n", "noise = noise.to_pycbc()\n", "white_noise, psd = noise.whiten(\n", " 
len(noise) / (2 * SRATE),\n", " len(noise) / (4 * SRATE),\n", " remove_corrupted=False,\n", " return_psd=True,\n", ")\n", "print(f'white_noise: {len(white_noise)} samples dtype: {white_noise.dtype}')" ] }, { "cell_type": "code", "execution_count": 6, "id": "aa000008", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d6be0b3c9c9a4e1aa28b4071515ce979", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Classifying [GlitchGAN (epoch 210, SNR=50)]: 0%| | 0/70 [00:00" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "GlitchGAN (epoch 210, SNR=50) accuracy: 0.714\n" ] } ], "source": [ "acc = plot_confusion(df, f'GlitchGAN (epoch 210, SNR={SNR_TARGET})', 'gspy_cm')" ] } ], "metadata": { "kernelspec": { "display_name": "cdvgan", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.15" } }, "nbformat": 4, "nbformat_minor": 5 }