full upload so not to lose anything important
This commit is contained in:
@@ -1,18 +1,23 @@
|
||||
from base.base_trainer import BaseTrainer
|
||||
from base.base_dataset import BaseADDataset
|
||||
from base.base_net import BaseNet
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
from sklearn.metrics import roc_auc_score
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
import numpy as np
|
||||
from sklearn.metrics import (
|
||||
average_precision_score,
|
||||
precision_recall_curve,
|
||||
roc_auc_score,
|
||||
roc_curve,
|
||||
)
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
|
||||
from base.base_dataset import BaseADDataset
|
||||
from base.base_net import BaseNet
|
||||
from base.base_trainer import BaseTrainer
|
||||
|
||||
|
||||
class DeepSADTrainer(BaseTrainer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
c,
|
||||
@@ -50,13 +55,22 @@ class DeepSADTrainer(BaseTrainer):
|
||||
self.test_time = None
|
||||
self.test_scores = None
|
||||
|
||||
def train(self, dataset: BaseADDataset, net: BaseNet):
|
||||
def train(
|
||||
self, dataset: BaseADDataset, net: BaseNet, k_fold_idx: int = None
|
||||
) -> BaseNet:
|
||||
logger = logging.getLogger()
|
||||
|
||||
# Get train data loader
|
||||
train_loader, _, _ = dataset.loaders(
|
||||
batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
|
||||
)
|
||||
if k_fold_idx is not None:
|
||||
train_loader, _ = dataset.loaders_k_fold(
|
||||
fold_idx=k_fold_idx,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.n_jobs_dataloader,
|
||||
)
|
||||
else:
|
||||
train_loader, _, _ = dataset.loaders(
|
||||
batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
|
||||
)
|
||||
|
||||
# Set device for network
|
||||
net = net.to(self.device)
|
||||
@@ -82,14 +96,14 @@ class DeepSADTrainer(BaseTrainer):
|
||||
start_time = time.time()
|
||||
net.train()
|
||||
for epoch in range(self.n_epochs):
|
||||
|
||||
epoch_loss = 0.0
|
||||
n_batches = 0
|
||||
epoch_start_time = time.time()
|
||||
for data in train_loader:
|
||||
inputs, _, semi_targets, _ = data
|
||||
inputs, semi_targets = inputs.to(self.device), semi_targets.to(
|
||||
self.device
|
||||
inputs, _, semi_targets, _, _ = data
|
||||
inputs, semi_targets = (
|
||||
inputs.to(self.device),
|
||||
semi_targets.to(self.device),
|
||||
)
|
||||
|
||||
# Zero the network parameter gradients
|
||||
@@ -145,6 +159,7 @@ class DeepSADTrainer(BaseTrainer):
|
||||
logger.info("Starting inference...")
|
||||
n_batches = 0
|
||||
start_time = time.time()
|
||||
all_outputs = np.zeros((len(inference_loader.dataset), 1024), dtype=np.float32)
|
||||
scores = []
|
||||
net.eval()
|
||||
with torch.no_grad():
|
||||
@@ -155,6 +170,10 @@ class DeepSADTrainer(BaseTrainer):
|
||||
idx = idx.to(self.device)
|
||||
|
||||
outputs = net(inputs)
|
||||
all_idx = n_batches * self.batch_size
|
||||
all_outputs[all_idx : all_idx + len(inputs)] = (
|
||||
outputs.cpu().data.numpy()
|
||||
)
|
||||
dist = torch.sum((outputs - self.c) ** 2, dim=1)
|
||||
scores += dist.cpu().data.numpy().tolist()
|
||||
|
||||
@@ -166,15 +185,22 @@ class DeepSADTrainer(BaseTrainer):
|
||||
logger.info("Inference Time: {:.3f}s".format(self.inference_time))
|
||||
logger.info("Finished inference.")
|
||||
|
||||
return np.array(scores)
|
||||
return np.array(scores), all_outputs
|
||||
|
||||
def test(self, dataset: BaseADDataset, net: BaseNet):
|
||||
def test(self, dataset: BaseADDataset, net: BaseNet, k_fold_idx: int = None):
|
||||
logger = logging.getLogger()
|
||||
|
||||
# Get test data loader
|
||||
_, test_loader, _ = dataset.loaders(
|
||||
batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
|
||||
)
|
||||
if k_fold_idx is not None:
|
||||
_, test_loader = dataset.loaders_k_fold(
|
||||
fold_idx=k_fold_idx,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.n_jobs_dataloader,
|
||||
)
|
||||
else:
|
||||
_, test_loader, _ = dataset.loaders(
|
||||
batch_size=self.batch_size, num_workers=self.n_jobs_dataloader
|
||||
)
|
||||
|
||||
# Set device for network
|
||||
net = net.to(self.device)
|
||||
@@ -188,7 +214,7 @@ class DeepSADTrainer(BaseTrainer):
|
||||
net.eval()
|
||||
with torch.no_grad():
|
||||
for data in test_loader:
|
||||
inputs, labels, semi_targets, idx = data
|
||||
inputs, labels, semi_targets, idx, _ = data
|
||||
|
||||
inputs = inputs.to(self.device)
|
||||
labels = labels.to(self.device)
|
||||
@@ -225,6 +251,9 @@ class DeepSADTrainer(BaseTrainer):
|
||||
labels = np.array(labels)
|
||||
scores = np.array(scores)
|
||||
self.test_auc = roc_auc_score(labels, scores)
|
||||
self.test_roc = roc_curve(labels, scores)
|
||||
self.test_prc = precision_recall_curve(labels, scores)
|
||||
self.test_ap = average_precision_score(labels, scores)
|
||||
|
||||
# Log results
|
||||
logger.info("Test Loss: {:.6f}".format(epoch_loss / n_batches))
|
||||
@@ -241,7 +270,7 @@ class DeepSADTrainer(BaseTrainer):
|
||||
with torch.no_grad():
|
||||
for data in train_loader:
|
||||
# get the inputs of the batch
|
||||
inputs, _, _, _ = data
|
||||
inputs, _, _, _, _ = data
|
||||
inputs = inputs.to(self.device)
|
||||
outputs = net(inputs)
|
||||
n_samples += outputs.shape[0]
|
||||
|
||||
Reference in New Issue
Block a user