From ddf4e4aa36beda31f2631cb7b3f74d1be8b6a767 Mon Sep 17 00:00:00 2001
From: Jan Kowalczyk
Date: Thu, 17 Oct 2024 08:36:18 +0200
Subject: [PATCH] split subter implementation (training + inference)

---
 Deep-SAD-PyTorch/flake.nix                   |  51 +--
 Deep-SAD-PyTorch/src/datasets/main.py        |  19 +-
 Deep-SAD-PyTorch/src/datasets/subter.py      |  52 ++-
 Deep-SAD-PyTorch/src/datasets/subtersplit.py | 263 ++++++++++++++
 Deep-SAD-PyTorch/src/main.py                 |  14 +-
 Deep-SAD-PyTorch/src/networks/main.py        |  21 +-
 .../src/networks/subter_LeNet_Split.py       |  66 ++++
 tools/poetry.lock                            |  17 +-
 tools/pyproject.toml                         |   1 +
 tools/render2d.py                            | 330 +++++++++++++-----
 tools/util.py                                |  10 +-
 11 files changed, 699 insertions(+), 145 deletions(-)
 create mode 100644 Deep-SAD-PyTorch/src/datasets/subtersplit.py
 create mode 100644 Deep-SAD-PyTorch/src/networks/subter_LeNet_Split.py

diff --git a/Deep-SAD-PyTorch/flake.nix b/Deep-SAD-PyTorch/flake.nix
index 8db9b66..a2d315d 100644
--- a/Deep-SAD-PyTorch/flake.nix
+++ b/Deep-SAD-PyTorch/flake.nix
@@ -10,40 +10,51 @@
     };
   };
 
-  outputs = { self, nixpkgs, flake-utils, poetry2nix }:
-    flake-utils.lib.eachDefaultSystem (system:
+  outputs =
+    {
+      self,
+      nixpkgs,
+      flake-utils,
+      poetry2nix,
+    }:
+    flake-utils.lib.eachDefaultSystem (
+      system:
       let
        # see https://github.com/nix-community/poetry2nix/tree/master#api for more functions and examples.
-        pkgs = import nixpkgs{
-          inherit system;
-          config.allowUnfree = true;
-          config.cudaSupport = true;
-        };
+        pkgs = import nixpkgs {
+          inherit system;
+          config.allowUnfree = true;
+          config.cudaSupport = true;
+        };
         inherit (poetry2nix.lib.mkPoetry2Nix { inherit pkgs; }) mkPoetryApplication;
       in
       {
         packages = {
-          deepsad = mkPoetryApplication {
-            projectDir = self;
-            preferWheels = true;
-            python = pkgs.python311;
+          deepsad = mkPoetryApplication {
+            projectDir = self;
+            preferWheels = true;
+            python = pkgs.python311;
           };
           default = self.packages.${system}.deepsad;
         };
         devShells.default = pkgs.mkShell {
           inputsFrom = [ self.packages.${system}.deepsad ];
-          buildInputs = with pkgs.python311Packages; [
-            torch-bin
-            torchvision-bin
-          ];
-          #LD_LIBRARY_PATH = with pkgs; lib.makeLibraryPath [
-          #pkgs.stdenv.cc.cc
-          #];
+          buildInputs = with pkgs.python311Packages; [
+            torch-bin
+            torchvision-bin
+          ];
+          #LD_LIBRARY_PATH = with pkgs; lib.makeLibraryPath [
+          #pkgs.stdenv.cc.cc
+          #];
         };
         devShells.poetry = pkgs.mkShell {
-          packages = [ pkgs.poetry pkgs.python311 ];
+          packages = [
+            pkgs.poetry
+            pkgs.python311
+          ];
         };
-      });
+      }
+    );
 }
diff --git a/Deep-SAD-PyTorch/src/datasets/main.py b/Deep-SAD-PyTorch/src/datasets/main.py
index 12b9f2e..34b0c45 100644
--- a/Deep-SAD-PyTorch/src/datasets/main.py
+++ b/Deep-SAD-PyTorch/src/datasets/main.py
@@ -1,9 +1,10 @@
-from .mnist import MNIST_Dataset
-from .elpv import ELPV_Dataset
-from .subter import SubTer_Dataset
-from .fmnist import FashionMNIST_Dataset
 from .cifar10 import CIFAR10_Dataset
+from .elpv import ELPV_Dataset
+from .fmnist import FashionMNIST_Dataset
+from .mnist import MNIST_Dataset
 from .odds import ODDSADDataset
+from .subter import SubTer_Dataset
+from .subtersplit import SubTerSplit_Dataset
 
 
 def load_dataset(
@@ -24,6 +25,7 @@ def load_dataset(
         "mnist",
         "elpv",
         "subter",
+        "subtersplit",
         "fmnist",
         "cifar10",
         "arrhythmia",
@@ -46,6 +48,15 @@ def load_dataset(
             inference=inference,
         )
 
+    if dataset_name == "subtersplit":
+        dataset = SubTerSplit_Dataset(
+            root=data_path,
+            ratio_known_normal=ratio_known_normal,
+            ratio_known_outlier=ratio_known_outlier,
+            ratio_pollution=ratio_pollution,
+            inference=inference,
+        )
+
     if dataset_name == "elpv":
         dataset = ELPV_Dataset(
             root=data_path,
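Note on the dispatch above: load_dataset keeps one if-block per dataset name rather than a registry dict, and the new "subtersplit" branch follows that pattern. A minimal sketch of reaching it; the positional argument order and the omitted parameters are assumptions, only the keywords visible in this hunk are shown, and the data path is illustrative:

    from datasets.main import load_dataset

    # Hypothetical call; remaining arguments keep the defaults from
    # load_dataset's signature in datasets/main.py.
    dataset = load_dataset(
        "subtersplit",
        "../data/subter",
        ratio_known_normal=0.0,
        ratio_known_outlier=0.05,
        ratio_pollution=0.1,
        inference=False,
    )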
== "elpv": dataset = ELPV_Dataset( root=data_path, diff --git a/Deep-SAD-PyTorch/src/datasets/subter.py b/Deep-SAD-PyTorch/src/datasets/subter.py index c063141..d75f6b5 100644 --- a/Deep-SAD-PyTorch/src/datasets/subter.py +++ b/Deep-SAD-PyTorch/src/datasets/subter.py @@ -1,22 +1,21 @@ -from torch.utils.data import Subset -from PIL import Image -from torch.utils.data.dataset import ConcatDataset -from torchvision.datasets import VisionDataset -from base.torchvision_dataset import TorchvisionDataset -from .preprocessing import create_semisupervised_setting +import logging +import random +from pathlib import Path from typing import Callable, Optional -import logging +import numpy as np import torch import torchvision.transforms as transforms -import random -import numpy as np +from base.torchvision_dataset import TorchvisionDataset +from PIL import Image +from torch.utils.data import Subset +from torch.utils.data.dataset import ConcatDataset +from torchvision.datasets import VisionDataset -from pathlib import Path +from .preprocessing import create_semisupervised_setting class SubTer_Dataset(TorchvisionDataset): - def __init__( self, root: str, @@ -31,6 +30,7 @@ class SubTer_Dataset(TorchvisionDataset): self.n_classes = 2 # 0: normal, 1: outlier self.normal_classes = tuple([0]) self.outlier_classes = tuple([1]) + self.inference_set = None # MNIST preprocessing: feature scaling to [0, 1] # FIXME understand mnist feature scaling and check if it or other preprocessing is necessary for elpv @@ -78,7 +78,6 @@ class SubTer_Dataset(TorchvisionDataset): class SubTerTraining(VisionDataset): - def __init__( self, root: str, @@ -95,10 +94,18 @@ class SubTerTraining(VisionDataset): experiments_data = [] experiments_targets = [] + validation_files = [] + experiment_files = [] for experiment_file in Path(root).iterdir(): + if experiment_file.is_dir() and experiment_file.name == "validation": + for validation_file in experiment_file.iterdir(): + if validation_file.suffix != ".npy": + continue + validation_files.append(experiment_file) if experiment_file.suffix != ".npy": continue + experiment_files.append(experiment_file) experiment_data = np.load(experiment_file) # experiment_data = np.lib.format.open_memmap(experiment_file, mode='r+') experiment_targets = ( @@ -109,6 +116,26 @@ class SubTerTraining(VisionDataset): experiments_data.append(experiment_data) experiments_targets.append(experiment_targets) + filtered_validation_files = [] + for validation_file in validation_files: + validation_file_name = validation_file.name + file_exists_in_experiments = any( + experiment_file.name == validation_file_name + for experiment_file in experiment_files + ) + if not file_exists_in_experiments: + filtered_validation_files.append(validation_file) + validation_files = filtered_validation_files + + logger = logging.getLogger() + + logger.info( + f"Train/Test experiments: {[experiment_file.name for experiment_file in experiment_files]}" + ) + logger.info( + f"Validation experiments: {[validation_file.name for validation_file in validation_files]}" + ) + lidar_projections = np.concatenate(experiments_data) smoke_presence = np.concatenate(experiments_targets) @@ -166,7 +193,6 @@ class SubTerTraining(VisionDataset): class SubTerInference(VisionDataset): - def __init__( self, root: str, diff --git a/Deep-SAD-PyTorch/src/datasets/subtersplit.py b/Deep-SAD-PyTorch/src/datasets/subtersplit.py new file mode 100644 index 0000000..13a94a4 --- /dev/null +++ b/Deep-SAD-PyTorch/src/datasets/subtersplit.py @@ -0,0 +1,263 @@ +import 
diff --git a/Deep-SAD-PyTorch/src/datasets/subtersplit.py b/Deep-SAD-PyTorch/src/datasets/subtersplit.py
new file mode 100644
index 0000000..13a94a4
--- /dev/null
+++ b/Deep-SAD-PyTorch/src/datasets/subtersplit.py
@@ -0,0 +1,263 @@
+import logging
+from pathlib import Path
+from typing import Callable, Optional
+
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+from base.torchvision_dataset import TorchvisionDataset
+from PIL import Image
+from torch.utils.data import Subset
+from torchvision.datasets import VisionDataset
+
+from .preprocessing import create_semisupervised_setting
+
+
+class SubTerSplit_Dataset(TorchvisionDataset):
+    def __init__(
+        self,
+        root: str,
+        ratio_known_normal: float = 0.0,
+        ratio_known_outlier: float = 0.0,
+        ratio_pollution: float = 0.0,
+        inference: bool = False,
+    ):
+        super().__init__(root)
+
+        # Define normal and outlier classes
+        self.n_classes = 2  # 0: normal, 1: outlier
+        self.normal_classes = tuple([0])
+        self.outlier_classes = tuple([1])
+        self.inference_set = None
+
+        # MNIST preprocessing: feature scaling to [0, 1]
+        # FIXME understand mnist feature scaling and check if it or other preprocessing is necessary for elpv
+        transform = transforms.ToTensor()
+        target_transform = transforms.Lambda(lambda x: int(x in self.outlier_classes))
+
+        if inference:
+            self.inference_set = SubTerSplitInference(
+                root=self.root,
+                transform=transform,
+            )
+        else:
+            # Get train set
+            train_set = SubTerSplitTraining(
+                root=self.root,
+                transform=transform,
+                target_transform=target_transform,
+                train=True,
+            )
+
+            # Create semi-supervised setting
+            idx, _, semi_targets = create_semisupervised_setting(
+                train_set.targets.cpu().data.numpy(),
+                self.normal_classes,
+                self.outlier_classes,
+                self.outlier_classes,
+                ratio_known_normal,
+                ratio_known_outlier,
+                ratio_pollution,
+            )
+            train_set.semi_targets[idx] = torch.tensor(
+                np.array(semi_targets, dtype=np.int8)
+            )  # set respective semi-supervised labels
+
+            # Subset train_set to semi-supervised setup
+            self.train_set = Subset(train_set, idx)
+
+            # Get test set
+            self.test_set = SubTerSplitTraining(
+                root=self.root,
+                train=False,
+                transform=transform,
+                target_transform=target_transform,
+            )
+
+
+def split_array_into_subarrays(array, split_height, split_width):
+    original_shape = array.shape
+    height, width = original_shape[-2], original_shape[-1]
+    assert height % split_height == 0, "The height is not divisible by the split_height"
+    assert width % split_width == 0, "The width is not divisible by the split_width"
+    num_splits_height = height // split_height
+    num_splits_width = width // split_width
+    reshaped_array = array.reshape(
+        -1, num_splits_height, split_height, num_splits_width, split_width
+    )
+    transposed_array = reshaped_array.transpose(0, 1, 3, 2, 4)
+    final_array = transposed_array.reshape(-1, split_height, split_width)
+    return final_array
+
+
+class SubTerSplitTraining(VisionDataset):
+    def __init__(
+        self,
+        root: str,
+        transforms: Optional[Callable] = None,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        train=False,
+        split=0.7,
+        seed=0,
+        height=16,
+        width=256,
+    ):
+        super(SubTerSplitTraining, self).__init__(
+            root, transforms, transform, target_transform
+        )
+
+        experiments_data = []
+        experiments_targets = []
+        validation_files = []
+        experiment_files = []
+
+        logger = logging.getLogger()
+
+        for experiment_file in Path(root).iterdir():
+            if experiment_file.is_dir() and experiment_file.name == "validation":
+                for validation_file in experiment_file.iterdir():
+                    if validation_file.suffix != ".npy":
+                        continue
+                    validation_files.append(validation_file)
+            if experiment_file.suffix != ".npy":
+                continue
+            experiment_files.append(experiment_file)
+            experiment_data = np.load(experiment_file)
+
+            if (
+                experiment_data.shape[1] % height != 0
+                or experiment_data.shape[2] % width != 0
+            ):
+                logger.error(
+                    f"Experiment {experiment_file.name} has shape {experiment_data.shape} which is not divisible by {height}x{width}"
+                )
+            experiment_data = split_array_into_subarrays(experiment_data, height, width)
+
+            # experiment_data = np.lib.format.open_memmap(experiment_file, mode='r+')
+            experiment_targets = (
+                np.ones(experiment_data.shape[0], dtype=np.int8)
+                if "smoke" in experiment_file.name
+                else np.zeros(experiment_data.shape[0], dtype=np.int8)
+            )
+            experiments_data.append(experiment_data)
+            experiments_targets.append(experiment_targets)
+
+        filtered_validation_files = []
+        for validation_file in validation_files:
+            validation_file_name = validation_file.name
+            file_exists_in_experiments = any(
+                experiment_file.name == validation_file_name
+                for experiment_file in experiment_files
+            )
+            if not file_exists_in_experiments:
+                filtered_validation_files.append(validation_file)
+        validation_files = filtered_validation_files
+
+        logger.info(
+            f"Train/Test experiments: {[experiment_file.name for experiment_file in experiment_files]}"
+        )
+        logger.info(
+            f"Validation experiments: {[validation_file.name for validation_file in validation_files]}"
+        )
+
+        lidar_projections = np.concatenate(experiments_data)
+        smoke_presence = np.concatenate(experiments_targets)
+
+        np.random.seed(seed)
+
+        shuffled_indices = np.random.permutation(lidar_projections.shape[0])
+        shuffled_lidar_projections = lidar_projections[shuffled_indices]
+        shuffled_smoke_presence = smoke_presence[shuffled_indices]
+
+        split_idx = int(split * shuffled_lidar_projections.shape[0])
+
+        if train:
+            self.data = shuffled_lidar_projections[:split_idx]
+            self.targets = shuffled_smoke_presence[:split_idx]
+
+        else:
+            self.data = shuffled_lidar_projections[split_idx:]
+            self.targets = shuffled_smoke_presence[split_idx:]
+
+        self.data = np.nan_to_num(self.data)
+
+        self.data = torch.tensor(self.data)
+        self.targets = torch.tensor(self.targets, dtype=torch.int8)
+
+        self.semi_targets = torch.zeros_like(self.targets, dtype=torch.int8)
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        """Override the original method of the MNIST class.
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target, semi_target, index)
+        """
+        img, target, semi_target = (
+            self.data[index],
+            int(self.targets[index]),
+            int(self.semi_targets[index]),
+        )
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(img.numpy(), mode="F")
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target, semi_target, index
+
+
+class SubTerSplitInference(VisionDataset):
+    def __init__(
+        self,
+        root: str,
+        transforms: Optional[Callable] = None,
+        transform: Optional[Callable] = None,
+    ):
+        super(SubTerSplitInference, self).__init__(root, transforms, transform)
+        logger = logging.getLogger()
+
+        self.experiment_file_path = Path(root)
+
+        if not self.experiment_file_path.is_file():
+            logger.error(
+                "For inference the data path has to be a single experiment file!"
+            )
+            raise Exception("Inference data is not a loadable file!")
+
+        self.data = np.load(self.experiment_file_path)
+        self.data = split_array_into_subarrays(self.data, 16, 256)
+        self.data = np.nan_to_num(self.data)
+        self.data = torch.tensor(self.data)
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        """Override the original method of the MNIST class.
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, index)
+        """
+        img = self.data[index]
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(img.numpy(), mode="F")
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        return img, index
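split_array_into_subarrays is the heart of the new dataset: each full range image (32x2048 with the defaults) becomes a grid of 16x256 patches, flattened frame-major, so each frame's label is replicated per patch; the inference path applies the same tiling to a single recording. Note the training loader only logs an error for non-divisible shapes and then relies on the asserts inside the split function. A self-contained shape check; the function body mirrors the one above, and the input shape is chosen to match the loader's defaults:

    import numpy as np

    def split_array_into_subarrays(array, split_height, split_width):
        # Tile each HxW frame into (H/split_height) x (W/split_width) patches.
        height, width = array.shape[-2], array.shape[-1]
        assert height % split_height == 0 and width % split_width == 0
        nh, nw = height // split_height, width // split_width
        a = array.reshape(-1, nh, split_height, nw, split_width)
        return a.transpose(0, 1, 3, 2, 4).reshape(-1, split_height, split_width)

    frames = np.arange(10 * 32 * 2048).reshape(10, 32, 2048)
    patches = split_array_into_subarrays(frames, 16, 256)
    assert patches.shape == (160, 16, 256)  # 10 frames x 2 row blocks x 8 col blocks
    assert (patches[0] == frames[0, :16, :256]).all()  # patch 0 = top-left block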
diff --git a/Deep-SAD-PyTorch/src/main.py b/Deep-SAD-PyTorch/src/main.py
index 09c071a..084c0fb 100644
--- a/Deep-SAD-PyTorch/src/main.py
+++ b/Deep-SAD-PyTorch/src/main.py
@@ -1,14 +1,14 @@
-import click
-import torch
 import logging
 import random
-import numpy as np
 from pathlib import Path
 
+import click
+import numpy as np
+import torch
+from datasets.main import load_dataset
+from DeepSAD import DeepSAD
 from utils.config import Config
 from utils.visualization.plot_images_grid import plot_images_grid
-from DeepSAD import DeepSAD
-from datasets.main import load_dataset
 
 
 ################################################################################
@@ -31,6 +31,7 @@
         "mnist",
         "elpv",
         "subter",
+        "subtersplit",
         "fmnist",
         "cifar10",
         "arrhythmia",
@@ -49,6 +50,7 @@
         "mnist_LeNet",
         "elpv_LeNet",
         "subter_LeNet",
+        "subter_LeNet_Split",
         "fmnist_LeNet",
         "cifar10_LeNet",
         "arrhythmia_mlp",
@@ -315,7 +317,6 @@ def main(
     logger.info("Number of dataloader workers: %d" % n_jobs_dataloader)
 
     if action == "train":
-
         # Load data
         dataset = load_dataset(
             dataset_name,
@@ -413,7 +414,6 @@ def main(
         ]  # from lowest to highest score
 
         if dataset_name in ("mnist", "fmnist", "cifar10", "elpv"):
-
             if dataset_name in ("mnist", "fmnist", "elpv"):
                 X_all_low = dataset.test_set.data[idx_all_sorted[:32], ...].unsqueeze(1)
                 X_all_high = dataset.test_set.data[idx_all_sorted[-32:], ...].unsqueeze(
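With the two new registry entries, the split variant plugs into the existing train flow. A sketch mirroring the calls main.py makes; the DeepSAD method names and keywords are assumed from the upstream Deep-SAD-PyTorch API, and all values are illustrative:

    from DeepSAD import DeepSAD
    from datasets.main import load_dataset

    dataset = load_dataset("subtersplit", "../data/subter",
                           ratio_known_normal=0.0, ratio_known_outlier=0.05,
                           ratio_pollution=0.1, inference=False)
    deepsad = DeepSAD(eta=1.0)
    deepsad.set_network("subter_LeNet_Split")   # new net from networks/main.py
    deepsad.train(dataset, optimizer_name="adam", lr=0.001, n_epochs=50,
                  device="cuda")
    deepsad.test(dataset, device="cuda")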
diff --git a/Deep-SAD-PyTorch/src/networks/main.py b/Deep-SAD-PyTorch/src/networks/main.py
index 97a0ae0..95542a0 100644
--- a/Deep-SAD-PyTorch/src/networks/main.py
+++ b/Deep-SAD-PyTorch/src/networks/main.py
@@ -1,11 +1,12 @@
-from .mnist_LeNet import MNIST_LeNet, MNIST_LeNet_Autoencoder
-from .elpv_LeNet import ELPV_LeNet, ELPV_LeNet_Autoencoder
-from .subter_LeNet import SubTer_LeNet, SubTer_LeNet_Autoencoder
-from .fmnist_LeNet import FashionMNIST_LeNet, FashionMNIST_LeNet_Autoencoder
 from .cifar10_LeNet import CIFAR10_LeNet, CIFAR10_LeNet_Autoencoder
-from .mlp import MLP, MLP_Autoencoder
-from .vae import VariationalAutoencoder
 from .dgm import DeepGenerativeModel, StackedDeepGenerativeModel
+from .elpv_LeNet import ELPV_LeNet, ELPV_LeNet_Autoencoder
+from .fmnist_LeNet import FashionMNIST_LeNet, FashionMNIST_LeNet_Autoencoder
+from .mlp import MLP, MLP_Autoencoder
+from .mnist_LeNet import MNIST_LeNet, MNIST_LeNet_Autoencoder
+from .subter_LeNet import SubTer_LeNet, SubTer_LeNet_Autoencoder
+from .subter_LeNet_Split import SubTer_LeNet_Split, SubTer_LeNet_Split_Autoencoder
+from .vae import VariationalAutoencoder
 
 
 def build_network(net_name, ae_net=None):
@@ -15,6 +16,7 @@ def build_network(net_name, ae_net=None):
         "mnist_LeNet",
         "elpv_LeNet",
         "subter_LeNet",
+        "subter_LeNet_Split",
         "mnist_DGM_M2",
         "mnist_DGM_M1M2",
         "fmnist_LeNet",
@@ -46,6 +48,9 @@ def build_network(net_name, ae_net=None):
     if net_name == "subter_LeNet":
         net = SubTer_LeNet()
 
+    if net_name == "subter_LeNet_Split":
+        net = SubTer_LeNet_Split()
+
     if net_name == "elpv_LeNet":
         net = ELPV_LeNet()
 
@@ -130,6 +135,7 @@ def build_autoencoder(net_name):
     implemented_networks = (
         "elpv_LeNet",
         "subter_LeNet",
+        "subter_LeNet_Split",
         "mnist_LeNet",
         "mnist_DGM_M1M2",
         "fmnist_LeNet",
@@ -154,6 +160,9 @@ def build_autoencoder(net_name):
     if net_name == "subter_LeNet":
         ae_net = SubTer_LeNet_Autoencoder()
 
+    if net_name == "subter_LeNet_Split":
+        ae_net = SubTer_LeNet_Split_Autoencoder()
+
     if net_name == "elpv_LeNet":
         ae_net = ELPV_LeNet_Autoencoder()
 
diff --git a/Deep-SAD-PyTorch/src/networks/subter_LeNet_Split.py b/Deep-SAD-PyTorch/src/networks/subter_LeNet_Split.py
new file mode 100644
index 0000000..f6ea7cc
--- /dev/null
+++ b/Deep-SAD-PyTorch/src/networks/subter_LeNet_Split.py
@@ -0,0 +1,66 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from base.base_net import BaseNet
+
+
+class SubTer_LeNet_Split(BaseNet):
+    def __init__(self, rep_dim=256):
+        super().__init__()
+
+        self.rep_dim = rep_dim
+        self.pool = nn.MaxPool2d(2, 2)
+
+        self.conv1 = nn.Conv2d(1, 8, 5, bias=False, padding=2)
+        self.bn1 = nn.BatchNorm2d(8, eps=1e-04, affine=False)
+        self.conv2 = nn.Conv2d(8, 4, 5, bias=False, padding=2)
+        self.bn2 = nn.BatchNorm2d(4, eps=1e-04, affine=False)
+        self.fc1 = nn.Linear(4 * 64 * 4, self.rep_dim, bias=False)
+
+    def forward(self, x):
+        x = x.view(-1, 1, 16, 256)
+        x = self.conv1(x)
+        x = self.pool(F.leaky_relu(self.bn1(x)))
+        x = self.conv2(x)
+        x = self.pool(F.leaky_relu(self.bn2(x)))
+        x = x.view(int(x.size(0)), -1)
+        x = self.fc1(x)
+        return x
+
+
+class SubTer_LeNet_Split_Decoder(BaseNet):
+    def __init__(self, rep_dim=256):
+        super().__init__()
+
+        self.rep_dim = rep_dim
+
+        # Decoder network
+        self.fc3 = nn.Linear(self.rep_dim, 4 * 64 * 4, bias=False)
+        self.bn3 = nn.BatchNorm2d(4, eps=1e-04, affine=False)
+        self.deconv1 = nn.ConvTranspose2d(4, 8, 5, bias=False, padding=2)
+        self.bn4 = nn.BatchNorm2d(8, eps=1e-04, affine=False)
+        self.deconv2 = nn.ConvTranspose2d(8, 1, 5, bias=False, padding=2)
+
+    def forward(self, x):
+        x = self.fc3(x)
+        x = x.view(int(x.size(0)), 4, 4, 64)
+        x = F.interpolate(F.leaky_relu(self.bn3(x)), scale_factor=2)
+        x = self.deconv1(x)
+        x = F.interpolate(F.leaky_relu(self.bn4(x)), scale_factor=2)
+        x = self.deconv2(x)
+        x = torch.sigmoid(x)
+        return x
+
+
+class SubTer_LeNet_Split_Autoencoder(BaseNet):
+    def __init__(self, rep_dim=256):
+        super().__init__()
+
+        self.rep_dim = rep_dim
+        self.encoder = SubTer_LeNet_Split(rep_dim=rep_dim)
+        self.decoder = SubTer_LeNet_Split_Decoder(rep_dim=rep_dim)
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.decoder(x)
+        return x
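Shape bookkeeping for the new net, for the record: conv1 keeps 16x256 at 8 channels, pooling halves it to 8x128, conv2 maps to 4 channels, pooling halves again to 4x64, so fc1 sees 4*64*4 = 1024 features; the decoder mirrors this with two 2x upsamplings. A quick smoke test, assuming BaseNet is an nn.Module subclass as elsewhere in the repo:

    import torch
    from networks.subter_LeNet_Split import (
        SubTer_LeNet_Split,
        SubTer_LeNet_Split_Autoencoder,
    )

    x = torch.randn(2, 1, 16, 256)          # two 16x256 patches
    z = SubTer_LeNet_Split()(x)
    assert z.shape == (2, 256)              # rep_dim defaults to 256
    x_hat = SubTer_LeNet_Split_Autoencoder()(x)
    assert x_hat.shape == (2, 1, 16, 256)   # decoder restores the input shape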
diff --git a/tools/poetry.lock b/tools/poetry.lock
index c7b33a4..1fc61b8 100644
--- a/tools/poetry.lock
+++ b/tools/poetry.lock
@@ -1435,6 +1435,21 @@ toolz = "*"
 [package.extras]
 complete = ["blosc", "numpy (>=1.20.0)", "pandas (>=1.3)", "pyzmq"]
 
+[[package]]
+name = "pathvalidate"
+version = "3.2.0"
+description = "pathvalidate is a Python library to sanitize/validate a string such as filenames/file-paths/etc."
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "pathvalidate-3.2.0-py3-none-any.whl", hash = "sha256:cc593caa6299b22b37f228148257997e2fa850eea2daf7e4cc9205cef6908dee"},
+    {file = "pathvalidate-3.2.0.tar.gz", hash = "sha256:5e8378cf6712bff67fbe7a8307d99fa8c1a0cb28aa477056f8fc374f0dff24ad"},
+]
+
+[package.extras]
+docs = ["Sphinx (>=2.4)", "sphinx-rtd-theme (>=1.2.2)", "urllib3 (<2)"]
+test = ["Faker (>=1.0.8)", "allpairspy (>=2)", "click (>=6.2)", "pytest (>=6.0.1)", "pytest-discord (>=0.1.4)", "pytest-md-report (>=0.4.1)"]
+
 [[package]]
 name = "pillow"
 version = "10.3.0"
@@ -2416,4 +2431,4 @@ cffi = ["cffi (>=1.11)"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11"
-content-hash = "607634c6605566c6bec22c1e8f41d056d2920de45108297193592fe57b10507b"
+content-hash = "cd1f01135a813fbbc06da686dfa67b0981d26ccb3038f7d0e4a17359326f2c8d"
diff --git a/tools/pyproject.toml b/tools/pyproject.toml
index 280ac71..66be338 100644
--- a/tools/pyproject.toml
+++ b/tools/pyproject.toml
@@ -19,6 +19,7 @@ matplotlib = "^3.8.4"
 dask = "^2024.4.2"
 dask-expr = "^1.1.3"
 pandas = "^2.2.2"
+pathvalidate = "^3.2.0"
 
 [build-system]
 requires = ["poetry-core"]
diff --git a/tools/render2d.py b/tools/render2d.py
index 7d31f31..66240c6 100644
--- a/tools/render2d.py
+++ b/tools/render2d.py
@@ -1,18 +1,22 @@
-from configargparse import (
-    ArgParser,
-    YAMLConfigFileParser,
-    ArgumentDefaultsRawHelpFormatter,
-)
-from sys import exit
-from pathlib import Path
-from pointcloudset import Dataset
-from rich.progress import track
-from pandas import DataFrame
-from PIL import Image
 from math import pi
-from typing import Optional
+from multiprocessing import Pool
+from pathlib import Path
+from sys import exit
+from typing import Optional, Tuple
+
 import matplotlib
 import numpy as np
+from configargparse import (
+    ArgParser,
+    ArgumentDefaultsRawHelpFormatter,
+    YAMLConfigFileParser,
+)
+from pandas import DataFrame
+from pathvalidate import sanitize_filename
+from PIL import Image
+from pointcloudset import Dataset
+from rich.progress import track
+from rosbags.highlevel import AnyReader
 
 matplotlib.use("Agg")
 import matplotlib.pyplot as plt
@@ -20,16 +24,44 @@ import matplotlib.pyplot as plt
 from util import (
     angle,
     angle_width,
-    positive_int,
-    load_dataset,
-    existing_path,
-    create_video_from_images,
     calculate_average_frame_rate,
+    create_video_from_images,
+    existing_path,
     get_colormap_with_special_missing_color,
+    load_dataset,
+    positive_int,
 )
 
 
-def crop_lidar_data_to_roi(
+def save_image_topics(
+    bag_file_path: Path, output_folder_path: Path, topic_names: list[str]
+) -> list[Path]:
+    with AnyReader([bag_file_path]) as reader:
+        topic_paths = []
+        for topic_name in topic_names:
+            connections = [x for x in reader.connections if x.topic == topic_name]
+            frame_count = 0
+            topic_output_folder_path = output_folder_path / sanitize_filename(
+                topic_name
+            )
+            topic_output_folder_path.mkdir(exist_ok=True, parents=True)
+            topic_paths.append(topic_output_folder_path)
+            for connection, timestamp, rawdata in reader.messages(
+                connections=connections
+            ):
+                img_msg = reader.deserialize(rawdata, connection.msgtype)
+
+                img_data = np.frombuffer(
+                    img_msg.data,
+                    dtype=np.uint8 if "8" in img_msg.encoding else np.uint16,
+                ).reshape((img_msg.height, img_msg.width))
+                img = Image.fromarray(img_data, mode="L")
+                frame_count += 1
+                img.save(topic_output_folder_path / f"frame_{frame_count:04d}.png")
+    return topic_paths
+
+
+def crop_projection_data_to_roi(
     data: DataFrame,
     roi_angle_start: float,
     roi_angle_width: float,
@@ -83,103 +115,196 @@ def create_2d_projection(
     tmp_file_path.unlink()
 
 
+def process_frame(args) -> Tuple[int, np.ndarray, Optional[Path], float, float]:
+    (
+        i,
+        pc,
+        output_path,
+        colormap_name,
+        missing_data_color,
+        reverse_colormap,
+        vertical_resolution,
+        horizontal_resolution,
+        vertical_scale,
+        horizontal_scale,
+        roi_angle_start,
+        roi_angle_width,
+        render_images,
+    ) = args
+
+    lidar_data = pc.data.copy()
+
+    lidar_data["horizontal_position"] = (
+        lidar_data["original_id"] % horizontal_resolution
+    )
+    lidar_data["horizontal_position_yaw_f"] = (
+        0.5
+        * horizontal_resolution
+        * (np.arctan2(lidar_data["y"], lidar_data["x"]) / pi + 1.0)
+    )
+    lidar_data["horizontal_position_yaw"] = np.floor(
+        lidar_data["horizontal_position_yaw_f"]
+    )
+    lidar_data["vertical_position"] = np.floor(
+        lidar_data["original_id"] / horizontal_resolution
+    )
+    # fov = 32 * pi / 180
+    # fov_down = 17 * pi / 180
+    fov = 31.76 * pi / 180
+    fov_down = 17.3 * pi / 180
+    lidar_data["vertical_angle"] = np.arcsin(
+        lidar_data["z"]
+        / np.sqrt(lidar_data["x"] ** 2 + lidar_data["y"] ** 2 + lidar_data["z"] ** 2)
+    )
+    lidar_data["vertical_angle_degree"] = lidar_data["vertical_angle"] * 180 / pi
+
+    lidar_data["vertical_position_pitch_f"] = vertical_resolution * (
+        1 - ((lidar_data["vertical_angle"] + fov_down) / fov)
+    )
+    lidar_data["vertical_position_pitch"] = np.floor(
+        lidar_data["vertical_position_pitch_f"]
+    )
+
+    duplicates = lidar_data[
+        lidar_data.duplicated(
+            subset=["vertical_position_pitch", "horizontal_position_yaw"],
+            keep=False,
+        )
+    ].sort_values(by=["vertical_position_pitch", "horizontal_position_yaw"])
+
+    lidar_data["normalized_range"] = 1 / np.sqrt(
+        lidar_data["x"] ** 2 + lidar_data["y"] ** 2 + lidar_data["z"] ** 2
+    )
+    projection_data = lidar_data.pivot(
+        index="vertical_position_pitch",
+        columns="horizontal_position_yaw",
+        values="normalized_range",
+    )
+    projection_data = projection_data.reindex(
+        columns=range(horizontal_resolution), fill_value=0
+    )
+    projection_data = projection_data.reindex(
+        index=range(vertical_resolution), fill_value=0
+    )
+    projection_data, output_horizontal_resolution = crop_projection_data_to_roi(
+        projection_data, roi_angle_start, roi_angle_width, horizontal_resolution
+    )
+
+    if render_images:
+        image_path = create_2d_projection(
+            projection_data,
+            output_path / f"frame_{i:04d}.png",
+            output_path / f"tmp_{i:04d}.png",
+            colormap_name,
+            missing_data_color,
+            reverse_colormap,
+            horizontal_resolution=output_horizontal_resolution * horizontal_scale,
+            vertical_resolution=vertical_resolution * vertical_scale,
+        )
+    else:
+        image_path = None
+
+    return (
+        i,
+        projection_data.to_numpy(),
+        image_path,
+        lidar_data["vertical_position_pitch_f"].min(),
+        lidar_data["vertical_position_pitch_f"].max(),
+    )
+
+
+# Adjusted to use a generator for args_list
 def create_projection_data(
     dataset: Dataset,
     output_path: Path,
     colormap_name: str,
     missing_data_color: str,
     reverse_colormap: bool,
+    vertical_resolution: int,
     horizontal_resolution: int,
     vertical_scale: int,
     horizontal_scale: int,
     roi_angle_start: float,
     roi_angle_width: float,
     render_images: bool,
-) -> (np.ndarray, Optional[list[Path]]):
+) -> Tuple[np.ndarray, Optional[list[Path]]]:
     rendered_images = []
     converted_lidar_frames = []
 
+    # Generator for args_list
+    # def args_generator():
+    #     for i, pc in enumerate(dataset, 1):
+    #         yield (
+    #             i,
+    #             pc,
+    #             output_path,
+    #             colormap_name,
+    #             missing_data_color,
+    #             reverse_colormap,
+    #             vertical_resolution,
+    #             horizontal_resolution,
+    #             vertical_scale,
+    #             horizontal_scale,
+    #             roi_angle_start,
+    #             roi_angle_width,
+    #             render_images,
+    #         )
+
+    # results = []
+
+    # with Pool() as pool:
+    #     results_gen = pool.imap(process_frame, args_generator())
+    #     for result in track(
+    #         results_gen, description="Processing...", total=len(dataset)
+    #     ):
+    #         results.append(result)
+
+    # results.sort(key=lambda x: x[0])
+
     for i, pc in track(
-        enumerate(dataset, 1), description="Creating projections...", total=len(dataset)
+        enumerate(dataset, 1), description="Processing...", total=len(dataset)
     ):
-        vertical_resolution = int(pc.data["ring"].max() + 1)
-
-        # Angle calculation implementation
-
-        # projected_data = pc.data.copy()
-        # projected_data["arctan"] = np.arctan2(projected_data["y"], projected_data["x"])
-        # projected_data["arctan_normalized"] = 0.5 * (projected_data["arctan"] / pi + 1.0)
-        # projected_data["arctan_scaled"] = projected_data["arctan_normalized"] * horizontal_resolution
-        # #projected_data["horizontal_position"] = np.floor(projected_data["arctan_scaled"])
-        # projected_data["horizontal_position"] = np.round(projected_data["arctan_scaled"])
-        # projected_data["normalized_range"] = 1 / np.sqrt(
-        #     projected_data["x"] ** 2 + projected_data["y"] ** 2 + projected_data["z"] ** 2
-        # )
-        # duplicates = projected_data[projected_data.duplicated(subset=['ring', 'horizontal_position'], keep=False)].sort_values(by=['ring', 'horizontal_position'])
-        # sorted = projected_data.sort_values(by=['ring', 'horizontal_position'])
-
-        # FIXME: following pivot fails due to duplicates in the data, some points (x, y) are mapped to the same pixel in the projection, have to decide how to handles
-        # these cases
-
-        # projected_image_data = projected_data.pivot(
-        #     index="ring", columns="horizontal_position", values="normalized_range"
-        # )
-        # projected_image_data = projected_image_data.reindex(columns=range(horizontal_resolution), fill_value=0)
-
-        # projected_image_data, output_horizontal_resolution = crop_lidar_data_to_roi(
-        #     projected_image_data, roi_angle_start, roi_angle_width, horizontal_resolution
-        # )
-
-        # create_2d_projection(
-        #     projected_image_data,
-        #     output_path / f"frame_{i:04d}_projection.png",
-        #     output_path / "tmp.png",
-        #     colormap_name,
-        #     missing_data_color,
-        #     reverse_colormap,
-        #     horizontal_resolution=output_horizontal_resolution * horizontal_scale,
-        #     vertical_resolution=vertical_resolution * vertical_scale,
-        # )
-
-        lidar_data = pc.data.copy()
-        lidar_data["horizontal_position"] = (
-            lidar_data["original_id"] % horizontal_resolution
-        )
-        lidar_data["normalized_range"] = 1 / np.sqrt(
-            lidar_data["x"] ** 2 + lidar_data["y"] ** 2 + lidar_data["z"] ** 2
-        )
-        lidar_data = lidar_data.pivot(
-            index="ring", columns="horizontal_position", values="normalized_range"
-        )
-        lidar_data = lidar_data.reindex(
-            columns=range(horizontal_resolution), fill_value=0
-        )
-        lidar_data = lidar_data.reindex(index=range(vertical_resolution), fill_value=0)
-        lidar_data, output_horizontal_resolution = crop_lidar_data_to_roi(
-            lidar_data, roi_angle_start, roi_angle_width, horizontal_resolution
-        )
-
-        converted_lidar_frames.append(lidar_data.to_numpy())
-        if render_images:
-            image_path = create_2d_projection(
-                lidar_data,
-                output_path / f"frame_{i:04d}.png",
-                output_path / "tmp.png",
+        results = process_frame(
+            (
+                i,
+                pc,
+                output_path,
                 colormap_name,
                 missing_data_color,
                 reverse_colormap,
-                horizontal_resolution=output_horizontal_resolution * horizontal_scale,
-                vertical_resolution=vertical_resolution * vertical_scale,
+                vertical_resolution,
+                horizontal_resolution,
+                vertical_scale,
+                horizontal_scale,
+                roi_angle_start,
+                roi_angle_width,
+                render_images,
             )
+        )
+        converted_lidar_frames.append(results[1])
+        if render_images:
+            rendered_images.append(results[2])
 
-            rendered_images.append(image_path)
+    # min_all = 100
+    # max_all = 0
+
+    # for _, data, img_path, min_frame, max_frame in results:
+    #     converted_lidar_frames.append(data)
+    #     if img_path:
+    #         rendered_images.append(img_path)
+    #     if min_frame < min_all:
+    #         min_all = min_frame
+    #     if max_frame > max_all:
+    #         max_all = max_frame
+
+    # print(f"{min_all=}, {max_all=}")
 
     projection_data = np.stack(converted_lidar_frames, axis=0)
 
     if render_images:
-        return rendered_images, projection_data
+        return projection_data, rendered_images
     else:
-        return projection_data
+        return (projection_data,)
 
 
 def main() -> int:
@@ -204,6 +329,12 @@ def main() -> int:
         type=str,
         help="topic in the ros/mcap bag file containing the point cloud data",
     )
+    parser.add_argument(
+        "--image-topics",
+        default=[],
+        nargs="+",
+        help="topics in the ros/mcap bag file containing the image data",
+    )
     parser.add_argument(
         "--output-path",
         default=Path("./output"),
@@ -248,11 +379,17 @@ def main() -> int:
         type=bool,
         help="if colormap should be reversed",
     )
+    parser.add_argument(
+        "--vertical-resolution",
+        default=32,
+        type=positive_int,
+        help="number of vertical lidar data point rows",
+    )
     parser.add_argument(
         "--horizontal-resolution",
         default=2048,
         type=positive_int,
-        help="number of horizontal lidar data points",
+        help="number of horizontal lidar data point columns",
    )
     parser.add_argument(
         "--vertical-scale",
@@ -296,6 +433,7 @@ def main() -> int:
     dataset = load_dataset(args.input_experiment_path, args.pointcloud_topic)
 
     images = []
+    topic_paths = []
 
     if not args.output_no_images or not args.output_no_video:
         if not args.force_generation and all(
@@ -312,6 +450,7 @@ def main() -> int:
             args.colormap_name,
             args.missing_data_color,
             args.reverse_colormap,
+            args.vertical_resolution,
             args.horizontal_resolution,
             args.vertical_scale,
             args.horizontal_scale,
@@ -319,6 +458,10 @@ def main() -> int:
             args.roi_angle_width,
             render_images=True,
         )
+        if args.image_topics:
+            topic_paths = save_image_topics(
+                args.input_experiment_path, output_path, args.image_topics
+            )
 
     output_numpy_path = (output_path / args.input_experiment_path.stem).with_suffix(
         ".npy"
@@ -336,6 +479,7 @@ def main() -> int:
             args.colormap_name,
             args.missing_data_color,
             args.reverse_colormap,
+            args.vertical_resolution,
             args.horizontal_resolution,
             args.vertical_scale,
             args.horizontal_scale,
@@ -358,12 +502,20 @@ def main() -> int:
                 f"Skipping video generation for {args.input_experiment_path} as {output_path / args.input_experiment_path.stem}.mp4 already exists"
             )
         else:
+            frame_rate = calculate_average_frame_rate(dataset)
             input_images_pattern = f"{tmp_path}/frame_%04d.png"
             create_video_from_images(
                 input_images_pattern,
                 (output_path / args.input_experiment_path.stem).with_suffix(".mp4"),
-                calculate_average_frame_rate(dataset),
+                frame_rate,
             )
+            for topic_path in topic_paths:
+                input_images_pattern = f"{topic_path}/frame_%04d.png"
+                create_video_from_images(
+                    input_images_pattern,
+                    (output_path / topic_path.stem).with_suffix(".mp4"),
+                    frame_rate,
+                )
 
     if args.output_no_images:
         for image in images:
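The binning used by process_frame maps each point to a spherical-image pixel: the column comes from the azimuth atan2(y, x) scaled across horizontal_resolution, the row from the elevation against a 31.76 degree vertical field of view of which 17.3 degrees lie below the horizon (row 0 at the top). A worked check of the two formulas with the new CLI defaults (32x2048); the sample point is illustrative:

    from math import asin, atan2, floor, pi, sqrt

    H, V = 2048, 32                  # --horizontal/--vertical-resolution defaults
    fov = 31.76 * pi / 180           # total vertical field of view
    fov_down = 17.3 * pi / 180       # portion of the FOV below the horizon

    x, y, z = 10.0, 10.0, 1.0        # 45 deg to the left, slightly above horizon
    col = floor(0.5 * H * (atan2(y, x) / pi + 1.0))  # -> 1280
    pitch = asin(z / sqrt(x * x + y * y + z * z))    # ~4.04 degrees
    row = floor(V * (1 - (pitch + fov_down) / fov))  # -> 10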
diff --git a/tools/util.py b/tools/util.py
index 48d5f7b..f67c8ea 100644
--- a/tools/util.py
+++ b/tools/util.py
@@ -1,11 +1,11 @@
-from pointcloudset import Dataset
-from pointcloudset.io.dataset.ros import dataset_from_ros
-from pathlib import Path
 from argparse import ArgumentTypeError
-from subprocess import run
 from datetime import timedelta
-from matplotlib.colors import Colormap
+from pathlib import Path
+from subprocess import run
+
 from matplotlib import colormaps
+from matplotlib.colors import Colormap
+from pointcloudset import Dataset
 
 
 def load_dataset(