Skip to content

Commit

Permalink
waifu2x: Change random noise aug
Browse files Browse the repository at this point in the history
  • Loading branch information
nagadomi committed Nov 28, 2024
1 parent 905be0a commit b18ce4e
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions waifu2x/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,12 @@ def diff_pair_random_noise(input, target, strength=0.01, p=0.1):
if random.uniform(0., 1.) < p:
B, C, H, W = input.shape
noise1x = torch.randn((B, C, H, W), dtype=input.dtype, device=input.device)
noise2x = torch.randn((B, C, H // 2, W // 2), dtype=input.dtype, device=input.device)
noise2x = torch.nn.functional.interpolate(noise2x, size=(H, W), mode="nearest")
noise = ((noise1x + noise2x) * (strength / 2.0)).detach()
if random.uniform(0., 1.) < 0.5:
noise2x = torch.randn((B, C, H // 2, W // 2), dtype=input.dtype, device=input.device)
noise2x = torch.nn.functional.interpolate(noise2x, size=(H, W), mode="nearest")
noise = ((noise1x + noise2x) * (strength / 2.0)).detach()
else:
noise = (noise1x * strength).detach()
return input + noise, target + noise
else:
return input, target
Expand Down

0 comments on commit b18ce4e

Please sign in to comment.