Full upload so as not to lose anything important

This commit is contained in:
Jan Kowalczyk
2025-03-14 18:02:23 +01:00
parent 35fcfb7d5a
commit b824ff7482
33 changed files with 3539 additions and 353 deletions

View File

@@ -1,18 +1,23 @@
from base.base_trainer import BaseTrainer
from base.base_dataset import BaseADDataset
from base.base_net import BaseNet
from torch.utils.data.dataloader import DataLoader
from sklearn.metrics import roc_auc_score
import logging
import time
import numpy as np
import torch
import torch.optim as optim
import numpy as np
from sklearn.metrics import (
average_precision_score,
precision_recall_curve,
roc_auc_score,
roc_curve,
)
from torch.utils.data.dataloader import DataLoader
from base.base_dataset import BaseADDataset
from base.base_net import BaseNet
from base.base_trainer import BaseTrainer
class DeepSADTrainer(BaseTrainer):
def __init__(
self,
c,
@@ -50,13 +55,22 @@ class DeepSADTrainer(BaseTrainer):
self.test_time = None
self.test_scores = None
def train(self, dataset: BaseADDataset, net: BaseNet):
def train(
self, dataset: BaseADDataset, net: BaseNet, k_fold_idx: int = None
) -> BaseNet:
logger = logging.getLogger()
# Get train data loader
train_loader, _, _ = dataset.loaders(
batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
)
if k_fold_idx is not None:
train_loader, _ = dataset.loaders_k_fold(
fold_idx=k_fold_idx,
batch_size=self.batch_size,
num_workers=self.n_jobs_dataloader,
)
else:
train_loader, _, _ = dataset.loaders(
batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
)
# Set device for network
net = net.to(self.device)
@@ -82,14 +96,14 @@ class DeepSADTrainer(BaseTrainer):
start_time = time.time()
net.train()
for epoch in range(self.n_epochs):
epoch_loss = 0.0
n_batches = 0
epoch_start_time = time.time()
for data in train_loader:
inputs, _, semi_targets, _ = data
inputs, semi_targets = inputs.to(self.device), semi_targets.to(
self.device
inputs, _, semi_targets, _, _ = data
inputs, semi_targets = (
inputs.to(self.device),
semi_targets.to(self.device),
)
# Zero the network parameter gradients
@@ -145,6 +159,7 @@ class DeepSADTrainer(BaseTrainer):
logger.info("Starting inference...")
n_batches = 0
start_time = time.time()
all_outputs = np.zeros((len(inference_loader.dataset), 1024), dtype=np.float32)
scores = []
net.eval()
with torch.no_grad():
@@ -155,6 +170,10 @@ class DeepSADTrainer(BaseTrainer):
idx = idx.to(self.device)
outputs = net(inputs)
all_idx = n_batches * self.batch_size
all_outputs[all_idx : all_idx + len(inputs)] = (
outputs.cpu().data.numpy()
)
dist = torch.sum((outputs - self.c) ** 2, dim=1)
scores += dist.cpu().data.numpy().tolist()
@@ -166,15 +185,22 @@ class DeepSADTrainer(BaseTrainer):
logger.info("Inference Time: {:.3f}s".format(self.inference_time))
logger.info("Finished inference.")
return np.array(scores)
return np.array(scores), all_outputs
def test(self, dataset: BaseADDataset, net: BaseNet):
def test(self, dataset: BaseADDataset, net: BaseNet, k_fold_idx: int = None):
logger = logging.getLogger()
# Get test data loader
_, test_loader, _ = dataset.loaders(
batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
)
if k_fold_idx is not None:
_, test_loader = dataset.loaders_k_fold(
fold_idx=k_fold_idx,
batch_size=self.batch_size,
num_workers=self.n_jobs_dataloader,
)
else:
_, test_loader, _ = dataset.loaders(
batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
)
# Set device for network
net = net.to(self.device)
@@ -188,7 +214,7 @@ class DeepSADTrainer(BaseTrainer):
net.eval()
with torch.no_grad():
for data in test_loader:
inputs, labels, semi_targets, idx = data
inputs, labels, semi_targets, idx, _ = data
inputs = inputs.to(self.device)
labels = labels.to(self.device)
@@ -225,6 +251,9 @@ class DeepSADTrainer(BaseTrainer):
labels = np.array(labels)
scores = np.array(scores)
self.test_auc = roc_auc_score(labels, scores)
self.test_roc = roc_curve(labels, scores)
self.test_prc = precision_recall_curve(labels, scores)
self.test_ap = average_precision_score(labels, scores)
# Log results
logger.info("Test Loss: {:.6f}".format(epoch_loss / n_batches))
@@ -241,7 +270,7 @@ class DeepSADTrainer(BaseTrainer):
with torch.no_grad():
for data in train_loader:
# get the inputs of the batch
inputs, _, _, _ = data
inputs, _, _, _, _ = data
inputs = inputs.to(self.device)
outputs = net(inputs)
n_samples += outputs.shape[0]