2nd subter network arch

This commit is contained in:
Jan Kowalczyk
2025-06-17 07:26:03 +02:00
parent 9298dea329
commit bbd093da0c
9 changed files with 248 additions and 30 deletions

View File

@@ -73,6 +73,22 @@
"type": "github"
}
},
"nixpkgs-newest": {
"locked": {
"lastModified": 1749285348,
"narHash": "sha256-frdhQvPbmDYaScPFiCnfdh3B/Vh81Uuoo0w5TkWmmjU=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "3e3afe5174c561dee0df6f2c2b2236990146329f",
"type": "github"
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"type": "github"
}
},
"poetry2nix": {
"inputs": {
"flake-utils": "flake-utils_2",
@@ -101,6 +117,7 @@
"inputs": {
"flake-utils": "flake-utils",
"nixpkgs": "nixpkgs",
"nixpkgs-newest": "nixpkgs-newest",
"poetry2nix": "poetry2nix"
}
},

View File

@@ -4,6 +4,8 @@
inputs = {
flake-utils.url = "github:numtide/flake-utils";
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable-small";
# Added newest nixpkgs for an updated poetry package.
nixpkgs-newest.url = "github:NixOS/nixpkgs/nixos-unstable";
poetry2nix = {
url = "github:nix-community/poetry2nix";
inputs.nixpkgs.follows = "nixpkgs";
@@ -14,6 +16,7 @@
{
self,
nixpkgs,
nixpkgs-newest,
flake-utils,
poetry2nix,
}:
@@ -26,6 +29,7 @@
config.allowUnfree = true;
config.cudaSupport = true;
};
pkgsNew = nixpkgs-newest.legacyPackages.${system};
thundersvm = import ./nix/thundersvm.nix {
inherit pkgs;
inherit (pkgs) fetchFromGitHub cmake gcc12Stdenv;
@@ -37,7 +41,7 @@
pythonPackages = pkgs.python311Packages;
thundersvm = thundersvm;
};
inherit (poetry2nix.lib.mkPoetry2Nix { inherit pkgs; }) mkPoetryApplication;
inherit (poetry2nix.lib.mkPoetry2Nix { inherit pkgs; }) mkPoetryApplication defaultPoetryOverrides;
in
{
packages = {
@@ -45,6 +49,13 @@
projectDir = self;
preferWheels = true;
python = pkgs.python311;
overrides = defaultPoetryOverrides.extend (
final: prev: {
torch-receptive-field = prev.torch-receptive-field.overridePythonAttrs (old: {
buildInputs = (old.buildInputs or [ ]) ++ [ prev.setuptools ];
});
}
);
};
default = self.packages.${system}.deepsad;
};
@@ -63,7 +74,7 @@
devShells.poetry = pkgs.mkShell {
packages = [
pkgs.poetry
pkgsNew.poetry
pkgs.python311
];
};

View File

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand.
[[package]]
name = "click"
@@ -6,6 +6,7 @@ version = "8.1.7"
description = "Composable command line interface toolkit"
optional = false
python-versions = ">=3.7"
groups = ["main"]
files = [
{file = "click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28"},
{file = "click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de"},
@@ -20,6 +21,8 @@ version = "0.4.6"
description = "Cross-platform colored terminal text."
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
groups = ["main"]
markers = "platform_system == \"Windows\""
files = [
{file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
@@ -31,6 +34,7 @@ version = "1.2.1"
description = "Python library for calculating contours of 2D quadrilateral grids"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "contourpy-1.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bd7c23df857d488f418439686d3b10ae2fbf9bc256cd045b37a8c16575ea1040"},
{file = "contourpy-1.2.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5b9eb0ca724a241683c9685a484da9d35c872fd42756574a7cfbf58af26677fd"},
@@ -94,6 +98,7 @@ version = "1.3.2"
description = "Convex optimization package"
optional = false
python-versions = ">=3, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
groups = ["main"]
files = [
{file = "cvxopt-1.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:cd4a1bba537a34808b92f1e793e3499029d339a7a2ab6d989f82e395b7b740ff"},
{file = "cvxopt-1.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e3cd2db913b1cf64d84cdb7bc467a8a15adbd1f0f83a7a45a7167ad590f79408"},
@@ -126,6 +131,7 @@ version = "0.12.1"
description = "Composable style cycles"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "cycler-0.12.1-py3-none-any.whl", hash = "sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30"},
{file = "cycler-0.12.1.tar.gz", hash = "sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c"},
@@ -141,6 +147,7 @@ version = "4.53.0"
description = "Tools to manipulate font files"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "fonttools-4.53.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:52a6e0a7a0bf611c19bc8ec8f7592bdae79c8296c70eb05917fd831354699b20"},
{file = "fonttools-4.53.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:099634631b9dd271d4a835d2b2a9e042ccc94ecdf7e2dd9f7f34f7daf333358d"},
@@ -187,18 +194,18 @@ files = [
]
[package.extras]
all = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "fs (>=2.2.0,<3)", "lxml (>=4.0)", "lz4 (>=1.7.4.2)", "matplotlib", "munkres", "pycairo", "scipy", "skia-pathops (>=0.5.0)", "sympy", "uharfbuzz (>=0.23.0)", "unicodedata2 (>=15.1.0)", "xattr", "zopfli (>=0.1.4)"]
all = ["brotli (>=1.0.1) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\"", "fs (>=2.2.0,<3)", "lxml (>=4.0)", "lz4 (>=1.7.4.2)", "matplotlib", "munkres ; platform_python_implementation == \"PyPy\"", "pycairo", "scipy ; platform_python_implementation != \"PyPy\"", "skia-pathops (>=0.5.0)", "sympy", "uharfbuzz (>=0.23.0)", "unicodedata2 (>=15.1.0) ; python_version <= \"3.12\"", "xattr ; sys_platform == \"darwin\"", "zopfli (>=0.1.4)"]
graphite = ["lz4 (>=1.7.4.2)"]
interpolatable = ["munkres", "pycairo", "scipy"]
interpolatable = ["munkres ; platform_python_implementation == \"PyPy\"", "pycairo", "scipy ; platform_python_implementation != \"PyPy\""]
lxml = ["lxml (>=4.0)"]
pathops = ["skia-pathops (>=0.5.0)"]
plot = ["matplotlib"]
repacker = ["uharfbuzz (>=0.23.0)"]
symfont = ["sympy"]
type1 = ["xattr"]
type1 = ["xattr ; sys_platform == \"darwin\""]
ufo = ["fs (>=2.2.0,<3)"]
unicode = ["unicodedata2 (>=15.1.0)"]
woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"]
unicode = ["unicodedata2 (>=15.1.0) ; python_version <= \"3.12\""]
woff = ["brotli (>=1.0.1) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\"", "zopfli (>=0.1.4)"]
[[package]]
name = "joblib"
@@ -206,6 +213,7 @@ version = "1.4.2"
description = "Lightweight pipelining with Python functions"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6"},
{file = "joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e"},
@@ -217,6 +225,7 @@ version = "1.4.5"
description = "A fast implementation of the Cassowary constraint solver"
optional = false
python-versions = ">=3.7"
groups = ["main"]
files = [
{file = "kiwisolver-1.4.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:05703cf211d585109fcd72207a31bb170a0f22144d68298dc5e61b3c946518af"},
{file = "kiwisolver-1.4.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:146d14bebb7f1dc4d5fbf74f8a6cb15ac42baadee8912eb84ac0b3b2a3dc6ac3"},
@@ -330,6 +339,7 @@ version = "3.9.0"
description = "Python plotting package"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "matplotlib-3.9.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2bcee1dffaf60fe7656183ac2190bd630842ff87b3153afb3e384d966b57fe56"},
{file = "matplotlib-3.9.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3f988bafb0fa39d1074ddd5bacd958c853e11def40800c5824556eb630f94d3b"},
@@ -382,6 +392,7 @@ version = "2.0.0"
description = "Fundamental package for array computing in Python"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "numpy-2.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:04494f6ec467ccb5369d1808570ae55f6ed9b5809d7f035059000a37b8d7e86f"},
{file = "numpy-2.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2635dbd200c2d6faf2ef9a0d04f0ecc6b13b3cad54f7c67c61155138835515d2"},
@@ -436,6 +447,8 @@ version = "11.10.3.66"
description = "CUBLAS native runtime libraries"
optional = false
python-versions = ">=3"
groups = ["main"]
markers = "platform_system == \"Linux\""
files = [
{file = "nvidia_cublas_cu11-11.10.3.66-py3-none-manylinux1_x86_64.whl", hash = "sha256:d32e4d75f94ddfb93ea0a5dda08389bcc65d8916a25cb9f37ac89edaeed3bded"},
{file = "nvidia_cublas_cu11-11.10.3.66-py3-none-win_amd64.whl", hash = "sha256:8ac17ba6ade3ed56ab898a036f9ae0756f1e81052a317bf98f8c6d18dc3ae49e"},
@@ -451,6 +464,8 @@ version = "11.7.99"
description = "NVRTC native runtime libraries"
optional = false
python-versions = ">=3"
groups = ["main"]
markers = "platform_system == \"Linux\""
files = [
{file = "nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl", hash = "sha256:9f1562822ea264b7e34ed5930567e89242d266448e936b85bc97a3370feabb03"},
{file = "nvidia_cuda_nvrtc_cu11-11.7.99-py3-none-manylinux1_x86_64.whl", hash = "sha256:f7d9610d9b7c331fa0da2d1b2858a4a8315e6d49765091d28711c8946e7425e7"},
@@ -467,6 +482,8 @@ version = "11.7.99"
description = "CUDA Runtime native Libraries"
optional = false
python-versions = ">=3"
groups = ["main"]
markers = "platform_system == \"Linux\""
files = [
{file = "nvidia_cuda_runtime_cu11-11.7.99-py3-none-manylinux1_x86_64.whl", hash = "sha256:cc768314ae58d2641f07eac350f40f99dcb35719c4faff4bc458a7cd2b119e31"},
{file = "nvidia_cuda_runtime_cu11-11.7.99-py3-none-win_amd64.whl", hash = "sha256:bc77fa59a7679310df9d5c70ab13c4e34c64ae2124dd1efd7e5474b71be125c7"},
@@ -482,6 +499,8 @@ version = "8.5.0.96"
description = "cuDNN runtime libraries"
optional = false
python-versions = ">=3"
groups = ["main"]
markers = "platform_system == \"Linux\""
files = [
{file = "nvidia_cudnn_cu11-8.5.0.96-2-py3-none-manylinux1_x86_64.whl", hash = "sha256:402f40adfc6f418f9dae9ab402e773cfed9beae52333f6d86ae3107a1b9527e7"},
{file = "nvidia_cudnn_cu11-8.5.0.96-py3-none-manylinux1_x86_64.whl", hash = "sha256:71f8111eb830879ff2836db3cccf03bbd735df9b0d17cd93761732ac50a8a108"},
@@ -497,6 +516,7 @@ version = "24.1"
description = "Core utilities for Python packages"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"},
{file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"},
@@ -508,6 +528,7 @@ version = "2.2.2"
description = "Powerful data structures for data analysis, time series, and statistics"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"},
{file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"},
@@ -541,10 +562,7 @@ files = [
]
[package.dependencies]
numpy = [
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
]
numpy = {version = ">=1.23.2", markers = "python_version == \"3.11\""}
python-dateutil = ">=2.8.2"
pytz = ">=2020.1"
tzdata = ">=2022.7"
@@ -580,6 +598,7 @@ version = "10.3.0"
description = "Python Imaging Library (Fork)"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "pillow-10.3.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:90b9e29824800e90c84e4022dd5cc16eb2d9605ee13f05d47641eb183cd73d45"},
{file = "pillow-10.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a2c405445c79c3f5a124573a051062300936b0281fee57637e706453e452746c"},
@@ -657,7 +676,7 @@ docs = ["furo", "olefile", "sphinx (>=2.4)", "sphinx-copybutton", "sphinx-inline
fpx = ["olefile"]
mic = ["olefile"]
tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"]
typing = ["typing-extensions"]
typing = ["typing-extensions ; python_version < \"3.10\""]
xmp = ["defusedxml"]
[[package]]
@@ -666,6 +685,7 @@ version = "3.1.2"
description = "pyparsing module - Classes and methods to define and execute parsing grammars"
optional = false
python-versions = ">=3.6.8"
groups = ["main"]
files = [
{file = "pyparsing-3.1.2-py3-none-any.whl", hash = "sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742"},
{file = "pyparsing-3.1.2.tar.gz", hash = "sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad"},
@@ -680,6 +700,7 @@ version = "2.9.0.post0"
description = "Extensions to the standard Python datetime module"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
groups = ["main"]
files = [
{file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"},
{file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"},
@@ -694,6 +715,7 @@ version = "2024.1"
description = "World timezone definitions, modern and historical"
optional = false
python-versions = "*"
groups = ["main"]
files = [
{file = "pytz-2024.1-py2.py3-none-any.whl", hash = "sha256:328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319"},
{file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"},
@@ -705,6 +727,7 @@ version = "1.5.0"
description = "A set of python modules for machine learning and data mining"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "scikit_learn-1.5.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:12e40ac48555e6b551f0a0a5743cc94cc5a765c9513fe708e01f0aa001da2801"},
{file = "scikit_learn-1.5.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:f405c4dae288f5f6553b10c4ac9ea7754d5180ec11e296464adb5d6ac68b6ef5"},
@@ -750,6 +773,7 @@ version = "1.14.0"
description = "Fundamental algorithms for scientific computing in Python"
optional = false
python-versions = ">=3.10"
groups = ["main"]
files = [
{file = "scipy-1.14.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7e911933d54ead4d557c02402710c2396529540b81dd554fc1ba270eb7308484"},
{file = "scipy-1.14.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:687af0a35462402dd851726295c1a5ae5f987bd6e9026f52e9505994e2f84ef6"},
@@ -784,7 +808,7 @@ numpy = ">=1.23.5,<2.3"
[package.extras]
dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy (==1.10.0)", "pycodestyle", "pydevtool", "rich-click", "ruff (>=0.0.292)", "types-psutil", "typing_extensions"]
doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.13.1)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"]
test = ["Cython", "array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
test = ["Cython", "array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja ; sys_platform != \"emscripten\"", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
[[package]]
name = "seaborn"
@@ -792,6 +816,7 @@ version = "0.13.2"
description = "Statistical data visualization"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "seaborn-0.13.2-py3-none-any.whl", hash = "sha256:636f8336facf092165e27924f223d3c62ca560b1f2bb5dff7ab7fad265361987"},
{file = "seaborn-0.13.2.tar.gz", hash = "sha256:93e60a40988f4d65e9f4885df477e2fdaff6b73a9ded434c1ab356dd57eefff7"},
@@ -813,19 +838,21 @@ version = "80.9.0"
description = "Easily download, build, install, upgrade, and uninstall Python packages"
optional = false
python-versions = ">=3.9"
groups = ["main"]
markers = "platform_system == \"Linux\""
files = [
{file = "setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922"},
{file = "setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c"},
]
[package.extras]
check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.8.0)"]
core = ["importlib_metadata (>=6)", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"]
check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""]
core = ["importlib_metadata (>=6) ; python_version < \"3.10\"", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1) ; python_version < \"3.11\"", "wheel (>=0.43.0)"]
cover = ["pytest-cov"]
doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
enabler = ["pytest-enabler (>=2.2)"]
test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
type = ["importlib_metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.14.*)", "pytest-mypy"]
test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21) ; python_version >= \"3.9\" and sys_platform != \"cygwin\"", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf ; sys_platform != \"cygwin\"", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
type = ["importlib_metadata (>=7.0.2) ; python_version < \"3.10\"", "jaraco.develop (>=7.21) ; sys_platform != \"cygwin\"", "mypy (==1.14.*)", "pytest-mypy"]
[[package]]
name = "six"
@@ -833,6 +860,7 @@ version = "1.16.0"
description = "Python 2 and 3 compatibility utilities"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
groups = ["main"]
files = [
{file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"},
{file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
@@ -844,6 +872,7 @@ version = "3.5.0"
description = "threadpoolctl"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "threadpoolctl-3.5.0-py3-none-any.whl", hash = "sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467"},
{file = "threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107"},
@@ -855,6 +884,7 @@ version = "1.13.1"
description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
optional = false
python-versions = ">=3.7.0"
groups = ["main"]
files = [
{file = "torch-1.13.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:fd12043868a34a8da7d490bf6db66991108b00ffbeecb034228bfcbbd4197143"},
{file = "torch-1.13.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:d9fe785d375f2e26a5d5eba5de91f89e6a3be5d11efb497e76705fdf93fa3c2e"},
@@ -889,12 +919,29 @@ typing-extensions = "*"
[package.extras]
opt-einsum = ["opt-einsum (>=3.3)"]
[[package]]
name = "torch_receptive_field"
version = "0.1.0"
description = "Compute CNN receptive field size in pytorch in one line"
optional = false
python-versions = "*"
groups = ["main"]
files = []
develop = false
[package.source]
type = "git"
url = "https://github.com/Fangyh09/pytorch-receptive-field.git"
reference = "HEAD"
resolved_reference = "0aeb7f80cd1dd8aa1ed8e6a6882f651dd7e6e877"
[[package]]
name = "torchscan"
version = "0.1.2"
description = "Useful information about your Pytorch module"
optional = false
python-versions = "<4,>=3.6"
groups = ["main"]
files = [
{file = "torchscan-0.1.2-py3-none-any.whl", hash = "sha256:5c69f9cc20e5041cc1d307efc8070fac992d472539260fbbd269e24079eb8971"},
{file = "torchscan-0.1.2.tar.gz", hash = "sha256:acff638d5cfd639fa6838b2f61969f4206927ae9a41d6c358513c943bd2c4cfa"},
@@ -915,6 +962,7 @@ version = "4.13.2"
description = "Backported and Experimental Type Hints for Python 3.8+"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "typing_extensions-4.13.2-py3-none-any.whl", hash = "sha256:a439e7c04b49fec3e5d3e2beaa21755cadbbdc391694e28ccdd36ca4a1408f8c"},
{file = "typing_extensions-4.13.2.tar.gz", hash = "sha256:e6c81219bd689f51865d9e372991c540bda33a0379d5573cddb9a3a23f7caaef"},
@@ -926,6 +974,7 @@ version = "2024.1"
description = "Provider of IANA time zone data"
optional = false
python-versions = ">=2"
groups = ["main"]
files = [
{file = "tzdata-2024.1-py2.py3-none-any.whl", hash = "sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252"},
{file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"},
@@ -937,6 +986,8 @@ version = "0.45.1"
description = "A built-package format for Python"
optional = false
python-versions = ">=3.8"
groups = ["main"]
markers = "platform_system == \"Linux\""
files = [
{file = "wheel-0.45.1-py3-none-any.whl", hash = "sha256:708e7481cc80179af0e556bbf0cc00b8444c7321e2700b8d8580231d13017248"},
{file = "wheel-0.45.1.tar.gz", hash = "sha256:661e1abd9198507b1409a20c02106d9670b2576e916d58f520316666abca6729"},
@@ -946,6 +997,6 @@ files = [
test = ["pytest (>=6.0.0)", "setuptools (>=65)"]
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
content-hash = "810be7af5be0353a0399b8d1db6dc64a3681b6988860623c6ceda7b81b865968"
lock-version = "2.1"
python-versions = ">=3.11,<3.12"
content-hash = "9091021019614bb6f1e01f945ff51fd509262c1611c1a84d66655bef630ccb7b"

View File

@@ -6,7 +6,7 @@ authors = ["Your Name <you@example.com>"]
readme = "README.md"
[tool.poetry.dependencies]
python = "^3.11"
python = ">=3.11,<3.12"
click = "^8.1.7"
matplotlib = "^3.9.0"
numpy = "^2.0.0"
@@ -24,6 +24,7 @@ scipy = "^1.14.0"
seaborn = "^0.13.2"
six = "^1.16.0"
torchscan = "^0.1.2"
torch-receptive-field = {git = "https://github.com/Fangyh09/pytorch-receptive-field.git"}
[build-system]

View File

@@ -1,6 +1,5 @@
import logging
import numpy as np
import torch.nn as nn
import torchscan
@@ -32,8 +31,5 @@ class BaseNet(nn.Module):
"Input dimension is not set. Please set input_dim before calling summary."
)
return
self.logger.info(
self.logger.info("torchscan:\n")
torchscan.summary(self, self.input_dim, receptive_field=receptive_field)
)
module_info = torchscan.crawl_module(self, self.input_dim)
pass

View File

@@ -55,6 +55,7 @@ from utils.visualization.plot_images_grid import plot_images_grid
"mnist_LeNet",
"elpv_LeNet",
"subter_LeNet",
"subter_efficient",
"subter_LeNet_Split",
"fmnist_LeNet",
"cifar10_LeNet",

View File

@@ -5,6 +5,7 @@ from .fmnist_LeNet import FashionMNIST_LeNet, FashionMNIST_LeNet_Autoencoder
from .mlp import MLP, MLP_Autoencoder
from .mnist_LeNet import MNIST_LeNet, MNIST_LeNet_Autoencoder
from .subter_LeNet import SubTer_LeNet, SubTer_LeNet_Autoencoder
from .subter_LeNet_rf import SubTer_Efficient_AE, SubTer_EfficientEncoder
from .subter_LeNet_Split import SubTer_LeNet_Split, SubTer_LeNet_Split_Autoencoder
from .vae import VariationalAutoencoder
@@ -16,6 +17,7 @@ def build_network(net_name, rep_dim, ae_net=None):
"mnist_LeNet",
"elpv_LeNet",
"subter_LeNet",
"subter_efficient",
"subter_LeNet_Split",
"mnist_DGM_M2",
"mnist_DGM_M1M2",
@@ -48,6 +50,9 @@ def build_network(net_name, rep_dim, ae_net=None):
if net_name == "subter_LeNet":
net = SubTer_LeNet(rep_dim=rep_dim)
if net_name == "subter_efficient":
net = SubTer_EfficientEncoder(rep_dim=rep_dim)
if net_name == "subter_LeNet_Split":
net = SubTer_LeNet_Split()
@@ -135,6 +140,7 @@ def build_autoencoder(net_name, rep_dim):
implemented_networks = (
"elpv_LeNet",
"subter_LeNet",
"subter_efficient",
"subter_LeNet_Split",
"mnist_LeNet",
"mnist_DGM_M1M2",
@@ -160,6 +166,9 @@ def build_autoencoder(net_name, rep_dim):
if net_name == "subter_LeNet":
ae_net = SubTer_LeNet_Autoencoder(rep_dim=rep_dim)
if net_name == "subter_efficient":
ae_net = SubTer_Efficient_AE(rep_dim=rep_dim)
if net_name == "subter_LeNet_Split":
ae_net = SubTer_LeNet_Split_Autoencoder()

View File

@@ -1,6 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_receptive_field
from base.base_net import BaseNet
@@ -29,6 +30,13 @@ class SubTer_LeNet(BaseNet):
x = self.fc1(x)
return x
def summary(self, receptive_field: bool = False):
# first run super method to log parameters and structure
super().summary(receptive_field=receptive_field)
self.logger.info("torch_receptive_field:")
torch_receptive_field.receptive_field(self, input_size=self.input_dim)
# torch_receptive_field.receptive_field_for_unit(rf, "2", (2,2))
class SubTer_LeNet_Decoder(BaseNet):
def __init__(self, rep_dim=1024):

View File

@@ -0,0 +1,124 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_receptive_field
from base.base_net import BaseNet
# ---------------------- helper ---------------------------------------------
def circ_pad_x(x, pad_w):
"""Circular pad on width (azimuth) only."""
return F.pad(x, (pad_w, pad_w, 0, 0), mode="circular")
class DWSeparableConv(nn.Module):
"""Depthwise separable 3×17 conv + 1×1 pointwise + optional channel shuffle."""
def __init__(self, c_in: int, c_out: int, shuffle: bool = False):
super().__init__()
self.dw = nn.Conv2d(
c_in, c_in, kernel_size=(3, 17), padding=(1, 0), groups=c_in, bias=False
)
self.pw = nn.Conv2d(c_in, c_out, kernel_size=1, bias=False)
self.bn = nn.BatchNorm2d(c_out, eps=1e-4, affine=False)
self.shuffle = shuffle
def _shuffle(self, x):
if x.size(1) % 2 != 0:
return x # can't shuffle odd channels
b, c, h, w = x.shape
x = x.view(b, 2, c // 2, h, w)
x = torch.transpose(x, 1, 2).contiguous()
return x.view(b, c, h, w)
def forward(self, x):
x = circ_pad_x(x, 8)
x = self.dw(x)
x = self.pw(x)
if self.shuffle:
x = self._shuffle(x)
return F.leaky_relu(self.bn(x), 0.1)
def summary(self, receptive_field: bool = False):
# first run super method to log parameters and structure
super().summary(receptive_field=receptive_field)
self.logger.info("torch_receptive_field:")
torch_receptive_field.receptive_field(self, input_size=self.input_dim)
# torch_receptive_field.receptive_field_for_unit(rf, "2", (2,2))
# ---------------------- encoder --------------------------------------------
class SubTer_EfficientEncoder(BaseNet):
def __init__(self, rep_dim: int = 512):
super().__init__()
self.input_dim = (1, 32, 2048)
self.rep_dim = rep_dim
self.conv1 = DWSeparableConv(1, 16)
self.pool_h4 = nn.MaxPool2d(kernel_size=(1, 4), stride=(1, 4)) # 2048 ➔ 512
self.conv2 = DWSeparableConv(16, 32, shuffle=True)
self.pool2 = nn.MaxPool2d(2, 2) # 32 ➔ 16 vertically, 512 ➔ 256 horizontally
self.pool3 = nn.MaxPool2d(2, 2) # 16 ➔ 8 , 256 ➔ 128
self.squeeze = nn.Conv2d(32, 8, 1, bias=False)
self.fc = nn.Linear(8 * 8 * 128, rep_dim, bias=False)
def forward(self, x):
x = x.view(-1, 1, 32, 2048)
x = self.conv1(x)
x = self.pool_h4(x)
x = self.conv2(x)
x = self.pool2(x)
x = self.pool3(x)
x = self.squeeze(x)
return self.fc(x.flatten(1))
# ---------------------- decoder (NN upsample) ------------------------------
class SubTer_EfficientDecoder(BaseNet):
def __init__(self, rep_dim: int = 512):
super().__init__()
self.fc = nn.Linear(rep_dim, 8 * 8 * 128, bias=False)
self.expand = nn.Conv2d(8, 32, 1, bias=False)
self.rep_dim = rep_dim
# Nearestneighbour upsampling layers
self.up1 = nn.Sequential(
nn.Upsample(scale_factor=2, mode="nearest"),
DWSeparableConv(32, 32, shuffle=True),
) # 8×128 ➔ 16×256
self.up2 = nn.Sequential(
nn.Upsample(scale_factor=(1, 4), mode="nearest"),
DWSeparableConv(32, 16),
) # 16×256 ➔ 16×1024
self.up3 = nn.Sequential(
nn.Upsample(scale_factor=(2, 2), mode="nearest"),
DWSeparableConv(16, 8),
) # 16×1024 ➔ 32×2048
self.out_conv = nn.Conv2d(8, 1, kernel_size=(3, 17), padding=(1, 0), bias=False)
def forward(self, x):
x = self.fc(x).view(x.size(0), 8, 8, 128)
x = self.expand(x)
x = self.up1(x)
x = self.up2(x)
x = self.up3(x)
x = circ_pad_x(x, 8)
return torch.sigmoid(self.out_conv(x))
# ---------------------- autoencoder wrapper -------------------------------
class SubTer_Efficient_AE(BaseNet):
def __init__(self, rep_dim: int = 512):
super().__init__()
self.input_dim = (1, 32, 2048) # Input dimension for the network
self.rep_dim = rep_dim
self.encoder = SubTer_EfficientEncoder(rep_dim)
self.decoder = SubTer_EfficientDecoder(rep_dim)
def forward(self, x):
return self.decoder(self.encoder(x))