224 lines
7.3 KiB
Python
224 lines
7.3 KiB
Python
|
|
"""
|
|||
|
|
Downloads the cats_vs_dogs dataset, then generates four PNGs:
|
|||
|
|
- supervised_grid.png
|
|||
|
|
- unsupervised_clusters.png
|
|||
|
|
- semi_supervised_classification.png
|
|||
|
|
- semi_supervised_clustering.png
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import os
|
|||
|
|
import random
|
|||
|
|
|
|||
|
|
import matplotlib.pyplot as plt
|
|||
|
|
import numpy as np
|
|||
|
|
import tensorflow as tf
|
|||
|
|
import tensorflow_datasets as tfds
|
|||
|
|
from sklearn.cluster import KMeans
|
|||
|
|
from sklearn.decomposition import PCA
|
|||
|
|
from sklearn.linear_model import LogisticRegression
|
|||
|
|
from sklearn.semi_supervised import LabelSpreading
|
|||
|
|
|
|||
|
|
# -----------------------------------------------------------------------------
|
|||
|
|
# CONFIGURATION
|
|||
|
|
# -----------------------------------------------------------------------------
|
|||
|
|
# Supervised grid: layout of the figure and how many images are drawn into it.
GRID_ROWS = 4
GRID_COLS = 4
NUM_SUPERVISED = GRID_ROWS * GRID_COLS  # one image per grid cell

# Number of images fed to the PCA / KMeans demos.
UNSUP_SAMPLES = 200

# How many labeled points "seed" the semi-supervised methods.
N_LABELED_CLASS = 10  # for the classification demo
N_SEEDS_PER_CLASS = 3  # for the clustering demo

# Directory the generated figures are written to.
OUTPUT_DIR = "outputs"
|
|||
|
|
|
|||
|
|
|
|||
|
|
# -----------------------------------------------------------------------------
|
|||
|
|
# UTILITIES
|
|||
|
|
# -----------------------------------------------------------------------------
|
|||
|
|
def ensure_dir(path):
    """Create the directory ``path``, including any missing parent directories.

    A directory that already exists is left as it is and is not treated as an
    error.
    """
    os.makedirs(name=path, exist_ok=True)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# -----------------------------------------------------------------------------
|
|||
|
|
# 1) Supervised grid
|
|||
|
|
# -----------------------------------------------------------------------------
|
|||
|
|
def plot_supervised_grid(ds, info, num, rows, cols, outpath):
    """Save a ``rows`` x ``cols`` grid of labelled example images.

    The first ``num`` ``(image, label)`` pairs of ``ds`` are drawn one per
    subplot with the axes hidden. Each subplot is titled with the class name
    obtained from ``info.features["label"].int2str``. The figure is written to
    ``outpath`` at 150 dpi and then closed.
    """
    fig = plt.figure(figsize=(cols * 2, rows * 2))
    # Subplot positions are 1-based, so count from 1 instead of adding 1 later.
    for position, (image, label) in enumerate(ds.take(num), start=1):
        panel = fig.add_subplot(rows, cols, position)
        panel.imshow(image.numpy().astype("uint8"))
        panel.axis("off")
        class_name = info.features["label"].int2str(label.numpy())
        panel.set_title(class_name, fontsize=9)
    fig.tight_layout()
    fig.savefig(outpath, dpi=150)
    plt.close(fig)
    print(f"✔ Saved supervised grid → {outpath}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# -----------------------------------------------------------------------------
|
|||
|
|
# 2) Unsupervised clustering (PCA + KMeans)
|
|||
|
|
# -----------------------------------------------------------------------------
|
|||
|
|
def plot_unsupervised_clusters(ds, outpath):
    """Embed images in 2-D with PCA, cluster them with KMeans, and plot them.

    The first ``UNSUP_SAMPLES`` images of ``ds`` are resized to 64x64,
    averaged over axis 2 (grayscale), flattened and divided by 255. The
    resulting matrix is reduced to two PCA components, which are split into two
    KMeans clusters. A scatter plot coloured by cluster is written to
    ``outpath`` at 150 dpi.

    Returns:
        A tuple ``(X2, clusters)``: the 2-D PCA coordinates and the cluster
        index assigned to each sample, for reuse by later demos.
    """

    def to_feature_vector(image):
        # Resize and grayscale to keep the PCA input small and fast.
        small = tf.image.resize(image, (64, 64)).numpy().astype("float32")
        return small.mean(axis=2).ravel() / 255.0

    features = np.stack(
        [to_feature_vector(image) for image, _ in ds.take(UNSUP_SAMPLES)]
    )
    embedding = PCA(n_components=2, random_state=0).fit_transform(features)
    assignments = KMeans(n_clusters=2, random_state=0).fit_predict(embedding)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.scatter(embedding[:, 0], embedding[:, 1], c=assignments, s=15, alpha=0.6)
    ax.set_title("Unsupervised: PCA + KMeans")
    ax.set_xlabel("PCA 1")
    ax.set_ylabel("PCA 2")
    fig.tight_layout()
    fig.savefig(outpath, dpi=150)
    plt.close(fig)
    print(f"✔ Saved unsupervised clusters → {outpath}")

    return embedding, assignments
|
|||
|
|
|
|||
|
|
|
|||
|
|
# -----------------------------------------------------------------------------
|
|||
|
|
# 3) Semi‐supervised classification demo
|
|||
|
|
# -----------------------------------------------------------------------------
|
|||
|
|
def _pick_labeled_indices(y_true, n_labeled):
    """Randomly pick indices whose labels are revealed, covering every class.

    One index is drawn from each class first, then the selection is topped up
    with random indices until it holds ``n_labeled`` entries (or every sample,
    if there are fewer than that).

    Raises:
        ValueError: If ``y_true`` holds fewer than two distinct classes.
    """
    y_true = np.asarray(y_true)
    classes = np.unique(y_true)
    if classes.size < 2:
        raise ValueError(
            "need samples from at least two classes to fit a classifier; "
            f"got classes {classes.tolist()}"
        )
    # A purely random draw can land entirely in one class, which makes
    # LogisticRegression.fit raise; guarantee one representative per class.
    chosen = [
        random.choice(np.flatnonzero(y_true == cls).tolist()) for cls in classes
    ]
    taken = set(chosen)
    pool = [i for i in range(y_true.shape[0]) if i not in taken]
    n_extra = min(max(n_labeled - len(chosen), 0), len(pool))
    chosen.extend(random.sample(pool, n_extra))
    return chosen


def plot_semi_supervised_classification(X2, y_true, outpath):
    """Plot supervised vs. semi-supervised decision regions side by side.

    ``N_LABELED_CLASS`` points (at least one per class) are treated as labeled.
    A ``LogisticRegression`` is fitted on those points only, while a
    ``LabelSpreading`` model is fitted on all points with the remaining labels
    hidden. Both models are evaluated on a 200x200 grid spanning the data, and
    the predicted regions are drawn in two panels together with the unlabeled
    points and the labeled points (coloured by true class, black edge). The
    figure is written to ``outpath`` at 150 dpi.

    Args:
        X2: Array of 2-D coordinates, one row per sample.
        y_true: Class label of each row of ``X2``.
        outpath: Destination path of the figure.

    Raises:
        ValueError: If ``y_true`` holds fewer than two distinct classes.
    """
    y_true = np.asarray(y_true)
    n = X2.shape[0]
    labeled_idx = _pick_labeled_indices(y_true, N_LABELED_CLASS)
    revealed = set(labeled_idx)
    unlabeled_idx = [i for i in range(n) if i not in revealed]

    # Pure supervised baseline: only sees the labeled points.
    clf = LogisticRegression().fit(X2[labeled_idx], y_true[labeled_idx])

    # Semi-supervised: sees every point; -1 marks a hidden label.
    y_train = np.full(n, -1, dtype=int)
    y_train[labeled_idx] = y_true[labeled_idx]
    ls = LabelSpreading().fit(X2, y_train)

    # Grid over the data range (padded by 1) used to draw decision regions.
    x_min, x_max = X2[:, 0].min() - 1, X2[:, 0].max() + 1
    y_min, y_max = X2[:, 1].min() - 1, X2[:, 1].max() + 1
    xx, yy = np.meshgrid(
        np.linspace(x_min, x_max, 200), np.linspace(y_min, y_max, 200)
    )
    grid = np.c_[xx.ravel(), yy.ravel()]
    pred_sup = clf.predict(grid).reshape(xx.shape)
    pred_semi = ls.predict(grid).reshape(xx.shape)

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    panels = zip(axes, [pred_sup, pred_semi], ["Supervised only", "Semi-supervised"])
    for ax, regions, title in panels:
        ax.contourf(xx, yy, regions, alpha=0.3)
        ax.scatter(X2[unlabeled_idx, 0], X2[unlabeled_idx, 1], s=20, alpha=0.4)
        ax.scatter(
            X2[labeled_idx, 0],
            X2[labeled_idx, 1],
            c=y_true[labeled_idx],
            s=80,
            edgecolor="k",
        )
        ax.set_title(title)
        ax.set_xlabel("PCA 1")
        ax.set_ylabel("PCA 2")
    fig.tight_layout()
    fig.savefig(outpath, dpi=150)
    plt.close(fig)
    print(f"✔ Saved semi-supervised classification → {outpath}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# -----------------------------------------------------------------------------
|
|||
|
|
# 4) Semi‐supervised clustering (seeded KMeans)
|
|||
|
|
# -----------------------------------------------------------------------------
|
|||
|
|
def plot_semi_supervised_clustering(X2, y_true, outpath):
    """Plot plain KMeans next to KMeans initialised from a few labeled seeds.

    ``N_SEEDS_PER_CLASS`` indices are drawn without replacement from class 0
    ("cats") and from class 1 ("dogs"). The left panel shows a default
    two-cluster KMeans fit of ``X2``. The right panel shows a KMeans fit whose
    two initial centres are the means of the class-0 and class-1 seeds. Seed
    points are overlaid as "x" markers coloured by their true class. The figure
    is written to ``outpath`` at 150 dpi.

    Args:
        X2: Array of 2-D coordinates, one row per sample.
        y_true: NumPy array with the class label (0 or 1) of each row of ``X2``.
        outpath: Destination path of the figure.

    Raises:
        ValueError: Raised by ``np.random.choice`` when a class has fewer than
            ``N_SEEDS_PER_CLASS`` samples.
    """
    # Pick a few seeds per class (class 0 first, then class 1).
    cat_idx = np.where(y_true == 0)[0]
    dog_idx = np.where(y_true == 1)[0]
    cat_seeds = np.random.choice(cat_idx, N_SEEDS_PER_CLASS, replace=False)
    dog_seeds = np.random.choice(dog_idx, N_SEEDS_PER_CLASS, replace=False)
    seeds = np.concatenate([cat_seeds, dog_seeds])

    # Pure KMeans: no label information at all.
    lab1 = KMeans(n_clusters=2, random_state=0).fit(X2).labels_

    # Seeded KMeans: start each centre at the mean of one class's seeds.
    init_centers = np.vstack(
        [X2[cat_seeds].mean(axis=0), X2[dog_seeds].mean(axis=0)]
    )
    lab2 = (
        KMeans(n_clusters=2, init=init_centers, n_init=1, random_state=0)
        .fit(X2)
        .labels_
    )

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    for ax, labels, title in zip(axes, [lab1, lab2], ["Pure KMeans", "Seeded KMeans"]):
        ax.scatter(X2[:, 0], X2[:, 1], c=labels, s=20, alpha=0.6)
        # "x" is an unfilled marker: matplotlib ignores any edgecolor for it
        # and warns, so none is passed. The colour comes from ``c`` alone.
        ax.scatter(
            X2[seeds, 0],
            X2[seeds, 1],
            c=y_true[seeds],
            marker="x",
            s=100,
            linewidths=2,
        )
        ax.set_title(title)
        ax.set_xlabel("PCA 1")
        ax.set_ylabel("PCA 2")
    fig.tight_layout()
    fig.savefig(outpath, dpi=150)
    plt.close(fig)
    print(f"✔ Saved semi-supervised clustering → {outpath}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# -----------------------------------------------------------------------------
|
|||
|
|
# MAIN
|
|||
|
|
# -----------------------------------------------------------------------------
|
|||
|
|
def main():
    """Load cats_vs_dogs and write the four demo figures into ``OUTPUT_DIR``."""
    ensure_dir(OUTPUT_DIR)

    # Load.
    print("▶ Loading cats_vs_dogs...")
    ds, info = tfds.load(
        "cats_vs_dogs", split="train", with_info=True, as_supervised=True
    )

    # Only the first max(NUM_SUPERVISED, UNSUP_SAMPLES) shuffled examples are
    # ever consumed below. Truncate *before* caching: a cache over the whole
    # split is never read to the end, so it would never be finalized and each
    # pass would redo the decoding and the shuffle-buffer fill.
    n_needed = max(NUM_SUPERVISED, UNSUP_SAMPLES)
    ds = ds.shuffle(1000, reshuffle_each_iteration=False).take(n_needed).cache()
    # One pass to exhaustion finalizes the cache, so every later pass reads the
    # same examples in the same order from memory.
    for _ in ds:
        pass

    # Supervised.
    plot_supervised_grid(
        ds,
        info,
        NUM_SUPERVISED,
        GRID_ROWS,
        GRID_COLS,
        os.path.join(OUTPUT_DIR, "supervised_grid.png"),
    )

    # Unsupervised; also returns X2 for the downstream demos.
    X2, _ = plot_unsupervised_clusters(
        ds, os.path.join(OUTPUT_DIR, "unsupervised_clusters.png")
    )

    # Collect the true labels for that same subset
    # (the semi-supervised demos need a y_true array aligned with X2).
    y_true = np.array([lbl.numpy() for _, lbl in ds.take(UNSUP_SAMPLES)])

    # Semi-supervised classification.
    plot_semi_supervised_classification(
        X2, y_true, os.path.join(OUTPUT_DIR, "semi_supervised_classification.png")
    )

    # Semi-supervised clustering.
    plot_semi_supervised_clustering(
        X2, y_true, os.path.join(OUTPUT_DIR, "semi_supervised_clustering.png")
    )


if __name__ == "__main__":
    main()
|