Skip to content

Commit

Permalink
Restucture components to loosely follow PyTorch
Browse files Browse the repository at this point in the history
  • Loading branch information
garrettgibo committed Mar 17, 2021
1 parent 92799e5 commit 037b3ec
Show file tree
Hide file tree
Showing 12 changed files with 94 additions and 82 deletions.
2 changes: 1 addition & 1 deletion .github/linters/.flake8
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[flake8]
max-line-length = 88
extend-ignore = E203
extend-ignore = E203, F401
3 changes: 1 addition & 2 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import json

from wavenet import demo
from wavenet.logger import new_logger
from wavenet.utils import convert_mp3_folder, get_data
from wavenet.utils import convert_mp3_folder, get_data, new_logger

logger = new_logger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion wavenet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Demo of WaveNet Pipeline"""

from torch.utils.data import DataLoader
from wavenet.data import WAVData
from wavenet.utils import convert_mp3_folder
from wavenet.utils.data import WAVData


def demo(
Expand Down
2 changes: 2 additions & 0 deletions wavenet/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .activation import GatedActivationUnit
from .conv import CausalConv, ResidualStack
33 changes: 33 additions & 0 deletions wavenet/modules/activation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Gated activation function"""
import torch
import torch.nn as nn


class GatedActivationUnit(torch.nn.Module):
"""Gated activation unit layer from PixelCNN"""

def __init__(self):
"""Initialize a gated activation unit"""
super().__init__()
self.tanh = nn.Tanh()
self.sigmoid = nn.Sigmoid()

def forward(self, filter_data, gate_data):
"""Apply gated activation unit to dilated convolutions.
From:
https://arxiv.org/abs/1606.05328
z = tanh(W_{f, k} ∗ x) ⊙ σ(W_{g,k} ∗ x)
Where:
∗ denotes a convolution operator
⊙denotes and element-wise multiplication operator
σ(·) is a sigmoid function
k is the layer index
f and g denote filter and gate respectively
W is learnable convolution filter
"""
output = self.tanh(filter_data) * self.sigmoid(gate_data)

return output
80 changes: 4 additions & 76 deletions wavenet/model.py → wavenet/modules/conv.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,8 @@
"""
Networks and components that define WaveNet
"""
from collections import OrderedDict

"""Convolution layers for WaveNet"""
import torch
import torch.nn as nn


class WaveNet(torch.nn.Module):
"""Full WaveNet implementation.
Major architecture is based off of Figure 4 in:
https://arxiv.org/abs/1609.03499
"""

def __init__(self, in_channels: int, res_channels: int):
"""Initialize WaveNet.
Args:
in_channels: number of channels for input channel. The number of
skip channels is the same as input channels.
res_channels: number of channels for residual input and output
"""
super().__init__()
self.model = nn.Sequential(
OrderedDict(
[
("causal_conv", CausalConv(in_channels, res_channels)),
(
"residual_stack",
ResidualStack(
in_channels=res_channels, out_channels=in_channels
),
),
("conv_1", nn.Conv1d(in_channels, in_channels, kernel_size=1)),
("relu_1", nn.ReLU()),
("conv_2", nn.Conv1d(in_channels, in_channels, kernel_size=1)),
("relu_2", nn.ReLU()),
("softmax", nn.Softmax()),
]
)
)

def forward(self, data):
"""Forward pass through full architecture"""
return self.model(data)
import wavenet
from wavenet.modules import GatedActivationUnit


class CausalConv(torch.nn.Module):
Expand Down Expand Up @@ -142,7 +100,7 @@ def __init__(self, in_channels: int, out_channels: int, dilation: int):
)
self.gated_activation = GatedActivationUnit()
self.conv = nn.Conv1d(
out_channels, out_channels, kernel_size=1, stride=1, bias=False,
out_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False
)

def forward(self, data):
Expand Down Expand Up @@ -193,33 +151,3 @@ def forward(self, data):
gate_output = self.conv_gate(data)

return filter_output, gate_output


class GatedActivationUnit(torch.nn.Module):
"""Gated activation unit layer from PixelCNN"""

def __init__(self):
"""Initialize a gated activation unit"""
super().__init__()
self.tanh = nn.Tanh()
self.sigmoid = nn.Sigmoid()

def forward(self, filter_data, gate_data):
"""Apply gated activation unit to dilated convolutions.
From:
https://arxiv.org/abs/1606.05328
z = tanh(W_{f, k} ∗ x) ⊙ σ(W_{g,k} ∗ x)
Where:
∗ denotes a convolution operator
⊙denotes and element-wise multiplication operator
σ(·) is a sigmoid function
k is the layer index
f and g denote filter and gate respectively
W is learnable convolution filter
"""
output = self.tanh(filter_data) * self.sigmoid(gate_data)

return output
48 changes: 48 additions & 0 deletions wavenet/modules/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""
Networks and components that define WaveNet
"""
from collections import OrderedDict

import torch
import torch.nn as nn
from wavenet.modules import CausalConv, ResidualStack


class WaveNet(torch.nn.Module):
"""Full WaveNet implementation.
Major architecture is based off of Figure 4 in:
https://arxiv.org/abs/1609.03499
"""

def __init__(self, in_channels: int, res_channels: int):
"""Initialize WaveNet.
Args:
in_channels: number of channels for input channel. The number of
skip channels is the same as input channels.
res_channels: number of channels for residual input and output
"""
super().__init__()
self.model = nn.Sequential(
OrderedDict(
[
("causal_conv", CausalConv(in_channels, res_channels)),
(
"residual_stack",
ResidualStack(
in_channels=res_channels, out_channels=in_channels
),
),
("conv_1", nn.Conv1d(in_channels, in_channels, kernel_size=1)),
("relu_1", nn.ReLU()),
("conv_2", nn.Conv1d(in_channels, in_channels, kernel_size=1)),
("relu_2", nn.ReLU()),
("softmax", nn.Softmax()),
]
)
)

def forward(self, data):
"""Forward pass through full architecture"""
return self.model(data)
2 changes: 2 additions & 0 deletions wavenet/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .logger import new_logger
from .utils import convert_mp3_folder, get_data
2 changes: 1 addition & 1 deletion wavenet/data.py → wavenet/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np
from scipy.io import wavfile
from torch.utils.data import Dataset
from wavenet.logger import new_logger
from wavenet.utils import new_logger

logger = new_logger(__name__)

Expand Down
File renamed without changes.
Empty file added wavenet/utils/train.py
Empty file.
2 changes: 1 addition & 1 deletion wavenet/utils.py → wavenet/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os

from pydub import AudioSegment
from wavenet.logger import new_logger
from wavenet.utils import new_logger

logger = new_logger(__name__)

Expand Down

0 comments on commit 037b3ec

Please sign in to comment.