Files
mt/tools/plot_scripts/background_ml_illustrations.py
2025-08-13 14:17:12 +02:00

224 lines
7.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Downloads the cats_vs_dogs dataset, then generates four PNGs:
- supervised_grid.png
- unsupervised_clusters.png
- semi_supervised_classification.png
- semi_supervised_clustering.png
"""
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.semi_supervised import LabelSpreading
# -----------------------------------------------------------------------------
# CONFIGURATION
# -----------------------------------------------------------------------------
NUM_SUPERVISED = 16
GRID_ROWS = 4
GRID_COLS = 4
UNSUP_SAMPLES = 200
# how many labeled points to “seed” semi-sup methods
N_LABELED_CLASS = 10 # for classification demo
N_SEEDS_PER_CLASS = 3 # for clustering demo
OUTPUT_DIR = "outputs"
# -----------------------------------------------------------------------------
# UTILITIES
# -----------------------------------------------------------------------------
def ensure_dir(path):
    """Create the directory ``path`` (and any missing parents).

    An already existing directory is left untouched and is not an error.
    """
    os.makedirs(name=path, exist_ok=True)
# -----------------------------------------------------------------------------
# 1) Supervised grid
# -----------------------------------------------------------------------------
def plot_supervised_grid(ds, info, num, rows, cols, outpath):
    """Save a ``rows`` x ``cols`` grid of example images titled with their class.

    The first ``num`` ``(image, label)`` pairs of ``ds`` are drawn, one per
    subplot. The class name shown above each image is looked up through
    ``info.features["label"].int2str``. The figure is written to ``outpath``
    and then closed.
    """
    figure_size = (cols * 2, rows * 2)  # 2 inches per grid cell
    plt.figure(figsize=figure_size)

    # Subplot positions are 1-based, hence start=1.
    for position, (image, label) in enumerate(ds.take(num), start=1):
        axis = plt.subplot(rows, cols, position)
        pixels = image.numpy().astype("uint8")
        axis.imshow(pixels)
        axis.axis("off")
        class_name = info.features["label"].int2str(label.numpy())
        axis.set_title(class_name, fontsize=9)

    plt.tight_layout()
    plt.savefig(outpath, dpi=150)
    plt.close()
    print(f"✔ Saved supervised grid → {outpath}")
# -----------------------------------------------------------------------------
# 2) Unsupervised clustering (PCA + KMeans)
# -----------------------------------------------------------------------------
def plot_unsupervised_clusters(ds, outpath):
    """Project images to 2-D with PCA, cluster them with KMeans, and plot.

    The first ``UNSUP_SAMPLES`` images of ``ds`` are resized to 64x64,
    averaged over the channel axis, flattened and scaled by 1/255 before
    being fed to a 2-component PCA. KMeans (k=2) is then run on the
    projected points and the coloured scatter plot is saved to ``outpath``.

    Returns:
        A tuple ``(X2, clusters)``: the PCA-projected points and the KMeans
        cluster index of each point, so callers can reuse the projection.
    """

    def _to_feature_vector(image):
        # Downscale and average the channels to keep the PCA input small.
        small = tf.image.resize(image, (64, 64)).numpy().astype("float32")
        gray = small.mean(axis=2)
        return gray.ravel() / 255.0

    features = np.stack(
        [_to_feature_vector(image) for image, _ in ds.take(UNSUP_SAMPLES)]
    )

    X2 = PCA(n_components=2, random_state=0).fit_transform(features)
    clusters = KMeans(n_clusters=2, random_state=0).fit_predict(X2)

    plt.figure(figsize=(6, 6))
    plt.scatter(X2[:, 0], X2[:, 1], c=clusters, s=15, alpha=0.6)
    plt.title("Unsupervised: PCA + KMeans")
    plt.xlabel("PCA 1")
    plt.ylabel("PCA 2")
    plt.tight_layout()
    plt.savefig(outpath, dpi=150)
    plt.close()
    print(f"✔ Saved unsupervised clusters → {outpath}")

    return X2, clusters  # return for reuse
# -----------------------------------------------------------------------------
# 3) Semi-supervised classification demo
# -----------------------------------------------------------------------------
def plot_semi_supervised_classification(X2, y_true, outpath):
    """Plot supervised vs. semi-supervised decision regions side by side.

    ``N_LABELED_CLASS`` points are drawn at random and treated as labeled;
    all remaining points are treated as unlabeled. A ``LogisticRegression``
    is fitted on the labeled points only, and a ``LabelSpreading`` model is
    fitted on all points with the unlabeled ones marked ``-1``. Both models
    are evaluated on a 200x200 grid covering the data and the resulting
    regions are drawn in two panels saved to ``outpath``.

    Args:
        X2: array of shape (n, 2) with the 2-D coordinates of each point.
        y_true: NumPy array of length n holding the true class of each point.
        outpath: file path the figure is written to.

    Raises:
        ValueError: if ``y_true`` contains fewer than two distinct classes,
            because no decision boundary can be fitted in that case.
    """
    n = X2.shape[0]
    all_idx = list(range(n))
    labeled_idx = random.sample(all_idx, N_LABELED_CLASS)

    # A purely random draw can contain a single class, which makes
    # LogisticRegression.fit fail. Keep the draw as-is when it is usable and
    # otherwise swap the last pick for a random point of another class.
    if len(labeled_idx) >= 2 and len(np.unique(y_true[labeled_idx])) < 2:
        drawn_class = y_true[labeled_idx[0]]
        other_idx = [i for i in all_idx if y_true[i] != drawn_class]
        if not other_idx:
            raise ValueError(
                "y_true must contain at least two classes to fit a classifier"
            )
        labeled_idx[-1] = random.choice(other_idx)

    labeled_set = set(labeled_idx)
    unlabeled_idx = [i for i in all_idx if i not in labeled_set]

    # pure supervised: sees the labeled points only
    clf = LogisticRegression().fit(X2[labeled_idx], y_true[labeled_idx])

    # semi-supervised: sees every point, with -1 marking "no label"
    y_train = np.full(n, -1, dtype=int)
    y_train[labeled_idx] = y_true[labeled_idx]
    ls = LabelSpreading().fit(X2, y_train)

    # grid for decision boundary (one unit of margin around the data)
    x_min, x_max = X2[:, 0].min() - 1, X2[:, 0].max() + 1
    y_min, y_max = X2[:, 1].min() - 1, X2[:, 1].max() + 1
    xx, yy = np.meshgrid(
        np.linspace(x_min, x_max, 200), np.linspace(y_min, y_max, 200)
    )
    grid = np.c_[xx.ravel(), yy.ravel()]
    pred_sup = clf.predict(grid).reshape(xx.shape)
    pred_semi = ls.predict(grid).reshape(xx.shape)

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    panels = zip(
        axes, [pred_sup, pred_semi], ["Supervised only", "Semi-supervised"]
    )
    for ax, Z, title in panels:
        ax.contourf(xx, yy, Z, alpha=0.3)
        # unlabeled points: small and faint
        ax.scatter(X2[unlabeled_idx, 0], X2[unlabeled_idx, 1], s=20, alpha=0.4)
        # labeled points: large, outlined, coloured by true class
        ax.scatter(
            X2[labeled_idx, 0],
            X2[labeled_idx, 1],
            c=y_true[labeled_idx],
            s=80,
            edgecolor="k",
        )
        ax.set_title(title)
        ax.set_xlabel("PCA 1")
        ax.set_ylabel("PCA 2")

    plt.tight_layout()
    plt.savefig(outpath, dpi=150)
    plt.close()
    print(f"✔ Saved semi-supervised classification → {outpath}")
# -----------------------------------------------------------------------------
# 4) Semi-supervised clustering (seeded KMeans)
# -----------------------------------------------------------------------------
def plot_semi_supervised_clustering(X2, y_true, outpath):
    """Plot plain KMeans next to KMeans initialised from labeled seeds.

    ``N_SEEDS_PER_CLASS`` points are drawn at random (without replacement)
    from the points with ``y_true == 0`` and from those with ``y_true == 1``.
    The left panel shows an ordinary 2-cluster KMeans on ``X2``; the right
    panel shows KMeans whose two initial centers are the means of the seed
    points of each class. The seeds are overlaid as "x" markers coloured by
    their true class, and the figure is saved to ``outpath``.

    Args:
        X2: array of shape (n, 2) with the 2-D coordinates of each point.
        y_true: NumPy array of length n with class 0 / 1 per point.
        outpath: file path the figure is written to.
    """
    # pick a few seeds per class (class 0 first, then class 1)
    cats = np.where(y_true == 0)[0]
    dogs = np.where(y_true == 1)[0]
    cat_seeds = np.random.choice(cats, N_SEEDS_PER_CLASS, replace=False)
    dog_seeds = np.random.choice(dogs, N_SEEDS_PER_CLASS, replace=False)
    seeds = list(cat_seeds) + list(dog_seeds)

    # pure KMeans
    lab1 = KMeans(n_clusters=2, random_state=0).fit(X2).labels_

    # seeded: init centers from seed means; n_init=1 because the
    # initialisation is fixed and restarts would be identical
    ctr0 = X2[cat_seeds].mean(axis=0)
    ctr1 = X2[dog_seeds].mean(axis=0)
    km2 = KMeans(
        n_clusters=2, init=np.vstack([ctr0, ctr1]), n_init=1, random_state=0
    ).fit(X2)
    lab2 = km2.labels_

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    panels = zip(axes, [lab1, lab2], ["Pure KMeans", "Seeded KMeans"])
    for ax, labels, title in panels:
        ax.scatter(X2[:, 0], X2[:, 1], c=labels, s=20, alpha=0.6)
        # "x" is an unfilled marker: matplotlib ignores any edgecolor for it
        # (and warns), so only the face colour from `c` is given here.
        ax.scatter(
            X2[seeds, 0],
            X2[seeds, 1],
            c=y_true[seeds],
            marker="x",
            s=100,
            linewidths=2,
        )
        ax.set_title(title)
        ax.set_xlabel("PCA 1")
        ax.set_ylabel("PCA 2")

    plt.tight_layout()
    plt.savefig(outpath, dpi=150)
    plt.close()
    print(f"✔ Saved semi-supervised clustering → {outpath}")
# -----------------------------------------------------------------------------
# MAIN
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Script entry point: load the dataset once, then render the four figures
    # into OUTPUT_DIR in a fixed order.
    ensure_dir(OUTPUT_DIR)

    supervised_png = os.path.join(OUTPUT_DIR, "supervised_grid.png")
    unsupervised_png = os.path.join(OUTPUT_DIR, "unsupervised_clusters.png")
    classification_png = os.path.join(
        OUTPUT_DIR, "semi_supervised_classification.png"
    )
    clustering_png = os.path.join(OUTPUT_DIR, "semi_supervised_clustering.png")

    # load
    print("▶ Loading cats_vs_dogs...")
    ds, info = tfds.load(
        "cats_vs_dogs", split="train", with_info=True, as_supervised=True
    )
    # reshuffle_each_iteration=False is requested so that the repeated
    # ds.take(...) calls below are meant to see the same leading samples.
    ds = ds.shuffle(1000, reshuffle_each_iteration=False)
    ds = ds.cache()

    # supervised
    plot_supervised_grid(
        ds, info, NUM_SUPERVISED, GRID_ROWS, GRID_COLS, supervised_png
    )

    # unsupervised; the returned 2-D projection feeds the downstream demos
    X2, _ = plot_unsupervised_clusters(ds, unsupervised_png)

    # collect true labels for that same subset
    # (the semi-supervised demos need a y_true array)
    y_true = np.array([label.numpy() for _image, label in ds.take(UNSUP_SAMPLES)])

    # semi-supervised classification
    plot_semi_supervised_classification(X2, y_true, classification_png)

    # semi-supervised clustering
    plot_semi_supervised_clustering(X2, y_true, clustering_png)