black formatted files before changes

This commit is contained in:
Jan Kowalczyk
2024-06-28 11:36:46 +02:00
parent d33c6b1e16
commit 71f9662022
40 changed files with 2938 additions and 1260 deletions

View File

@@ -38,7 +38,9 @@ def log_sum_exp(tensor, dim=-1, sum_op=torch.sum):
:return: LSE
"""
max, _ = torch.max(tensor, dim=dim, keepdim=True)
return torch.log(sum_op(torch.exp(tensor - max), dim=dim, keepdim=True) + 1e-8) + max
return (
torch.log(sum_op(torch.exp(tensor - max), dim=dim, keepdim=True) + 1e-8) + max
)
def binary_cross_entropy(x, y):