ocsvm working

This commit is contained in:
Jan Kowalczyk
2025-06-13 10:24:54 +02:00
parent d88719e718
commit 9298dea329
6 changed files with 376 additions and 137 deletions

View File

@@ -1,4 +1,3 @@
import json
import logging
import pickle
import time
@@ -20,7 +19,7 @@ from networks.main import build_autoencoder
class OCSVM(object):
"""A class for One-Class SVM models."""
def __init__(self, kernel="rbf", nu=0.1, hybrid=False):
def __init__(self, kernel="rbf", nu=0.1, hybrid=False, latent_space_dim=128):
"""Init OCSVM instance."""
self.kernel = kernel
self.nu = nu
@@ -30,6 +29,7 @@ class OCSVM(object):
self.model = OneClassSVM(kernel=kernel, nu=nu, verbose=True, max_mem_size=4048)
self.hybrid = hybrid
self.latent_space_dim = latent_space_dim
self.ae_net = None # autoencoder network for the case of a hybrid model
self.linear_model = (
None # also init a model with linear kernel if hybrid approach
@@ -38,8 +38,16 @@ class OCSVM(object):
self.results = {
"train_time": None,
"test_time": None,
"test_auc": None,
"test_scores": None,
"test_auc_exp_based": None,
"test_roc_exp_based": None,
"test_prc_exp_based": None,
"test_ap_exp_based": None,
"test_scores_exp_based": None,
"test_auc_manual_based": None,
"test_roc_manual_based": None,
"test_prc_manual_based": None,
"test_ap_manual_based": None,
"test_scores_manual_based": None,
"train_time_linear": None,
"test_time_linear": None,
"test_auc_linear": None,
@@ -70,15 +78,11 @@ class OCSVM(object):
# Get data from loader
X = ()
for data in train_loader:
inputs, _, _, _, _ = data
inputs, _, _, _, _, _ = data # Updated unpacking
inputs = inputs.to(device)
if self.hybrid:
inputs = self.ae_net.encoder(
inputs
) # in hybrid approach, take code representation of AE as features
X_batch = inputs.view(
inputs.size(0), -1
) # X_batch.shape = (batch_size, n_channels * height * width)
inputs = self.ae_net.encoder(inputs)
X_batch = inputs.view(inputs.size(0), -1)
X += (X_batch.cpu().data.numpy(),)
X = np.concatenate(X)
@@ -101,40 +105,59 @@ class OCSVM(object):
batch_size=batch_size, num_workers=n_jobs_dataloader
)
# Sample hold-out set from test set
X_test = ()
labels = []
labels_exp = []
labels_manual = []
for data in test_loader:
inputs, label_batch, _, _, _ = data
inputs, label_batch = inputs.to(device), label_batch.to(device)
inputs, label_exp, label_manual, _, _, _ = data # Updated unpacking
inputs = inputs.to(device)
label_exp = label_exp.to(device)
label_manual = label_manual.to(device)
if self.hybrid:
inputs = self.ae_net.encoder(
inputs
) # in hybrid approach, take code representation of AE as features
X_batch = inputs.view(
inputs.size(0), -1
) # X_batch.shape = (batch_size, n_channels * height * width)
inputs = self.ae_net.encoder(inputs)
X_batch = inputs.view(inputs.size(0), -1)
X_test += (X_batch.cpu().data.numpy(),)
labels += label_batch.cpu().data.numpy().astype(np.int64).tolist()
X_test, labels = np.concatenate(X_test), np.array(labels)
n_test, n_normal, n_outlier = (
len(X_test),
np.sum(labels == 0),
np.sum(labels == 1),
)
n_val = int(0.1 * n_test)
n_val_normal, n_val_outlier = (
int(n_val * (n_normal / n_test)),
int(n_val * (n_outlier / n_test)),
)
perm = np.random.permutation(n_test)
labels_exp += label_exp.cpu().data.numpy().astype(np.int64).tolist()
labels_manual += label_manual.cpu().data.numpy().astype(np.int64).tolist()
X_test = np.concatenate(X_test)
labels_exp = np.array(labels_exp)
labels_manual = np.array(labels_manual)
# Use experiment-based labels for model selection (could also use manual-based)
labels = labels_exp
# Count samples for validation split (-1: anomaly, 1: normal, 0: unknown)
n_test = len(X_test)
n_normal = np.sum(labels == 1)
n_outlier = np.sum(labels == -1)
n_val = int(0.1 * n_test) # Use 10% of test data for validation
# Only consider labeled samples for validation
valid_mask = labels != 0
X_test_valid = X_test[valid_mask]
labels_valid = labels[valid_mask]
# Calculate validation split sizes
n_val_normal = int(n_val * (n_normal / (n_normal + n_outlier)))
n_val_outlier = int(n_val * (n_outlier / (n_normal + n_outlier)))
# Create validation set that preserves the normal/anomaly ratio of the labeled test samples (stratified, not 50/50 balanced)
perm = np.random.permutation(len(X_test_valid))
X_val = np.concatenate(
(
X_test[perm][labels[perm] == 0][:n_val_normal],
X_test[perm][labels[perm] == 1][:n_val_outlier],
X_test_valid[perm][labels_valid[perm] == 1][:n_val_normal],
X_test_valid[perm][labels_valid[perm] == -1][:n_val_outlier],
)
)
labels = np.array([0] * n_val_normal + [1] * n_val_outlier)
val_labels = np.array(
[0] * n_val_normal + [1] * n_val_outlier
) # Convert to binary (0: normal, 1: anomaly)
# Model selection loop
i = 1
for gamma in gammas:
# Model candidate
@@ -155,12 +178,12 @@ class OCSVM(object):
scores = (-1.0) * model.decision_function(X_val)
scores = scores.flatten()
# Compute AUC
auc = roc_auc_score(labels, scores)
# Compute AUC with binary labels
auc = roc_auc_score(val_labels, scores)
logger.info(
f" | Model {i:02}/{len(gammas):02} | Gamma: {gamma:.8f} | Train Time: {train_time:.3f}s "
f"| Val AUC: {100. * auc:.2f} |"
f"| Val AUC: {100.0 * auc:.2f} |"
)
if auc > best_auc:
@@ -182,7 +205,7 @@ class OCSVM(object):
self.results["train_time_linear"] = train_time
logger.info(
f"Best Model: | Gamma: {self.gamma:.8f} | AUC: {100. * best_auc:.2f}"
f"Best Model: | Gamma: {self.gamma:.8f} | AUC: {100.0 * best_auc:.2f}"
)
logger.info("Training Time: {:.3f}s".format(self.results["train_time"]))
logger.info("Finished training.")
@@ -210,51 +233,121 @@ class OCSVM(object):
)
# Get data from loader
idx_label_score = []
idx_label_score_exp = []
idx_label_score_manual = []
X = ()
idxs = []
labels = []
labels_exp = []
labels_manual = []
for data in test_loader:
inputs, label_batch, _, idx, _ = data
inputs, label_batch, idx = (
inputs, label_exp, label_manual, _, idx, _ = data # Updated unpacking
inputs, label_exp, label_manual, idx = (
inputs.to(device),
label_batch.to(device),
label_exp.to(device),
label_manual.to(device),
idx.to(device),
)
if self.hybrid:
inputs = self.ae_net.encoder(
inputs
) # in hybrid approach, take code representation of AE as features
X_batch = inputs.view(
inputs.size(0), -1
) # X_batch.shape = (batch_size, n_channels * height * width)
inputs = self.ae_net.encoder(inputs)
X_batch = inputs.view(inputs.size(0), -1)
X += (X_batch.cpu().data.numpy(),)
idxs += idx.cpu().data.numpy().astype(np.int64).tolist()
labels += label_batch.cpu().data.numpy().astype(np.int64).tolist()
labels_exp += label_exp.cpu().data.numpy().astype(np.int64).tolist()
labels_manual += label_manual.cpu().data.numpy().astype(np.int64).tolist()
X = np.concatenate(X)
labels_exp = np.array(labels_exp)
labels_manual = np.array(labels_manual)
# Count and log label stats for exp_based
n_exp_normal = np.sum(labels_exp == 1)
n_exp_anomaly = np.sum(labels_exp == -1)
n_exp_unknown = np.sum(labels_exp == 0)
logger.info(
f"Exp-based labels: normal(1)={n_exp_normal}, "
f"anomaly(-1)={n_exp_anomaly}, unknown(0)={n_exp_unknown}"
)
# Count and log label stats for manual_based
n_manual_normal = np.sum(labels_manual == 1)
n_manual_anomaly = np.sum(labels_manual == -1)
n_manual_unknown = np.sum(labels_manual == 0)
logger.info(
f"Manual-based labels: normal(1)={n_manual_normal}, "
f"anomaly(-1)={n_manual_anomaly}, unknown(0)={n_manual_unknown}"
)
# Testing
logger.info("Starting testing...")
start_time = time.time()
scores = (-1.0) * self.model.decision_function(X)
self.results["test_time"] = time.time() - start_time
scores = scores.flatten()
self.rho = -self.model.intercept_[0]
# Save triples of (idx, label, score) in a list
idx_label_score += list(zip(idxs, labels, scores.tolist()))
self.results["test_scores"] = idx_label_score
# Save triples of (idx, label, score) for both label types
idx_label_score_exp += list(zip(idxs, labels_exp.tolist(), scores.tolist()))
idx_label_score_manual += list(
zip(idxs, labels_manual.tolist(), scores.tolist())
)
# Compute AUC
_, labels, scores = zip(*idx_label_score)
labels = np.array(labels)
scores = np.array(scores)
self.results["test_auc"] = roc_auc_score(labels, scores)
self.results["test_roc"] = roc_curve(labels, scores)
self.results["test_prc"] = precision_recall_curve(labels, scores)
self.results["test_ap"] = average_precision_score(labels, scores)
self.results["test_scores_exp_based"] = idx_label_score_exp
self.results["test_scores_manual_based"] = idx_label_score_manual
# --- Evaluation for exp_based (only labeled samples) ---
valid_mask_exp = labels_exp != 0
if np.any(valid_mask_exp):
labels_exp_binary = (labels_exp[valid_mask_exp] == -1).astype(int)
scores_exp_valid = scores[valid_mask_exp]
self.results["test_auc_exp_based"] = roc_auc_score(
labels_exp_binary, scores_exp_valid
)
self.results["test_roc_exp_based"] = roc_curve(
labels_exp_binary, scores_exp_valid
)
self.results["test_prc_exp_based"] = precision_recall_curve(
labels_exp_binary, scores_exp_valid
)
self.results["test_ap_exp_based"] = average_precision_score(
labels_exp_binary, scores_exp_valid
)
logger.info(
"Test AUC (exp_based): {:.2f}%".format(
100.0 * self.results["test_auc_exp_based"]
)
)
else:
logger.info("Test AUC (exp_based): N/A (no labeled samples)")
# --- Evaluation for manual_based (only labeled samples) ---
valid_mask_manual = labels_manual != 0
if np.any(valid_mask_manual):
labels_manual_binary = (labels_manual[valid_mask_manual] == -1).astype(int)
scores_manual_valid = scores[valid_mask_manual]
self.results["test_auc_manual_based"] = roc_auc_score(
labels_manual_binary, scores_manual_valid
)
self.results["test_roc_manual_based"] = roc_curve(
labels_manual_binary, scores_manual_valid
)
self.results["test_prc_manual_based"] = precision_recall_curve(
labels_manual_binary, scores_manual_valid
)
self.results["test_ap_manual_based"] = average_precision_score(
labels_manual_binary, scores_manual_valid
)
logger.info(
"Test AUC (manual_based): {:.2f}%".format(
100.0 * self.results["test_auc_manual_based"]
)
)
else:
logger.info("Test AUC (manual_based): N/A (no labeled samples)")
# If hybrid, also test model with linear kernel
if self.hybrid:
@@ -262,35 +355,115 @@ class OCSVM(object):
scores_linear = (-1.0) * self.linear_model.decision_function(X)
self.results["test_time_linear"] = time.time() - start_time
scores_linear = scores_linear.flatten()
self.results["test_auc_linear"] = roc_auc_score(labels, scores_linear)
logger.info(
"Test AUC linear model: {:.2f}%".format(
100.0 * self.results["test_auc_linear"]
# Save linear model results for both label types
# --- exp_based ---
valid_mask_exp_linear = labels_exp != 0
if np.any(valid_mask_exp_linear):
labels_exp_binary_linear = (
labels_exp[valid_mask_exp_linear] == -1
).astype(int)
scores_exp_linear_valid = scores_linear[valid_mask_exp_linear]
self.results["test_auc_linear_exp_based"] = roc_auc_score(
labels_exp_binary_linear, scores_exp_linear_valid
)
)
self.results["test_roc_linear_exp_based"] = roc_curve(
labels_exp_binary_linear, scores_exp_linear_valid
)
self.results["test_prc_linear_exp_based"] = precision_recall_curve(
labels_exp_binary_linear, scores_exp_linear_valid
)
self.results["test_ap_linear_exp_based"] = average_precision_score(
labels_exp_binary_linear, scores_exp_linear_valid
)
else:
self.results["test_auc_linear_exp_based"] = None
self.results["test_roc_linear_exp_based"] = None
self.results["test_prc_linear_exp_based"] = None
self.results["test_ap_linear_exp_based"] = None
# --- manual_based ---
valid_mask_manual_linear = labels_manual != 0
if np.any(valid_mask_manual_linear):
labels_manual_binary_linear = (
labels_manual[valid_mask_manual_linear] == -1
).astype(int)
scores_manual_linear_valid = scores_linear[valid_mask_manual_linear]
self.results["test_auc_linear_manual_based"] = roc_auc_score(
labels_manual_binary_linear, scores_manual_linear_valid
)
self.results["test_roc_linear_manual_based"] = roc_curve(
labels_manual_binary_linear, scores_manual_linear_valid
)
self.results["test_prc_linear_manual_based"] = precision_recall_curve(
labels_manual_binary_linear, scores_manual_linear_valid
)
self.results["test_ap_linear_manual_based"] = average_precision_score(
labels_manual_binary_linear, scores_manual_linear_valid
)
else:
self.results["test_auc_linear_manual_based"] = None
self.results["test_roc_linear_manual_based"] = None
self.results["test_prc_linear_manual_based"] = None
self.results["test_ap_linear_manual_based"] = None
# Log exp_based results for linear model
if self.results["test_auc_linear_exp_based"] is not None:
logger.info(
"Test AUC linear model (exp_based): {:.2f}%".format(
100.0 * self.results["test_auc_linear_exp_based"]
)
)
else:
logger.info(
"Test AUC linear model (exp_based): N/A (no labeled samples)"
)
# Log manual_based results for linear model
if self.results["test_auc_linear_manual_based"] is not None:
logger.info(
"Test AUC linear model (manual_based): {:.2f}%".format(
100.0 * self.results["test_auc_linear_manual_based"]
)
)
else:
logger.info(
"Test AUC linear model (manual_based): N/A (no labeled samples)"
)
logger.info(
"Test Time linear model: {:.3f}s".format(
self.results["test_time_linear"]
)
)
# Log results
logger.info("Test AUC: {:.2f}%".format(100.0 * self.results["test_auc"]))
# Log results for both label types
if self.results.get("test_auc_exp_based") is not None:
logger.info(
"Test AUC (exp_based): {:.2f}%".format(
100.0 * self.results["test_auc_exp_based"]
)
)
else:
logger.info("Test AUC (exp_based): N/A (no labeled samples)")
if self.results.get("test_auc_manual_based") is not None:
logger.info(
"Test AUC (manual_based): {:.2f}%".format(
100.0 * self.results["test_auc_manual_based"]
)
)
else:
logger.info("Test AUC (manual_based): N/A (no labeled samples)")
logger.info("Test Time: {:.3f}s".format(self.results["test_time"]))
logger.info("Finished testing.")
def load_ae(self, dataset_name, model_path):
def load_ae(self, model_path, net_name, device="cpu"):
"""Load pretrained autoencoder from model_path for feature extraction in a hybrid OC-SVM model."""
model_dict = torch.load(model_path, map_location="cpu")
ae_net_dict = model_dict["ae_net_dict"]
if dataset_name in ["mnist", "fmnist", "cifar10"]:
net_name = dataset_name + "_LeNet"
else:
net_name = dataset_name + "_mlp"
if self.ae_net is None:
self.ae_net = build_autoencoder(net_name)
self.ae_net = build_autoencoder(net_name, rep_dim=self.latent_space_dim)
# update keys (since there was a change in network definition)
ae_keys = list(self.ae_net.state_dict().keys())
@@ -301,6 +474,8 @@ class OCSVM(object):
i += 1
self.ae_net.load_state_dict(ae_net_dict)
if device != "cpu":
self.ae_net.to(torch.device(device))
self.ae_net.eval()
def save_model(self, export_path):