Files
mt/tools/plot_scripts/background_ml_supervised.py

100 lines
3.6 KiB
Python
Raw Normal View History

2025-08-13 14:17:12 +02:00
"""
Downloads the cats_vs_dogs dataset, then generates a supervised grid image:
- supervised_grid.png
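This script saves its outputs in a folder named after the current date and time, then copies
those outputs (together with a copy of this script) into a "latest" folder.
Finally, the datetime folder is moved into an "archive" folder, so every version of the outputs
and of the script that produced them is kept.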
"""
from datetime import datetime
from pathlib import Path
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
# -----------------------------------------------------------------------------
# CONFIGURATION
# -----------------------------------------------------------------------------
# Grid layout: how many labelled samples are drawn and how they are arranged
GRID_ROWS = 4
GRID_COLS = 4
NUM_SUPERVISED = 16
# Base output location; every run writes into its own timestamped subfolder
output_path = Path("/home/fedex/mt/plots/background_ml_supervised")
archive_folder_path = output_path.joinpath("archive")
latest_folder_path = output_path.joinpath("latest")
# Timestamp taken once at import time, e.g. "2025-08-13_14-17-12"
datetime_folder_name = f"{datetime.now():%Y-%m-%d_%H-%M-%S}"
output_datetime_path = output_path.joinpath(datetime_folder_name)
# -----------------------------------------------------------------------------
# UTILITIES
# -----------------------------------------------------------------------------
def ensure_dir(directory: Path):
    """Create ``directory`` and any missing parents; do nothing if it already exists."""
    directory.mkdir(parents=True, exist_ok=True)
# Make sure every folder this script writes to exists before any work starts
for _required_dir in (
    output_path,
    output_datetime_path,
    latest_folder_path,
    archive_folder_path,
):
    ensure_dir(_required_dir)
# -----------------------------------------------------------------------------
# Supervised grid plot
# -----------------------------------------------------------------------------
def plot_supervised_grid(ds, info, num, rows, cols, outpath):
    """
    Plot a grid of dataset images titled with their class names and save it.

    Args:
        ds: Dataset of ``(image, label)`` pairs. It must provide ``.take(num)``
            and its elements must provide ``.numpy()`` (the caller in this
            script passes a dataset loaded with ``as_supervised=True``).
        info: Dataset metadata; ``info.features["label"].int2str`` is used to
            turn an integer label into its class name.
        num: Maximum number of samples taken from ``ds``.
        rows: Number of grid rows.
        cols: Number of grid columns.
        outpath: Path the image is written to.
    """
    # Hold an explicit figure handle instead of relying on pyplot's implicit
    # "current figure", so exactly this figure is saved and closed.
    fig = plt.figure(figsize=(cols * 2, rows * 2))
    try:
        for i, (img, lbl) in enumerate(ds.take(num)):
            # Subplot indices are 1-based.
            ax = fig.add_subplot(rows, cols, i + 1)
            ax.imshow(img.numpy().astype("uint8"))
            ax.axis("off")
            cname = info.features["label"].int2str(lbl.numpy())
            ax.set_title(cname, fontsize=9)
        fig.tight_layout()
        fig.savefig(outpath, dpi=150)
    finally:
        # Release the figure even when plotting or saving fails; otherwise it
        # stays registered with pyplot and leaks.
        plt.close(fig)
    print(f"✔ Saved supervised grid → {outpath}")
# -----------------------------------------------------------------------------
# MAIN
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Only the script entry point copies/moves files, so shutil is imported
    # here, before any work starts, rather than halfway through the block.
    import shutil

    # Download and prepare the dataset
    print("▶ Loading cats_vs_dogs dataset...")
    ds, info = tfds.load(
        "cats_vs_dogs", split="train", with_info=True, as_supervised=True
    )
    # Shuffle within a 1000-element buffer. No .cache() is applied: the grid
    # only reads the first few samples, and tf.data discards (and warns about)
    # a cache whose iterator is not read to the end.
    ds = ds.shuffle(1000, reshuffle_each_iteration=False)

    # Generate the supervised grid image
    supervised_outfile = output_datetime_path / "supervised_grid.png"
    plot_supervised_grid(
        ds, info, NUM_SUPERVISED, GRID_ROWS, GRID_COLS, supervised_outfile
    )

    # -------------------------------------------------------------------------
    # Update the 'latest' results folder: remove previous and copy current outputs
    # -------------------------------------------------------------------------
    shutil.rmtree(latest_folder_path, ignore_errors=True)
    ensure_dir(latest_folder_path)
    for file in output_datetime_path.iterdir():
        shutil.copy2(file, latest_folder_path)

    # Copy this script to the output folder and to the latest folder to
    # preserve the code that produced the outputs
    script_path = Path(__file__)
    shutil.copy2(script_path, output_datetime_path)
    shutil.copy2(script_path, latest_folder_path)

    # Move the output datetime folder into the archive folder for record
    # keeping. The paths are passed as str because, before Python 3.9,
    # shutil.move fails on a Path source when the destination is an existing
    # directory.
    shutil.move(str(output_datetime_path), str(archive_folder_path))