"""
Downloads the cats_vs_dogs dataset, then generates an unsupervised clustering image:
- unsupervised_clusters.png

This script saves its outputs in a datetime-stamped folder and also copies the latest outputs to a "latest" folder.
All versions of the outputs and scripts are archived.
"""
|
|
|
|
import shutil
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
import tensorflow_datasets as tfds
|
|
from sklearn.cluster import KMeans
|
|
from sklearn.decomposition import PCA
|
|
|
|
# -----------------------------------------------------------------------------
# CONFIGURATION
# -----------------------------------------------------------------------------
|
|
# Maximum number of dataset images sampled for the clustering plot.
UNSUP_SAMPLES = 200

# Root directory under which every run writes its outputs.
output_path = Path("/home/fedex/mt/plots/background_ml_unsupervised")
# Timestamp (naive local time, second resolution) used as this run's folder name.
datetime_folder_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# Fixed sub-folders of the root, plus the folder dedicated to this run.
latest_folder_path = output_path / "latest"
archive_folder_path = output_path / "archive"
output_datetime_path = output_path / datetime_folder_name
|
|
|
|
|
|
# -----------------------------------------------------------------------------
# UTILITIES
# -----------------------------------------------------------------------------
|
|
def ensure_dir(directory: Path) -> None:
    """Create ``directory`` if it does not already exist.

    Missing parent directories are created as well, and an already existing
    directory is left untouched.

    Args:
        directory: Path of the directory to create.

    Raises:
        FileExistsError: If ``directory`` exists but is not a directory.
    """
    directory.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
# Create required output directories.
# NOTE: this runs at module level, so the directories are created as soon as
# the file is executed or imported.
ensure_dir(output_path)
ensure_dir(output_datetime_path)
ensure_dir(latest_folder_path)
ensure_dir(archive_folder_path)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
# Unsupervised Clustering Plot (PCA + KMeans)
# -----------------------------------------------------------------------------
|
|
def plot_unsupervised_clusters(ds, outpath):
    """Cluster a sample of dataset images and save a 2-D scatterplot.

    Up to ``UNSUP_SAMPLES`` images are taken from ``ds``. Each one is resized
    to 64x64, averaged over its channel axis, flattened and divided by 255.
    The resulting vectors are projected onto two principal components with
    PCA, split into two clusters with KMeans, and drawn as a scatterplot
    coloured by cluster.

    Args:
        ds: Dataset yielding ``(image, label)`` pairs and providing
            ``take()``; the labels are ignored. Images are presumably
            height x width x channels tensors with values in 0-255 --
            confirm against the dataset actually passed in.
        outpath: Destination of the figure, passed to ``savefig``.

    Raises:
        ValueError: If fewer than two images could be read from ``ds``
            (a 2-component PCA needs at least two samples).
    """
    features = []
    for image, _ in ds.take(UNSUP_SAMPLES):
        # Resize to 64x64, then convert to grayscale by averaging channels.
        gray = tf.image.resize(image, (64, 64)).numpy().astype("float32").mean(axis=2)
        features.append(gray.ravel() / 255.0)

    # Fail early with a clear message instead of an opaque error raised from
    # inside np.stack / PCA (same exception type as before).
    if len(features) < 2:
        raise ValueError(
            f"Need at least 2 images for a 2-component PCA, got {len(features)}"
        )

    X = np.stack(features)
    X2 = PCA(n_components=2, random_state=0).fit_transform(X)
    clusters = KMeans(n_clusters=2, random_state=0).fit_predict(X2)

    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        ax.scatter(X2[:, 0], X2[:, 1], c=clusters, s=15, alpha=0.6)
        ax.set_title("Unsupervised: PCA + KMeans")
        ax.set_xlabel("PCA 1")
        ax.set_ylabel("PCA 2")
        fig.tight_layout()
        fig.savefig(outpath, dpi=150)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    print(f"✔ Saved unsupervised clusters → {outpath}")
|
|
|
|
|
|
# -----------------------------------------------------------------------------
# MAIN
# -----------------------------------------------------------------------------
|
|
if __name__ == "__main__":
    # Load and prepare the dataset
    print("▶ Loading cats_vs_dogs dataset...")
    # with_info=True makes tfds.load return (dataset, info); the info is unused.
    ds, _ = tfds.load("cats_vs_dogs", split="train", with_info=True, as_supervised=True)
    ds = ds.shuffle(1000, reshuffle_each_iteration=False).cache()

    # Generate the unsupervised clusters image
    unsupervised_outfile = output_datetime_path / "unsupervised_clusters.png"
    plot_unsupervised_clusters(ds, unsupervised_outfile)

    # -------------------------------------------------------------------------
    # Update the 'latest' results folder: remove previous and copy current outputs
    # -------------------------------------------------------------------------
    shutil.rmtree(latest_folder_path, ignore_errors=True)
    ensure_dir(latest_folder_path)
    for file in output_datetime_path.iterdir():
        shutil.copy2(file, latest_folder_path)

    # Copy this script to the output folder and to the latest folder to
    # preserve the code that produced the outputs.
    script_path = Path(__file__)
    shutil.copy2(script_path, output_datetime_path)
    shutil.copy2(script_path, latest_folder_path)

    # Move the output datetime folder to the archive folder for record keeping.
    # Paths are passed as str because shutil.move only documents path-like
    # support for src and dst from Python 3.9 onwards.
    shutil.move(str(output_datetime_path), str(archive_folder_path))
|