wip
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from base.base_dataset import BaseADDataset
|
from base.base_dataset import BaseADDataset
|
||||||
@@ -43,10 +44,47 @@ class DeepSAD(object):
|
|||||||
self.ae_optimizer_name = None
|
self.ae_optimizer_name = None
|
||||||
|
|
||||||
self.results = {
|
self.results = {
|
||||||
"train_time": None,
|
"train": {
|
||||||
"test_auc": None,
|
"time": None,
|
||||||
"test_time": None,
|
"indices": None,
|
||||||
"test_scores": None,
|
"file_ids": None,
|
||||||
|
"frame_ids": None,
|
||||||
|
"file_names": None, # mapping of file_ids to file names
|
||||||
|
},
|
||||||
|
"test": {
|
||||||
|
"time": None,
|
||||||
|
"indices": None,
|
||||||
|
"file_ids": None,
|
||||||
|
"frame_ids": None,
|
||||||
|
"file_names": None, # mapping of file_ids to file names
|
||||||
|
"exp_based": {
|
||||||
|
"auc": None,
|
||||||
|
"roc": None,
|
||||||
|
"prc": None,
|
||||||
|
"ap": None,
|
||||||
|
"scores": None,
|
||||||
|
"indices": None,
|
||||||
|
"labels": None,
|
||||||
|
"valid_mask": None,
|
||||||
|
},
|
||||||
|
"manual_based": {
|
||||||
|
"auc": None,
|
||||||
|
"roc": None,
|
||||||
|
"prc": None,
|
||||||
|
"ap": None,
|
||||||
|
"scores": None,
|
||||||
|
"indices": None,
|
||||||
|
"labels": None,
|
||||||
|
"valid_mask": None,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"inference": {
|
||||||
|
"time": None,
|
||||||
|
"indices": None,
|
||||||
|
"file_ids": None,
|
||||||
|
"frame_ids": None,
|
||||||
|
"file_names": None, # mapping of file_ids to file names
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
self.ae_results = {"train_time": None, "test_auc": None, "test_time": None}
|
self.ae_results = {"train_time": None, "test_auc": None, "test_time": None}
|
||||||
@@ -86,8 +124,17 @@ class DeepSAD(object):
|
|||||||
)
|
)
|
||||||
# Get the model
|
# Get the model
|
||||||
self.net = self.trainer.train(dataset, self.net, k_fold_idx=k_fold_idx)
|
self.net = self.trainer.train(dataset, self.net, k_fold_idx=k_fold_idx)
|
||||||
self.results["train_time"] = self.trainer.train_time
|
# Store training results including indices
|
||||||
self.c = self.trainer.c.cpu().data.numpy().tolist() # get as list
|
self.results["train"]["time"] = self.trainer.train_time
|
||||||
|
self.results["train"]["indices"] = self.trainer.train_indices
|
||||||
|
self.results["train"]["file_ids"] = self.trainer.train_file_ids
|
||||||
|
self.results["train"]["frame_ids"] = self.trainer.train_frame_ids
|
||||||
|
|
||||||
|
# Get file names mapping for training data
|
||||||
|
self.results["train"]["file_names"] = {
|
||||||
|
file_id: dataset.get_file_name_from_idx(file_id)
|
||||||
|
for file_id in np.unique(self.trainer.train_file_ids)
|
||||||
|
}
|
||||||
|
|
||||||
def inference(
|
def inference(
|
||||||
self, dataset: BaseADDataset, device: str = "cuda", n_jobs_dataloader: int = 0
|
self, dataset: BaseADDataset, device: str = "cuda", n_jobs_dataloader: int = 0
|
||||||
@@ -99,7 +146,21 @@ class DeepSAD(object):
|
|||||||
self.c, self.eta, device=device, n_jobs_dataloader=n_jobs_dataloader
|
self.c, self.eta, device=device, n_jobs_dataloader=n_jobs_dataloader
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.trainer.infer(dataset, self.net)
|
scores, outputs = self.trainer.infer(dataset, self.net)
|
||||||
|
|
||||||
|
# Store inference indices and mappings
|
||||||
|
self.results["inference"]["time"] = self.trainer.inference_time
|
||||||
|
self.results["inference"]["indices"] = self.trainer.inference_indices
|
||||||
|
self.results["inference"]["file_ids"] = self.trainer.inference_file_ids
|
||||||
|
self.results["inference"]["frame_ids"] = self.trainer.inference_frame_ids
|
||||||
|
|
||||||
|
# Get file names mapping for inference data
|
||||||
|
self.results["inference"]["file_names"] = {
|
||||||
|
file_id: dataset.get_file_name_from_idx(file_id)
|
||||||
|
for file_id in np.unique(self.trainer.inference_file_ids)
|
||||||
|
}
|
||||||
|
|
||||||
|
return scores, outputs
|
||||||
|
|
||||||
def test(
|
def test(
|
||||||
self,
|
self,
|
||||||
@@ -117,13 +178,51 @@ class DeepSAD(object):
|
|||||||
|
|
||||||
self.trainer.test(dataset, self.net, k_fold_idx=k_fold_idx)
|
self.trainer.test(dataset, self.net, k_fold_idx=k_fold_idx)
|
||||||
|
|
||||||
# Get results
|
# Store all test indices and mappings
|
||||||
self.results["test_auc"] = self.trainer.test_auc
|
self.results["test"]["time"] = self.trainer.test_time
|
||||||
self.results["test_roc"] = self.trainer.test_roc
|
self.results["test"]["indices"] = self.trainer.test_indices
|
||||||
self.results["test_prc"] = self.trainer.test_prc
|
self.results["test"]["file_ids"] = self.trainer.test_file_ids
|
||||||
self.results["test_ap"] = self.trainer.test_ap
|
self.results["test"]["frame_ids"] = self.trainer.test_frame_ids
|
||||||
self.results["test_time"] = self.trainer.test_time
|
|
||||||
self.results["test_scores"] = self.trainer.test_scores
|
# Get file names mapping for test data
|
||||||
|
self.results["test"]["file_names"] = {
|
||||||
|
file_id: dataset.get_file_name_from_idx(file_id)
|
||||||
|
for file_id in np.unique(self.trainer.test_file_ids)
|
||||||
|
}
|
||||||
|
|
||||||
|
# Store experiment-based results
|
||||||
|
self.results["test"]["exp_based"]["auc"] = self.trainer.test_auc_exp_based
|
||||||
|
self.results["test"]["exp_based"]["roc"] = self.trainer.test_roc_exp_based
|
||||||
|
self.results["test"]["exp_based"]["prc"] = self.trainer.test_prc_exp_based
|
||||||
|
self.results["test"]["exp_based"]["ap"] = self.trainer.test_ap_exp_based
|
||||||
|
self.results["test"]["exp_based"]["scores"] = self.trainer.test_scores_exp_based
|
||||||
|
self.results["test"]["exp_based"]["indices"] = self.trainer.test_index_mapping[
|
||||||
|
"exp_based"
|
||||||
|
]["indices"]
|
||||||
|
self.results["test"]["exp_based"]["labels"] = self.trainer.test_index_mapping[
|
||||||
|
"exp_based"
|
||||||
|
]["labels"]
|
||||||
|
self.results["test"]["exp_based"]["valid_mask"] = (
|
||||||
|
self.trainer.test_index_mapping["exp_based"]["valid_mask"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store manual-based results
|
||||||
|
self.results["test"]["manual_based"]["auc"] = self.trainer.test_auc_manual_based
|
||||||
|
self.results["test"]["manual_based"]["roc"] = self.trainer.test_roc_manual_based
|
||||||
|
self.results["test"]["manual_based"]["prc"] = self.trainer.test_prc_manual_based
|
||||||
|
self.results["test"]["manual_based"]["ap"] = self.trainer.test_ap_manual_based
|
||||||
|
self.results["test"]["manual_based"]["scores"] = (
|
||||||
|
self.trainer.test_scores_manual_based
|
||||||
|
)
|
||||||
|
self.results["test"]["manual_based"]["indices"] = (
|
||||||
|
self.trainer.test_index_mapping["manual_based"]["indices"]
|
||||||
|
)
|
||||||
|
self.results["test"]["manual_based"]["labels"] = (
|
||||||
|
self.trainer.test_index_mapping["manual_based"]["labels"]
|
||||||
|
)
|
||||||
|
self.results["test"]["manual_based"]["valid_mask"] = (
|
||||||
|
self.trainer.test_index_mapping["manual_based"]["valid_mask"]
|
||||||
|
)
|
||||||
|
|
||||||
def pretrain(
|
def pretrain(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -63,6 +63,8 @@ class TorchvisionDataset(BaseADDataset):
|
|||||||
shuffle_test=False,
|
shuffle_test=False,
|
||||||
num_workers: int = 0,
|
num_workers: int = 0,
|
||||||
) -> (DataLoader, DataLoader):
|
) -> (DataLoader, DataLoader):
|
||||||
|
if self.k_fold_number is None:
|
||||||
|
raise ValueError("k_fold_number must be set to a positive integer.")
|
||||||
if self.fold_indices is None:
|
if self.fold_indices is None:
|
||||||
# Define the K-fold Cross Validator
|
# Define the K-fold Cross Validator
|
||||||
kfold = KFold(n_splits=self.k_fold_number, shuffle=False)
|
kfold = KFold(n_splits=self.k_fold_number, shuffle=False)
|
||||||
|
|||||||
@@ -51,9 +51,16 @@ class IsoForest(object):
|
|||||||
self.results = {
|
self.results = {
|
||||||
"train_time": None,
|
"train_time": None,
|
||||||
"test_time": None,
|
"test_time": None,
|
||||||
"test_auc": None,
|
"test_auc_exp_based": None,
|
||||||
"test_roc": None,
|
"test_roc_exp_based": None,
|
||||||
"test_scores": None,
|
"test_prc_exp_based": None,
|
||||||
|
"test_ap_exp_based": None,
|
||||||
|
"test_scores_exp_based": None,
|
||||||
|
"test_auc_manual_based": None,
|
||||||
|
"test_roc_manual_based": None,
|
||||||
|
"test_prc_manual_based": None,
|
||||||
|
"test_ap_manual_based": None,
|
||||||
|
"test_scores_manual_based": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
@@ -89,7 +96,7 @@ class IsoForest(object):
|
|||||||
# Get data from loader
|
# Get data from loader
|
||||||
X = ()
|
X = ()
|
||||||
for data in train_loader:
|
for data in train_loader:
|
||||||
inputs, _, _, _, _ = data
|
inputs, _, _, _, _, _ = data
|
||||||
inputs = inputs.to(device)
|
inputs = inputs.to(device)
|
||||||
if self.hybrid:
|
if self.hybrid:
|
||||||
inputs = self.ae_net.encoder(
|
inputs = self.ae_net.encoder(
|
||||||
@@ -133,28 +140,50 @@ class IsoForest(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Get data from loader
|
# Get data from loader
|
||||||
idx_label_score = []
|
idx_label_score_exp = []
|
||||||
|
idx_label_score_manual = []
|
||||||
X = ()
|
X = ()
|
||||||
idxs = []
|
idxs = []
|
||||||
labels = []
|
labels_exp = []
|
||||||
|
labels_manual = []
|
||||||
|
|
||||||
for data in test_loader:
|
for data in test_loader:
|
||||||
inputs, label_batch, _, idx, _ = data
|
inputs, label_exp, label_manual, _, idx, _ = data
|
||||||
inputs, label_batch, idx = (
|
inputs, label_exp, label_manual, idx = (
|
||||||
inputs.to(device),
|
inputs.to(device),
|
||||||
label_batch.to(device),
|
label_exp.to(device),
|
||||||
|
label_manual.to(device),
|
||||||
idx.to(device),
|
idx.to(device),
|
||||||
)
|
)
|
||||||
if self.hybrid:
|
if self.hybrid:
|
||||||
inputs = self.ae_net.encoder(
|
inputs = self.ae_net.encoder(inputs)
|
||||||
inputs
|
X_batch = inputs.view(inputs.size(0), -1)
|
||||||
) # in hybrid approach, take code representation of AE as features
|
|
||||||
X_batch = inputs.view(
|
|
||||||
inputs.size(0), -1
|
|
||||||
) # X_batch.shape = (batch_size, n_channels * height * width)
|
|
||||||
X += (X_batch.cpu().data.numpy(),)
|
X += (X_batch.cpu().data.numpy(),)
|
||||||
idxs += idx.cpu().data.numpy().astype(np.int64).tolist()
|
idxs += idx.cpu().data.numpy().astype(np.int64).tolist()
|
||||||
labels += label_batch.cpu().data.numpy().astype(np.int64).tolist()
|
labels_exp += label_exp.cpu().data.numpy().astype(np.int64).tolist()
|
||||||
|
labels_manual += label_manual.cpu().data.numpy().astype(np.int64).tolist()
|
||||||
|
|
||||||
X = np.concatenate(X)
|
X = np.concatenate(X)
|
||||||
|
labels_exp = np.array(labels_exp)
|
||||||
|
labels_manual = np.array(labels_manual)
|
||||||
|
|
||||||
|
# Count and log label stats for exp_based
|
||||||
|
n_exp_normal = np.sum(labels_exp == 1)
|
||||||
|
n_exp_anomaly = np.sum(labels_exp == -1)
|
||||||
|
n_exp_unknown = np.sum(labels_exp == 0)
|
||||||
|
logger.info(
|
||||||
|
f"Exp-based labels: normal(1)={n_exp_normal}, "
|
||||||
|
f"anomaly(-1)={n_exp_anomaly}, unknown(0)={n_exp_unknown}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Count and log label stats for manual_based
|
||||||
|
n_manual_normal = np.sum(labels_manual == 1)
|
||||||
|
n_manual_anomaly = np.sum(labels_manual == -1)
|
||||||
|
n_manual_unknown = np.sum(labels_manual == 0)
|
||||||
|
logger.info(
|
||||||
|
f"Manual-based labels: normal(1)={n_manual_normal}, "
|
||||||
|
f"anomaly(-1)={n_manual_anomaly}, unknown(0)={n_manual_unknown}"
|
||||||
|
)
|
||||||
|
|
||||||
# Testing
|
# Testing
|
||||||
logger.info("Starting testing...")
|
logger.info("Starting testing...")
|
||||||
@@ -163,21 +192,72 @@ class IsoForest(object):
|
|||||||
self.results["test_time"] = time.time() - start_time
|
self.results["test_time"] = time.time() - start_time
|
||||||
scores = scores.flatten()
|
scores = scores.flatten()
|
||||||
|
|
||||||
# Save triples of (idx, label, score) in a list
|
# Save triples of (idx, label, score) in a list for both label types
|
||||||
idx_label_score += list(zip(idxs, labels, scores.tolist()))
|
idx_label_score_exp += list(zip(idxs, labels_exp.tolist(), scores.tolist()))
|
||||||
self.results["test_scores"] = idx_label_score
|
idx_label_score_manual += list(
|
||||||
|
zip(idxs, labels_manual.tolist(), scores.tolist())
|
||||||
|
)
|
||||||
|
|
||||||
# Compute AUC
|
self.results["test_scores_exp_based"] = idx_label_score_exp
|
||||||
_, labels, scores = zip(*idx_label_score)
|
self.results["test_scores_manual_based"] = idx_label_score_manual
|
||||||
labels = np.array(labels)
|
|
||||||
scores = np.array(scores)
|
# --- Evaluation for exp_based (only labeled samples) ---
|
||||||
self.results["test_auc"] = roc_auc_score(labels, scores)
|
# Filter out unknown labels and convert to binary (1: anomaly, 0: normal) for ROC
|
||||||
self.results["test_roc"] = roc_curve(labels, scores)
|
valid_mask_exp = labels_exp != 0
|
||||||
self.results["test_prc"] = precision_recall_curve(labels, scores)
|
if np.any(valid_mask_exp):
|
||||||
self.results["test_ap"] = average_precision_score(labels, scores)
|
# Convert to binary labels for ROC (-1 → 1, 1 → 0)
|
||||||
|
labels_exp_binary = (labels_exp[valid_mask_exp] == -1).astype(int)
|
||||||
|
scores_exp_valid = scores[valid_mask_exp]
|
||||||
|
|
||||||
|
self.results["test_auc_exp_based"] = roc_auc_score(
|
||||||
|
labels_exp_binary, scores_exp_valid
|
||||||
|
)
|
||||||
|
self.results["test_roc_exp_based"] = roc_curve(
|
||||||
|
labels_exp_binary, scores_exp_valid
|
||||||
|
)
|
||||||
|
self.results["test_prc_exp_based"] = precision_recall_curve(
|
||||||
|
labels_exp_binary, scores_exp_valid
|
||||||
|
)
|
||||||
|
self.results["test_ap_exp_based"] = average_precision_score(
|
||||||
|
labels_exp_binary, scores_exp_valid
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Test AUC (exp_based): {:.2f}%".format(
|
||||||
|
100.0 * self.results["test_auc_exp_based"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info("Test AUC (exp_based): N/A (no labeled samples)")
|
||||||
|
|
||||||
|
# --- Evaluation for manual_based (only labeled samples) ---
|
||||||
|
valid_mask_manual = labels_manual != 0
|
||||||
|
if np.any(valid_mask_manual):
|
||||||
|
# Convert to binary labels for ROC (-1 → 1, 1 → 0)
|
||||||
|
labels_manual_binary = (labels_manual[valid_mask_manual] == -1).astype(int)
|
||||||
|
scores_manual_valid = scores[valid_mask_manual]
|
||||||
|
|
||||||
|
self.results["test_auc_manual_based"] = roc_auc_score(
|
||||||
|
labels_manual_binary, scores_manual_valid
|
||||||
|
)
|
||||||
|
self.results["test_roc_manual_based"] = roc_curve(
|
||||||
|
labels_manual_binary, scores_manual_valid
|
||||||
|
)
|
||||||
|
self.results["test_prc_manual_based"] = precision_recall_curve(
|
||||||
|
labels_manual_binary, scores_manual_valid
|
||||||
|
)
|
||||||
|
self.results["test_ap_manual_based"] = average_precision_score(
|
||||||
|
labels_manual_binary, scores_manual_valid
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Test AUC (manual_based): {:.2f}%".format(
|
||||||
|
100.0 * self.results["test_auc_manual_based"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info("Test AUC (manual_based): N/A (no labeled samples)")
|
||||||
|
|
||||||
# Log results
|
|
||||||
logger.info("Test AUC: {:.2f}%".format(100.0 * self.results["test_auc"]))
|
|
||||||
logger.info("Test Time: {:.3f}s".format(self.results["test_time"]))
|
logger.info("Test Time: {:.3f}s".format(self.results["test_time"]))
|
||||||
logger.info("Finished testing.")
|
logger.info("Finished testing.")
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ def load_dataset(
|
|||||||
ratio_pollution: float = 0.0,
|
ratio_pollution: float = 0.0,
|
||||||
random_state=None,
|
random_state=None,
|
||||||
inference: bool = False,
|
inference: bool = False,
|
||||||
k_fold: bool = False,
|
k_fold_num: int = None,
|
||||||
num_known_normal: int = 0,
|
num_known_normal: int = 0,
|
||||||
num_known_outlier: int = 0,
|
num_known_outlier: int = 0,
|
||||||
):
|
):
|
||||||
@@ -45,11 +45,8 @@ def load_dataset(
|
|||||||
if dataset_name == "subter":
|
if dataset_name == "subter":
|
||||||
dataset = SubTer_Dataset(
|
dataset = SubTer_Dataset(
|
||||||
root=data_path,
|
root=data_path,
|
||||||
ratio_known_normal=ratio_known_normal,
|
|
||||||
ratio_known_outlier=ratio_known_outlier,
|
|
||||||
ratio_pollution=ratio_pollution,
|
|
||||||
inference=inference,
|
inference=inference,
|
||||||
k_fold=k_fold,
|
k_fold_num=k_fold_num,
|
||||||
num_known_normal=num_known_normal,
|
num_known_normal=num_known_normal,
|
||||||
num_known_outlier=num_known_outlier,
|
num_known_outlier=num_known_outlier,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import random
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
@@ -8,596 +7,350 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torch.utils.data import Subset
|
|
||||||
from torch.utils.data.dataset import ConcatDataset
|
|
||||||
from torchvision.datasets import VisionDataset
|
from torchvision.datasets import VisionDataset
|
||||||
|
|
||||||
from base.torchvision_dataset import TorchvisionDataset
|
from base.torchvision_dataset import TorchvisionDataset
|
||||||
|
|
||||||
from .preprocessing import create_semisupervised_setting
|
|
||||||
|
|
||||||
|
|
||||||
class SubTer_Dataset(TorchvisionDataset):
|
class SubTer_Dataset(TorchvisionDataset):
|
||||||
|
"""
|
||||||
|
Wrapper for SubTerTraining and SubTerInference, sets up train/test/inference/data_set as needed.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
root: str,
|
root: str,
|
||||||
ratio_known_normal: float = 0.0,
|
|
||||||
ratio_known_outlier: float = 0.0,
|
|
||||||
ratio_pollution: float = 0.0,
|
|
||||||
inference: bool = False,
|
|
||||||
k_fold: bool = False,
|
|
||||||
num_known_normal: int = 0,
|
num_known_normal: int = 0,
|
||||||
num_known_outlier: int = 0,
|
num_known_outlier: int = 0,
|
||||||
only_use_given_semi_targets_for_evaluation: bool = True,
|
k_fold_num: int = None,
|
||||||
|
inference: bool = False,
|
||||||
|
transform: Optional[Callable] = None,
|
||||||
|
target_transform: Optional[Callable] = None,
|
||||||
|
seed: int = 0,
|
||||||
|
split: float = 0.7,
|
||||||
):
|
):
|
||||||
super().__init__(root)
|
super().__init__(root, k_fold_number=k_fold_num)
|
||||||
if Path(root).is_dir():
|
|
||||||
with open(Path(root) / "semi_targets.json", "r") as f:
|
|
||||||
data = json.load(f)
|
|
||||||
semi_targets_given = {
|
|
||||||
item["filename"]: (
|
|
||||||
item["semi_target_begin_frame"],
|
|
||||||
item["semi_target_end_frame"],
|
|
||||||
)
|
|
||||||
for item in data["files"]
|
|
||||||
}
|
|
||||||
|
|
||||||
# Define normal and outlier classes
|
|
||||||
self.n_classes = 2 # 0: normal, 1: outlier
|
|
||||||
self.normal_classes = tuple([0])
|
|
||||||
self.outlier_classes = tuple([1])
|
|
||||||
self.inference_set = None
|
self.inference_set = None
|
||||||
|
self.train_set = None
|
||||||
# MNIST preprocessing: feature scaling to [0, 1]
|
self.test_set = None
|
||||||
# FIXME understand mnist feature scaling and check if it or other preprocessing is necessary for elpv
|
self.data_set = None
|
||||||
transform = transforms.ToTensor()
|
|
||||||
target_transform = transforms.Lambda(lambda x: int(x in self.outlier_classes))
|
|
||||||
|
|
||||||
if inference:
|
if inference:
|
||||||
self.inference_set = SubTerInference(
|
self.inference_set = SubTerInference(
|
||||||
root=self.root,
|
root=root,
|
||||||
transform=transform,
|
transform=transform,
|
||||||
)
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Always require the manual label file
|
||||||
|
manual_json_path = Path(root) / "manually_labeled_anomaly_frames.json"
|
||||||
|
if not manual_json_path.exists():
|
||||||
|
raise FileNotFoundError(f"Required file not found: {manual_json_path}")
|
||||||
|
|
||||||
|
# For k_fold, data_set is the full dataset, train/test are None
|
||||||
|
if k_fold_num is not None:
|
||||||
|
self.data_set = SubTerTraining(
|
||||||
|
root=root,
|
||||||
|
num_known_normal=num_known_normal,
|
||||||
|
num_known_outlier=num_known_outlier,
|
||||||
|
transform=transform,
|
||||||
|
target_transform=target_transform,
|
||||||
|
seed=seed,
|
||||||
|
split=1.0, # use all data for k-fold
|
||||||
|
)
|
||||||
|
self.train_set = None
|
||||||
|
self.test_set = None
|
||||||
else:
|
else:
|
||||||
if k_fold:
|
# Standard split
|
||||||
# Get train set
|
self.train_set = SubTerTraining(
|
||||||
data_set = SubTerTraining(
|
root=root,
|
||||||
root=self.root,
|
num_known_normal=num_known_normal,
|
||||||
transform=transform,
|
num_known_outlier=num_known_outlier,
|
||||||
target_transform=target_transform,
|
transform=transform,
|
||||||
train=True,
|
target_transform=target_transform,
|
||||||
split=1,
|
seed=seed,
|
||||||
semi_targets_given=semi_targets_given,
|
split=split,
|
||||||
)
|
train=True,
|
||||||
|
)
|
||||||
|
self.test_set = SubTerTraining(
|
||||||
|
root=root,
|
||||||
|
num_known_normal=num_known_normal,
|
||||||
|
num_known_outlier=num_known_outlier,
|
||||||
|
transform=transform,
|
||||||
|
target_transform=target_transform,
|
||||||
|
seed=seed,
|
||||||
|
split=split,
|
||||||
|
train=False,
|
||||||
|
)
|
||||||
|
self.data_set = None # not used unless k_fold
|
||||||
|
|
||||||
np.random.seed(0)
|
def get_file_name_from_idx(self, idx: int) -> Optional[str]:
|
||||||
semi_targets = data_set.semi_targets.numpy()
|
"""
|
||||||
|
Get filename for a file_id by delegating to the appropriate dataset.
|
||||||
|
|
||||||
# Find indices where semi_targets is -1 (abnormal) or 1 (normal)
|
Args:
|
||||||
normal_indices = np.where(semi_targets == 1)[0]
|
idx: The file index to look up
|
||||||
abnormal_indices = np.where(semi_targets == -1)[0]
|
|
||||||
|
|
||||||
# Randomly select the specified number of indices to keep for each category
|
Returns:
|
||||||
if len(normal_indices) > num_known_normal:
|
str: The filename corresponding to the index, or None if not found
|
||||||
keep_normal_indices = np.random.choice(
|
"""
|
||||||
normal_indices, size=num_known_normal, replace=False
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
keep_normal_indices = (
|
|
||||||
normal_indices # Keep all if there are fewer than required
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(abnormal_indices) > num_known_outlier:
|
# For non-inference, use any available dataset (they all have the same files)
|
||||||
keep_abnormal_indices = np.random.choice(
|
if self.data_set is not None:
|
||||||
abnormal_indices, size=num_known_outlier, replace=False
|
return self.data_set.get_file_name_from_idx(idx)
|
||||||
)
|
if self.train_set is not None:
|
||||||
else:
|
return self.train_set.get_file_name_from_idx(idx)
|
||||||
keep_abnormal_indices = (
|
if self.test_set is not None:
|
||||||
abnormal_indices # Keep all if there are fewer than required
|
return self.test_set.get_file_name_from_idx(idx)
|
||||||
)
|
if self.inference_set is not None:
|
||||||
|
return self.inference_set.get_file_name_from_idx(idx)
|
||||||
|
|
||||||
# Set all values to 0, then restore only the selected -1 and 1 values
|
return None
|
||||||
semi_targets[(semi_targets == 1) | (semi_targets == -1)] = 0
|
|
||||||
semi_targets[keep_normal_indices] = 1
|
|
||||||
semi_targets[keep_abnormal_indices] = -1
|
|
||||||
data_set.semi_targets = torch.tensor(semi_targets, dtype=torch.int8)
|
|
||||||
|
|
||||||
self.data_set = data_set
|
|
||||||
|
|
||||||
# # Create semi-supervised setting
|
|
||||||
# idx, _, semi_targets = create_semisupervised_setting(
|
|
||||||
# data_set.targets.cpu().data.numpy(),
|
|
||||||
# self.normal_classes,
|
|
||||||
# self.outlier_classes,
|
|
||||||
# self.outlier_classes,
|
|
||||||
# ratio_known_normal,
|
|
||||||
# ratio_known_outlier,
|
|
||||||
# ratio_pollution,
|
|
||||||
# )
|
|
||||||
# data_set.semi_targets[idx] = torch.tensor(
|
|
||||||
# np.array(semi_targets, dtype=np.int8)
|
|
||||||
# ) # set respective semi-supervised labels
|
|
||||||
|
|
||||||
# # Subset data_set to semi-supervised setup
|
|
||||||
# self.data_set = Subset(data_set, idx)
|
|
||||||
else:
|
|
||||||
# Get train set
|
|
||||||
if only_use_given_semi_targets_for_evaluation:
|
|
||||||
pass
|
|
||||||
train_set = SubTerTrainingSelective(
|
|
||||||
root=self.root,
|
|
||||||
transform=transform,
|
|
||||||
target_transform=target_transform,
|
|
||||||
train=True,
|
|
||||||
num_known_outlier=num_known_outlier,
|
|
||||||
semi_targets_given=semi_targets_given,
|
|
||||||
)
|
|
||||||
|
|
||||||
np.random.seed(0)
|
|
||||||
semi_targets = train_set.semi_targets.numpy()
|
|
||||||
|
|
||||||
# Find indices where semi_targets is -1 (abnormal) or 1 (normal)
|
|
||||||
normal_indices = np.where(semi_targets == 1)[0]
|
|
||||||
|
|
||||||
# Randomly select the specified number of indices to keep for each category
|
|
||||||
if len(normal_indices) > num_known_normal:
|
|
||||||
keep_normal_indices = np.random.choice(
|
|
||||||
normal_indices, size=num_known_normal, replace=False
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
keep_normal_indices = (
|
|
||||||
normal_indices # Keep all if there are fewer than required
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set all values to 0, then restore only the selected -1 and 1 values
|
|
||||||
semi_targets[semi_targets == 1] = 0
|
|
||||||
semi_targets[keep_normal_indices] = 1
|
|
||||||
train_set.semi_targets = torch.tensor(
|
|
||||||
semi_targets, dtype=torch.int8
|
|
||||||
)
|
|
||||||
|
|
||||||
self.train_set = train_set
|
|
||||||
self.test_set = SubTerTrainingSelective(
|
|
||||||
root=self.root,
|
|
||||||
transform=transform,
|
|
||||||
target_transform=target_transform,
|
|
||||||
num_known_outlier=num_known_outlier,
|
|
||||||
train=False,
|
|
||||||
semi_targets_given=semi_targets_given,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
train_set = SubTerTraining(
|
|
||||||
root=self.root,
|
|
||||||
transform=transform,
|
|
||||||
target_transform=target_transform,
|
|
||||||
train=True,
|
|
||||||
semi_targets_given=semi_targets_given,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create semi-supervised setting
|
|
||||||
idx, _, semi_targets = create_semisupervised_setting(
|
|
||||||
train_set.targets.cpu().data.numpy(),
|
|
||||||
self.normal_classes,
|
|
||||||
self.outlier_classes,
|
|
||||||
self.outlier_classes,
|
|
||||||
ratio_known_normal,
|
|
||||||
ratio_known_outlier,
|
|
||||||
ratio_pollution,
|
|
||||||
)
|
|
||||||
train_set.semi_targets[idx] = torch.tensor(
|
|
||||||
np.array(semi_targets, dtype=np.int8)
|
|
||||||
) # set respective semi-supervised labels
|
|
||||||
|
|
||||||
# Subset train_set to semi-supervised setup
|
|
||||||
self.train_set = Subset(train_set, idx)
|
|
||||||
|
|
||||||
# Get test set
|
|
||||||
self.test_set = SubTerTraining(
|
|
||||||
root=self.root,
|
|
||||||
train=False,
|
|
||||||
transform=transform,
|
|
||||||
target_transform=target_transform,
|
|
||||||
semi_targets_given=semi_targets_given,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SubTerTraining(VisionDataset):
|
class SubTerTraining(VisionDataset):
|
||||||
|
"""
|
||||||
|
Loads all data, builds targets, and supports train/test split.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
root: str,
|
root: str,
|
||||||
transforms: Optional[Callable] = None,
|
num_known_normal: int = 0,
|
||||||
|
num_known_outlier: int = 0,
|
||||||
transform: Optional[Callable] = None,
|
transform: Optional[Callable] = None,
|
||||||
target_transform: Optional[Callable] = None,
|
target_transform: Optional[Callable] = None,
|
||||||
train=False,
|
seed: int = 0,
|
||||||
split=0.7,
|
split: float = 0.7,
|
||||||
seed=0,
|
train: bool = True,
|
||||||
semi_targets_given=None,
|
|
||||||
only_use_given_semi_targets_for_evaluation=False,
|
|
||||||
):
|
):
|
||||||
super(SubTerTraining, self).__init__(
|
super().__init__(root, transform=transform, target_transform=target_transform)
|
||||||
root, transforms, transform, target_transform
|
|
||||||
)
|
|
||||||
|
|
||||||
experiments_data = []
|
|
||||||
experiments_targets = []
|
|
||||||
experiments_semi_targets = []
|
|
||||||
# validation_files = []
|
|
||||||
experiment_files = []
|
|
||||||
experiment_frame_ids = []
|
|
||||||
experiment_file_ids = []
|
|
||||||
file_names = {}
|
|
||||||
|
|
||||||
for file_idx, experiment_file in enumerate(sorted(Path(root).iterdir())):
|
|
||||||
# if experiment_file.is_dir() and experiment_file.name == "validation":
|
|
||||||
# for validation_file in experiment_file.iterdir():
|
|
||||||
# if validation_file.suffix != ".npy":
|
|
||||||
# continue
|
|
||||||
# validation_files.append(experiment_file)
|
|
||||||
if experiment_file.suffix != ".npy":
|
|
||||||
continue
|
|
||||||
file_names[file_idx] = experiment_file.name
|
|
||||||
experiment_files.append(experiment_file)
|
|
||||||
experiment_data = np.load(experiment_file)
|
|
||||||
experiment_targets = (
|
|
||||||
np.ones(experiment_data.shape[0], dtype=np.int8)
|
|
||||||
if "smoke" in experiment_file.name
|
|
||||||
else np.zeros(experiment_data.shape[0], dtype=np.int8)
|
|
||||||
)
|
|
||||||
# experiment_data = np.lib.format.open_memmap(experiment_file, mode='r+')
|
|
||||||
experiment_semi_targets = np.zeros(experiment_data.shape[0], dtype=np.int8)
|
|
||||||
if "smoke" not in experiment_file.name:
|
|
||||||
experiment_semi_targets = np.ones(
|
|
||||||
experiment_data.shape[0], dtype=np.int8
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if semi_targets_given:
|
|
||||||
if experiment_file.name in semi_targets_given:
|
|
||||||
semi_target_begin_frame, semi_target_end_frame = (
|
|
||||||
semi_targets_given[experiment_file.name]
|
|
||||||
)
|
|
||||||
experiment_semi_targets[
|
|
||||||
semi_target_begin_frame:semi_target_end_frame
|
|
||||||
] = -1
|
|
||||||
else:
|
|
||||||
experiment_semi_targets = (
|
|
||||||
np.ones(experiment_data.shape[0], dtype=np.int8) * -1
|
|
||||||
)
|
|
||||||
|
|
||||||
experiment_file_ids.append(
|
|
||||||
np.full(experiment_data.shape[0], file_idx, dtype=np.int8)
|
|
||||||
)
|
|
||||||
experiment_frame_ids.append(
|
|
||||||
np.arange(experiment_data.shape[0], dtype=np.int32)
|
|
||||||
)
|
|
||||||
experiments_data.append(experiment_data)
|
|
||||||
experiments_targets.append(experiment_targets)
|
|
||||||
experiments_semi_targets.append(experiment_semi_targets)
|
|
||||||
|
|
||||||
# filtered_validation_files = []
|
|
||||||
# for validation_file in validation_files:
|
|
||||||
# validation_file_name = validation_file.name
|
|
||||||
# file_exists_in_experiments = any(
|
|
||||||
# experiment_file.name == validation_file_name
|
|
||||||
# for experiment_file in experiment_files
|
|
||||||
# )
|
|
||||||
# if not file_exists_in_experiments:
|
|
||||||
# filtered_validation_files.append(validation_file)
|
|
||||||
# validation_files = filtered_validation_files
|
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
logger.info(
|
manual_json_path = Path(root) / "manually_labeled_anomaly_frames.json"
|
||||||
f"Train/Test experiments: {[experiment_file.name for experiment_file in experiment_files]}"
|
with open(manual_json_path, "r") as f:
|
||||||
)
|
manual_data = json.load(f)
|
||||||
# logger.info(
|
manual_anomaly_ranges = {
|
||||||
# f"Validation experiments: {[validation_file.name for validation_file in validation_files]}"
|
item["filename"]: (
|
||||||
# )
|
item["semi_target_begin_frame"],
|
||||||
|
item["semi_target_end_frame"],
|
||||||
|
)
|
||||||
|
for item in manual_data["files"]
|
||||||
|
}
|
||||||
|
|
||||||
lidar_projections = np.concatenate(experiments_data)
|
all_data = []
|
||||||
smoke_presence = np.concatenate(experiments_targets)
|
all_file_ids = []
|
||||||
semi_targets = np.concatenate(experiments_semi_targets)
|
all_frame_ids = []
|
||||||
file_ids = np.concatenate(experiment_file_ids)
|
all_filenames = []
|
||||||
frame_ids = np.concatenate(experiment_frame_ids)
|
test_target_experiment_based = []
|
||||||
|
test_target_manually_set = []
|
||||||
|
train_semi_targets = []
|
||||||
|
file_names = {}
|
||||||
|
file_idx = 0
|
||||||
|
|
||||||
|
for experiment_file in sorted(Path(root).iterdir()):
|
||||||
|
if experiment_file.suffix != ".npy":
|
||||||
|
continue
|
||||||
|
|
||||||
|
file_names[file_idx] = experiment_file.name
|
||||||
|
experiment_data = np.load(experiment_file)
|
||||||
|
n_frames = experiment_data.shape[0]
|
||||||
|
|
||||||
|
is_smoke = "smoke" in experiment_file.name
|
||||||
|
if is_smoke:
|
||||||
|
if experiment_file.name not in manual_anomaly_ranges:
|
||||||
|
raise ValueError(
|
||||||
|
f"Experiment file {experiment_file.name} is marked as smoke but has no manual anomaly ranges."
|
||||||
|
)
|
||||||
|
manual_anomaly_start_frame, manual_anomaly_end_frame = (
|
||||||
|
manual_anomaly_ranges[experiment_file.name]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Experiment-based: 1 (normal), -1 (anomaly)
|
||||||
|
exp_based_targets = (
|
||||||
|
np.full(n_frames, -1, dtype=np.int8) # anomaly
|
||||||
|
if is_smoke
|
||||||
|
else np.full(n_frames, 1, dtype=np.int8) # normal
|
||||||
|
)
|
||||||
|
|
||||||
|
# Manually set: 1 (normal), -1 (anomaly), 0 (unknown/NaN)
|
||||||
|
if not is_smoke:
|
||||||
|
manual_targets = np.full(n_frames, 1, dtype=np.int8) # normal
|
||||||
|
else:
|
||||||
|
manual_targets = np.zeros(n_frames, dtype=np.int8) # unknown
|
||||||
|
manual_targets[
|
||||||
|
manual_anomaly_start_frame:manual_anomaly_end_frame
|
||||||
|
] = -1 # anomaly
|
||||||
|
|
||||||
|
# log how many manual anomaly frames were set to each value
|
||||||
|
logger.info(
|
||||||
|
f"Experiment {experiment_file.name}: "
|
||||||
|
f"Manual targets - normal(1): {np.sum(manual_targets == 1)}, "
|
||||||
|
f"anomaly(-1): {np.sum(manual_targets == -1)}, "
|
||||||
|
f"unknown(0): {np.sum(manual_targets == 0)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Semi-supervised targets: 1 (known normal), -1 (known anomaly), 0 (unknown)
|
||||||
|
if not is_smoke:
|
||||||
|
semi_targets = np.ones(n_frames, dtype=np.int8) # normal
|
||||||
|
else:
|
||||||
|
semi_targets = np.zeros(n_frames, dtype=np.int8) # unknown
|
||||||
|
semi_targets[
|
||||||
|
manual_anomaly_start_frame:manual_anomaly_end_frame
|
||||||
|
] = -1 # anomaly
|
||||||
|
|
||||||
|
all_data.append(experiment_data)
|
||||||
|
all_file_ids.append(np.full(n_frames, file_idx, dtype=np.int32))
|
||||||
|
all_frame_ids.append(np.arange(n_frames, dtype=np.int32))
|
||||||
|
all_filenames.extend([experiment_file.name] * n_frames)
|
||||||
|
test_target_experiment_based.append(exp_based_targets)
|
||||||
|
test_target_manually_set.append(manual_targets)
|
||||||
|
train_semi_targets.append(semi_targets)
|
||||||
|
|
||||||
|
file_idx += 1
|
||||||
|
|
||||||
|
# Flatten everything
|
||||||
|
data = np.nan_to_num(np.concatenate(all_data))
|
||||||
|
file_ids = np.concatenate(all_file_ids)
|
||||||
|
frame_ids = np.concatenate(all_frame_ids)
|
||||||
|
filenames = all_filenames
|
||||||
self.file_names = file_names
|
self.file_names = file_names
|
||||||
|
|
||||||
|
test_target_experiment_based = np.concatenate(test_target_experiment_based)
|
||||||
|
test_target_manually_set = np.concatenate(test_target_manually_set)
|
||||||
|
semi_targets_np = np.concatenate(train_semi_targets)
|
||||||
|
|
||||||
|
# Limit the number of known normal/anomaly samples for training
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
|
normal_indices = np.where(semi_targets_np == 1)[0]
|
||||||
|
anomaly_indices = np.where(semi_targets_np == -1)[0]
|
||||||
|
|
||||||
shuffled_indices = np.random.permutation(lidar_projections.shape[0])
|
if num_known_normal > 0 and len(normal_indices) > num_known_normal:
|
||||||
shuffled_lidar_projections = lidar_projections[shuffled_indices]
|
keep_normal = np.random.choice(
|
||||||
shuffled_smoke_presence = smoke_presence[shuffled_indices]
|
normal_indices, size=num_known_normal, replace=False
|
||||||
shuffled_file_ids = file_ids[shuffled_indices]
|
)
|
||||||
shuffled_frame_ids = frame_ids[shuffled_indices]
|
else:
|
||||||
shuffled_semis = semi_targets[shuffled_indices]
|
keep_normal = normal_indices
|
||||||
|
|
||||||
split_idx = int(split * shuffled_lidar_projections.shape[0])
|
if num_known_outlier > 0 and len(anomaly_indices) > num_known_outlier:
|
||||||
|
keep_anomaly = np.random.choice(
|
||||||
|
anomaly_indices, size=num_known_outlier, replace=False
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
keep_anomaly = anomaly_indices
|
||||||
|
|
||||||
|
semi_targets_np[(semi_targets_np == 1) | (semi_targets_np == -1)] = 0
|
||||||
|
semi_targets_np[keep_normal] = 1
|
||||||
|
semi_targets_np[keep_anomaly] = -1
|
||||||
|
|
||||||
|
# Shuffle and split
|
||||||
|
indices = np.arange(len(data))
|
||||||
|
np.random.seed(seed)
|
||||||
|
np.random.shuffle(indices)
|
||||||
|
split_idx = int(split * len(data))
|
||||||
if train:
|
if train:
|
||||||
self.data = shuffled_lidar_projections[:split_idx]
|
use_idx = indices[:split_idx]
|
||||||
self.targets = shuffled_smoke_presence[:split_idx]
|
|
||||||
semi_targets = shuffled_semis[:split_idx]
|
|
||||||
self.shuffled_file_ids = shuffled_file_ids[:split_idx]
|
|
||||||
self.shuffled_frame_ids = shuffled_frame_ids[:split_idx]
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.data = shuffled_lidar_projections[split_idx:]
|
use_idx = indices[split_idx:]
|
||||||
self.targets = shuffled_smoke_presence[split_idx:]
|
|
||||||
semi_targets = shuffled_semis[split_idx:]
|
|
||||||
self.shuffled_file_ids = shuffled_file_ids[split_idx:]
|
|
||||||
self.shuffled_frame_ids = shuffled_frame_ids[split_idx:]
|
|
||||||
|
|
||||||
self.data = np.nan_to_num(self.data)
|
self.data = torch.tensor(data[use_idx])
|
||||||
|
self.file_ids = file_ids[use_idx]
|
||||||
|
self.frame_ids = frame_ids[use_idx]
|
||||||
|
self.filenames = [filenames[i] for i in use_idx]
|
||||||
|
self.test_target_experiment_based = torch.tensor(
|
||||||
|
test_target_experiment_based[use_idx], dtype=torch.int8
|
||||||
|
)
|
||||||
|
self.test_target_manually_set = torch.tensor(
|
||||||
|
test_target_manually_set[use_idx], dtype=torch.int8
|
||||||
|
)
|
||||||
|
|
||||||
self.data = torch.tensor(self.data)
|
# log how many of the test_target_manually_set are in each category
|
||||||
self.targets = torch.tensor(self.targets, dtype=torch.int8)
|
logger.info(
|
||||||
|
f"Test targets - normal(1): {np.sum(self.test_target_manually_set.numpy() == 1)}, "
|
||||||
|
f"anomaly(-1): {np.sum(self.test_target_manually_set.numpy() == -1)}, "
|
||||||
|
f"unknown(0): {np.sum(self.test_target_manually_set.numpy() == 0)}"
|
||||||
|
)
|
||||||
|
|
||||||
if semi_targets_given is not None:
|
self.train_semi_targets = torch.tensor(
|
||||||
self.semi_targets = torch.tensor(semi_targets, dtype=torch.int8)
|
semi_targets_np[use_idx], dtype=torch.int8
|
||||||
else:
|
)
|
||||||
self.semi_targets = torch.zeros_like(self.targets, dtype=torch.int8)
|
|
||||||
|
self.transform = transform if transform else transforms.ToTensor()
|
||||||
|
self.target_transform = target_transform
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
"""Override the original method of the MNIST class.
|
img = self.data[index]
|
||||||
Args:
|
target_experiment_based = int(self.test_target_experiment_based[index])
|
||||||
index (int): Index
|
target_manually_set = int(self.test_target_manually_set[index])
|
||||||
|
semi_target = int(self.train_semi_targets[index])
|
||||||
|
file_id = int(self.file_ids[index])
|
||||||
|
frame_id = int(self.frame_ids[index])
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (image, target, semi_target, index)
|
|
||||||
"""
|
|
||||||
img, target, semi_target, file_id, frame_id = (
|
|
||||||
self.data[index],
|
|
||||||
int(self.targets[index]),
|
|
||||||
int(self.semi_targets[index]),
|
|
||||||
int(self.shuffled_file_ids[index]),
|
|
||||||
int(self.shuffled_frame_ids[index]),
|
|
||||||
)
|
|
||||||
|
|
||||||
# doing this so that it is consistent with all other datasets
|
|
||||||
# to return a PIL Image
|
|
||||||
img = Image.fromarray(img.numpy(), mode="F")
|
img = Image.fromarray(img.numpy(), mode="F")
|
||||||
|
|
||||||
if self.transform is not None:
|
if self.transform is not None:
|
||||||
img = self.transform(img)
|
img = self.transform(img)
|
||||||
|
|
||||||
if self.target_transform is not None:
|
if self.target_transform is not None:
|
||||||
target = self.target_transform(target)
|
target_experiment_based = self.target_transform(target_experiment_based)
|
||||||
|
target_manually_set = self.target_transform(target_manually_set)
|
||||||
|
semi_target = self.target_transform(semi_target)
|
||||||
|
|
||||||
return img, target, semi_target, index, (file_id, frame_id)
|
return (
|
||||||
|
img,
|
||||||
|
target_experiment_based,
|
||||||
|
target_manually_set,
|
||||||
|
semi_target,
|
||||||
|
index,
|
||||||
|
(file_id, frame_id),
|
||||||
|
)
|
||||||
|
|
||||||
def get_file_name_from_idx(self, idx: int):
|
def get_file_name_from_idx(self, idx: int):
|
||||||
return self.file_names[idx]
|
return self.file_names.get(idx, None)
|
||||||
|
|
||||||
|
|
||||||
class SubTerInference(VisionDataset):
|
class SubTerInference(VisionDataset):
|
||||||
|
"""
|
||||||
|
Loads a single experiment file for inference.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
root: str,
|
root: str,
|
||||||
transforms: Optional[Callable] = None,
|
|
||||||
transform: Optional[Callable] = None,
|
transform: Optional[Callable] = None,
|
||||||
):
|
):
|
||||||
super(SubTerInference, self).__init__(root, transforms, transform)
|
super().__init__(root, transform=transform)
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
self.experiment_file_path = Path(root)
|
experiment_file = Path(root)
|
||||||
|
if not experiment_file.is_file():
|
||||||
if not self.experiment_file_path.is_file():
|
|
||||||
logger.error(
|
logger.error(
|
||||||
"For inference the data path has to be a single experiment file!"
|
"For inference the data path has to be a single experiment file!"
|
||||||
)
|
)
|
||||||
raise Exception("Inference data is not a loadable file!")
|
raise Exception("Inference data is not a loadable file!")
|
||||||
|
|
||||||
self.data = np.load(self.experiment_file_path)
|
self.data = np.load(experiment_file)
|
||||||
self.data = np.nan_to_num(self.data)
|
self.data = np.nan_to_num(self.data)
|
||||||
self.data = torch.tensor(self.data)
|
self.data = torch.tensor(self.data)
|
||||||
|
self.filenames = [experiment_file.name] * self.data.shape[0]
|
||||||
|
self.file_ids = np.zeros(self.data.shape[0], dtype=np.int32)
|
||||||
|
self.frame_ids = np.arange(self.data.shape[0], dtype=np.int32)
|
||||||
|
self.file_names = {0: experiment_file.name}
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
"""Override the original method of the MNIST class.
|
|
||||||
Args:
|
|
||||||
index (int): Index
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (image, index)
|
|
||||||
"""
|
|
||||||
img = self.data[index]
|
img = self.data[index]
|
||||||
|
file_id = int(self.file_ids[index])
|
||||||
|
frame_id = int(self.frame_ids[index])
|
||||||
|
|
||||||
# doing this so that it is consistent with all other datasets
|
|
||||||
# to return a PIL Image
|
|
||||||
img = Image.fromarray(img.numpy(), mode="F")
|
img = Image.fromarray(img.numpy(), mode="F")
|
||||||
|
|
||||||
if self.transform is not None:
|
if self.transform is not None:
|
||||||
img = self.transform(img)
|
img = self.transform(img)
|
||||||
|
|
||||||
return img, index
|
return img, index, (file_id, frame_id)
|
||||||
|
|
||||||
|
|
||||||
class SubTerTrainingSelective(VisionDataset):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
root: str,
|
|
||||||
transforms: Optional[Callable] = None,
|
|
||||||
transform: Optional[Callable] = None,
|
|
||||||
target_transform: Optional[Callable] = None,
|
|
||||||
train=False,
|
|
||||||
num_known_outlier=0,
|
|
||||||
seed=0,
|
|
||||||
semi_targets_given=None,
|
|
||||||
ratio_test_normal_to_anomalous=3,
|
|
||||||
):
|
|
||||||
super(SubTerTrainingSelective, self).__init__(
|
|
||||||
root, transforms, transform, target_transform
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger()
|
|
||||||
|
|
||||||
if semi_targets_given is None:
|
|
||||||
raise ValueError(
|
|
||||||
"semi_targets_given must be provided for selective training"
|
|
||||||
)
|
|
||||||
|
|
||||||
experiments_data = []
|
|
||||||
experiments_targets = []
|
|
||||||
experiments_semi_targets = []
|
|
||||||
# validation_files = []
|
|
||||||
experiment_files = []
|
|
||||||
experiment_frame_ids = []
|
|
||||||
experiment_file_ids = []
|
|
||||||
file_names = {}
|
|
||||||
|
|
||||||
for file_idx, experiment_file in enumerate(sorted(Path(root).iterdir())):
|
|
||||||
if experiment_file.suffix != ".npy":
|
|
||||||
continue
|
|
||||||
|
|
||||||
file_names[file_idx] = experiment_file.name
|
|
||||||
experiment_files.append(experiment_file)
|
|
||||||
experiment_data = np.load(experiment_file)
|
|
||||||
|
|
||||||
experiment_targets = (
|
|
||||||
np.ones(experiment_data.shape[0], dtype=np.int8)
|
|
||||||
if "smoke" in experiment_file.name
|
|
||||||
else np.zeros(experiment_data.shape[0], dtype=np.int8)
|
|
||||||
)
|
|
||||||
|
|
||||||
experiment_semi_targets = np.zeros(experiment_data.shape[0], dtype=np.int8)
|
|
||||||
if "smoke" not in experiment_file.name:
|
|
||||||
experiment_semi_targets = np.ones(
|
|
||||||
experiment_data.shape[0], dtype=np.int8
|
|
||||||
)
|
|
||||||
elif experiment_file.name in semi_targets_given:
|
|
||||||
semi_target_begin_frame, semi_target_end_frame = semi_targets_given[
|
|
||||||
experiment_file.name
|
|
||||||
]
|
|
||||||
experiment_semi_targets[
|
|
||||||
semi_target_begin_frame:semi_target_end_frame
|
|
||||||
] = -1
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"smoke experiment not in given semi_targets. required for selective training"
|
|
||||||
)
|
|
||||||
|
|
||||||
experiment_file_ids.append(
|
|
||||||
np.full(experiment_data.shape[0], file_idx, dtype=np.int8)
|
|
||||||
)
|
|
||||||
experiment_frame_ids.append(
|
|
||||||
np.arange(experiment_data.shape[0], dtype=np.int32)
|
|
||||||
)
|
|
||||||
experiments_data.append(experiment_data)
|
|
||||||
experiments_targets.append(experiment_targets)
|
|
||||||
experiments_semi_targets.append(experiment_semi_targets)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Train/Test experiments: {[experiment_file.name for experiment_file in experiment_files]}"
|
|
||||||
)
|
|
||||||
|
|
||||||
lidar_projections = np.concatenate(experiments_data)
|
|
||||||
smoke_presence = np.concatenate(experiments_targets)
|
|
||||||
semi_targets = np.concatenate(experiments_semi_targets)
|
|
||||||
file_ids = np.concatenate(experiment_file_ids)
|
|
||||||
frame_ids = np.concatenate(experiment_frame_ids)
|
|
||||||
self.file_names = file_names
|
|
||||||
|
|
||||||
np.random.seed(seed)
|
|
||||||
|
|
||||||
shuffled_indices = np.random.permutation(lidar_projections.shape[0])
|
|
||||||
shuffled_lidar_projections = lidar_projections[shuffled_indices]
|
|
||||||
shuffled_smoke_presence = smoke_presence[shuffled_indices]
|
|
||||||
shuffled_file_ids = file_ids[shuffled_indices]
|
|
||||||
shuffled_frame_ids = frame_ids[shuffled_indices]
|
|
||||||
shuffled_semis = semi_targets[shuffled_indices]
|
|
||||||
|
|
||||||
# check if there are enough known normal and known outlier samples
|
|
||||||
outlier_indices = np.where(shuffled_semis == -1)[0]
|
|
||||||
normal_indices = np.where(shuffled_semis == 1)[0]
|
|
||||||
|
|
||||||
if len(outlier_indices) < num_known_outlier:
|
|
||||||
raise ValueError(
|
|
||||||
f"Not enough known outliers in dataset. Required: {num_known_outlier}, Found: {len(outlier_indices)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# randomly select known normal and outlier samples
|
|
||||||
keep_outlier_indices = np.random.choice(
|
|
||||||
outlier_indices, size=num_known_outlier, replace=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# put outliers that are not kept into test set and the same number of normal samples aside for testing
|
|
||||||
test_outlier_indices = np.setdiff1d(outlier_indices, keep_outlier_indices)
|
|
||||||
num_test_outliers = len(test_outlier_indices)
|
|
||||||
test_normal_indices = np.random.choice(
|
|
||||||
normal_indices,
|
|
||||||
size=num_test_outliers * ratio_test_normal_to_anomalous,
|
|
||||||
replace=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# combine test indices
|
|
||||||
test_indices = np.concatenate([test_outlier_indices, test_normal_indices])
|
|
||||||
|
|
||||||
# training indices are the rest
|
|
||||||
train_indices = np.setdiff1d(np.arange(len(shuffled_semis)), test_indices)
|
|
||||||
|
|
||||||
if train:
|
|
||||||
self.data = shuffled_lidar_projections[train_indices]
|
|
||||||
self.targets = shuffled_smoke_presence[train_indices]
|
|
||||||
semi_targets = shuffled_semis[train_indices]
|
|
||||||
self.shuffled_file_ids = shuffled_file_ids[train_indices]
|
|
||||||
self.shuffled_frame_ids = shuffled_frame_ids[train_indices]
|
|
||||||
|
|
||||||
else:
|
|
||||||
self.data = shuffled_lidar_projections[test_indices]
|
|
||||||
self.targets = shuffled_smoke_presence[test_indices]
|
|
||||||
semi_targets = shuffled_semis[test_indices]
|
|
||||||
self.shuffled_file_ids = shuffled_file_ids[test_indices]
|
|
||||||
self.shuffled_frame_ids = shuffled_frame_ids[test_indices]
|
|
||||||
|
|
||||||
self.data = np.nan_to_num(self.data)
|
|
||||||
|
|
||||||
self.data = torch.tensor(self.data)
|
|
||||||
self.targets = torch.tensor(self.targets, dtype=torch.int8)
|
|
||||||
self.semi_targets = torch.tensor(semi_targets, dtype=torch.int8)
|
|
||||||
|
|
||||||
# log some stats to ensure the data is loaded correctly
|
|
||||||
if train:
|
|
||||||
logger.info(
|
|
||||||
f"Training set: {len(self.data)} samples, {sum(self.semi_targets == -1)} semi-supervised samples"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.info(
|
|
||||||
f"Test set: {len(self.data)} samples, {sum(self.semi_targets == -1)} semi-supervised samples"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.data)
|
|
||||||
|
|
||||||
def __getitem__(self, index):
|
|
||||||
"""Override the original method of the MNIST class.
|
|
||||||
Args:
|
|
||||||
index (int): Index
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (image, target, semi_target, index)
|
|
||||||
"""
|
|
||||||
img, target, semi_target, file_id, frame_id = (
|
|
||||||
self.data[index],
|
|
||||||
int(self.targets[index]),
|
|
||||||
int(self.semi_targets[index]),
|
|
||||||
int(self.shuffled_file_ids[index]),
|
|
||||||
int(self.shuffled_frame_ids[index]),
|
|
||||||
)
|
|
||||||
|
|
||||||
# doing this so that it is consistent with all other datasets
|
|
||||||
# to return a PIL Image
|
|
||||||
img = Image.fromarray(img.numpy(), mode="F")
|
|
||||||
|
|
||||||
if self.transform is not None:
|
|
||||||
img = self.transform(img)
|
|
||||||
|
|
||||||
if self.target_transform is not None:
|
|
||||||
target = self.target_transform(target)
|
|
||||||
|
|
||||||
return img, target, semi_target, index, (file_id, frame_id)
|
|
||||||
|
|
||||||
def get_file_name_from_idx(self, idx: int):
|
def get_file_name_from_idx(self, idx: int):
|
||||||
return self.file_names[idx]
|
return self.file_names.get(idx, None)
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from utils.visualization.plot_images_grid import plot_images_grid
|
|||||||
[
|
[
|
||||||
"train",
|
"train",
|
||||||
"infer",
|
"infer",
|
||||||
|
"ae_elbow_test", # Add new action
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -76,8 +77,8 @@ from utils.visualization.plot_images_grid import plot_images_grid
|
|||||||
@click.option(
|
@click.option(
|
||||||
"--k_fold_num",
|
"--k_fold_num",
|
||||||
type=int,
|
type=int,
|
||||||
default=5,
|
default=None,
|
||||||
help="Number of folds for k-fold cross-validation (default: 5).",
|
help="Number of folds for k-fold cross-validation (default: None).",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--num_known_normal",
|
"--num_known_normal",
|
||||||
@@ -277,6 +278,13 @@ from utils.visualization.plot_images_grid import plot_images_grid
|
|||||||
default=-1,
|
default=-1,
|
||||||
help="Number of jobs for model training.",
|
help="Number of jobs for model training.",
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
"--ae_elbow_dims",
|
||||||
|
type=int,
|
||||||
|
multiple=True,
|
||||||
|
default=[128, 256, 384, 512, 768, 1024],
|
||||||
|
help="List of latent space dimensions to test for autoencoder elbow analysis.",
|
||||||
|
)
|
||||||
def main(
|
def main(
|
||||||
action,
|
action,
|
||||||
dataset_name,
|
dataset_name,
|
||||||
@@ -319,6 +327,7 @@ def main(
|
|||||||
isoforest_max_samples,
|
isoforest_max_samples,
|
||||||
isoforest_contamination,
|
isoforest_contamination,
|
||||||
isoforest_n_jobs_model,
|
isoforest_n_jobs_model,
|
||||||
|
ae_elbow_dims,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Deep SAD, a method for deep semi-supervised anomaly detection.
|
Deep SAD, a method for deep semi-supervised anomaly detection.
|
||||||
@@ -402,7 +411,7 @@ def main(
|
|||||||
ratio_known_outlier,
|
ratio_known_outlier,
|
||||||
ratio_pollution,
|
ratio_pollution,
|
||||||
random_state=np.random.RandomState(cfg.settings["seed"]),
|
random_state=np.random.RandomState(cfg.settings["seed"]),
|
||||||
k_fold=k_fold,
|
k_fold_num=k_fold_num,
|
||||||
num_known_normal=num_known_normal,
|
num_known_normal=num_known_normal,
|
||||||
num_known_outlier=num_known_outlier,
|
num_known_outlier=num_known_outlier,
|
||||||
)
|
)
|
||||||
@@ -593,18 +602,35 @@ def main(
|
|||||||
|
|
||||||
# Plot most anomalous and most normal test samples
|
# Plot most anomalous and most normal test samples
|
||||||
if train_deepsad:
|
if train_deepsad:
|
||||||
indices, labels, scores = zip(*deepSAD.results["test_scores"])
|
# Use experiment-based scores for plotting
|
||||||
|
indices, labels, scores = zip(
|
||||||
|
*deepSAD.results["test"]["exp_based"]["scores"]
|
||||||
|
)
|
||||||
indices, labels, scores = (
|
indices, labels, scores = (
|
||||||
np.array(indices),
|
np.array(indices),
|
||||||
np.array(labels),
|
np.array(labels),
|
||||||
np.array(scores),
|
np.array(scores),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Filter out samples with unknown labels (0)
|
||||||
|
valid_mask = labels != 0
|
||||||
|
indices = indices[valid_mask]
|
||||||
|
labels = labels[valid_mask]
|
||||||
|
scores = scores[valid_mask]
|
||||||
|
|
||||||
|
# Convert labels from -1/1 to 0/1 for plotting
|
||||||
|
labels = (labels == -1).astype(int) # -1 (anomaly) → 1, 1 (normal) → 0
|
||||||
|
|
||||||
idx_all_sorted = indices[
|
idx_all_sorted = indices[
|
||||||
np.argsort(scores)
|
np.argsort(scores)
|
||||||
] # from lowest to highest score
|
] # from lowest to highest score
|
||||||
idx_normal_sorted = indices[labels == 0][
|
idx_normal_sorted = indices[labels == 0][
|
||||||
np.argsort(scores[labels == 0])
|
np.argsort(scores[labels == 0])
|
||||||
] # from lowest to highest score
|
]
|
||||||
|
|
||||||
|
# Optionally plot manual-based results:
|
||||||
|
# indices_m, labels_m, scores_m = zip(*deepSAD.results["test"]["manual_based"]["scores"])
|
||||||
|
# ...same processing as above...
|
||||||
|
|
||||||
if dataset_name in (
|
if dataset_name in (
|
||||||
"mnist",
|
"mnist",
|
||||||
@@ -745,6 +771,71 @@ def main(
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"Inference: median={np.median(inference_results)} mean={np.mean(inference_results)} min={inference_results.min()} max={inference_results.max()}"
|
f"Inference: median={np.median(inference_results)} mean={np.mean(inference_results)} min={inference_results.min()} max={inference_results.max()}"
|
||||||
)
|
)
|
||||||
|
elif action == "ae_elbow_test":
|
||||||
|
# Load data once
|
||||||
|
dataset = load_dataset(
|
||||||
|
dataset_name,
|
||||||
|
data_path,
|
||||||
|
normal_class,
|
||||||
|
known_outlier_class,
|
||||||
|
n_known_outlier_classes,
|
||||||
|
ratio_known_normal,
|
||||||
|
ratio_known_outlier,
|
||||||
|
ratio_pollution,
|
||||||
|
random_state=np.random.RandomState(cfg.settings["seed"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dictionary to store results for each dimension
|
||||||
|
elbow_results = {"dimensions": list(ae_elbow_dims), "ae_results": {}}
|
||||||
|
|
||||||
|
# Test each dimension
|
||||||
|
for rep_dim in ae_elbow_dims:
|
||||||
|
logger.info(f"Testing autoencoder with latent dimension: {rep_dim}")
|
||||||
|
|
||||||
|
# Initialize DeepSAD model with current dimension
|
||||||
|
deepSAD = DeepSAD(cfg.settings["eta"])
|
||||||
|
deepSAD.set_network(
|
||||||
|
net_name, rep_dim=rep_dim
|
||||||
|
) # Pass rep_dim to network builder
|
||||||
|
|
||||||
|
# Pretrain autoencoder with current dimension
|
||||||
|
deepSAD.pretrain(
|
||||||
|
dataset,
|
||||||
|
optimizer_name=cfg.settings["ae_optimizer_name"],
|
||||||
|
lr=cfg.settings["ae_lr"],
|
||||||
|
n_epochs=cfg.settings["ae_n_epochs"],
|
||||||
|
lr_milestones=cfg.settings["ae_lr_milestone"],
|
||||||
|
batch_size=cfg.settings["ae_batch_size"],
|
||||||
|
weight_decay=cfg.settings["ae_weight_decay"],
|
||||||
|
device=device,
|
||||||
|
n_jobs_dataloader=n_jobs_dataloader,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store results for this dimension
|
||||||
|
elbow_results["ae_results"][rep_dim] = {
|
||||||
|
"train_time": deepSAD.ae.train_time,
|
||||||
|
"train_loss": deepSAD.ae.train_loss,
|
||||||
|
"test_auc": deepSAD.ae.test_auc, # if available
|
||||||
|
"test_loss": deepSAD.ae.test_loss,
|
||||||
|
"scores": deepSAD.ae.test_scores,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"Finished testing dimension {rep_dim}")
|
||||||
|
logger.info(f"Train time: {deepSAD.ae.train_time:.3f}s")
|
||||||
|
logger.info(f"Final train loss: {deepSAD.ae.train_loss[-1]:.6f}")
|
||||||
|
logger.info(f"Final test loss: {deepSAD.ae.test_loss:.6f}")
|
||||||
|
|
||||||
|
# Clear some memory
|
||||||
|
del deepSAD
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# Save all results
|
||||||
|
results_path = Path(xp_path) / "ae_elbow_results.pkl"
|
||||||
|
with open(results_path, "wb") as f:
|
||||||
|
pickle.dump(elbow_results, f)
|
||||||
|
|
||||||
|
logger.info(f"Saved elbow test results to {results_path}")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.error(f"Unknown action: {action}")
|
logger.error(f"Unknown action: {action}")
|
||||||
|
|
||||||
|
|||||||
@@ -55,6 +55,15 @@ class DeepSADTrainer(BaseTrainer):
|
|||||||
self.test_time = None
|
self.test_time = None
|
||||||
self.test_scores = None
|
self.test_scores = None
|
||||||
|
|
||||||
|
# Add new attributes for storing indices
|
||||||
|
self.train_indices = None
|
||||||
|
self.train_file_ids = None
|
||||||
|
self.train_frame_ids = None
|
||||||
|
|
||||||
|
self.test_indices = None
|
||||||
|
self.test_file_ids = None
|
||||||
|
self.test_frame_ids = None
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
self, dataset: BaseADDataset, net: BaseNet, k_fold_idx: int = None
|
self, dataset: BaseADDataset, net: BaseNet, k_fold_idx: int = None
|
||||||
) -> BaseNet:
|
) -> BaseNet:
|
||||||
@@ -95,12 +104,18 @@ class DeepSADTrainer(BaseTrainer):
|
|||||||
logger.info("Starting training...")
|
logger.info("Starting training...")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
net.train()
|
net.train()
|
||||||
|
|
||||||
|
# Lists to collect all indices during training
|
||||||
|
all_indices = []
|
||||||
|
all_file_ids = []
|
||||||
|
all_frame_ids = []
|
||||||
|
|
||||||
for epoch in range(self.n_epochs):
|
for epoch in range(self.n_epochs):
|
||||||
epoch_loss = 0.0
|
epoch_loss = 0.0
|
||||||
n_batches = 0
|
n_batches = 0
|
||||||
epoch_start_time = time.time()
|
epoch_start_time = time.time()
|
||||||
for data in train_loader:
|
for data in train_loader:
|
||||||
inputs, _, semi_targets, _, _ = data
|
inputs, _, _, semi_targets, idx, (file_id, frame_id) = data
|
||||||
inputs, semi_targets = (
|
inputs, semi_targets = (
|
||||||
inputs.to(self.device),
|
inputs.to(self.device),
|
||||||
semi_targets.to(self.device),
|
semi_targets.to(self.device),
|
||||||
@@ -124,6 +139,11 @@ class DeepSADTrainer(BaseTrainer):
|
|||||||
epoch_loss += loss.item()
|
epoch_loss += loss.item()
|
||||||
n_batches += 1
|
n_batches += 1
|
||||||
|
|
||||||
|
# Store indices
|
||||||
|
all_indices.extend(idx.cpu().numpy().tolist())
|
||||||
|
all_file_ids.extend(file_id.cpu().numpy().tolist())
|
||||||
|
all_frame_ids.extend(frame_id.cpu().numpy().tolist())
|
||||||
|
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
if epoch in self.lr_milestones:
|
if epoch in self.lr_milestones:
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -142,6 +162,11 @@ class DeepSADTrainer(BaseTrainer):
|
|||||||
logger.info("Training Time: {:.3f}s".format(self.train_time))
|
logger.info("Training Time: {:.3f}s".format(self.train_time))
|
||||||
logger.info("Finished training.")
|
logger.info("Finished training.")
|
||||||
|
|
||||||
|
# Store all training indices
|
||||||
|
self.train_indices = np.array(all_indices)
|
||||||
|
self.train_file_ids = np.array(all_file_ids)
|
||||||
|
self.train_frame_ids = np.array(all_frame_ids)
|
||||||
|
|
||||||
return net
|
return net
|
||||||
|
|
||||||
def infer(self, dataset: BaseADDataset, net: BaseNet):
|
def infer(self, dataset: BaseADDataset, net: BaseNet):
|
||||||
@@ -162,9 +187,14 @@ class DeepSADTrainer(BaseTrainer):
|
|||||||
all_outputs = np.zeros((len(inference_loader.dataset), 1024), dtype=np.float32)
|
all_outputs = np.zeros((len(inference_loader.dataset), 1024), dtype=np.float32)
|
||||||
scores = []
|
scores = []
|
||||||
net.eval()
|
net.eval()
|
||||||
|
|
||||||
|
all_indices = []
|
||||||
|
all_file_ids = []
|
||||||
|
all_frame_ids = []
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for data in inference_loader:
|
for data in inference_loader:
|
||||||
inputs, idx = data
|
inputs, idx, (file_id, frame_id) = data
|
||||||
|
|
||||||
inputs = inputs.to(self.device)
|
inputs = inputs.to(self.device)
|
||||||
idx = idx.to(self.device)
|
idx = idx.to(self.device)
|
||||||
@@ -177,6 +207,11 @@ class DeepSADTrainer(BaseTrainer):
|
|||||||
dist = torch.sum((outputs - self.c) ** 2, dim=1)
|
dist = torch.sum((outputs - self.c) ** 2, dim=1)
|
||||||
scores += dist.cpu().data.numpy().tolist()
|
scores += dist.cpu().data.numpy().tolist()
|
||||||
|
|
||||||
|
# Store indices
|
||||||
|
all_indices.extend(idx.cpu().numpy().tolist())
|
||||||
|
all_file_ids.extend(file_id.cpu().numpy().tolist())
|
||||||
|
all_frame_ids.extend(frame_id.cpu().numpy().tolist())
|
||||||
|
|
||||||
n_batches += 1
|
n_batches += 1
|
||||||
|
|
||||||
self.inference_time = time.time() - start_time
|
self.inference_time = time.time() - start_time
|
||||||
@@ -185,6 +220,17 @@ class DeepSADTrainer(BaseTrainer):
|
|||||||
logger.info("Inference Time: {:.3f}s".format(self.inference_time))
|
logger.info("Inference Time: {:.3f}s".format(self.inference_time))
|
||||||
logger.info("Finished inference.")
|
logger.info("Finished inference.")
|
||||||
|
|
||||||
|
# Store all inference indices
|
||||||
|
self.inference_indices = np.array(all_indices)
|
||||||
|
self.inference_file_ids = np.array(all_file_ids)
|
||||||
|
self.inference_frame_ids = np.array(all_frame_ids)
|
||||||
|
|
||||||
|
self.inference_index_mapping = {
|
||||||
|
"indices": self.inference_indices,
|
||||||
|
"file_ids": self.inference_file_ids,
|
||||||
|
"frame_ids": self.inference_frame_ids,
|
||||||
|
}
|
||||||
|
|
||||||
return np.array(scores), all_outputs
|
return np.array(scores), all_outputs
|
||||||
|
|
||||||
def test(self, dataset: BaseADDataset, net: BaseNet, k_fold_idx: int = None):
|
def test(self, dataset: BaseADDataset, net: BaseNet, k_fold_idx: int = None):
|
||||||
@@ -210,15 +256,34 @@ class DeepSADTrainer(BaseTrainer):
|
|||||||
epoch_loss = 0.0
|
epoch_loss = 0.0
|
||||||
n_batches = 0
|
n_batches = 0
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
idx_label_score = []
|
idx_label_score_exp = []
|
||||||
|
idx_label_score_manual = []
|
||||||
|
all_labels_exp = []
|
||||||
|
all_labels_manual = []
|
||||||
|
all_scores = []
|
||||||
|
all_idx = []
|
||||||
|
|
||||||
|
# Lists to collect all indices during testing
|
||||||
|
all_indices = []
|
||||||
|
all_file_ids = []
|
||||||
|
all_frame_ids = []
|
||||||
|
|
||||||
net.eval()
|
net.eval()
|
||||||
net.summary(receptive_field=True)
|
net.summary(receptive_field=True)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for data in test_loader:
|
for data in test_loader:
|
||||||
inputs, labels, semi_targets, idx, _ = data
|
(
|
||||||
|
inputs,
|
||||||
|
labels_exp_based,
|
||||||
|
labels_manual_based,
|
||||||
|
semi_targets,
|
||||||
|
idx,
|
||||||
|
(file_id, frame_id),
|
||||||
|
) = data
|
||||||
|
|
||||||
inputs = inputs.to(self.device)
|
inputs = inputs.to(self.device)
|
||||||
labels = labels.to(self.device)
|
labels_exp_based = labels_exp_based.to(self.device)
|
||||||
|
labels_manual_based = labels_manual_based.to(self.device)
|
||||||
semi_targets = semi_targets.to(self.device)
|
semi_targets = semi_targets.to(self.device)
|
||||||
idx = idx.to(self.device)
|
idx = idx.to(self.device)
|
||||||
|
|
||||||
@@ -232,34 +297,161 @@ class DeepSADTrainer(BaseTrainer):
|
|||||||
loss = torch.mean(losses)
|
loss = torch.mean(losses)
|
||||||
scores = dist
|
scores = dist
|
||||||
|
|
||||||
# Save triples of (idx, label, score) in a list
|
# Save for evaluation
|
||||||
idx_label_score += list(
|
idx_label_score_exp += list(
|
||||||
zip(
|
zip(
|
||||||
idx.cpu().data.numpy().tolist(),
|
idx.cpu().data.numpy().tolist(),
|
||||||
labels.cpu().data.numpy().tolist(),
|
labels_exp_based.cpu().data.numpy().tolist(),
|
||||||
scores.cpu().data.numpy().tolist(),
|
scores.cpu().data.numpy().tolist(),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
idx_label_score_manual += list(
|
||||||
|
zip(
|
||||||
|
idx.cpu().data.numpy().tolist(),
|
||||||
|
labels_manual_based.cpu().data.numpy().tolist(),
|
||||||
|
scores.cpu().data.numpy().tolist(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
all_labels_exp.append(labels_exp_based.cpu().numpy())
|
||||||
|
all_labels_manual.append(labels_manual_based.cpu().numpy())
|
||||||
|
all_scores.append(scores.cpu().numpy())
|
||||||
|
all_idx.append(idx.cpu().numpy())
|
||||||
|
|
||||||
|
# Store indices
|
||||||
|
all_indices.extend(idx.cpu().numpy().tolist())
|
||||||
|
all_file_ids.extend(file_id.cpu().numpy().tolist())
|
||||||
|
all_frame_ids.extend(frame_id.cpu().numpy().tolist())
|
||||||
|
|
||||||
epoch_loss += loss.item()
|
epoch_loss += loss.item()
|
||||||
n_batches += 1
|
n_batches += 1
|
||||||
|
|
||||||
self.test_time = time.time() - start_time
|
self.test_time = time.time() - start_time
|
||||||
self.test_scores = idx_label_score
|
self.test_scores_exp_based = idx_label_score_exp
|
||||||
|
self.test_scores_manual_based = idx_label_score_manual
|
||||||
|
|
||||||
# Compute AUC
|
# Flatten arrays for counting and evaluation
|
||||||
_, labels, scores = zip(*idx_label_score)
|
all_labels_exp = np.concatenate(all_labels_exp)
|
||||||
labels = np.array(labels)
|
all_labels_manual = np.concatenate(all_labels_manual)
|
||||||
scores = np.array(scores)
|
all_scores = np.concatenate(all_scores)
|
||||||
self.test_auc = roc_auc_score(labels, scores)
|
all_idx = np.concatenate(all_idx)
|
||||||
self.test_roc = roc_curve(labels, scores)
|
|
||||||
self.test_prc = precision_recall_curve(labels, scores)
|
# Count and log label stats for exp_based
|
||||||
self.test_ap = average_precision_score(labels, scores)
|
n_exp_normal = np.sum(all_labels_exp == 1)
|
||||||
|
n_exp_anomaly = np.sum(all_labels_exp == -1)
|
||||||
|
n_exp_unknown = np.sum(all_labels_exp == 0)
|
||||||
|
logger.info(
|
||||||
|
f"Exp-based labels: normal(1)={n_exp_normal}, "
|
||||||
|
f"anomaly(-1)={n_exp_anomaly}, unknown(0)={n_exp_unknown}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Count and log label stats for manual_based
|
||||||
|
n_manual_normal = np.sum(all_labels_manual == 1)
|
||||||
|
n_manual_anomaly = np.sum(all_labels_manual == -1)
|
||||||
|
n_manual_unknown = np.sum(all_labels_manual == 0)
|
||||||
|
logger.info(
|
||||||
|
f"Manual-based labels: normal(1)={n_manual_normal}, "
|
||||||
|
f"anomaly(-1)={n_manual_anomaly}, unknown(0)={n_manual_unknown}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Evaluation for exp_based (only labeled samples) ---
|
||||||
|
idxs_exp, labels_exp, scores_exp = zip(*idx_label_score_exp)
|
||||||
|
labels_exp = np.array(labels_exp)
|
||||||
|
scores_exp = np.array(scores_exp)
|
||||||
|
|
||||||
|
# Filter out unknown labels and convert to binary (1: anomaly, 0: normal) for ROC
|
||||||
|
valid_mask_exp = labels_exp != 0
|
||||||
|
if np.any(valid_mask_exp):
|
||||||
|
# Convert to binary labels for ROC (-1 → 1, 1 → 0)
|
||||||
|
labels_exp_binary = (labels_exp[valid_mask_exp] == -1).astype(int)
|
||||||
|
scores_exp_valid = scores_exp[valid_mask_exp]
|
||||||
|
|
||||||
|
self.test_auc_exp_based = roc_auc_score(labels_exp_binary, scores_exp_valid)
|
||||||
|
self.test_roc_exp_based = roc_curve(labels_exp_binary, scores_exp_valid)
|
||||||
|
self.test_prc_exp_based = precision_recall_curve(
|
||||||
|
labels_exp_binary, scores_exp_valid
|
||||||
|
)
|
||||||
|
self.test_ap_exp_based = average_precision_score(
|
||||||
|
labels_exp_binary, scores_exp_valid
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Test Loss: {:.6f}".format(epoch_loss / n_batches))
|
||||||
|
logger.info(
|
||||||
|
"Test AUC (exp_based): {:.2f}%".format(100.0 * self.test_auc_exp_based)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info("Test AUC (exp_based): N/A (no labeled samples)")
|
||||||
|
self.test_auc_exp_based = None
|
||||||
|
self.test_roc_exp_based = None
|
||||||
|
self.test_prc_exp_based = None
|
||||||
|
self.test_ap_exp_based = None
|
||||||
|
|
||||||
# Log results
|
|
||||||
logger.info("Test Loss: {:.6f}".format(epoch_loss / n_batches))
|
|
||||||
logger.info("Test AUC: {:.2f}%".format(100.0 * self.test_auc))
|
|
||||||
logger.info("Test Time: {:.3f}s".format(self.test_time))
|
logger.info("Test Time: {:.3f}s".format(self.test_time))
|
||||||
|
|
||||||
|
# --- Evaluation for manual_based (only labeled samples) ---
|
||||||
|
idxs_manual, labels_manual, scores_manual = zip(*idx_label_score_manual)
|
||||||
|
labels_manual = np.array(labels_manual)
|
||||||
|
scores_manual = np.array(scores_manual)
|
||||||
|
|
||||||
|
# Filter out unknown labels and convert to binary for ROC
|
||||||
|
valid_mask_manual = labels_manual != 0
|
||||||
|
if np.any(valid_mask_manual):
|
||||||
|
# Convert to binary labels for ROC (-1 → 1, 1 → 0)
|
||||||
|
labels_manual_binary = (labels_manual[valid_mask_manual] == -1).astype(int)
|
||||||
|
scores_manual_valid = scores_manual[valid_mask_manual]
|
||||||
|
|
||||||
|
self.test_auc_manual_based = roc_auc_score(
|
||||||
|
labels_manual_binary, scores_manual_valid
|
||||||
|
)
|
||||||
|
self.test_roc_manual_based = roc_curve(
|
||||||
|
labels_manual_binary, scores_manual_valid
|
||||||
|
)
|
||||||
|
self.test_prc_manual_based = precision_recall_curve(
|
||||||
|
labels_manual_binary, scores_manual_valid
|
||||||
|
)
|
||||||
|
self.test_ap_manual_based = average_precision_score(
|
||||||
|
labels_manual_binary, scores_manual_valid
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Test AUC (manual_based): {:.2f}%".format(
|
||||||
|
100.0 * self.test_auc_manual_based
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.test_auc_manual_based = None
|
||||||
|
self.test_roc_manual_based = None
|
||||||
|
self.test_prc_manual_based = None
|
||||||
|
self.test_ap_manual_based = None
|
||||||
|
logger.info("Test AUC (manual_based): N/A (no labeled samples)")
|
||||||
|
|
||||||
|
# Store all test indices
|
||||||
|
self.test_indices = np.array(all_indices)
|
||||||
|
self.test_file_ids = np.array(all_file_ids)
|
||||||
|
self.test_frame_ids = np.array(all_frame_ids)
|
||||||
|
|
||||||
|
# Add logging for indices
|
||||||
|
logger.info(f"Number of test samples: {len(self.test_indices)}")
|
||||||
|
logger.info(f"Number of unique files: {len(np.unique(self.test_file_ids))}")
|
||||||
|
|
||||||
|
# Create a mapping of indices to their file/frame information
|
||||||
|
self.test_index_mapping = {
|
||||||
|
"indices": self.test_indices,
|
||||||
|
"file_ids": self.test_file_ids,
|
||||||
|
"frame_ids": self.test_frame_ids,
|
||||||
|
"exp_based": {
|
||||||
|
"indices": np.array(idxs_exp),
|
||||||
|
"labels": np.array(labels_exp),
|
||||||
|
"scores": np.array(scores_exp),
|
||||||
|
"valid_mask": valid_mask_exp,
|
||||||
|
},
|
||||||
|
"manual_based": {
|
||||||
|
"indices": np.array(idxs_manual),
|
||||||
|
"labels": np.array(labels_manual),
|
||||||
|
"scores": np.array(scores_manual),
|
||||||
|
"valid_mask": valid_mask_manual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
logger.info("Finished testing.")
|
logger.info("Finished testing.")
|
||||||
|
|
||||||
def init_center_c(self, train_loader: DataLoader, net: BaseNet, eps=0.1):
|
def init_center_c(self, train_loader: DataLoader, net: BaseNet, eps=0.1):
|
||||||
@@ -272,7 +464,7 @@ class DeepSADTrainer(BaseTrainer):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for data in train_loader:
|
for data in train_loader:
|
||||||
# get the inputs of the batch
|
# get the inputs of the batch
|
||||||
inputs, _, _, _, _ = data
|
inputs, _, _, _, _, _ = data
|
||||||
inputs = inputs.to(self.device)
|
inputs = inputs.to(self.device)
|
||||||
outputs = net(inputs)
|
outputs = net(inputs)
|
||||||
n_samples += outputs.shape[0]
|
n_samples += outputs.shape[0]
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ class AETrainer(BaseTrainer):
|
|||||||
n_batches = 0
|
n_batches = 0
|
||||||
epoch_start_time = time.time()
|
epoch_start_time = time.time()
|
||||||
for data in train_loader:
|
for data in train_loader:
|
||||||
inputs, _, _, _, file_frame_ids = data
|
inputs, _, _, _, _, file_frame_ids = data
|
||||||
inputs = inputs.to(self.device)
|
inputs = inputs.to(self.device)
|
||||||
all_training_data.append(
|
all_training_data.append(
|
||||||
np.dstack(
|
np.dstack(
|
||||||
@@ -166,7 +166,7 @@ class AETrainer(BaseTrainer):
|
|||||||
all_training_data = []
|
all_training_data = []
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for data in test_loader:
|
for data in test_loader:
|
||||||
inputs, labels, _, idx, file_frame_ids = data
|
inputs, labels, _, _, idx, file_frame_ids = data
|
||||||
inputs, labels, idx = (
|
inputs, labels, idx = (
|
||||||
inputs.to(self.device),
|
inputs.to(self.device),
|
||||||
labels.to(self.device),
|
labels.to(self.device),
|
||||||
|
|||||||
Reference in New Issue
Block a user