-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Restructure components to loosely follow PyTorch
- Loading branch information
1 parent
92799e5
commit 037b3ec
Showing
12 changed files
with
94 additions
and
82 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
[flake8] | ||
max-line-length = 88 | ||
extend-ignore = E203 | ||
extend-ignore = E203, F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .activation import GatedActivationUnit | ||
from .conv import CausalConv, ResidualStack |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
"""Gated activation function""" | ||
import torch | ||
import torch.nn as nn | ||
|
||
|
||
class GatedActivationUnit(nn.Module):
    """Gated activation unit layer from PixelCNN."""

    def __init__(self):
        """Initialize a gated activation unit."""
        super().__init__()
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

    def forward(self, filter_data, gate_data):
        """Apply the gated activation unit to dilated convolution outputs.

        From https://arxiv.org/abs/1606.05328:
            z = tanh(W_{f,k} ∗ x) ⊙ σ(W_{g,k} ∗ x)
        Where:
            ∗ denotes a convolution operator
            ⊙ denotes an element-wise multiplication operator
            σ(·) is a sigmoid function
            k is the layer index
            f and g denote filter and gate respectively
            W is a learnable convolution filter

        Args:
            filter_data: output of the filter convolution, W_{f,k} ∗ x
            gate_data: output of the gate convolution, W_{g,k} ∗ x

        Returns:
            Element-wise product tanh(filter_data) * sigmoid(gate_data),
            same shape as the inputs.
        """
        return self.tanh(filter_data) * self.sigmoid(gate_data)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
""" | ||
Networks and components that define WaveNet | ||
""" | ||
from collections import OrderedDict | ||
|
||
import torch | ||
import torch.nn as nn | ||
from wavenet.modules import CausalConv, ResidualStack | ||
|
||
|
||
class WaveNet(torch.nn.Module):
    """Full WaveNet implementation.

    Major architecture is based off of Figure 4 in:
    https://arxiv.org/abs/1609.03499
    """

    def __init__(self, in_channels: int, res_channels: int):
        """Initialize WaveNet.

        Args:
            in_channels: number of channels for input channel. The number of
                skip channels is the same as input channels.
            res_channels: number of channels for residual input and output
        """
        super().__init__()
        self.model = nn.Sequential(
            OrderedDict(
                [
                    ("causal_conv", CausalConv(in_channels, res_channels)),
                    (
                        "residual_stack",
                        ResidualStack(
                            in_channels=res_channels, out_channels=in_channels
                        ),
                    ),
                    ("conv_1", nn.Conv1d(in_channels, in_channels, kernel_size=1)),
                    ("relu_1", nn.ReLU()),
                    ("conv_2", nn.Conv1d(in_channels, in_channels, kernel_size=1)),
                    ("relu_2", nn.ReLU()),
                    # Softmax must act on the channel (quantization) axis.
                    # Without an explicit dim, torch's deprecated implicit-dim
                    # inference picks dim 0 (the batch axis) for the 3-D
                    # (batch, channels, time) output of Conv1d — a bug.
                    ("softmax", nn.Softmax(dim=1)),
                ]
            )
        )

    def forward(self, data):
        """Forward pass through full architecture.

        Args:
            data: input tensor fed to the causal convolution
                (presumably shaped (batch, in_channels, time) — TODO confirm
                against CausalConv, which is defined elsewhere).

        Returns:
            Per-timestep channel distribution (softmax over dim 1).
        """
        return self.model(data)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .logger import new_logger | ||
from .utils import convert_mp3_folder, get_data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters