add torchscan for summary and receptive field (wip)

This commit is contained in:
Jan Kowalczyk
2025-06-04 09:45:24 +02:00
parent 3a0f35f21d
commit 3538b15073
5 changed files with 189 additions and 10 deletions

View File

@@ -430,6 +430,67 @@ files = [
{file = "numpy-2.0.0.tar.gz", hash = "sha256:cf5d1c9e6837f8af9f92b6bd3e86d513cdc11f60fd62185cc49ec7d1aba34864"}, {file = "numpy-2.0.0.tar.gz", hash = "sha256:cf5d1c9e6837f8af9f92b6bd3e86d513cdc11f60fd62185cc49ec7d1aba34864"},
] ]
[[package]]
name = "nvidia-cublas-cu11"
version = "11.10.3.66"
description = "CUBLAS native runtime libraries"
optional = false
python-versions = ">=3"
files = [
{file = "nvidia_cublas_cu11-11.10.3.66-py3-none-manylinux1_x86_64.whl", hash = "sha256:d32e4d75f94ddfb93ea0a5dda08389bcc65d8916a25cb9f37ac89edaeed3bded"},
{file = "nvidia_cublas_cu11-11.10.3.66-py3-none-win_amd64.whl", hash = "sha256:8ac17ba6ade3ed56ab898a036f9ae0756f1e81052a317bf98f8c6d18dc3ae49e"},
]
[package.dependencies]
setuptools = "*"
wheel = "*"
[[package]]
name = "nvidia-cuda-nvrtc-cu11"
version = "11.7.99"
description = "NVRTC native runtime libraries"
optional = false
python-versions = ">=3"
files = [
{file = "nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl", hash = "sha256:9f1562822ea264b7e34ed5930567e89242d266448e936b85bc97a3370feabb03"},
{file = "nvidia_cuda_nvrtc_cu11-11.7.99-py3-none-manylinux1_x86_64.whl", hash = "sha256:f7d9610d9b7c331fa0da2d1b2858a4a8315e6d49765091d28711c8946e7425e7"},
{file = "nvidia_cuda_nvrtc_cu11-11.7.99-py3-none-win_amd64.whl", hash = "sha256:f2effeb1309bdd1b3854fc9b17eaf997808f8b25968ce0c7070945c4265d64a3"},
]
[package.dependencies]
setuptools = "*"
wheel = "*"
[[package]]
name = "nvidia-cuda-runtime-cu11"
version = "11.7.99"
description = "CUDA Runtime native Libraries"
optional = false
python-versions = ">=3"
files = [
{file = "nvidia_cuda_runtime_cu11-11.7.99-py3-none-manylinux1_x86_64.whl", hash = "sha256:cc768314ae58d2641f07eac350f40f99dcb35719c4faff4bc458a7cd2b119e31"},
{file = "nvidia_cuda_runtime_cu11-11.7.99-py3-none-win_amd64.whl", hash = "sha256:bc77fa59a7679310df9d5c70ab13c4e34c64ae2124dd1efd7e5474b71be125c7"},
]
[package.dependencies]
setuptools = "*"
wheel = "*"
[[package]]
name = "nvidia-cudnn-cu11"
version = "8.5.0.96"
description = "cuDNN runtime libraries"
optional = false
python-versions = ">=3"
files = [
{file = "nvidia_cudnn_cu11-8.5.0.96-2-py3-none-manylinux1_x86_64.whl", hash = "sha256:402f40adfc6f418f9dae9ab402e773cfed9beae52333f6d86ae3107a1b9527e7"},
{file = "nvidia_cudnn_cu11-8.5.0.96-py3-none-manylinux1_x86_64.whl", hash = "sha256:71f8111eb830879ff2836db3cccf03bbd735df9b0d17cd93761732ac50a8a108"},
]
[package.dependencies]
setuptools = "*"
wheel = "*"
[[package]] [[package]]
name = "packaging" name = "packaging"
version = "24.1" version = "24.1"
@@ -746,6 +807,26 @@ dev = ["flake8", "flit", "mypy", "pandas-stubs", "pre-commit", "pytest", "pytest
docs = ["ipykernel", "nbconvert", "numpydoc", "pydata_sphinx_theme (==0.10.0rc2)", "pyyaml", "sphinx (<6.0.0)", "sphinx-copybutton", "sphinx-design", "sphinx-issues"] docs = ["ipykernel", "nbconvert", "numpydoc", "pydata_sphinx_theme (==0.10.0rc2)", "pyyaml", "sphinx (<6.0.0)", "sphinx-copybutton", "sphinx-design", "sphinx-issues"]
stats = ["scipy (>=1.7)", "statsmodels (>=0.12)"] stats = ["scipy (>=1.7)", "statsmodels (>=0.12)"]
[[package]]
name = "setuptools"
version = "80.9.0"
description = "Easily download, build, install, upgrade, and uninstall Python packages"
optional = false
python-versions = ">=3.9"
files = [
{file = "setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922"},
{file = "setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c"},
]
[package.extras]
check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.8.0)"]
core = ["importlib_metadata (>=6)", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"]
cover = ["pytest-cov"]
doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
enabler = ["pytest-enabler (>=2.2)"]
test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
type = ["importlib_metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.14.*)", "pytest-mypy"]
[[package]] [[package]]
name = "six" name = "six"
version = "1.16.0" version = "1.16.0"
@@ -768,6 +849,77 @@ files = [
{file = "threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107"}, {file = "threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107"},
] ]
[[package]]
name = "torch"
version = "1.13.1"
description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
optional = false
python-versions = ">=3.7.0"
files = [
{file = "torch-1.13.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:fd12043868a34a8da7d490bf6db66991108b00ffbeecb034228bfcbbd4197143"},
{file = "torch-1.13.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:d9fe785d375f2e26a5d5eba5de91f89e6a3be5d11efb497e76705fdf93fa3c2e"},
{file = "torch-1.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:98124598cdff4c287dbf50f53fb455f0c1e3a88022b39648102957f3445e9b76"},
{file = "torch-1.13.1-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:393a6273c832e047581063fb74335ff50b4c566217019cc6ace318cd79eb0566"},
{file = "torch-1.13.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:0122806b111b949d21fa1a5f9764d1fd2fcc4a47cb7f8ff914204fd4fc752ed5"},
{file = "torch-1.13.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:22128502fd8f5b25ac1cd849ecb64a418382ae81dd4ce2b5cebaa09ab15b0d9b"},
{file = "torch-1.13.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:76024be052b659ac1304ab8475ab03ea0a12124c3e7626282c9c86798ac7bc11"},
{file = "torch-1.13.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:ea8dda84d796094eb8709df0fcd6b56dc20b58fdd6bc4e8d7109930dafc8e419"},
{file = "torch-1.13.1-cp37-cp37m-win_amd64.whl", hash = "sha256:2ee7b81e9c457252bddd7d3da66fb1f619a5d12c24d7074de91c4ddafb832c93"},
{file = "torch-1.13.1-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:0d9b8061048cfb78e675b9d2ea8503bfe30db43d583599ae8626b1263a0c1380"},
{file = "torch-1.13.1-cp37-none-macosx_11_0_arm64.whl", hash = "sha256:f402ca80b66e9fbd661ed4287d7553f7f3899d9ab54bf5c67faada1555abde28"},
{file = "torch-1.13.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:727dbf00e2cf858052364c0e2a496684b9cb5aa01dc8a8bc8bbb7c54502bdcdd"},
{file = "torch-1.13.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:df8434b0695e9ceb8cc70650afc1310d8ba949e6db2a0525ddd9c3b2b181e5fe"},
{file = "torch-1.13.1-cp38-cp38-win_amd64.whl", hash = "sha256:5e1e722a41f52a3f26f0c4fcec227e02c6c42f7c094f32e49d4beef7d1e213ea"},
{file = "torch-1.13.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:33e67eea526e0bbb9151263e65417a9ef2d8fa53cbe628e87310060c9dcfa312"},
{file = "torch-1.13.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:eeeb204d30fd40af6a2d80879b46a7efbe3cf43cdbeb8838dd4f3d126cc90b2b"},
{file = "torch-1.13.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:50ff5e76d70074f6653d191fe4f6a42fdbe0cf942fbe2a3af0b75eaa414ac038"},
{file = "torch-1.13.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:2c3581a3fd81eb1f0f22997cddffea569fea53bafa372b2c0471db373b26aafc"},
{file = "torch-1.13.1-cp39-cp39-win_amd64.whl", hash = "sha256:0aa46f0ac95050c604bcf9ef71da9f1172e5037fdf2ebe051962d47b123848e7"},
{file = "torch-1.13.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:6930791efa8757cb6974af73d4996b6b50c592882a324b8fb0589c6a9ba2ddaf"},
{file = "torch-1.13.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:e0df902a7c7dd6c795698532ee5970ce898672625635d885eade9976e5a04949"},
]
[package.dependencies]
nvidia-cublas-cu11 = {version = "11.10.3.66", markers = "platform_system == \"Linux\""}
nvidia-cuda-nvrtc-cu11 = {version = "11.7.99", markers = "platform_system == \"Linux\""}
nvidia-cuda-runtime-cu11 = {version = "11.7.99", markers = "platform_system == \"Linux\""}
nvidia-cudnn-cu11 = {version = "8.5.0.96", markers = "platform_system == \"Linux\""}
typing-extensions = "*"
[package.extras]
opt-einsum = ["opt-einsum (>=3.3)"]
[[package]]
name = "torchscan"
version = "0.1.2"
description = "Useful information about your Pytorch module"
optional = false
python-versions = "<4,>=3.6"
files = [
{file = "torchscan-0.1.2-py3-none-any.whl", hash = "sha256:5c69f9cc20e5041cc1d307efc8070fac992d472539260fbbd269e24079eb8971"},
{file = "torchscan-0.1.2.tar.gz", hash = "sha256:acff638d5cfd639fa6838b2f61969f4206927ae9a41d6c358513c943bd2c4cfa"},
]
[package.dependencies]
torch = ">=1.5.0,<2.0.0"
[package.extras]
dev = ["Jinja2 (<3.1)", "black (>=22.1,<23.0)", "coverage[toml] (>=4.5.4)", "flake8 (>=3.9.0)", "furo (>=2022.3.4)", "isort (>=5.7.0)", "mypy (>=0.812)", "pydocstyle[toml] (>=6.0.0)", "pytest (>=5.3.2)", "sphinx (>=3.0.0,!=3.5.0)", "sphinx-copybutton (>=0.3.1)", "sphinxemoji (>=0.1.8)"]
docs = ["Jinja2 (<3.1)", "furo (>=2022.3.4)", "sphinx (>=3.0.0,!=3.5.0)", "sphinx-copybutton (>=0.3.1)", "sphinxemoji (>=0.1.8)"]
quality = ["black (>=22.1,<23.0)", "flake8 (>=3.9.0)", "isort (>=5.7.0)", "mypy (>=0.812)", "pydocstyle[toml] (>=6.0.0)"]
test = ["coverage[toml] (>=4.5.4)", "pytest (>=5.3.2)"]
[[package]]
name = "typing-extensions"
version = "4.13.2"
description = "Backported and Experimental Type Hints for Python 3.8+"
optional = false
python-versions = ">=3.8"
files = [
{file = "typing_extensions-4.13.2-py3-none-any.whl", hash = "sha256:a439e7c04b49fec3e5d3e2beaa21755cadbbdc391694e28ccdd36ca4a1408f8c"},
{file = "typing_extensions-4.13.2.tar.gz", hash = "sha256:e6c81219bd689f51865d9e372991c540bda33a0379d5573cddb9a3a23f7caaef"},
]
[[package]] [[package]]
name = "tzdata" name = "tzdata"
version = "2024.1" version = "2024.1"
@@ -779,7 +931,21 @@ files = [
{file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"}, {file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"},
] ]
[[package]]
name = "wheel"
version = "0.45.1"
description = "A built-package format for Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "wheel-0.45.1-py3-none-any.whl", hash = "sha256:708e7481cc80179af0e556bbf0cc00b8444c7321e2700b8d8580231d13017248"},
{file = "wheel-0.45.1.tar.gz", hash = "sha256:661e1abd9198507b1409a20c02106d9670b2576e916d58f520316666abca6729"},
]
[package.extras]
test = ["pytest (>=6.0.0)", "setuptools (>=65)"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.11" python-versions = "^3.11"
content-hash = "09616985362fba4910821705a0680cd56af41cf7c1d23bd9507f0e09cd55205f" content-hash = "810be7af5be0353a0399b8d1db6dc64a3681b6988860623c6ceda7b81b865968"

View File

@@ -23,6 +23,7 @@ scikit-learn = "^1.5.0"
scipy = "^1.14.0" scipy = "^1.14.0"
seaborn = "^0.13.2" seaborn = "^0.13.2"
six = "^1.16.0" six = "^1.16.0"
torchscan = "^0.1.2"
[build-system] [build-system]

View File

@@ -1,6 +1,8 @@
import logging import logging
import torch.nn as nn
import numpy as np import numpy as np
import torch.nn as nn
import torchscan
class BaseNet(nn.Module): class BaseNet(nn.Module):
@@ -10,6 +12,7 @@ class BaseNet(nn.Module):
super().__init__() super().__init__()
self.logger = logging.getLogger(self.__class__.__name__) self.logger = logging.getLogger(self.__class__.__name__)
self.rep_dim = None # representation dimensionality, i.e. dim of the code layer or last layer self.rep_dim = None # representation dimensionality, i.e. dim of the code layer or last layer
self.input_dim = None # input dimensionality, i.e. dim of the input layer
def forward(self, *input): def forward(self, *input):
""" """
@@ -18,9 +21,17 @@ class BaseNet(nn.Module):
""" """
raise NotImplementedError raise NotImplementedError
def summary(self): def summary(self, receptive_field: bool = False):
"""Network summary.""" """Network summary."""
net_parameters = filter(lambda p: p.requires_grad, self.parameters()) # net_parameters = filter(lambda p: p.requires_grad, self.parameters())
params = sum([np.prod(p.size()) for p in net_parameters]) # params = sum([np.prod(p.size()) for p in net_parameters])
self.logger.info("Trainable parameters: {}".format(params)) # self.logger.info("Trainable parameters: {}".format(params))
self.logger.info(self) # self.logger.info(self)
if not self.input_dim:
self.logger.warning(
"Input dimension is not set. Please set input_dim before calling summary."
)
return
self.logger.info(
torchscan.summary(self, self.input_dim, receptive_field=receptive_field)
)

View File

@@ -6,10 +6,10 @@ from base.base_net import BaseNet
class SubTer_LeNet(BaseNet): class SubTer_LeNet(BaseNet):
def __init__(self, rep_dim=1024): def __init__(self, rep_dim=1024):
super().__init__() super().__init__()
self.input_dim = (1, 32, 2048) # Input dimension for the network
self.rep_dim = rep_dim self.rep_dim = rep_dim
self.pool = nn.MaxPool2d(2, 2) self.pool = nn.MaxPool2d(2, 2)
@@ -31,7 +31,6 @@ class SubTer_LeNet(BaseNet):
class SubTer_LeNet_Decoder(BaseNet): class SubTer_LeNet_Decoder(BaseNet):
def __init__(self, rep_dim=1024): def __init__(self, rep_dim=1024):
super().__init__() super().__init__()
@@ -56,10 +55,10 @@ class SubTer_LeNet_Decoder(BaseNet):
class SubTer_LeNet_Autoencoder(BaseNet): class SubTer_LeNet_Autoencoder(BaseNet):
def __init__(self, rep_dim=1024): def __init__(self, rep_dim=1024):
super().__init__() super().__init__()
self.input_dim = (1, 32, 2048) # Input dimension for the network
self.rep_dim = rep_dim self.rep_dim = rep_dim
self.encoder = SubTer_LeNet(rep_dim=rep_dim) self.encoder = SubTer_LeNet(rep_dim=rep_dim)
self.decoder = SubTer_LeNet_Decoder(rep_dim=rep_dim) self.decoder = SubTer_LeNet_Decoder(rep_dim=rep_dim)

View File

@@ -212,6 +212,7 @@ class DeepSADTrainer(BaseTrainer):
start_time = time.time() start_time = time.time()
idx_label_score = [] idx_label_score = []
net.eval() net.eval()
net.summary(receptive_field=True)
with torch.no_grad(): with torch.no_grad():
for data in test_loader: for data in test_loader:
inputs, labels, semi_targets, idx, _ = data inputs, labels, semi_targets, idx, _ = data
@@ -267,6 +268,7 @@ class DeepSADTrainer(BaseTrainer):
c = torch.zeros(net.rep_dim, device=self.device) c = torch.zeros(net.rep_dim, device=self.device)
net.eval() net.eval()
net.summary(receptive_field=True)
with torch.no_grad(): with torch.no_grad():
for data in train_loader: for data in train_loader:
# get the inputs of the batch # get the inputs of the batch