173 lines
6.0 KiB
Python
173 lines
6.0 KiB
Python
|
|
import pickle
|
||
|
|
import re
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
import matplotlib.pyplot as plt
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
# Root directory that is scanned for per-run result folders.
results_path = Path("/home/fedex/mt/results/done")

# One mapping per architecture: latent dimension -> result folder.
efficient_paths, lenet_paths = {}, {}

for result_folder in results_path.iterdir():
    folder_name = result_folder.name
    # Skip anything that is not a directory whose name contains "n0_a0".
    if not result_folder.is_dir() or "n0_a0" not in folder_name:
        continue

    # Each entry: (substring identifying the architecture, regex extracting
    # the latent dimension, dictionary that collects the folder).
    for arch_tag, name_pattern, arch_paths in (
        ("efficient", r"subter_efficient_latent(\d+)_n0_a0", efficient_paths),
        ("LeNet", r"subter_LeNet_latent(\d+)_n0_a0", lenet_paths),
    ):
        if arch_tag not in folder_name:
            continue
        match = re.search(name_pattern, folder_name)
        if match is None:
            raise ValueError("Could not extract latent value from string using regex")
        latent_value = int(match.group(1))
        arch_paths[latent_value] = result_folder

# NOTE(review): the three bare strings below are no-op expression statements;
# they look like a reminder of key names found in the loaded result
# dictionaries -- confirm before removing.
"test"
"exp_basedmanual_based"
"auc"
|
||
|
|
|
||
|
|
def _load_network_results(store, network_key, network_paths):
    """Unpickle every per-fold result file of one architecture into *store*.

    For each latent dimension (ascending) the folder is printed, then the
    files ``results_<method>_<fold>.pkl`` for folds 0-4 and the three methods
    are loaded into ``store[network_key][latent][method][fold]``.
    """
    for latent_value, path in sorted(network_paths.items()):
        print(f"Latent {latent_value}: {path}")
        for kfold_idx in range(5):
            for method in ("deepsad", "ocsvm", "isoforest"):
                with open(path / f"results_{method}_{kfold_idx}.pkl", "rb") as f:
                    loaded = pickle.load(f)
                    # Nested levels are created lazily, only once a file
                    # has actually been read.
                    per_latent = store.setdefault(network_key, {}).setdefault(
                        latent_value, {}
                    )
                    per_latent.setdefault(method, {})[kfold_idx] = loaded


# results[network][latent dimension][method][fold index] -> unpickled object
results = {}

print("Efficient paths:")
# The set of latent dimensions is taken from the EfficientNet folders only.
latent_dims = set(efficient_paths)
_load_network_results(results, "efficient", efficient_paths)

print("\nLeNet paths:")
_load_network_results(results, "lenet", lenet_paths)
|
||
|
|
|
||
|
|
|
||
|
|
def _fold_test_auc(fold_result, method, label_kind):
    """Return the test AUC stored in one fold's result for *label_kind*.

    DeepSAD results keep the value nested under
    ``["test"][label_kind]["auc"]``; the other methods keep it flat under
    ``"test_auc_<label_kind>"``.
    """
    if method == "deepsad":
        return fold_result["test"][label_kind]["auc"]
    return fold_result[f"test_auc_{label_kind}"]


# Validate that every (network, latent dimension, method) combination was
# loaded, then store the mean test AUC over the five folds next to the
# per-fold entries under "mean_auc_exp" / "mean_auc_man".
# Sorted so that the first reported problem is deterministic.
for latent_dim in sorted(latent_dims):
    for network in ["efficient", "lenet"]:
        # .get() so that a network with no loaded results at all reaches the
        # ValueError below instead of failing with a bare KeyError.
        network_results = results.get(network, {})
        for method in ["deepsad", "ocsvm", "isoforest"]:
            if (
                latent_dim not in network_results
                or method not in network_results[latent_dim]
            ):
                raise ValueError(
                    f"Missing results for {network} with latent {latent_dim} and method {method}"
                )

            method_results = network_results[latent_dim][method]
            for short_name, label_kind in (
                ("exp", "exp_based"),
                ("man", "manual_based"),
            ):
                method_results[f"mean_auc_{short_name}"] = np.mean(
                    [
                        _fold_test_auc(method_results[kfold_idx], method, label_kind)
                        for kfold_idx in range(5)
                    ]
                )
|
||
|
|
|
||
|
|
|
||
|
|
def plot_auc_comparison(results, evaluation_type):
    """Draw mean AUC against latent dimension for every network/method pair.

    Args:
        results: Nested dict indexed as
            ``results[network][latent_dim][method]["mean_auc_<evaluation_type>"]``.
        evaluation_type: Either 'exp' or 'man' for experiment or manual based
            evaluation; selects which mean AUC is plotted and the title.

    Returns:
        The matplotlib figure holding the plot.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    # Marker shape and opacity encode the method ...
    method_markers = {"deepsad": "o", "ocsvm": "s", "isoforest": "^"}
    method_alphas = {
        "deepsad": 1.0,  # full intensity
        "ocsvm": 0.7,  # slightly lighter
        "isoforest": 0.4,  # even lighter
    }
    # ... while colour and line style encode the architecture.
    network_colors = {
        "efficient": "#1f77b4",  # blue
        "lenet": "#d62728",  # red
    }
    network_linestyles = {"efficient": "-", "lenet": "--"}

    # The x axis uses the latent dimensions present for the "efficient" network.
    latent_dims = sorted(results["efficient"].keys())
    mean_key = f"mean_auc_{evaluation_type}"

    for network in ("efficient", "lenet"):
        for method in ("deepsad", "ocsvm", "isoforest"):
            auc_values = []
            for dim in latent_dims:
                auc_values.append(results[network][dim][method][mean_key])

            ax.plot(
                latent_dims,
                auc_values,
                marker=method_markers[method],
                color=network_colors[network],
                alpha=method_alphas[method],
                linestyle=network_linestyles[network],
                label=f"{network.capitalize()} {method.upper()}",
                markersize=8,
            )

    ax.set_xlabel("Latent Dimension")
    ax.set_ylabel("Mean AUC")
    # Anything other than "exp" gets the manual-label title.
    title = (
        "AUC Comparison (Experiment Based Evaluation Labels)"
        if evaluation_type == "exp"
        else "AUC Comparison (Manual Based Evaluation Labels)"
    )
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    # Legend is anchored outside the axes, to the right of the plot area.
    ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    ax.set_xticks(latent_dims)
    fig.tight_layout()
    return fig
|
||
|
|
|
||
|
|
|
||
|
|
# Create and save both plots (experiment-based and manual-based labels).
# The output directory is created up front so that savefig does not fail with
# FileNotFoundError when the script runs in a fresh working directory.
output_dir = Path("auc_comp")
output_dir.mkdir(parents=True, exist_ok=True)

for eval_type in ["exp", "man"]:
    fig = plot_auc_comparison(results, eval_type)
    fig.savefig(
        output_dir / f"auc_comparison_{eval_type}_based.png",
        dpi=300,
        bbox_inches="tight",
    )
    # Release the figure once it has been written to disk.
    plt.close(fig)
|