Files
mt/tools/plot_scripts/background_ml_illustrations.py

224 lines
7.3 KiB
Python
Raw Normal View History

2025-08-13 14:17:12 +02:00
"""
Downloads the cats_vs_dogs dataset, then generates four PNGs:
- supervised_grid.png
- unsupervised_clusters.png
- semi_supervised_classification.png
- semi_supervised_clustering.png
"""
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.semi_supervised import LabelSpreading
# -----------------------------------------------------------------------------
# CONFIGURATION
# -----------------------------------------------------------------------------
NUM_SUPERVISED = 16
GRID_ROWS = 4
GRID_COLS = 4
UNSUP_SAMPLES = 200
# how many labeled points to “seed” semi-sup methods
N_LABELED_CLASS = 10 # for classification demo
N_SEEDS_PER_CLASS = 3 # for clustering demo
OUTPUT_DIR = "outputs"
# -----------------------------------------------------------------------------
# UTILITIES
# -----------------------------------------------------------------------------
def ensure_dir(path):
    """Create directory ``path`` together with any missing parent directories.

    Does not raise when the directory is already present (``exist_ok=True``).
    """
    os.makedirs(name=path, exist_ok=True)
# -----------------------------------------------------------------------------
# 1) Supervised grid
# -----------------------------------------------------------------------------
def plot_supervised_grid(ds, info, num, rows, cols, outpath):
    """Save a ``rows`` x ``cols`` grid of labelled example images.

    The first ``num`` (image, label) pairs taken from ``ds`` are each drawn in
    their own subplot with the axes hidden. Every subplot is titled with the
    class name returned by ``info.features["label"].int2str`` for its label.
    The finished figure is written to ``outpath`` at 150 dpi and then closed.
    """
    fig = plt.figure(figsize=(cols * 2, rows * 2))

    # Subplot positions are 1-based, so enumerate from 1.
    for position, (image, label) in enumerate(ds.take(num), start=1):
        axis = fig.add_subplot(rows, cols, position)
        pixels = image.numpy().astype("uint8")
        axis.imshow(pixels)
        axis.axis("off")
        class_name = info.features["label"].int2str(label.numpy())
        axis.set_title(class_name, fontsize=9)

    fig.tight_layout()
    fig.savefig(outpath, dpi=150)
    plt.close(fig)
    print(f"✔ Saved supervised grid → {outpath}")
# -----------------------------------------------------------------------------
# 2) Unsupervised clustering (PCA + KMeans)
# -----------------------------------------------------------------------------
def plot_unsupervised_clusters(ds, outpath):
    """Project images to 2-D with PCA, cluster them with KMeans, and plot.

    The first ``UNSUP_SAMPLES`` images of ``ds`` are resized to 64x64,
    averaged over axis 2 to get a single-channel image, flattened and divided
    by 255. The resulting matrix is reduced to two PCA components, which are
    then split into two KMeans clusters. A scatter plot coloured by cluster
    is saved to ``outpath`` at 150 dpi.

    Returns:
        A ``(X2, clusters)`` tuple: the 2-D PCA coordinates and the KMeans
        cluster index of every sample, so later demos can reuse them.
    """
    # --- feature extraction -------------------------------------------------
    features = []
    for image, _ in ds.take(UNSUP_SAMPLES):
        # resize and grayscale to speed up
        shrunk = tf.image.resize(image, (64, 64))
        gray = shrunk.numpy().astype("float32").mean(axis=2)
        features.append(gray.ravel() / 255.0)
    X = np.stack(features)

    # --- embedding + clustering --------------------------------------------
    X2 = PCA(n_components=2, random_state=0).fit_transform(X)
    clusters = KMeans(n_clusters=2, random_state=0).fit_predict(X2)

    # --- plotting -----------------------------------------------------------
    fig = plt.figure(figsize=(6, 6))
    ax = fig.gca()
    ax.scatter(X2[:, 0], X2[:, 1], c=clusters, s=15, alpha=0.6)
    ax.set_title("Unsupervised: PCA + KMeans")
    ax.set_xlabel("PCA 1")
    ax.set_ylabel("PCA 2")
    fig.tight_layout()
    fig.savefig(outpath, dpi=150)
    plt.close(fig)
    print(f"✔ Saved unsupervised clusters → {outpath}")

    return X2, clusters  # return for reuse
# -----------------------------------------------------------------------------
# 3) Semi-supervised classification demo
# -----------------------------------------------------------------------------
def _sample_labeled_indices(candidates, y_true, k, max_tries=100):
    """Draw ``k`` random indices, redrawing while they cover only one class.

    The first draw is exactly ``random.sample(candidates, k)``. It is only
    replaced when its labels all belong to a single class although ``y_true``
    contains at least two classes and ``k >= 2``. After ``max_tries`` redraws
    the last sample is returned as-is.
    """
    picked = random.sample(candidates, k)
    mix_possible = k >= 2 and np.unique(y_true).size >= 2
    tries = 0
    while mix_possible and tries < max_tries and np.unique(y_true[picked]).size < 2:
        picked = random.sample(candidates, k)
        tries += 1
    return picked


def plot_semi_supervised_classification(X2, y_true, outpath):
    """Plot supervised vs. semi-supervised decision regions side by side.

    ``N_LABELED_CLASS`` points are chosen at random as the labeled set; all
    other points are treated as unlabeled. A ``LogisticRegression`` is fitted
    on the labeled points only, and a ``LabelSpreading`` model is fitted on
    all points with the unlabeled ones marked ``-1``. Both models are
    evaluated on a 200x200 grid spanning the data (padded by 1 on each side)
    and the predicted regions are drawn as filled contours, with unlabeled
    points drawn small and labeled points drawn large with a black edge.

    Args:
        X2: 2-D array of shape ``(n, 2)`` holding the point coordinates.
        y_true: NumPy array of the ``n`` true class labels.
        outpath: File path the figure is saved to (150 dpi).
    """
    n = X2.shape[0]
    all_idx = list(range(n))

    # LogisticRegression refuses to fit on a single class, so make sure the
    # randomly chosen labeled points contain both classes whenever possible.
    labeled_idx = _sample_labeled_indices(all_idx, y_true, N_LABELED_CLASS)
    labeled_lookup = set(labeled_idx)
    unlabeled_idx = [i for i in all_idx if i not in labeled_lookup]

    # pure supervised: only the labeled points are seen by the model
    clf = LogisticRegression().fit(X2[labeled_idx], y_true[labeled_idx])

    # semi-supervised: every point is passed in, -1 marks the unlabeled ones
    y_train = np.full(n, -1, dtype=int)
    y_train[labeled_idx] = y_true[labeled_idx]
    ls = LabelSpreading().fit(X2, y_train)

    # grid for decision boundary
    x_min, x_max = X2[:, 0].min() - 1, X2[:, 0].max() + 1
    y_min, y_max = X2[:, 1].min() - 1, X2[:, 1].max() + 1
    xx, yy = np.meshgrid(
        np.linspace(x_min, x_max, 200), np.linspace(y_min, y_max, 200)
    )
    grid = np.c_[xx.ravel(), yy.ravel()]
    pred_sup = clf.predict(grid).reshape(xx.shape)
    pred_semi = ls.predict(grid).reshape(xx.shape)

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    panels = zip(axes, [pred_sup, pred_semi], ["Supervised only", "Semi-supervised"])
    for ax, Z, title in panels:
        ax.contourf(xx, yy, Z, alpha=0.3)
        ax.scatter(X2[unlabeled_idx, 0], X2[unlabeled_idx, 1], s=20, alpha=0.4)
        ax.scatter(
            X2[labeled_idx, 0],
            X2[labeled_idx, 1],
            c=y_true[labeled_idx],
            s=80,
            edgecolor="k",
        )
        ax.set_title(title)
        ax.set_xlabel("PCA 1")
        ax.set_ylabel("PCA 2")
    fig.tight_layout()
    fig.savefig(outpath, dpi=150)
    plt.close(fig)
    print(f"✔ Saved semi-supervised classification → {outpath}")
# -----------------------------------------------------------------------------
# 4) Semi-supervised clustering (seeded KMeans)
# -----------------------------------------------------------------------------
def plot_semi_supervised_clustering(X2, y_true, outpath):
    """Plot plain KMeans next to KMeans initialised from labeled seed points.

    ``N_SEEDS_PER_CLASS`` indices are drawn without replacement from the
    points labeled 0 and from the points labeled 1. The left panel shows a
    regular 2-cluster KMeans; the right panel shows a KMeans whose two
    initial centres are the mean positions of the class-0 and class-1 seeds.
    The seeds are overdrawn as large "x" markers coloured by their true label.

    Args:
        X2: 2-D array of shape ``(n, 2)`` holding the point coordinates.
        y_true: NumPy array of the ``n`` true class labels (0 or 1).
        outpath: File path the figure is saved to (150 dpi).

    Raises:
        ValueError: from ``np.random.choice`` when a class has fewer than
            ``N_SEEDS_PER_CLASS`` members.
    """
    # pick a few seeds per class (class 0 first, then class 1)
    cats = np.where(y_true == 0)[0]
    dogs = np.where(y_true == 1)[0]
    seeds = list(np.random.choice(cats, N_SEEDS_PER_CLASS, replace=False)) + list(
        np.random.choice(dogs, N_SEEDS_PER_CLASS, replace=False)
    )

    # pure KMeans
    km1 = KMeans(n_clusters=2, random_state=0).fit(X2)
    lab1 = km1.labels_

    # seeded: init centers from seed means; n_init=1 because the start
    # position is given explicitly
    ctr0 = X2[seeds[:N_SEEDS_PER_CLASS]].mean(axis=0)
    ctr1 = X2[seeds[N_SEEDS_PER_CLASS:]].mean(axis=0)
    km2 = KMeans(
        n_clusters=2, init=np.vstack([ctr0, ctr1]), n_init=1, random_state=0
    ).fit(X2)
    lab2 = km2.labels_

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    for ax, labels, title in zip(axes, [lab1, lab2], ["Pure KMeans", "Seeded KMeans"]):
        ax.scatter(X2[:, 0], X2[:, 1], c=labels, s=20, alpha=0.6)
        # "x" is an unfilled marker: matplotlib ignores an edge colour for it
        # and warns, so none is passed. The marker colour comes from ``c``.
        ax.scatter(
            X2[seeds, 0],
            X2[seeds, 1],
            c=y_true[seeds],
            marker="x",
            s=100,
            linewidths=2,
        )
        ax.set_title(title)
        ax.set_xlabel("PCA 1")
        ax.set_ylabel("PCA 2")
    fig.tight_layout()
    fig.savefig(outpath, dpi=150)
    plt.close(fig)
    print(f"✔ Saved semi-supervised clustering → {outpath}")
# -----------------------------------------------------------------------------
# MAIN
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    ensure_dir(OUTPUT_DIR)

    # Resolve all output locations up front.
    grid_png = os.path.join(OUTPUT_DIR, "supervised_grid.png")
    clusters_png = os.path.join(OUTPUT_DIR, "unsupervised_clusters.png")
    classification_png = os.path.join(
        OUTPUT_DIR, "semi_supervised_classification.png"
    )
    clustering_png = os.path.join(OUTPUT_DIR, "semi_supervised_clustering.png")

    # Load the dataset as (image, label) pairs, then shuffle once and cache.
    print("▶ Loading cats_vs_dogs...")
    ds, info = tfds.load(
        "cats_vs_dogs", split="train", with_info=True, as_supervised=True
    )
    ds = ds.shuffle(1000, reshuffle_each_iteration=False)
    ds = ds.cache()

    # 1) supervised grid
    plot_supervised_grid(ds, info, NUM_SUPERVISED, GRID_ROWS, GRID_COLS, grid_png)

    # 2) unsupervised clusters; the 2-D embedding is kept for the later demos
    X2, _ = plot_unsupervised_clusters(ds, clusters_png)

    # True labels for the semi-supervised demos, read with a second
    # take(UNSUP_SAMPLES) pass that is meant to line up with the rows of X2.
    # NOTE(review): presumably relies on the shuffle order staying fixed
    # between passes (reshuffle_each_iteration=False) — confirm.
    label_values = []
    for _, lbl in ds.take(UNSUP_SAMPLES):
        label_values.append(lbl.numpy())
    y_true = np.array(label_values)

    # 3) semi-supervised classification
    plot_semi_supervised_classification(X2, y_true, classification_png)

    # 4) semi-supervised clustering
    plot_semi_supervised_clustering(X2, y_true, clustering_png)