ae elbow work
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import pickle
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
@@ -278,13 +279,6 @@ from utils.visualization.plot_images_grid import plot_images_grid
|
||||
default=-1,
|
||||
help="Number of jobs for model training.",
|
||||
)
|
||||
@click.option(
|
||||
"--ae_elbow_dims",
|
||||
type=int,
|
||||
multiple=True,
|
||||
default=[128, 256, 384, 512, 768, 1024],
|
||||
help="List of latent space dimensions to test for autoencoder elbow analysis.",
|
||||
)
|
||||
def main(
|
||||
action,
|
||||
dataset_name,
|
||||
@@ -327,7 +321,6 @@ def main(
|
||||
isoforest_max_samples,
|
||||
isoforest_contamination,
|
||||
isoforest_n_jobs_model,
|
||||
ae_elbow_dims,
|
||||
):
|
||||
"""
|
||||
Deep SAD, a method for deep semi-supervised anomaly detection.
|
||||
@@ -786,6 +779,8 @@ def main(
|
||||
)
|
||||
|
||||
# Dictionary to store results for each dimension
|
||||
# ae_elbow_dims = [32, 64, 128, 256, 384, 512, 768, 1024]
|
||||
ae_elbow_dims = [32, 64]
|
||||
elbow_results = {"dimensions": list(ae_elbow_dims), "ae_results": {}}
|
||||
|
||||
# Test each dimension
|
||||
@@ -812,25 +807,16 @@ def main(
|
||||
)
|
||||
|
||||
# Store results for this dimension
|
||||
elbow_results["ae_results"][rep_dim] = {
|
||||
"train_time": deepSAD.ae.train_time,
|
||||
"train_loss": deepSAD.ae.train_loss,
|
||||
"test_auc": deepSAD.ae.test_auc, # if available
|
||||
"test_loss": deepSAD.ae.test_loss,
|
||||
"scores": deepSAD.ae.test_scores,
|
||||
}
|
||||
elbow_results["ae_results"][rep_dim] = deepSAD.ae_results
|
||||
|
||||
logger.info(f"Finished testing dimension {rep_dim}")
|
||||
logger.info(f"Train time: {deepSAD.ae.train_time:.3f}s")
|
||||
logger.info(f"Final train loss: {deepSAD.ae.train_loss[-1]:.6f}")
|
||||
logger.info(f"Final test loss: {deepSAD.ae.test_loss:.6f}")
|
||||
|
||||
# Clear some memory
|
||||
del deepSAD
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Save all results
|
||||
results_path = Path(xp_path) / "ae_elbow_results.pkl"
|
||||
results_path = Path(xp_path) / f"ae_elbow_results_{net_name}.pkl"
|
||||
with open(results_path, "wb") as f:
|
||||
pickle.dump(elbow_results, f)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user