tools, lockfile, deps
This commit is contained in:
105
tools/plot_scripts/background_ml_unsupervised.py
Normal file
105
tools/plot_scripts/background_ml_unsupervised.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""
Downloads the cats_vs_dogs dataset, then generates an unsupervised clusters image:
- unsupervised_clusters.png

This script saves outputs in a datetime folder and also copies the latest outputs to a "latest" folder.
All versions of the outputs and scripts are archived.
"""
|
||||
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import tensorflow_datasets as tfds
|
||||
from sklearn.cluster import KMeans
|
||||
from sklearn.decomposition import PCA
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# CONFIGURATION
|
||||
# -----------------------------------------------------------------------------
|
||||
# Number of dataset elements fed into the PCA + KMeans plot.
UNSUP_SAMPLES = 200

# Root folder under which every run's results are written.
output_path = Path("/home/fedex/mt/plots/background_ml_unsupervised")

# Timestamp captured once at import time; it names this run's own folder.
datetime_folder_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# Sub-folders of the root: the most recent results, the archive of past
# runs, and the folder for the current run.
latest_folder_path = output_path.joinpath("latest")
archive_folder_path = output_path.joinpath("archive")
output_datetime_path = output_path.joinpath(datetime_folder_name)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# UTILITIES
|
||||
# -----------------------------------------------------------------------------
|
||||
def ensure_dir(directory: Path):
    """Create ``directory`` together with any missing parent folders.

    Calling this for a directory that already exists is a no-op.
    """
    directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# Make sure every folder the script writes into exists before any work starts.
# The order (root, run folder, latest, archive) is kept from the original calls.
for _required_dir in (
    output_path,
    output_datetime_path,
    latest_folder_path,
    archive_folder_path,
):
    ensure_dir(_required_dir)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Unsupervised Clustering Plot (PCA + KMeans)
|
||||
# -----------------------------------------------------------------------------
|
||||
def plot_unsupervised_clusters(ds, outpath):
    """
    Save a scatterplot of a PCA + KMeans clustering of images taken from ``ds``.

    The first ``UNSUP_SAMPLES`` elements of ``ds`` are used. Each image is
    resized to 64x64, averaged over axis 2, flattened and divided by 255.
    The stacked vectors are projected onto two PCA components, KMeans with
    two clusters is fitted on that projection, and the projected points are
    drawn colored by their cluster label.

    Args:
        ds: Dataset exposing ``take`` and yielding ``(image, label)`` pairs;
            the label is ignored. Images are presumably
            (height, width, channels) -- TODO confirm against the caller.
        outpath: Destination handed to ``savefig`` for the PNG.
    """

    def _to_feature_vector(image):
        # Resize to 64x64, then collapse axis 2 (the channel axis) by
        # averaging to obtain a grayscale image.
        gray = tf.image.resize(image, (64, 64)).numpy().astype("float32").mean(axis=2)
        # Flatten and scale the 0-255 intensities down to 0-1.
        return gray.ravel() / 255.0

    features = np.stack(
        [_to_feature_vector(image) for image, _label in ds.take(UNSUP_SAMPLES)]
    )

    # Reduce every image vector to two coordinates for plotting.
    projected = PCA(n_components=2, random_state=0).fit_transform(features)

    # Cluster in the reduced 2-D space, not in the original pixel space.
    labels = KMeans(n_clusters=2, random_state=0).fit_predict(projected)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.scatter(projected[:, 0], projected[:, 1], c=labels, s=15, alpha=0.6)
    ax.set_title("Unsupervised: PCA + KMeans")
    ax.set_xlabel("PCA 1")
    ax.set_ylabel("PCA 2")
    fig.tight_layout()
    fig.savefig(outpath, dpi=150)
    plt.close(fig)
    print(f"✔ Saved unsupervised clusters → {outpath}")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# MAIN
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == "__main__":
    # Load and prepare the dataset. The dataset-info object was requested but
    # never used, so only the dataset itself is loaded.
    print("▶ Loading cats_vs_dogs dataset...")
    ds = tfds.load("cats_vs_dogs", split="train", as_supervised=True)
    # reshuffle_each_iteration=False keeps the shuffled order fixed across
    # iterations of the dataset.
    ds = ds.shuffle(1000, reshuffle_each_iteration=False).cache()

    # Generate the unsupervised clusters image inside this run's folder.
    unsupervised_outfile = output_datetime_path / "unsupervised_clusters.png"
    plot_unsupervised_clusters(ds, unsupervised_outfile)

    # -------------------------------------------------------------------------
    # Update the 'latest' results folder: remove previous and copy current outputs
    # -------------------------------------------------------------------------
    shutil.rmtree(latest_folder_path, ignore_errors=True)
    ensure_dir(latest_folder_path)
    for file in output_datetime_path.iterdir():
        shutil.copy2(file, latest_folder_path)

    # Copy this script next to the outputs (run folder and 'latest') so the
    # exact code that produced them is preserved.
    script_path = Path(__file__)
    shutil.copy2(script_path, output_datetime_path)
    shutil.copy2(script_path, latest_folder_path)

    # Move the run folder into the archive folder for record keeping.
    # shutil.move only accepts path-like objects from Python 3.9 on; with an
    # existing destination directory, older versions raise AttributeError for
    # Path arguments, so plain strings are passed.
    shutil.move(str(output_datetime_path), str(archive_folder_path))
|
||||
Reference in New Issue
Block a user