black formatted files before changes
This commit is contained in:
@@ -10,7 +10,7 @@ class Config(object):
|
||||
def load_config(self, import_json):
|
||||
"""Load settings dict from import_json (path/filename.json) JSON-file."""
|
||||
|
||||
with open(import_json, 'r') as fp:
|
||||
with open(import_json, "r") as fp:
|
||||
settings = json.load(fp)
|
||||
|
||||
for key, value in settings.items():
|
||||
@@ -19,5 +19,5 @@ class Config(object):
|
||||
def save_config(self, export_json):
|
||||
"""Save settings dict to export_json (path/filename.json) JSON-file."""
|
||||
|
||||
with open(export_json, 'w') as fp:
|
||||
with open(export_json, "w") as fp:
|
||||
json.dump(self.settings, fp)
|
||||
|
||||
@@ -38,7 +38,9 @@ def log_sum_exp(tensor, dim=-1, sum_op=torch.sum):
|
||||
:return: LSE
|
||||
"""
|
||||
max, _ = torch.max(tensor, dim=dim, keepdim=True)
|
||||
return torch.log(sum_op(torch.exp(tensor - max), dim=dim, keepdim=True) + 1e-8) + max
|
||||
return (
|
||||
torch.log(sum_op(torch.exp(tensor - max), dim=dim, keepdim=True) + 1e-8) + max
|
||||
)
|
||||
|
||||
|
||||
def binary_cross_entropy(x, y):
|
||||
|
||||
@@ -1,26 +1,37 @@
|
||||
import torch
|
||||
import matplotlib
|
||||
matplotlib.use('Agg') # or 'PS', 'PDF', 'SVG'
|
||||
|
||||
matplotlib.use("Agg") # or 'PS', 'PDF', 'SVG'
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from torchvision.utils import make_grid
|
||||
|
||||
|
||||
def plot_images_grid(x: torch.tensor, export_img, title: str = '', nrow=8, padding=2, normalize=False, pad_value=0):
|
||||
def plot_images_grid(
|
||||
x: torch.tensor,
|
||||
export_img,
|
||||
title: str = "",
|
||||
nrow=8,
|
||||
padding=2,
|
||||
normalize=False,
|
||||
pad_value=0,
|
||||
):
|
||||
"""Plot 4D Tensor of images of shape (B x C x H x W) as a grid."""
|
||||
|
||||
grid = make_grid(x, nrow=nrow, padding=padding, normalize=normalize, pad_value=pad_value)
|
||||
grid = make_grid(
|
||||
x, nrow=nrow, padding=padding, normalize=normalize, pad_value=pad_value
|
||||
)
|
||||
npgrid = grid.cpu().numpy()
|
||||
|
||||
plt.imshow(np.transpose(npgrid, (1, 2, 0)), interpolation='nearest')
|
||||
plt.imshow(np.transpose(npgrid, (1, 2, 0)), interpolation="nearest")
|
||||
|
||||
ax = plt.gca()
|
||||
ax.xaxis.set_visible(False)
|
||||
ax.yaxis.set_visible(False)
|
||||
|
||||
if not (title == ''):
|
||||
if not (title == ""):
|
||||
plt.title(title)
|
||||
|
||||
plt.savefig(export_img, bbox_inches='tight', pad_inches=0.1)
|
||||
plt.savefig(export_img, bbox_inches="tight", pad_inches=0.1)
|
||||
plt.clf()
|
||||
|
||||
Reference in New Issue
Block a user