2024-06-28 07:42:12 +02:00
|
|
|
import json
|
2025-03-14 18:02:23 +01:00
|
|
|
import pickle
|
|
|
|
|
|
2024-06-28 07:42:12 +02:00
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
from base.base_dataset import BaseADDataset
|
2025-03-14 18:02:23 +01:00
|
|
|
from networks.main import build_autoencoder, build_network
|
2024-06-28 07:42:12 +02:00
|
|
|
from optim.ae_trainer import AETrainer
|
2025-03-14 18:02:23 +01:00
|
|
|
from optim.DeepSAD_trainer import DeepSADTrainer
|
2024-06-28 07:42:12 +02:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class DeepSAD(object):
    """A class for the Deep SAD method.

    Typical usage: ``set_network`` -> (optional) ``pretrain`` -> ``train`` ->
    ``test`` / ``inference`` -> ``save_model`` / ``save_results``.

    Attributes:
        eta: Deep SAD hyperparameter eta (must be 0 < eta).
        c: Hypersphere center c. ``None`` until ``train`` or ``load_model``;
            ``train`` stores it as a plain Python list.
        net_name: A string indicating the name of the neural network to use.
        net: The neural network phi.
        trainer: DeepSADTrainer to train a Deep SAD model.
        optimizer_name: A string indicating the optimizer to use for training the Deep SAD network.
        ae_net: The autoencoder network corresponding to phi for network weights pretraining.
        ae_trainer: AETrainer to train an autoencoder in pretraining.
        ae_optimizer_name: A string indicating the optimizer to use for pretraining the autoencoder.
        results: A dictionary to save the results.
        ae_results: A dictionary to save the autoencoder results.
    """

    def __init__(self, eta: float = 1.0):
        """Inits DeepSAD with hyperparameter eta."""

        self.eta = eta
        self.c = None  # hypersphere center c

        self.net_name = None
        self.net = None  # neural network phi

        self.trainer = None
        self.optimizer_name = None

        self.ae_net = None  # autoencoder network for pretraining
        self.ae_trainer = None
        self.ae_optimizer_name = None

        # Every key that ``train``/``test`` write is pre-declared so the dict
        # (and the pickle written by ``save_results``) has a stable shape even
        # before testing has run.
        self.results = {
            "train_time": None,
            "test_auc": None,
            "test_roc": None,
            "test_prc": None,
            "test_ap": None,
            "test_time": None,
            "test_scores": None,
        }

        self.ae_results = {"train_time": None, "test_auc": None, "test_time": None}

    def set_network(self, net_name):
        """Builds the neural network phi.

        Args:
            net_name: Network identifier passed to ``build_network``; it is also
                remembered so ``pretrain``/``load_model`` can build the matching
                autoencoder via ``build_autoencoder``.
        """
        self.net_name = net_name
        self.net = build_network(net_name)

    def train(
        self,
        dataset: BaseADDataset,
        optimizer_name: str = "adam",
        lr: float = 0.001,
        n_epochs: int = 50,
        lr_milestones: tuple = (),
        batch_size: int = 128,
        weight_decay: float = 1e-6,
        device: str = "cuda",
        n_jobs_dataloader: int = 0,
        k_fold_idx: int = None,
    ):
        """Trains the Deep SAD model on the training data.

        Requires ``set_network`` to have been called first. A fresh
        ``DeepSADTrainer`` is created on every call, seeded with the current
        center ``self.c`` (``None`` on a first run).

        Args:
            dataset: Dataset to train on.
            optimizer_name: Optimizer name forwarded to the trainer.
            lr: Learning rate.
            n_epochs: Number of training epochs.
            lr_milestones: Epochs at which the trainer adjusts the learning rate.
            batch_size: Mini-batch size.
            weight_decay: Weight decay forwarded to the trainer.
            device: Torch device string, e.g. ``"cuda"`` or ``"cpu"``.
            n_jobs_dataloader: Number of data-loader workers.
            k_fold_idx: Optional fold index forwarded to ``DeepSADTrainer.train``.
        """

        self.optimizer_name = optimizer_name
        self.trainer = DeepSADTrainer(
            self.c,
            self.eta,
            optimizer_name=optimizer_name,
            lr=lr,
            n_epochs=n_epochs,
            lr_milestones=lr_milestones,
            batch_size=batch_size,
            weight_decay=weight_decay,
            device=device,
            n_jobs_dataloader=n_jobs_dataloader,
        )
        # Get the model
        self.net = self.trainer.train(dataset, self.net, k_fold_idx=k_fold_idx)
        self.results["train_time"] = self.trainer.train_time
        # The trainer holds c as a tensor; keep a device-independent list copy.
        self.c = self.trainer.c.cpu().data.numpy().tolist()

    def inference(
        self, dataset: BaseADDataset, device: str = "cuda", n_jobs_dataloader: int = 0
    ):
        """Runs the Deep SAD model on ``dataset`` and returns its inference output.

        If no trainer exists yet (e.g. the model was only loaded via
        ``load_model``), one is created from ``self.c`` and ``self.eta``.
        If a trainer already exists, it is reused and ``device`` and
        ``n_jobs_dataloader`` have no effect.

        Args:
            dataset: Dataset to run inference on.
            device: Torch device string, used only when a new trainer is created.
            n_jobs_dataloader: Data-loader workers, used only when a new trainer
                is created.

        Returns:
            Whatever ``DeepSADTrainer.infer`` returns (format defined by the
            trainer).
        """

        if self.trainer is None:
            self.trainer = DeepSADTrainer(
                self.c, self.eta, device=device, n_jobs_dataloader=n_jobs_dataloader
            )

        return self.trainer.infer(dataset, self.net)

    def test(
        self,
        dataset: BaseADDataset,
        device: str = "cuda",
        n_jobs_dataloader: int = 0,
        k_fold_idx: int = None,
    ):
        """Tests the Deep SAD model on the test data and stores the metrics.

        As with ``inference``, an existing trainer is reused; ``device`` and
        ``n_jobs_dataloader`` only matter when a new trainer must be created.

        Args:
            dataset: Dataset to evaluate on.
            device: Torch device string, used only when a new trainer is created.
            n_jobs_dataloader: Data-loader workers, used only when a new trainer
                is created.
            k_fold_idx: Optional fold index forwarded to ``DeepSADTrainer.test``.

        The metrics are written into ``self.results`` under ``test_auc``,
        ``test_roc``, ``test_prc``, ``test_ap``, ``test_time`` and
        ``test_scores``.
        """

        if self.trainer is None:
            self.trainer = DeepSADTrainer(
                self.c, self.eta, device=device, n_jobs_dataloader=n_jobs_dataloader
            )

        self.trainer.test(dataset, self.net, k_fold_idx=k_fold_idx)

        # Get results
        self.results["test_auc"] = self.trainer.test_auc
        self.results["test_roc"] = self.trainer.test_roc
        self.results["test_prc"] = self.trainer.test_prc
        self.results["test_ap"] = self.trainer.test_ap
        self.results["test_time"] = self.trainer.test_time
        self.results["test_scores"] = self.trainer.test_scores

    def pretrain(
        self,
        dataset: BaseADDataset,
        optimizer_name: str = "adam",
        lr: float = 0.001,
        n_epochs: int = 100,
        lr_milestones: tuple = (),
        batch_size: int = 128,
        weight_decay: float = 1e-6,
        device: str = "cuda",
        n_jobs_dataloader: int = 0,
        k_fold_idx: int = None,
    ):
        """Pretrains the weights for the Deep SAD network phi via autoencoder.

        The method does four things:
        1. Builds the autoencoder matching ``self.net_name``.
        2. Trains and tests it, recording the results in ``self.ae_results``.
        3. Copies the encoder weights that share key names with ``self.net``
           into ``self.net``.

        Requires ``set_network`` to have been called first.

        Args:
            dataset: Dataset used for both autoencoder training and testing.
            optimizer_name: Optimizer name forwarded to ``AETrainer``.
            lr: Learning rate.
            n_epochs: Number of pretraining epochs.
            lr_milestones: Epochs at which the trainer adjusts the learning rate.
            batch_size: Mini-batch size.
            weight_decay: Weight decay forwarded to the trainer.
            device: Torch device string.
            n_jobs_dataloader: Number of data-loader workers.
            k_fold_idx: Optional fold index forwarded to ``AETrainer.train``/``test``.
        """

        # Set autoencoder network
        self.ae_net = build_autoencoder(self.net_name)

        # Train
        self.ae_optimizer_name = optimizer_name
        self.ae_trainer = AETrainer(
            optimizer_name,
            lr=lr,
            n_epochs=n_epochs,
            lr_milestones=lr_milestones,
            batch_size=batch_size,
            weight_decay=weight_decay,
            device=device,
            n_jobs_dataloader=n_jobs_dataloader,
        )
        self.ae_net = self.ae_trainer.train(dataset, self.ae_net, k_fold_idx=k_fold_idx)

        # Get train results
        self.ae_results["train_time"] = self.ae_trainer.train_time

        # Test
        self.ae_trainer.test(dataset, self.ae_net, k_fold_idx=k_fold_idx)

        # Get test results
        self.ae_results["test_auc"] = self.ae_trainer.test_auc
        self.ae_results["test_time"] = self.ae_trainer.test_time

        # Initialize Deep SAD network weights from pre-trained encoder
        self.init_network_weights_from_pretraining()

    def init_network_weights_from_pretraining(self):
        """Initialize the Deep SAD network weights from the encoder weights of the pretraining autoencoder.

        Only autoencoder parameters whose state-dict key also exists in
        ``self.net`` are copied; all other keys are dropped, which presumably
        removes the decoder weights. Keys are matched by name only.
        """

        net_dict = self.net.state_dict()
        ae_net_dict = self.ae_net.state_dict()

        # Filter out decoder network keys
        ae_net_dict = {k: v for k, v in ae_net_dict.items() if k in net_dict}
        # Overwrite values in the existing state_dict
        net_dict.update(ae_net_dict)
        # Load the new state_dict
        self.net.load_state_dict(net_dict)

    def save_model(self, export_model, save_ae=True):
        """Save Deep SAD model to export_model.

        Args:
            export_model: Path or file-like object passed to ``torch.save``.
            save_ae: If true and an autoencoder exists, also store its
                state dict. If no autoencoder exists (``pretrain`` was never
                run), ``None`` is stored instead of raising.
        """

        net_dict = self.net.state_dict()
        # Guard against ae_net being None: save_ae defaults to True, so a model
        # trained without pretraining would otherwise crash here.
        ae_net_dict = (
            self.ae_net.state_dict() if save_ae and self.ae_net is not None else None
        )

        torch.save(
            {"c": self.c, "net_dict": net_dict, "ae_net_dict": ae_net_dict},
            export_model,
        )

    def load_model(self, model_path, load_ae=False, map_location="cpu"):
        """Load Deep SAD model from model_path.

        Requires ``set_network`` to have been called first, because
        ``self.net`` must already exist to receive the weights.

        Args:
            model_path: Checkpoint written by ``save_model``.
            load_ae: If true, also restore the autoencoder, building it from
                ``self.net_name`` if needed. The checkpoint must then contain
                autoencoder weights.
            map_location: Forwarded to ``torch.load``.

        WARNING: ``torch.load`` unpickles the file. Only load trusted checkpoints.
        """

        model_dict = torch.load(model_path, map_location=map_location)

        self.c = model_dict["c"]
        self.net.load_state_dict(model_dict["net_dict"])

        # load autoencoder parameters if specified
        if load_ae:
            if self.ae_net is None:
                self.ae_net = build_autoencoder(self.net_name)
            # NOTE(review): "ae_net_dict" is None when the checkpoint was saved
            # without an autoencoder; load_state_dict will then fail.
            self.ae_net.load_state_dict(model_dict["ae_net_dict"])

    def save_results(self, export_pkl):
        """Save the results dict to a pickle file.

        Pickle is used rather than JSON. Presumably this is because entries
        such as ``test_scores`` or ``test_roc`` may not be JSON-serializable;
        confirm against DeepSADTrainer.
        """
        with open(export_pkl, "wb") as fp:
            pickle.dump(self.results, fp)

    def save_ae_results(self, export_json):
        """Save autoencoder results dict to a JSON-file."""
        with open(export_json, "w") as fp:
            json.dump(self.ae_results, fp)
|