100 lines
3.6 KiB
Python
100 lines
3.6 KiB
Python
|
|
"""
Download the cats_vs_dogs dataset and generate a supervised grid image:

- supervised_grid.png

Outputs are saved in a timestamped folder, and the latest outputs are also
copied to a "latest" folder. Every run's outputs, together with the script
that produced them, are archived.
"""

import shutil
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import tensorflow_datasets as tfds


# -----------------------------------------------------------------------------
# CONFIGURATION
# -----------------------------------------------------------------------------
# Grid layout of the supervised sample plot; exactly one sample per grid cell.
GRID_ROWS = 4
GRID_COLS = 4
NUM_SUPERVISED = GRID_ROWS * GRID_COLS  # 16 samples fill the 4x4 grid

# Output locations. Each run writes into a timestamped folder; its contents are
# later copied into "latest" and the folder itself is moved under "archive".
output_path = Path("/home/fedex/mt/plots/background_ml_supervised")
datetime_folder_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_datetime_path = output_path / datetime_folder_name
latest_folder_path = output_path / "latest"
archive_folder_path = output_path / "archive"


# -----------------------------------------------------------------------------
# UTILITIES
# -----------------------------------------------------------------------------
def ensure_dir(directory: Path) -> None:
    """Create *directory*, including any missing parent folders.

    Does nothing if the directory already exists.
    """
    directory.mkdir(parents=True, exist_ok=True)


# Create every required output directory, in the same order as before.
# NOTE(review): this runs at import time, so merely importing this module
# creates a fresh (possibly never-archived) timestamped folder.
for _required_dir in (
    output_path,
    output_datetime_path,
    latest_folder_path,
    archive_folder_path,
):
    ensure_dir(_required_dir)
del _required_dir


# -----------------------------------------------------------------------------
# Supervised grid plot
# -----------------------------------------------------------------------------
def plot_supervised_grid(ds, info, num, rows, cols, outpath):
    """Plot labelled images from *ds* in a ``rows x cols`` grid and save it.

    Args:
        ds: Dataset yielding ``(image, label)`` pairs whose elements expose
            ``.numpy()`` (as produced by ``tfds.load(..., as_supervised=True)``).
        info: Dataset info object; ``info.features["label"].int2str`` is used to
            turn integer labels into class names.
        num: Number of samples to draw. Capped at ``rows * cols`` (the grid
            capacity) and floored at 0.
        rows: Number of grid rows.
        cols: Number of grid columns.
        outpath: Destination file for the saved figure.
    """
    # plt.subplot(rows, cols, k) raises ValueError for k > rows * cols, and
    # ds.take() with a negative count reads the *entire* dataset, so clamp.
    num = max(0, min(num, rows * cols))
    label_feature = info.features["label"]  # loop-invariant lookup

    fig, axes = plt.subplots(
        rows, cols, figsize=(cols * 2, rows * 2), squeeze=False
    )
    flat_axes = axes.ravel()
    try:
        # Hide every cell's axes; cells without an image simply stay blank.
        for ax in flat_axes:
            ax.axis("off")
        for ax, (img, lbl) in zip(flat_axes, ds.take(num)):
            ax.imshow(img.numpy().astype("uint8"))
            ax.set_title(label_feature.int2str(lbl.numpy()), fontsize=9)
        fig.tight_layout()
        fig.savefig(outpath, dpi=150)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    print(f"✔ Saved supervised grid → {outpath}")


# -----------------------------------------------------------------------------
# MAIN
# -----------------------------------------------------------------------------
def _publish_outputs(run_dir: Path, latest_dir: Path, archive_dir: Path, script: Path) -> None:
    """Copy a run's outputs and *script* into *latest_dir*, then archive *run_dir*.

    *latest_dir* is wiped first, so afterwards it holds only this run's files.
    *run_dir* ends up moved inside *archive_dir*.
    """
    # Replace the 'latest' folder wholesale so stale files from older runs vanish.
    shutil.rmtree(latest_dir, ignore_errors=True)
    ensure_dir(latest_dir)

    # Keep the exact code used to produce these outputs next to them; copying it
    # into the run folder first means the loop below also puts it into 'latest'.
    shutil.copy2(script, run_dir)
    for item in run_dir.iterdir():
        if item.is_file():  # copy2 cannot copy directories
            shutil.copy2(item, latest_dir)

    # str(): before Python 3.9, shutil.move raised AttributeError for a Path
    # source when the destination is an existing directory.
    shutil.move(str(run_dir), str(archive_dir))


def main() -> None:
    """Load cats_vs_dogs, render the supervised grid, then publish and archive the run."""
    print("▶ Loading cats_vs_dogs dataset...")
    ds, info = tfds.load(
        "cats_vs_dogs", split="train", with_info=True, as_supervised=True
    )
    # Shuffle so the grid shows a random mix of samples. No .cache(): the data
    # is read once and only partially (via take()), so TensorFlow would discard
    # the partial cache anyway and emit a warning.
    ds = ds.shuffle(1000, reshuffle_each_iteration=False)

    supervised_outfile = output_datetime_path / "supervised_grid.png"
    plot_supervised_grid(
        ds, info, NUM_SUPERVISED, GRID_ROWS, GRID_COLS, supervised_outfile
    )

    _publish_outputs(
        output_datetime_path,
        latest_folder_path,
        archive_folder_path,
        Path(__file__),
    )


if __name__ == "__main__":
    main()