split subter implementation (training + inference)
This commit is contained in:
@@ -10,11 +10,18 @@
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
outputs = { self, nixpkgs, flake-utils, poetry2nix }:
|
outputs =
|
||||||
flake-utils.lib.eachDefaultSystem (system:
|
{
|
||||||
|
self,
|
||||||
|
nixpkgs,
|
||||||
|
flake-utils,
|
||||||
|
poetry2nix,
|
||||||
|
}:
|
||||||
|
flake-utils.lib.eachDefaultSystem (
|
||||||
|
system:
|
||||||
let
|
let
|
||||||
# see https://github.com/nix-community/poetry2nix/tree/master#api for more functions and examples.
|
# see https://github.com/nix-community/poetry2nix/tree/master#api for more functions and examples.
|
||||||
pkgs = import nixpkgs{
|
pkgs = import nixpkgs {
|
||||||
inherit system;
|
inherit system;
|
||||||
config.allowUnfree = true;
|
config.allowUnfree = true;
|
||||||
config.cudaSupport = true;
|
config.cudaSupport = true;
|
||||||
@@ -43,7 +50,11 @@
|
|||||||
};
|
};
|
||||||
|
|
||||||
devShells.poetry = pkgs.mkShell {
|
devShells.poetry = pkgs.mkShell {
|
||||||
packages = [ pkgs.poetry pkgs.python311 ];
|
packages = [
|
||||||
|
pkgs.poetry
|
||||||
|
pkgs.python311
|
||||||
|
];
|
||||||
};
|
};
|
||||||
});
|
}
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
from .mnist import MNIST_Dataset
|
|
||||||
from .elpv import ELPV_Dataset
|
|
||||||
from .subter import SubTer_Dataset
|
|
||||||
from .fmnist import FashionMNIST_Dataset
|
|
||||||
from .cifar10 import CIFAR10_Dataset
|
from .cifar10 import CIFAR10_Dataset
|
||||||
|
from .elpv import ELPV_Dataset
|
||||||
|
from .fmnist import FashionMNIST_Dataset
|
||||||
|
from .mnist import MNIST_Dataset
|
||||||
from .odds import ODDSADDataset
|
from .odds import ODDSADDataset
|
||||||
|
from .subter import SubTer_Dataset
|
||||||
|
from .subtersplit import SubTerSplit_Dataset
|
||||||
|
|
||||||
|
|
||||||
def load_dataset(
|
def load_dataset(
|
||||||
@@ -24,6 +25,7 @@ def load_dataset(
|
|||||||
"mnist",
|
"mnist",
|
||||||
"elpv",
|
"elpv",
|
||||||
"subter",
|
"subter",
|
||||||
|
"subtersplit",
|
||||||
"fmnist",
|
"fmnist",
|
||||||
"cifar10",
|
"cifar10",
|
||||||
"arrhythmia",
|
"arrhythmia",
|
||||||
@@ -46,6 +48,15 @@ def load_dataset(
|
|||||||
inference=inference,
|
inference=inference,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if dataset_name == "subtersplit":
|
||||||
|
dataset = SubTerSplit_Dataset(
|
||||||
|
root=data_path,
|
||||||
|
ratio_known_normal=ratio_known_normal,
|
||||||
|
ratio_known_outlier=ratio_known_outlier,
|
||||||
|
ratio_pollution=ratio_pollution,
|
||||||
|
inference=inference,
|
||||||
|
)
|
||||||
|
|
||||||
if dataset_name == "elpv":
|
if dataset_name == "elpv":
|
||||||
dataset = ELPV_Dataset(
|
dataset = ELPV_Dataset(
|
||||||
root=data_path,
|
root=data_path,
|
||||||
|
|||||||
@@ -1,22 +1,21 @@
|
|||||||
from torch.utils.data import Subset
|
import logging
|
||||||
from PIL import Image
|
import random
|
||||||
from torch.utils.data.dataset import ConcatDataset
|
from pathlib import Path
|
||||||
from torchvision.datasets import VisionDataset
|
|
||||||
from base.torchvision_dataset import TorchvisionDataset
|
|
||||||
from .preprocessing import create_semisupervised_setting
|
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import logging
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
import random
|
from base.torchvision_dataset import TorchvisionDataset
|
||||||
import numpy as np
|
from PIL import Image
|
||||||
|
from torch.utils.data import Subset
|
||||||
|
from torch.utils.data.dataset import ConcatDataset
|
||||||
|
from torchvision.datasets import VisionDataset
|
||||||
|
|
||||||
from pathlib import Path
|
from .preprocessing import create_semisupervised_setting
|
||||||
|
|
||||||
|
|
||||||
class SubTer_Dataset(TorchvisionDataset):
|
class SubTer_Dataset(TorchvisionDataset):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
root: str,
|
root: str,
|
||||||
@@ -31,6 +30,7 @@ class SubTer_Dataset(TorchvisionDataset):
|
|||||||
self.n_classes = 2 # 0: normal, 1: outlier
|
self.n_classes = 2 # 0: normal, 1: outlier
|
||||||
self.normal_classes = tuple([0])
|
self.normal_classes = tuple([0])
|
||||||
self.outlier_classes = tuple([1])
|
self.outlier_classes = tuple([1])
|
||||||
|
self.inference_set = None
|
||||||
|
|
||||||
# MNIST preprocessing: feature scaling to [0, 1]
|
# MNIST preprocessing: feature scaling to [0, 1]
|
||||||
# FIXME understand mnist feature scaling and check if it or other preprocessing is necessary for elpv
|
# FIXME understand mnist feature scaling and check if it or other preprocessing is necessary for elpv
|
||||||
@@ -78,7 +78,6 @@ class SubTer_Dataset(TorchvisionDataset):
|
|||||||
|
|
||||||
|
|
||||||
class SubTerTraining(VisionDataset):
|
class SubTerTraining(VisionDataset):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
root: str,
|
root: str,
|
||||||
@@ -95,10 +94,18 @@ class SubTerTraining(VisionDataset):
|
|||||||
|
|
||||||
experiments_data = []
|
experiments_data = []
|
||||||
experiments_targets = []
|
experiments_targets = []
|
||||||
|
validation_files = []
|
||||||
|
experiment_files = []
|
||||||
|
|
||||||
for experiment_file in Path(root).iterdir():
|
for experiment_file in Path(root).iterdir():
|
||||||
|
if experiment_file.is_dir() and experiment_file.name == "validation":
|
||||||
|
for validation_file in experiment_file.iterdir():
|
||||||
|
if validation_file.suffix != ".npy":
|
||||||
|
continue
|
||||||
|
validation_files.append(experiment_file)
|
||||||
if experiment_file.suffix != ".npy":
|
if experiment_file.suffix != ".npy":
|
||||||
continue
|
continue
|
||||||
|
experiment_files.append(experiment_file)
|
||||||
experiment_data = np.load(experiment_file)
|
experiment_data = np.load(experiment_file)
|
||||||
# experiment_data = np.lib.format.open_memmap(experiment_file, mode='r+')
|
# experiment_data = np.lib.format.open_memmap(experiment_file, mode='r+')
|
||||||
experiment_targets = (
|
experiment_targets = (
|
||||||
@@ -109,6 +116,26 @@ class SubTerTraining(VisionDataset):
|
|||||||
experiments_data.append(experiment_data)
|
experiments_data.append(experiment_data)
|
||||||
experiments_targets.append(experiment_targets)
|
experiments_targets.append(experiment_targets)
|
||||||
|
|
||||||
|
filtered_validation_files = []
|
||||||
|
for validation_file in validation_files:
|
||||||
|
validation_file_name = validation_file.name
|
||||||
|
file_exists_in_experiments = any(
|
||||||
|
experiment_file.name == validation_file_name
|
||||||
|
for experiment_file in experiment_files
|
||||||
|
)
|
||||||
|
if not file_exists_in_experiments:
|
||||||
|
filtered_validation_files.append(validation_file)
|
||||||
|
validation_files = filtered_validation_files
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Train/Test experiments: {[experiment_file.name for experiment_file in experiment_files]}"
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Validation experiments: {[validation_file.name for validation_file in validation_files]}"
|
||||||
|
)
|
||||||
|
|
||||||
lidar_projections = np.concatenate(experiments_data)
|
lidar_projections = np.concatenate(experiments_data)
|
||||||
smoke_presence = np.concatenate(experiments_targets)
|
smoke_presence = np.concatenate(experiments_targets)
|
||||||
|
|
||||||
@@ -166,7 +193,6 @@ class SubTerTraining(VisionDataset):
|
|||||||
|
|
||||||
|
|
||||||
class SubTerInference(VisionDataset):
|
class SubTerInference(VisionDataset):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
root: str,
|
root: str,
|
||||||
|
|||||||
263
Deep-SAD-PyTorch/src/datasets/subtersplit.py
Normal file
263
Deep-SAD-PyTorch/src/datasets/subtersplit.py
Normal file
@@ -0,0 +1,263 @@
|
|||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
from base.torchvision_dataset import TorchvisionDataset
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import Subset
|
||||||
|
from torchvision.datasets import VisionDataset
|
||||||
|
|
||||||
|
from .preprocessing import create_semisupervised_setting
|
||||||
|
|
||||||
|
|
||||||
|
class SubTerSplit_Dataset(TorchvisionDataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
root: str,
|
||||||
|
ratio_known_normal: float = 0.0,
|
||||||
|
ratio_known_outlier: float = 0.0,
|
||||||
|
ratio_pollution: float = 0.0,
|
||||||
|
inference: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__(root)
|
||||||
|
|
||||||
|
# Define normal and outlier classes
|
||||||
|
self.n_classes = 2 # 0: normal, 1: outlier
|
||||||
|
self.normal_classes = tuple([0])
|
||||||
|
self.outlier_classes = tuple([1])
|
||||||
|
self.inference_set = None
|
||||||
|
|
||||||
|
# MNIST preprocessing: feature scaling to [0, 1]
|
||||||
|
# FIXME understand mnist feature scaling and check if it or other preprocessing is necessary for elpv
|
||||||
|
transform = transforms.ToTensor()
|
||||||
|
target_transform = transforms.Lambda(lambda x: int(x in self.outlier_classes))
|
||||||
|
|
||||||
|
if inference:
|
||||||
|
self.inference_set = SubTerSplitInference(
|
||||||
|
root=self.root,
|
||||||
|
transform=transform,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Get train set
|
||||||
|
train_set = SubTerSplitTraining(
|
||||||
|
root=self.root,
|
||||||
|
transform=transform,
|
||||||
|
target_transform=target_transform,
|
||||||
|
train=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create semi-supervised setting
|
||||||
|
idx, _, semi_targets = create_semisupervised_setting(
|
||||||
|
train_set.targets.cpu().data.numpy(),
|
||||||
|
self.normal_classes,
|
||||||
|
self.outlier_classes,
|
||||||
|
self.outlier_classes,
|
||||||
|
ratio_known_normal,
|
||||||
|
ratio_known_outlier,
|
||||||
|
ratio_pollution,
|
||||||
|
)
|
||||||
|
train_set.semi_targets[idx] = torch.tensor(
|
||||||
|
np.array(semi_targets, dtype=np.int8)
|
||||||
|
) # set respective semi-supervised labels
|
||||||
|
|
||||||
|
# Subset train_set to semi-supervised setup
|
||||||
|
self.train_set = Subset(train_set, idx)
|
||||||
|
|
||||||
|
# Get test set
|
||||||
|
self.test_set = SubTerSplitTraining(
|
||||||
|
root=self.root,
|
||||||
|
train=False,
|
||||||
|
transform=transform,
|
||||||
|
target_transform=target_transform,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def split_array_into_subarrays(array, split_height, split_width):
|
||||||
|
original_shape = array.shape
|
||||||
|
height, width = original_shape[-2], original_shape[-1]
|
||||||
|
assert height % split_height == 0, "The height is not divisible by the split_height"
|
||||||
|
assert width % split_width == 0, "The width is not divisible by the split_width"
|
||||||
|
num_splits_height = height // split_height
|
||||||
|
num_splits_width = width // split_width
|
||||||
|
reshaped_array = array.reshape(
|
||||||
|
-1, num_splits_height, split_height, num_splits_width, split_width
|
||||||
|
)
|
||||||
|
transposed_array = reshaped_array.transpose(0, 1, 3, 2, 4)
|
||||||
|
final_array = transposed_array.reshape(-1, split_height, split_width)
|
||||||
|
return final_array
|
||||||
|
|
||||||
|
|
||||||
|
class SubTerSplitTraining(VisionDataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
root: str,
|
||||||
|
transforms: Optional[Callable] = None,
|
||||||
|
transform: Optional[Callable] = None,
|
||||||
|
target_transform: Optional[Callable] = None,
|
||||||
|
train=False,
|
||||||
|
split=0.7,
|
||||||
|
seed=0,
|
||||||
|
height=16,
|
||||||
|
width=256,
|
||||||
|
):
|
||||||
|
super(SubTerSplitTraining, self).__init__(
|
||||||
|
root, transforms, transform, target_transform
|
||||||
|
)
|
||||||
|
|
||||||
|
experiments_data = []
|
||||||
|
experiments_targets = []
|
||||||
|
validation_files = []
|
||||||
|
experiment_files = []
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
for experiment_file in Path(root).iterdir():
|
||||||
|
if experiment_file.is_dir() and experiment_file.name == "validation":
|
||||||
|
for validation_file in experiment_file.iterdir():
|
||||||
|
if validation_file.suffix != ".npy":
|
||||||
|
continue
|
||||||
|
validation_files.append(experiment_file)
|
||||||
|
if experiment_file.suffix != ".npy":
|
||||||
|
continue
|
||||||
|
experiment_files.append(experiment_file)
|
||||||
|
experiment_data = np.load(experiment_file)
|
||||||
|
|
||||||
|
if (
|
||||||
|
experiment_data.shape[1] % height != 0
|
||||||
|
or experiment_data.shape[2] % width != 0
|
||||||
|
):
|
||||||
|
logger.error(
|
||||||
|
f"Experiment {experiment_file.name} has shape {experiment_data.shape} which is not divisible by {height}x{width}"
|
||||||
|
)
|
||||||
|
experiment_data = split_array_into_subarrays(experiment_data, height, width)
|
||||||
|
|
||||||
|
# experiment_data = np.lib.format.open_memmap(experiment_file, mode='r+')
|
||||||
|
experiment_targets = (
|
||||||
|
np.ones(experiment_data.shape[0], dtype=np.int8)
|
||||||
|
if "smoke" in experiment_file.name
|
||||||
|
else np.zeros(experiment_data.shape[0], dtype=np.int8)
|
||||||
|
)
|
||||||
|
experiments_data.append(experiment_data)
|
||||||
|
experiments_targets.append(experiment_targets)
|
||||||
|
|
||||||
|
filtered_validation_files = []
|
||||||
|
for validation_file in validation_files:
|
||||||
|
validation_file_name = validation_file.name
|
||||||
|
file_exists_in_experiments = any(
|
||||||
|
experiment_file.name == validation_file_name
|
||||||
|
for experiment_file in experiment_files
|
||||||
|
)
|
||||||
|
if not file_exists_in_experiments:
|
||||||
|
filtered_validation_files.append(validation_file)
|
||||||
|
validation_files = filtered_validation_files
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Train/Test experiments: {[experiment_file.name for experiment_file in experiment_files]}"
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Validation experiments: {[validation_file.name for validation_file in validation_files]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
lidar_projections = np.concatenate(experiments_data)
|
||||||
|
smoke_presence = np.concatenate(experiments_targets)
|
||||||
|
|
||||||
|
np.random.seed(seed)
|
||||||
|
|
||||||
|
shuffled_indices = np.random.permutation(lidar_projections.shape[0])
|
||||||
|
shuffled_lidar_projections = lidar_projections[shuffled_indices]
|
||||||
|
shuffled_smoke_presence = smoke_presence[shuffled_indices]
|
||||||
|
|
||||||
|
split_idx = int(split * shuffled_lidar_projections.shape[0])
|
||||||
|
|
||||||
|
if train:
|
||||||
|
self.data = shuffled_lidar_projections[:split_idx]
|
||||||
|
self.targets = shuffled_smoke_presence[:split_idx]
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.data = shuffled_lidar_projections[split_idx:]
|
||||||
|
self.targets = shuffled_smoke_presence[split_idx:]
|
||||||
|
|
||||||
|
self.data = np.nan_to_num(self.data)
|
||||||
|
|
||||||
|
self.data = torch.tensor(self.data)
|
||||||
|
self.targets = torch.tensor(self.targets, dtype=torch.int8)
|
||||||
|
|
||||||
|
self.semi_targets = torch.zeros_like(self.targets, dtype=torch.int8)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
"""Override the original method of the MNIST class.
|
||||||
|
Args:
|
||||||
|
index (int): Index
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (image, target, semi_target, index)
|
||||||
|
"""
|
||||||
|
img, target, semi_target = (
|
||||||
|
self.data[index],
|
||||||
|
int(self.targets[index]),
|
||||||
|
int(self.semi_targets[index]),
|
||||||
|
)
|
||||||
|
|
||||||
|
# doing this so that it is consistent with all other datasets
|
||||||
|
# to return a PIL Image
|
||||||
|
img = Image.fromarray(img.numpy(), mode="F")
|
||||||
|
|
||||||
|
if self.transform is not None:
|
||||||
|
img = self.transform(img)
|
||||||
|
|
||||||
|
if self.target_transform is not None:
|
||||||
|
target = self.target_transform(target)
|
||||||
|
|
||||||
|
return img, target, semi_target, index
|
||||||
|
|
||||||
|
|
||||||
|
class SubTerSplitInference(VisionDataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
root: str,
|
||||||
|
transforms: Optional[Callable] = None,
|
||||||
|
transform: Optional[Callable] = None,
|
||||||
|
):
|
||||||
|
super(SubTerSplitInference, self).__init__(root, transforms, transform)
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
self.experiment_file_path = Path(root)
|
||||||
|
|
||||||
|
if not self.experiment_file_path.is_file():
|
||||||
|
logger.error(
|
||||||
|
"For inference the data path has to be a single experiment file!"
|
||||||
|
)
|
||||||
|
raise Exception("Inference data is not a loadable file!")
|
||||||
|
|
||||||
|
self.data = np.load(self.experiment_file_path)
|
||||||
|
self.data = split_array_into_subarrays(self.data, 16, 256)
|
||||||
|
self.data = np.nan_to_num(self.data)
|
||||||
|
self.data = torch.tensor(self.data)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
"""Override the original method of the MNIST class.
|
||||||
|
Args:
|
||||||
|
index (int): Index
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (image, index)
|
||||||
|
"""
|
||||||
|
img = self.data[index]
|
||||||
|
|
||||||
|
# doing this so that it is consistent with all other datasets
|
||||||
|
# to return a PIL Image
|
||||||
|
img = Image.fromarray(img.numpy(), mode="F")
|
||||||
|
|
||||||
|
if self.transform is not None:
|
||||||
|
img = self.transform(img)
|
||||||
|
|
||||||
|
return img, index
|
||||||
@@ -1,14 +1,14 @@
|
|||||||
import click
|
|
||||||
import torch
|
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
import numpy as np
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import click
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from datasets.main import load_dataset
|
||||||
|
from DeepSAD import DeepSAD
|
||||||
from utils.config import Config
|
from utils.config import Config
|
||||||
from utils.visualization.plot_images_grid import plot_images_grid
|
from utils.visualization.plot_images_grid import plot_images_grid
|
||||||
from DeepSAD import DeepSAD
|
|
||||||
from datasets.main import load_dataset
|
|
||||||
|
|
||||||
|
|
||||||
################################################################################
|
################################################################################
|
||||||
@@ -31,6 +31,7 @@ from datasets.main import load_dataset
|
|||||||
"mnist",
|
"mnist",
|
||||||
"elpv",
|
"elpv",
|
||||||
"subter",
|
"subter",
|
||||||
|
"subtersplit",
|
||||||
"fmnist",
|
"fmnist",
|
||||||
"cifar10",
|
"cifar10",
|
||||||
"arrhythmia",
|
"arrhythmia",
|
||||||
@@ -49,6 +50,7 @@ from datasets.main import load_dataset
|
|||||||
"mnist_LeNet",
|
"mnist_LeNet",
|
||||||
"elpv_LeNet",
|
"elpv_LeNet",
|
||||||
"subter_LeNet",
|
"subter_LeNet",
|
||||||
|
"subter_LeNet_Split",
|
||||||
"fmnist_LeNet",
|
"fmnist_LeNet",
|
||||||
"cifar10_LeNet",
|
"cifar10_LeNet",
|
||||||
"arrhythmia_mlp",
|
"arrhythmia_mlp",
|
||||||
@@ -315,7 +317,6 @@ def main(
|
|||||||
logger.info("Number of dataloader workers: %d" % n_jobs_dataloader)
|
logger.info("Number of dataloader workers: %d" % n_jobs_dataloader)
|
||||||
|
|
||||||
if action == "train":
|
if action == "train":
|
||||||
|
|
||||||
# Load data
|
# Load data
|
||||||
dataset = load_dataset(
|
dataset = load_dataset(
|
||||||
dataset_name,
|
dataset_name,
|
||||||
@@ -413,7 +414,6 @@ def main(
|
|||||||
] # from lowest to highest score
|
] # from lowest to highest score
|
||||||
|
|
||||||
if dataset_name in ("mnist", "fmnist", "cifar10", "elpv"):
|
if dataset_name in ("mnist", "fmnist", "cifar10", "elpv"):
|
||||||
|
|
||||||
if dataset_name in ("mnist", "fmnist", "elpv"):
|
if dataset_name in ("mnist", "fmnist", "elpv"):
|
||||||
X_all_low = dataset.test_set.data[idx_all_sorted[:32], ...].unsqueeze(1)
|
X_all_low = dataset.test_set.data[idx_all_sorted[:32], ...].unsqueeze(1)
|
||||||
X_all_high = dataset.test_set.data[idx_all_sorted[-32:], ...].unsqueeze(
|
X_all_high = dataset.test_set.data[idx_all_sorted[-32:], ...].unsqueeze(
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
from .mnist_LeNet import MNIST_LeNet, MNIST_LeNet_Autoencoder
|
|
||||||
from .elpv_LeNet import ELPV_LeNet, ELPV_LeNet_Autoencoder
|
|
||||||
from .subter_LeNet import SubTer_LeNet, SubTer_LeNet_Autoencoder
|
|
||||||
from .fmnist_LeNet import FashionMNIST_LeNet, FashionMNIST_LeNet_Autoencoder
|
|
||||||
from .cifar10_LeNet import CIFAR10_LeNet, CIFAR10_LeNet_Autoencoder
|
from .cifar10_LeNet import CIFAR10_LeNet, CIFAR10_LeNet_Autoencoder
|
||||||
from .mlp import MLP, MLP_Autoencoder
|
|
||||||
from .vae import VariationalAutoencoder
|
|
||||||
from .dgm import DeepGenerativeModel, StackedDeepGenerativeModel
|
from .dgm import DeepGenerativeModel, StackedDeepGenerativeModel
|
||||||
|
from .elpv_LeNet import ELPV_LeNet, ELPV_LeNet_Autoencoder
|
||||||
|
from .fmnist_LeNet import FashionMNIST_LeNet, FashionMNIST_LeNet_Autoencoder
|
||||||
|
from .mlp import MLP, MLP_Autoencoder
|
||||||
|
from .mnist_LeNet import MNIST_LeNet, MNIST_LeNet_Autoencoder
|
||||||
|
from .subter_LeNet import SubTer_LeNet, SubTer_LeNet_Autoencoder
|
||||||
|
from .subter_LeNet_Split import SubTer_LeNet_Split, SubTer_LeNet_Split_Autoencoder
|
||||||
|
from .vae import VariationalAutoencoder
|
||||||
|
|
||||||
|
|
||||||
def build_network(net_name, ae_net=None):
|
def build_network(net_name, ae_net=None):
|
||||||
@@ -15,6 +16,7 @@ def build_network(net_name, ae_net=None):
|
|||||||
"mnist_LeNet",
|
"mnist_LeNet",
|
||||||
"elpv_LeNet",
|
"elpv_LeNet",
|
||||||
"subter_LeNet",
|
"subter_LeNet",
|
||||||
|
"subter_LeNet_Split",
|
||||||
"mnist_DGM_M2",
|
"mnist_DGM_M2",
|
||||||
"mnist_DGM_M1M2",
|
"mnist_DGM_M1M2",
|
||||||
"fmnist_LeNet",
|
"fmnist_LeNet",
|
||||||
@@ -46,6 +48,9 @@ def build_network(net_name, ae_net=None):
|
|||||||
if net_name == "subter_LeNet":
|
if net_name == "subter_LeNet":
|
||||||
net = SubTer_LeNet()
|
net = SubTer_LeNet()
|
||||||
|
|
||||||
|
if net_name == "subter_LeNet_Split":
|
||||||
|
net = SubTer_LeNet_Split()
|
||||||
|
|
||||||
if net_name == "elpv_LeNet":
|
if net_name == "elpv_LeNet":
|
||||||
net = ELPV_LeNet()
|
net = ELPV_LeNet()
|
||||||
|
|
||||||
@@ -130,6 +135,7 @@ def build_autoencoder(net_name):
|
|||||||
implemented_networks = (
|
implemented_networks = (
|
||||||
"elpv_LeNet",
|
"elpv_LeNet",
|
||||||
"subter_LeNet",
|
"subter_LeNet",
|
||||||
|
"subter_LeNet_Split",
|
||||||
"mnist_LeNet",
|
"mnist_LeNet",
|
||||||
"mnist_DGM_M1M2",
|
"mnist_DGM_M1M2",
|
||||||
"fmnist_LeNet",
|
"fmnist_LeNet",
|
||||||
@@ -154,6 +160,9 @@ def build_autoencoder(net_name):
|
|||||||
if net_name == "subter_LeNet":
|
if net_name == "subter_LeNet":
|
||||||
ae_net = SubTer_LeNet_Autoencoder()
|
ae_net = SubTer_LeNet_Autoencoder()
|
||||||
|
|
||||||
|
if net_name == "subter_LeNet_Split":
|
||||||
|
ae_net = SubTer_LeNet_Split_Autoencoder()
|
||||||
|
|
||||||
if net_name == "elpv_LeNet":
|
if net_name == "elpv_LeNet":
|
||||||
ae_net = ELPV_LeNet_Autoencoder()
|
ae_net = ELPV_LeNet_Autoencoder()
|
||||||
|
|
||||||
|
|||||||
66
Deep-SAD-PyTorch/src/networks/subter_LeNet_Split.py
Normal file
66
Deep-SAD-PyTorch/src/networks/subter_LeNet_Split.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from base.base_net import BaseNet
|
||||||
|
|
||||||
|
|
||||||
|
class SubTer_LeNet_Split(BaseNet):
|
||||||
|
def __init__(self, rep_dim=256):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.rep_dim = rep_dim
|
||||||
|
self.pool = nn.MaxPool2d(2, 2)
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(1, 8, 5, bias=False, padding=2)
|
||||||
|
self.bn1 = nn.BatchNorm2d(8, eps=1e-04, affine=False)
|
||||||
|
self.conv2 = nn.Conv2d(8, 4, 5, bias=False, padding=2)
|
||||||
|
self.bn2 = nn.BatchNorm2d(4, eps=1e-04, affine=False)
|
||||||
|
self.fc1 = nn.Linear(4 * 64 * 4, self.rep_dim, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x.view(-1, 1, 16, 256)
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.pool(F.leaky_relu(self.bn1(x)))
|
||||||
|
x = self.conv2(x)
|
||||||
|
x = self.pool(F.leaky_relu(self.bn2(x)))
|
||||||
|
x = x.view(int(x.size(0)), -1)
|
||||||
|
x = self.fc1(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SubTer_LeNet_Split_Decoder(BaseNet):
|
||||||
|
def __init__(self, rep_dim=256):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.rep_dim = rep_dim
|
||||||
|
|
||||||
|
# Decoder network
|
||||||
|
self.fc3 = nn.Linear(self.rep_dim, 4 * 64 * 4, bias=False)
|
||||||
|
self.bn3 = nn.BatchNorm2d(4, eps=1e-04, affine=False)
|
||||||
|
self.deconv1 = nn.ConvTranspose2d(4, 8, 5, bias=False, padding=2)
|
||||||
|
self.bn4 = nn.BatchNorm2d(8, eps=1e-04, affine=False)
|
||||||
|
self.deconv2 = nn.ConvTranspose2d(8, 1, 5, bias=False, padding=2)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.fc3(x)
|
||||||
|
x = x.view(int(x.size(0)), 4, 4, 64)
|
||||||
|
x = F.interpolate(F.leaky_relu(self.bn3(x)), scale_factor=2)
|
||||||
|
x = self.deconv1(x)
|
||||||
|
x = F.interpolate(F.leaky_relu(self.bn4(x)), scale_factor=2)
|
||||||
|
x = self.deconv2(x)
|
||||||
|
x = torch.sigmoid(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SubTer_LeNet_Split_Autoencoder(BaseNet):
|
||||||
|
def __init__(self, rep_dim=256):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.rep_dim = rep_dim
|
||||||
|
self.encoder = SubTer_LeNet_Split(rep_dim=rep_dim)
|
||||||
|
self.decoder = SubTer_LeNet_Split_Decoder(rep_dim=rep_dim)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.encoder(x)
|
||||||
|
x = self.decoder(x)
|
||||||
|
return x
|
||||||
17
tools/poetry.lock
generated
17
tools/poetry.lock
generated
@@ -1435,6 +1435,21 @@ toolz = "*"
|
|||||||
[package.extras]
|
[package.extras]
|
||||||
complete = ["blosc", "numpy (>=1.20.0)", "pandas (>=1.3)", "pyzmq"]
|
complete = ["blosc", "numpy (>=1.20.0)", "pandas (>=1.3)", "pyzmq"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pathvalidate"
|
||||||
|
version = "3.2.0"
|
||||||
|
description = "pathvalidate is a Python library to sanitize/validate a string such as filenames/file-paths/etc."
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
files = [
|
||||||
|
{file = "pathvalidate-3.2.0-py3-none-any.whl", hash = "sha256:cc593caa6299b22b37f228148257997e2fa850eea2daf7e4cc9205cef6908dee"},
|
||||||
|
{file = "pathvalidate-3.2.0.tar.gz", hash = "sha256:5e8378cf6712bff67fbe7a8307d99fa8c1a0cb28aa477056f8fc374f0dff24ad"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
docs = ["Sphinx (>=2.4)", "sphinx-rtd-theme (>=1.2.2)", "urllib3 (<2)"]
|
||||||
|
test = ["Faker (>=1.0.8)", "allpairspy (>=2)", "click (>=6.2)", "pytest (>=6.0.1)", "pytest-discord (>=0.1.4)", "pytest-md-report (>=0.4.1)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pillow"
|
name = "pillow"
|
||||||
version = "10.3.0"
|
version = "10.3.0"
|
||||||
@@ -2416,4 +2431,4 @@ cffi = ["cffi (>=1.11)"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.11"
|
python-versions = "^3.11"
|
||||||
content-hash = "607634c6605566c6bec22c1e8f41d056d2920de45108297193592fe57b10507b"
|
content-hash = "cd1f01135a813fbbc06da686dfa67b0981d26ccb3038f7d0e4a17359326f2c8d"
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ matplotlib = "^3.8.4"
|
|||||||
dask = "^2024.4.2"
|
dask = "^2024.4.2"
|
||||||
dask-expr = "^1.1.3"
|
dask-expr = "^1.1.3"
|
||||||
pandas = "^2.2.2"
|
pandas = "^2.2.2"
|
||||||
|
pathvalidate = "^3.2.0"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core"]
|
requires = ["poetry-core"]
|
||||||
|
|||||||
@@ -1,18 +1,22 @@
|
|||||||
from configargparse import (
|
|
||||||
ArgParser,
|
|
||||||
YAMLConfigFileParser,
|
|
||||||
ArgumentDefaultsRawHelpFormatter,
|
|
||||||
)
|
|
||||||
from sys import exit
|
|
||||||
from pathlib import Path
|
|
||||||
from pointcloudset import Dataset
|
|
||||||
from rich.progress import track
|
|
||||||
from pandas import DataFrame
|
|
||||||
from PIL import Image
|
|
||||||
from math import pi
|
from math import pi
|
||||||
from typing import Optional
|
from multiprocessing import Pool
|
||||||
|
from pathlib import Path
|
||||||
|
from sys import exit
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import matplotlib
|
import matplotlib
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from configargparse import (
|
||||||
|
ArgParser,
|
||||||
|
ArgumentDefaultsRawHelpFormatter,
|
||||||
|
YAMLConfigFileParser,
|
||||||
|
)
|
||||||
|
from pandas import DataFrame
|
||||||
|
from pathvalidate import sanitize_filename
|
||||||
|
from PIL import Image
|
||||||
|
from pointcloudset import Dataset
|
||||||
|
from rich.progress import track
|
||||||
|
from rosbags.highlevel import AnyReader
|
||||||
|
|
||||||
matplotlib.use("Agg")
|
matplotlib.use("Agg")
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
@@ -20,16 +24,44 @@ import matplotlib.pyplot as plt
|
|||||||
from util import (
|
from util import (
|
||||||
angle,
|
angle,
|
||||||
angle_width,
|
angle_width,
|
||||||
positive_int,
|
|
||||||
load_dataset,
|
|
||||||
existing_path,
|
|
||||||
create_video_from_images,
|
|
||||||
calculate_average_frame_rate,
|
calculate_average_frame_rate,
|
||||||
|
create_video_from_images,
|
||||||
|
existing_path,
|
||||||
get_colormap_with_special_missing_color,
|
get_colormap_with_special_missing_color,
|
||||||
|
load_dataset,
|
||||||
|
positive_int,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def crop_lidar_data_to_roi(
|
def save_image_topics(
|
||||||
|
bag_file_path: Path, output_folder_path: Path, topic_names: list[str]
|
||||||
|
) -> list[Path]:
|
||||||
|
with AnyReader([bag_file_path]) as reader:
|
||||||
|
topic_paths = []
|
||||||
|
for topic_name in topic_names:
|
||||||
|
connections = [x for x in reader.connections if x.topic == topic_name]
|
||||||
|
frame_count = 0
|
||||||
|
topic_output_folder_path = output_folder_path / sanitize_filename(
|
||||||
|
topic_name
|
||||||
|
)
|
||||||
|
topic_output_folder_path.mkdir(exist_ok=True, parents=True)
|
||||||
|
topic_paths.append(topic_output_folder_path)
|
||||||
|
for connection, timestamp, rawdata in reader.messages(
|
||||||
|
connections=connections
|
||||||
|
):
|
||||||
|
img_msg = reader.deserialize(rawdata, connection.msgtype)
|
||||||
|
|
||||||
|
img_data = np.frombuffer(
|
||||||
|
img_msg.data,
|
||||||
|
dtype=np.uint8 if "8" in img_msg.encoding else np.uint16,
|
||||||
|
).reshape((img_msg.height, img_msg.width))
|
||||||
|
img = Image.fromarray(img_data, mode="L")
|
||||||
|
frame_count += 1
|
||||||
|
img.save(topic_output_folder_path / f"frame_{frame_count:04d}.png")
|
||||||
|
return topic_paths
|
||||||
|
|
||||||
|
|
||||||
|
def crop_projection_data_to_roi(
|
||||||
data: DataFrame,
|
data: DataFrame,
|
||||||
roi_angle_start: float,
|
roi_angle_start: float,
|
||||||
roi_angle_width: float,
|
roi_angle_width: float,
|
||||||
@@ -83,103 +115,196 @@ def create_2d_projection(
|
|||||||
tmp_file_path.unlink()
|
tmp_file_path.unlink()
|
||||||
|
|
||||||
|
|
||||||
def create_projection_data(
|
def process_frame(args) -> Tuple[int, np.ndarray, Optional[Path]]:
|
||||||
dataset: Dataset,
|
(
|
||||||
output_path: Path,
|
i,
|
||||||
colormap_name: str,
|
pc,
|
||||||
missing_data_color: str,
|
output_path,
|
||||||
reverse_colormap: bool,
|
colormap_name,
|
||||||
horizontal_resolution: int,
|
missing_data_color,
|
||||||
vertical_scale: int,
|
reverse_colormap,
|
||||||
horizontal_scale: int,
|
vertical_resolution,
|
||||||
roi_angle_start: float,
|
horizontal_resolution,
|
||||||
roi_angle_width: float,
|
vertical_scale,
|
||||||
render_images: bool,
|
horizontal_scale,
|
||||||
) -> (np.ndarray, Optional[list[Path]]):
|
roi_angle_start,
|
||||||
rendered_images = []
|
roi_angle_width,
|
||||||
converted_lidar_frames = []
|
render_images,
|
||||||
|
) = args
|
||||||
for i, pc in track(
|
|
||||||
enumerate(dataset, 1), description="Creating projections...", total=len(dataset)
|
|
||||||
):
|
|
||||||
vertical_resolution = int(pc.data["ring"].max() + 1)
|
|
||||||
|
|
||||||
# Angle calculation implementation
|
|
||||||
|
|
||||||
# projected_data = pc.data.copy()
|
|
||||||
# projected_data["arctan"] = np.arctan2(projected_data["y"], projected_data["x"])
|
|
||||||
# projected_data["arctan_normalized"] = 0.5 * (projected_data["arctan"] / pi + 1.0)
|
|
||||||
# projected_data["arctan_scaled"] = projected_data["arctan_normalized"] * horizontal_resolution
|
|
||||||
# #projected_data["horizontal_position"] = np.floor(projected_data["arctan_scaled"])
|
|
||||||
# projected_data["horizontal_position"] = np.round(projected_data["arctan_scaled"])
|
|
||||||
# projected_data["normalized_range"] = 1 / np.sqrt(
|
|
||||||
# projected_data["x"] ** 2 + projected_data["y"] ** 2 + projected_data["z"] ** 2
|
|
||||||
# )
|
|
||||||
# duplicates = projected_data[projected_data.duplicated(subset=['ring', 'horizontal_position'], keep=False)].sort_values(by=['ring', 'horizontal_position'])
|
|
||||||
# sorted = projected_data.sort_values(by=['ring', 'horizontal_position'])
|
|
||||||
|
|
||||||
# FIXME: following pivot fails due to duplicates in the data, some points (x, y) are mapped to the same pixel in the projection, have to decide how to handles
|
|
||||||
# these cases
|
|
||||||
|
|
||||||
# projected_image_data = projected_data.pivot(
|
|
||||||
# index="ring", columns="horizontal_position", values="normalized_range"
|
|
||||||
# )
|
|
||||||
# projected_image_data = projected_image_data.reindex(columns=range(horizontal_resolution), fill_value=0)
|
|
||||||
|
|
||||||
# projected_image_data, output_horizontal_resolution = crop_lidar_data_to_roi(
|
|
||||||
# projected_image_data, roi_angle_start, roi_angle_width, horizontal_resolution
|
|
||||||
# )
|
|
||||||
|
|
||||||
# create_2d_projection(
|
|
||||||
# projected_image_data,
|
|
||||||
# output_path / f"frame_{i:04d}_projection.png",
|
|
||||||
# output_path / "tmp.png",
|
|
||||||
# colormap_name,
|
|
||||||
# missing_data_color,
|
|
||||||
# reverse_colormap,
|
|
||||||
# horizontal_resolution=output_horizontal_resolution * horizontal_scale,
|
|
||||||
# vertical_resolution=vertical_resolution * vertical_scale,
|
|
||||||
# )
|
|
||||||
|
|
||||||
lidar_data = pc.data.copy()
|
lidar_data = pc.data.copy()
|
||||||
|
|
||||||
lidar_data["horizontal_position"] = (
|
lidar_data["horizontal_position"] = (
|
||||||
lidar_data["original_id"] % horizontal_resolution
|
lidar_data["original_id"] % horizontal_resolution
|
||||||
)
|
)
|
||||||
|
lidar_data["horizontal_position_yaw_f"] = (
|
||||||
|
0.5
|
||||||
|
* horizontal_resolution
|
||||||
|
* (np.arctan2(lidar_data["y"], lidar_data["x"]) / pi + 1.0)
|
||||||
|
)
|
||||||
|
lidar_data["horizontal_position_yaw"] = np.floor(
|
||||||
|
lidar_data["horizontal_position_yaw_f"]
|
||||||
|
)
|
||||||
|
lidar_data["vertical_position"] = np.floor(
|
||||||
|
lidar_data["original_id"] / horizontal_resolution
|
||||||
|
)
|
||||||
|
# fov = 32 * pi / 180
|
||||||
|
# fov_down = 17 * pi / 180
|
||||||
|
fov = 31.76 * pi / 180
|
||||||
|
fov_down = 17.3 * pi / 180
|
||||||
|
lidar_data["vertical_angle"] = np.arcsin(
|
||||||
|
lidar_data["z"]
|
||||||
|
/ np.sqrt(lidar_data["x"] ** 2 + lidar_data["y"] ** 2 + lidar_data["z"] ** 2)
|
||||||
|
)
|
||||||
|
lidar_data["vertical_angle_degree"] = lidar_data["vertical_angle"] * 180 / pi
|
||||||
|
|
||||||
|
lidar_data["vertical_position_pitch_f"] = vertical_resolution * (
|
||||||
|
1 - ((lidar_data["vertical_angle"] + fov_down) / fov)
|
||||||
|
)
|
||||||
|
lidar_data["vertical_position_pitch"] = np.floor(
|
||||||
|
lidar_data["vertical_position_pitch_f"]
|
||||||
|
)
|
||||||
|
|
||||||
|
duplicates = lidar_data[
|
||||||
|
lidar_data.duplicated(
|
||||||
|
subset=["vertical_position_pitch", "horizontal_position_yaw"],
|
||||||
|
keep=False,
|
||||||
|
)
|
||||||
|
].sort_values(by=["vertical_position_pitch", "horizontal_position_yaw"])
|
||||||
|
|
||||||
lidar_data["normalized_range"] = 1 / np.sqrt(
|
lidar_data["normalized_range"] = 1 / np.sqrt(
|
||||||
lidar_data["x"] ** 2 + lidar_data["y"] ** 2 + lidar_data["z"] ** 2
|
lidar_data["x"] ** 2 + lidar_data["y"] ** 2 + lidar_data["z"] ** 2
|
||||||
)
|
)
|
||||||
lidar_data = lidar_data.pivot(
|
projection_data = lidar_data.pivot(
|
||||||
index="ring", columns="horizontal_position", values="normalized_range"
|
index="vertical_position_pitch",
|
||||||
|
columns="horizontal_position_yaw",
|
||||||
|
values="normalized_range",
|
||||||
)
|
)
|
||||||
lidar_data = lidar_data.reindex(
|
projection_data = projection_data.reindex(
|
||||||
columns=range(horizontal_resolution), fill_value=0
|
columns=range(horizontal_resolution), fill_value=0
|
||||||
)
|
)
|
||||||
lidar_data = lidar_data.reindex(index=range(vertical_resolution), fill_value=0)
|
projection_data = projection_data.reindex(
|
||||||
lidar_data, output_horizontal_resolution = crop_lidar_data_to_roi(
|
index=range(vertical_resolution), fill_value=0
|
||||||
lidar_data, roi_angle_start, roi_angle_width, horizontal_resolution
|
)
|
||||||
|
projection_data, output_horizontal_resolution = crop_projection_data_to_roi(
|
||||||
|
projection_data, roi_angle_start, roi_angle_width, horizontal_resolution
|
||||||
)
|
)
|
||||||
|
|
||||||
converted_lidar_frames.append(lidar_data.to_numpy())
|
|
||||||
if render_images:
|
if render_images:
|
||||||
image_path = create_2d_projection(
|
image_path = create_2d_projection(
|
||||||
lidar_data,
|
projection_data,
|
||||||
output_path / f"frame_{i:04d}.png",
|
output_path / f"frame_{i:04d}.png",
|
||||||
output_path / "tmp.png",
|
output_path / f"tmp_{i:04d}.png",
|
||||||
colormap_name,
|
colormap_name,
|
||||||
missing_data_color,
|
missing_data_color,
|
||||||
reverse_colormap,
|
reverse_colormap,
|
||||||
horizontal_resolution=output_horizontal_resolution * horizontal_scale,
|
horizontal_resolution=output_horizontal_resolution * horizontal_scale,
|
||||||
vertical_resolution=vertical_resolution * vertical_scale,
|
vertical_resolution=vertical_resolution * vertical_scale,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
image_path = None
|
||||||
|
|
||||||
rendered_images.append(image_path)
|
return (
|
||||||
|
i,
|
||||||
|
projection_data.to_numpy(),
|
||||||
|
image_path,
|
||||||
|
lidar_data["vertical_position_pitch_f"].min(),
|
||||||
|
lidar_data["vertical_position_pitch_f"].max(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Adjusted to use a generator for args_list
|
||||||
|
def create_projection_data(
|
||||||
|
dataset: Dataset,
|
||||||
|
output_path: Path,
|
||||||
|
colormap_name: str,
|
||||||
|
missing_data_color: str,
|
||||||
|
reverse_colormap: bool,
|
||||||
|
vertical_resolution: int,
|
||||||
|
horizontal_resolution: int,
|
||||||
|
vertical_scale: int,
|
||||||
|
horizontal_scale: int,
|
||||||
|
roi_angle_start: float,
|
||||||
|
roi_angle_width: float,
|
||||||
|
render_images: bool,
|
||||||
|
) -> Tuple[np.ndarray, Optional[list[Path]]]:
|
||||||
|
rendered_images = []
|
||||||
|
converted_lidar_frames = []
|
||||||
|
|
||||||
|
# Generator for args_list
|
||||||
|
# def args_generator():
|
||||||
|
# for i, pc in enumerate(dataset, 1):
|
||||||
|
# yield (
|
||||||
|
# i,
|
||||||
|
# pc,
|
||||||
|
# output_path,
|
||||||
|
# colormap_name,
|
||||||
|
# missing_data_color,
|
||||||
|
# reverse_colormap,
|
||||||
|
# vertical_resolution,
|
||||||
|
# horizontal_resolution,
|
||||||
|
# vertical_scale,
|
||||||
|
# horizontal_scale,
|
||||||
|
# roi_angle_start,
|
||||||
|
# roi_angle_width,
|
||||||
|
# render_images,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# results = []
|
||||||
|
|
||||||
|
# with Pool() as pool:
|
||||||
|
# results_gen = pool.imap(process_frame, args_generator())
|
||||||
|
# for result in track(
|
||||||
|
# results_gen, description="Processing...", total=len(dataset)
|
||||||
|
# ):
|
||||||
|
# results.append(result)
|
||||||
|
|
||||||
|
# results.sort(key=lambda x: x[0])
|
||||||
|
|
||||||
|
for i, pc in track(
|
||||||
|
enumerate(dataset, 1), description="Processing...", total=len(dataset)
|
||||||
|
):
|
||||||
|
results = process_frame(
|
||||||
|
(
|
||||||
|
i,
|
||||||
|
pc,
|
||||||
|
output_path,
|
||||||
|
colormap_name,
|
||||||
|
missing_data_color,
|
||||||
|
reverse_colormap,
|
||||||
|
vertical_resolution,
|
||||||
|
horizontal_resolution,
|
||||||
|
vertical_scale,
|
||||||
|
horizontal_scale,
|
||||||
|
roi_angle_start,
|
||||||
|
roi_angle_width,
|
||||||
|
render_images,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
converted_lidar_frames.append(results[1])
|
||||||
|
if render_images:
|
||||||
|
rendered_images.append(results[2])
|
||||||
|
|
||||||
|
# min_all = 100
|
||||||
|
# max_all = 0
|
||||||
|
|
||||||
|
# for _, data, img_path, min_frame, max_frame in results:
|
||||||
|
# converted_lidar_frames.append(data)
|
||||||
|
# if img_path:
|
||||||
|
# rendered_images.append(img_path)
|
||||||
|
# if min_frame < min_all:
|
||||||
|
# min_all = min_frame
|
||||||
|
# if max_frame > max_all:
|
||||||
|
# max_all = max_frame
|
||||||
|
|
||||||
|
# print(f"{min_all=}, {max_all=}")
|
||||||
|
|
||||||
projection_data = np.stack(converted_lidar_frames, axis=0)
|
projection_data = np.stack(converted_lidar_frames, axis=0)
|
||||||
|
|
||||||
if render_images:
|
if render_images:
|
||||||
return rendered_images, projection_data
|
return projection_data, rendered_images
|
||||||
else:
|
else:
|
||||||
return projection_data
|
return (projection_data,)
|
||||||
|
|
||||||
|
|
||||||
def main() -> int:
|
def main() -> int:
|
||||||
@@ -204,6 +329,12 @@ def main() -> int:
|
|||||||
type=str,
|
type=str,
|
||||||
help="topic in the ros/mcap bag file containing the point cloud data",
|
help="topic in the ros/mcap bag file containing the point cloud data",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--image-topics",
|
||||||
|
default=[],
|
||||||
|
nargs="+",
|
||||||
|
help="topics in the ros/mcap bag file containing the image data",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output-path",
|
"--output-path",
|
||||||
default=Path("./output"),
|
default=Path("./output"),
|
||||||
@@ -248,11 +379,17 @@ def main() -> int:
|
|||||||
type=bool,
|
type=bool,
|
||||||
help="if colormap should be reversed",
|
help="if colormap should be reversed",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--vertical-resolution",
|
||||||
|
default=32,
|
||||||
|
type=positive_int,
|
||||||
|
help="number of vertical lidar data point rows",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--horizontal-resolution",
|
"--horizontal-resolution",
|
||||||
default=2048,
|
default=2048,
|
||||||
type=positive_int,
|
type=positive_int,
|
||||||
help="number of horizontal lidar data points",
|
help="number of horizontal lidar data point columns",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vertical-scale",
|
"--vertical-scale",
|
||||||
@@ -296,6 +433,7 @@ def main() -> int:
|
|||||||
dataset = load_dataset(args.input_experiment_path, args.pointcloud_topic)
|
dataset = load_dataset(args.input_experiment_path, args.pointcloud_topic)
|
||||||
|
|
||||||
images = []
|
images = []
|
||||||
|
topic_paths = []
|
||||||
|
|
||||||
if not args.output_no_images or not args.output_no_video:
|
if not args.output_no_images or not args.output_no_video:
|
||||||
if not args.force_generation and all(
|
if not args.force_generation and all(
|
||||||
@@ -312,6 +450,7 @@ def main() -> int:
|
|||||||
args.colormap_name,
|
args.colormap_name,
|
||||||
args.missing_data_color,
|
args.missing_data_color,
|
||||||
args.reverse_colormap,
|
args.reverse_colormap,
|
||||||
|
args.vertical_resolution,
|
||||||
args.horizontal_resolution,
|
args.horizontal_resolution,
|
||||||
args.vertical_scale,
|
args.vertical_scale,
|
||||||
args.horizontal_scale,
|
args.horizontal_scale,
|
||||||
@@ -319,6 +458,10 @@ def main() -> int:
|
|||||||
args.roi_angle_width,
|
args.roi_angle_width,
|
||||||
render_images=True,
|
render_images=True,
|
||||||
)
|
)
|
||||||
|
if args.image_topics:
|
||||||
|
topic_paths = save_image_topics(
|
||||||
|
args.input_experiment_path, output_path, args.image_topics
|
||||||
|
)
|
||||||
|
|
||||||
output_numpy_path = (output_path / args.input_experiment_path.stem).with_suffix(
|
output_numpy_path = (output_path / args.input_experiment_path.stem).with_suffix(
|
||||||
".npy"
|
".npy"
|
||||||
@@ -336,6 +479,7 @@ def main() -> int:
|
|||||||
args.colormap_name,
|
args.colormap_name,
|
||||||
args.missing_data_color,
|
args.missing_data_color,
|
||||||
args.reverse_colormap,
|
args.reverse_colormap,
|
||||||
|
args.vertical_resolution,
|
||||||
args.horizontal_resolution,
|
args.horizontal_resolution,
|
||||||
args.vertical_scale,
|
args.vertical_scale,
|
||||||
args.horizontal_scale,
|
args.horizontal_scale,
|
||||||
@@ -358,11 +502,19 @@ def main() -> int:
|
|||||||
f"Skipping video generation for {args.input_experiment_path} as {output_path / args.input_experiment_path.stem}.mp4 already exists"
|
f"Skipping video generation for {args.input_experiment_path} as {output_path / args.input_experiment_path.stem}.mp4 already exists"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
frame_rate = calculate_average_frame_rate(dataset)
|
||||||
input_images_pattern = f"{tmp_path}/frame_%04d.png"
|
input_images_pattern = f"{tmp_path}/frame_%04d.png"
|
||||||
create_video_from_images(
|
create_video_from_images(
|
||||||
input_images_pattern,
|
input_images_pattern,
|
||||||
(output_path / args.input_experiment_path.stem).with_suffix(".mp4"),
|
(output_path / args.input_experiment_path.stem).with_suffix(".mp4"),
|
||||||
calculate_average_frame_rate(dataset),
|
frame_rate,
|
||||||
|
)
|
||||||
|
for topic_path in topic_paths:
|
||||||
|
input_images_pattern = f"{topic_path}/frame_%04d.png"
|
||||||
|
create_video_from_images(
|
||||||
|
input_images_pattern,
|
||||||
|
(output_path / topic_path.stem).with_suffix(".mp4"),
|
||||||
|
frame_rate,
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.output_no_images:
|
if args.output_no_images:
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
from pointcloudset import Dataset
|
|
||||||
from pointcloudset.io.dataset.ros import dataset_from_ros
|
|
||||||
from pathlib import Path
|
|
||||||
from argparse import ArgumentTypeError
|
from argparse import ArgumentTypeError
|
||||||
from subprocess import run
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from matplotlib.colors import Colormap
|
from pathlib import Path
|
||||||
|
from subprocess import run
|
||||||
|
|
||||||
from matplotlib import colormaps
|
from matplotlib import colormaps
|
||||||
|
from matplotlib.colors import Colormap
|
||||||
|
from pointcloudset import Dataset
|
||||||
|
|
||||||
|
|
||||||
def load_dataset(
|
def load_dataset(
|
||||||
|
|||||||
Reference in New Issue
Block a user