black formatted files before changes

This commit is contained in:
Jan Kowalczyk
2024-06-28 11:36:46 +02:00
parent d33c6b1e16
commit 71f9662022
40 changed files with 2938 additions and 1260 deletions

View File

@@ -10,7 +10,7 @@ class Config(object):
def load_config(self, import_json):
"""Load settings dict from import_json (path/filename.json) JSON-file."""
with open(import_json, 'r') as fp:
with open(import_json, "r") as fp:
settings = json.load(fp)
for key, value in settings.items():
@@ -19,5 +19,5 @@ class Config(object):
def save_config(self, export_json):
"""Save settings dict to export_json (path/filename.json) JSON-file."""
with open(export_json, 'w') as fp:
with open(export_json, "w") as fp:
json.dump(self.settings, fp)

View File

@@ -38,7 +38,9 @@ def log_sum_exp(tensor, dim=-1, sum_op=torch.sum):
:return: LSE
"""
max, _ = torch.max(tensor, dim=dim, keepdim=True)
return torch.log(sum_op(torch.exp(tensor - max), dim=dim, keepdim=True) + 1e-8) + max
return (
torch.log(sum_op(torch.exp(tensor - max), dim=dim, keepdim=True) + 1e-8) + max
)
def binary_cross_entropy(x, y):

View File

@@ -1,26 +1,37 @@
import torch
import matplotlib
matplotlib.use('Agg') # or 'PS', 'PDF', 'SVG'
matplotlib.use("Agg") # or 'PS', 'PDF', 'SVG'
import matplotlib.pyplot as plt
import numpy as np
from torchvision.utils import make_grid
def plot_images_grid(x: torch.tensor, export_img, title: str = '', nrow=8, padding=2, normalize=False, pad_value=0):
def plot_images_grid(
x: torch.tensor,
export_img,
title: str = "",
nrow=8,
padding=2,
normalize=False,
pad_value=0,
):
"""Plot 4D Tensor of images of shape (B x C x H x W) as a grid."""
grid = make_grid(x, nrow=nrow, padding=padding, normalize=normalize, pad_value=pad_value)
grid = make_grid(
x, nrow=nrow, padding=padding, normalize=normalize, pad_value=pad_value
)
npgrid = grid.cpu().numpy()
plt.imshow(np.transpose(npgrid, (1, 2, 0)), interpolation='nearest')
plt.imshow(np.transpose(npgrid, (1, 2, 0)), interpolation="nearest")
ax = plt.gca()
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)
if not (title == ''):
if not (title == ""):
plt.title(title)
plt.savefig(export_img, bbox_inches='tight', pad_inches=0.1)
plt.savefig(export_img, bbox_inches="tight", pad_inches=0.1)
plt.clf()