mt/Deep-SAD-PyTorch/src/DeepSAD.py
import pickle
import numpy as np
import torch
from base.base_dataset import BaseADDataset
from networks.main import build_autoencoder, build_network
from optim.ae_trainer import AETrainer
from optim.DeepSAD_trainer import DeepSADTrainer
class DeepSAD(object):
"""A class for the Deep SAD method.
Attributes:
eta: Deep SAD hyperparameter eta (must be 0 < eta).
c: Hypersphere center c.
net_name: A string indicating the name of the neural network to use.
net: The neural network phi.
trainer: DeepSADTrainer to train a Deep SAD model.
optimizer_name: A string indicating the optimizer to use for training the Deep SAD network.
ae_net: The autoencoder network corresponding to phi for network weights pretraining.
ae_trainer: AETrainer to train an autoencoder in pretraining.
ae_optimizer_name: A string indicating the optimizer to use for pretraining the autoencoder.
results: A dictionary to save the results.
ae_results: A dictionary to save the autoencoder results.
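
    Example:
        A minimal usage sketch, assuming ``dataset`` is a concrete
        BaseADDataset implementation; the network name and file paths are
        illustrative placeholders, not values defined in this module::

            deep_sad = DeepSAD(rep_dim=32, eta=1.0)
            deep_sad.set_network("my_net")  # hypothetical net_name
            deep_sad.pretrain(dataset)  # optional autoencoder pretraining
            deep_sad.train(dataset)
            deep_sad.test(dataset)
            deep_sad.save_model("./model.tar")
            deep_sad.save_results("./results.pkl")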
"""
    def __init__(self, rep_dim: int, eta: float = 1.0):
        """Inits DeepSAD with representation dimension rep_dim and hyperparameter eta."""
self.rep_dim = rep_dim # representation dimension
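        # eta weights the loss contribution of labeled samples relative to
        # unlabeled ones in the Deep SAD objective (must be > 0).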
self.eta = eta
self.c = None # hypersphere center c
self.net_name = None
self.net = None # neural network phi
self.trainer = None
self.optimizer_name = None
self.ae_net = None # autoencoder network for pretraining
self.ae_trainer = None
self.ae_optimizer_name = None
self.results = {
"train": {
"time": None,
"indices": None,
"file_ids": None,
"frame_ids": None,
"file_names": None, # mapping of file_ids to file names
},
"test": {
"time": None,
"indices": None,
"file_ids": None,
"frame_ids": None,
"file_names": None, # mapping of file_ids to file names
"exp_based": {
"auc": None,
"roc": None,
"prc": None,
"ap": None,
"scores": None,
"indices": None,
"labels": None,
"valid_mask": None,
},
"manual_based": {
"auc": None,
"roc": None,
"prc": None,
"ap": None,
"scores": None,
"indices": None,
"labels": None,
"valid_mask": None,
},
},
"inference": {
"time": None,
"indices": None,
"file_ids": None,
"frame_ids": None,
"file_names": None, # mapping of file_ids to file names
},
}
self.ae_results = {"train_time": None, "test_auc": None, "test_time": None}
def set_network(self, net_name):
"""Builds the neural network phi."""
self.net_name = net_name
self.net = build_network(net_name, self.rep_dim)
def train(
self,
dataset: BaseADDataset,
optimizer_name: str = "adam",
lr: float = 0.001,
n_epochs: int = 50,
lr_milestones: tuple = (),
batch_size: int = 128,
weight_decay: float = 1e-6,
device: str = "cuda",
n_jobs_dataloader: int = 0,
k_fold_idx: int = None,
):
"""Trains the Deep SAD model on the training data."""
self.optimizer_name = optimizer_name
self.trainer = DeepSADTrainer(
self.c,
self.eta,
optimizer_name=optimizer_name,
lr=lr,
n_epochs=n_epochs,
lr_milestones=lr_milestones,
batch_size=batch_size,
weight_decay=weight_decay,
device=device,
n_jobs_dataloader=n_jobs_dataloader,
)
# Get the model
self.net = self.trainer.train(dataset, self.net, k_fold_idx=k_fold_idx)
self.c = self.trainer.c.cpu().data.numpy().tolist() # get as list
# Store training results including indices
self.results["train"]["time"] = self.trainer.train_time
self.results["train"]["indices"] = self.trainer.train_indices
self.results["train"]["file_ids"] = self.trainer.train_file_ids
self.results["train"]["frame_ids"] = self.trainer.train_frame_ids
# Get file names mapping for training data
self.results["train"]["file_names"] = {
file_id: dataset.get_file_name_from_idx(file_id)
for file_id in np.unique(self.trainer.train_file_ids)
}
def inference(
self, dataset: BaseADDataset, device: str = "cuda", n_jobs_dataloader: int = 0
):
"""Tests the Deep SAD model on the test data."""
if self.trainer is None:
self.trainer = DeepSADTrainer(
self.c, self.eta, device=device, n_jobs_dataloader=n_jobs_dataloader
)
scores, outputs = self.trainer.infer(dataset, self.net)
# Store inference indices and mappings
self.results["inference"]["time"] = self.trainer.inference_time
self.results["inference"]["indices"] = self.trainer.inference_indices
self.results["inference"]["file_ids"] = self.trainer.inference_file_ids
self.results["inference"]["frame_ids"] = self.trainer.inference_frame_ids
# Get file names mapping for inference data
self.results["inference"]["file_names"] = {
file_id: dataset.get_file_name_from_idx(file_id)
for file_id in np.unique(self.trainer.inference_file_ids)
}
return scores, outputs
def test(
self,
dataset: BaseADDataset,
device: str = "cuda",
n_jobs_dataloader: int = 0,
k_fold_idx: int = None,
):
"""Tests the Deep SAD model on the test data."""
if self.trainer is None:
self.trainer = DeepSADTrainer(
self.c, self.eta, device=device, n_jobs_dataloader=n_jobs_dataloader
)
self.trainer.test(dataset, self.net, k_fold_idx=k_fold_idx)
# Store all test indices and mappings
self.results["test"]["time"] = self.trainer.test_time
self.results["test"]["indices"] = self.trainer.test_indices
self.results["test"]["file_ids"] = self.trainer.test_file_ids
self.results["test"]["frame_ids"] = self.trainer.test_frame_ids
# Get file names mapping for test data
self.results["test"]["file_names"] = {
file_id: dataset.get_file_name_from_idx(file_id)
for file_id in np.unique(self.trainer.test_file_ids)
}
# Store experiment-based results
self.results["test"]["exp_based"]["auc"] = self.trainer.test_auc_exp_based
self.results["test"]["exp_based"]["roc"] = self.trainer.test_roc_exp_based
self.results["test"]["exp_based"]["prc"] = self.trainer.test_prc_exp_based
self.results["test"]["exp_based"]["ap"] = self.trainer.test_ap_exp_based
self.results["test"]["exp_based"]["scores"] = self.trainer.test_scores_exp_based
self.results["test"]["exp_based"]["indices"] = self.trainer.test_index_mapping[
"exp_based"
]["indices"]
self.results["test"]["exp_based"]["labels"] = self.trainer.test_index_mapping[
"exp_based"
]["labels"]
self.results["test"]["exp_based"]["valid_mask"] = (
self.trainer.test_index_mapping["exp_based"]["valid_mask"]
)
# Store manual-based results
self.results["test"]["manual_based"]["auc"] = self.trainer.test_auc_manual_based
self.results["test"]["manual_based"]["roc"] = self.trainer.test_roc_manual_based
self.results["test"]["manual_based"]["prc"] = self.trainer.test_prc_manual_based
self.results["test"]["manual_based"]["ap"] = self.trainer.test_ap_manual_based
self.results["test"]["manual_based"]["scores"] = (
self.trainer.test_scores_manual_based
)
self.results["test"]["manual_based"]["indices"] = (
self.trainer.test_index_mapping["manual_based"]["indices"]
)
self.results["test"]["manual_based"]["labels"] = (
self.trainer.test_index_mapping["manual_based"]["labels"]
)
self.results["test"]["manual_based"]["valid_mask"] = (
self.trainer.test_index_mapping["manual_based"]["valid_mask"]
)
def pretrain(
self,
dataset: BaseADDataset,
optimizer_name: str = "adam",
lr: float = 0.001,
n_epochs: int = 100,
lr_milestones: tuple = (),
batch_size: int = 128,
weight_decay: float = 1e-6,
device: str = "cuda",
n_jobs_dataloader: int = 0,
k_fold_idx: int = None,
):
"""Pretrains the weights for the Deep SAD network phi via autoencoder."""
# Set autoencoder network
self.ae_net = build_autoencoder(self.net_name, self.rep_dim)
# Train
self.ae_optimizer_name = optimizer_name
self.ae_trainer = AETrainer(
optimizer_name,
lr=lr,
n_epochs=n_epochs,
lr_milestones=lr_milestones,
batch_size=batch_size,
weight_decay=weight_decay,
device=device,
n_jobs_dataloader=n_jobs_dataloader,
)
self.ae_net = self.ae_trainer.train(dataset, self.ae_net, k_fold_idx=k_fold_idx)
# Test
self.ae_trainer.test(dataset, self.ae_net, k_fold_idx=k_fold_idx)
# Get train results
self.ae_results = {
"train": {
"time": self.ae_trainer.train_time,
"indices": self.ae_trainer.train_indices,
"labels_exp_based": self.ae_trainer.train_labels_exp_based,
"labels_manual_based": self.ae_trainer.train_labels_manual_based,
"semi_targets": self.ae_trainer.train_semi_targets,
"file_ids": self.ae_trainer.train_file_ids,
"frame_ids": self.ae_trainer.train_frame_ids,
"scores": self.ae_trainer.train_scores,
"loss": self.ae_trainer.train_loss,
"file_names": {
file_id: dataset.get_file_name_from_idx(file_id)
for file_id in np.unique(self.ae_trainer.train_file_ids)
},
},
"test": {
"time": self.ae_trainer.test_time,
"indices": self.ae_trainer.test_indices,
"labels_exp_based": self.ae_trainer.test_labels_exp_based,
"labels_manual_based": self.ae_trainer.test_labels_manual_based,
"semi_targets": self.ae_trainer.test_semi_targets,
"file_ids": self.ae_trainer.test_file_ids,
"frame_ids": self.ae_trainer.test_frame_ids,
"scores": self.ae_trainer.test_scores,
"loss": self.ae_trainer.test_loss,
"file_names": {
file_id: dataset.get_file_name_from_idx(file_id)
for file_id in np.unique(self.ae_trainer.test_file_ids)
},
},
}
# Initialize Deep SAD network weights from pre-trained encoder
self.init_network_weights_from_pretraining()
def init_network_weights_from_pretraining(self):
"""Initialize the Deep SAD network weights from the encoder weights of the pretraining autoencoder."""
net_dict = self.net.state_dict()
ae_net_dict = self.ae_net.state_dict()
# Filter out decoder network keys
ae_net_dict = {k: v for k, v in ae_net_dict.items() if k in net_dict}
# Overwrite values in the existing state_dict
net_dict.update(ae_net_dict)
# Load the new state_dict
self.net.load_state_dict(net_dict)
def save_model(self, export_model, save_ae=True):
"""Save Deep SAD model to export_model."""
net_dict = self.net.state_dict()
        ae_net_dict = (
            self.ae_net.state_dict() if save_ae and self.ae_net is not None else None
        )
torch.save(
{"c": self.c, "net_dict": net_dict, "ae_net_dict": ae_net_dict},
export_model,
)
def load_model(self, model_path, load_ae=False, map_location="cpu"):
"""Load Deep SAD model from model_path."""
model_dict = torch.load(model_path, map_location=map_location)
self.c = model_dict["c"]
self.net.load_state_dict(model_dict["net_dict"])
# load autoencoder parameters if specified
if load_ae:
if self.ae_net is None:
self.ae_net = build_autoencoder(self.net_name, self.rep_dim)
self.ae_net.load_state_dict(model_dict["ae_net_dict"])
    def save_results(self, export_pkl):
        """Save the results dict to a pickle file."""
        with open(export_pkl, "wb") as fp:
            pickle.dump(self.results, fp)
    def save_ae_results(self, export_pkl):
        """Save the autoencoder results dict to a pickle file."""
        with open(export_pkl, "wb") as fp:
            pickle.dump(self.ae_results, fp)