added deepsad base code

This commit is contained in:
Jan Kowalczyk
2024-06-28 07:42:12 +02:00
parent 2eb1bf2e05
commit 914bb020d0
57 changed files with 4974 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
from .config import Config
from .visualization.plot_images_grid import plot_images_grid
from .misc import enumerate_discrete, log_sum_exp, binary_cross_entropy

View File

@@ -0,0 +1,23 @@
import json
class Config(object):
"""Base class for experimental setting/configuration."""
def __init__(self, settings):
self.settings = settings
def load_config(self, import_json):
"""Load settings dict from import_json (path/filename.json) JSON-file."""
with open(import_json, 'r') as fp:
settings = json.load(fp)
for key, value in settings.items():
self.settings[key] = value
def save_config(self, export_json):
"""Save settings dict to export_json (path/filename.json) JSON-file."""
with open(export_json, 'w') as fp:
json.dump(self.settings, fp)

View File

@@ -0,0 +1,46 @@
import torch
from torch.autograd import Variable
# Acknowledgements: https://github.com/wohlert/semi-supervised-pytorch
def enumerate_discrete(x, y_dim):
"""
Generates a 'torch.Tensor' of size batch_size x n_labels of the given label.
:param x: tensor with batch size to mimic
:param y_dim: number of total labels
:return variable
"""
def batch(batch_size, label):
labels = (torch.ones(batch_size, 1) * label).type(torch.LongTensor)
y = torch.zeros((batch_size, y_dim))
y.scatter_(1, labels, 1)
return y.type(torch.LongTensor)
batch_size = x.size(0)
generated = torch.cat([batch(batch_size, i) for i in range(y_dim)])
if x.is_cuda:
generated = generated.to(x.device)
return Variable(generated.float())
def log_sum_exp(tensor, dim=-1, sum_op=torch.sum):
"""
Uses the LogSumExp (LSE) as an approximation for the sum in a log-domain.
:param tensor: Tensor to compute LSE over
:param dim: dimension to perform operation over
:param sum_op: reductive operation to be applied, e.g. torch.sum or torch.mean
:return: LSE
"""
max, _ = torch.max(tensor, dim=dim, keepdim=True)
return torch.log(sum_op(torch.exp(tensor - max), dim=dim, keepdim=True) + 1e-8) + max
def binary_cross_entropy(x, y):
eps = 1e-8
return -torch.sum(y * torch.log(x + eps) + (1 - y) * torch.log(1 - x + eps), dim=-1)

View File

@@ -0,0 +1,26 @@
import torch
import matplotlib
matplotlib.use('Agg') # or 'PS', 'PDF', 'SVG'
import matplotlib.pyplot as plt
import numpy as np
from torchvision.utils import make_grid
def plot_images_grid(x: torch.tensor, export_img, title: str = '', nrow=8, padding=2, normalize=False, pad_value=0):
"""Plot 4D Tensor of images of shape (B x C x H x W) as a grid."""
grid = make_grid(x, nrow=nrow, padding=padding, normalize=normalize, pad_value=pad_value)
npgrid = grid.cpu().numpy()
plt.imshow(np.transpose(npgrid, (1, 2, 0)), interpolation='nearest')
ax = plt.gca()
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)
if not (title == ''):
plt.title(title)
plt.savefig(export_img, bbox_inches='tight', pad_inches=0.1)
plt.clf()