Update losses.py
Committed by OedoSoldier on Dec 19, 2023 (commit 2e7c173, 1 parent: 97f18c0)
Showing 2 changed files with 12 additions and 9 deletions.
losses.py (2 additions, 0 deletions)
@@ -67,6 +67,8 @@ def __init__(self, model, wd, model_sr, slm_sr=16000):
         self.wd = wd
         self.resample = torchaudio.transforms.Resample(model_sr, slm_sr)
         self.wavlm.eval()
+        for param in self.wavlm.parameters():
+            param.requires_grad = False
 
     def forward(self, wav, y_rec):
         with torch.no_grad():
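The two added lines in losses.py freeze the WavLM backbone: alongside the existing self.wavlm.eval(), every parameter now has requires_grad = False, so the pretrained speech model serves purely as a feature extractor and is skipped by autograd and by any optimizer that collects its parameters. A minimal sketch of the same pattern, using a small hypothetical encoder as a stand-in for WavLM; note that gradients still flow through the frozen weights to the generated input:

import torch
import torch.nn as nn
import torch.nn.functional as F

class FrozenFeatureLoss(nn.Module):
    # Hypothetical stand-in for WavLMLoss: a frozen encoder used only to compare features.
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        self.encoder.eval()                     # disable dropout / batch-norm updates
        for param in self.encoder.parameters():
            param.requires_grad = False         # exclude the weights from training

    def forward(self, wav, y_rec):
        with torch.no_grad():                   # reference branch builds no graph
            ref = self.encoder(wav)
        est = self.encoder(y_rec)               # grads flow through frozen weights to y_rec
        return F.l1_loss(est, ref)

encoder = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))
loss_fn = FrozenFeatureLoss(encoder)
y_rec = torch.randn(2, 16, requires_grad=True)
loss_fn(torch.randn(2, 16), y_rec).backward()
assert encoder[0].weight.grad is None           # frozen: no gradient accumulated
assert y_rec.grad is not None                   # generated audio still receives gradient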
train_ms.py (10 additions, 9 deletions)
@@ -349,14 +349,21 @@ def run():
         scheduler_dur_disc = None
     scaler = GradScaler(enabled=hps.train.bf16_run)
 
+    wl = WavLMLoss(
+        hps.model.slm.model,
+        net_wd,
+        hps.data.sampling_rate,
+        hps.model.slm.sr,
+    ).to(local_rank)
+
     for epoch in range(epoch_str, hps.train.epochs + 1):
         if rank == 0:
             train_and_evaluate(
                 rank,
                 local_rank,
                 epoch,
                 hps,
-                [net_g, net_d, net_dur_disc, net_wd],
+                [net_g, net_d, net_dur_disc, net_wd, wl],
                 [optim_g, optim_d, optim_dur_disc, optim_wd],
                 [scheduler_g, scheduler_d, scheduler_dur_disc, scheduler_wd],
                 scaler,
@@ -370,7 +377,7 @@ def run():
                 local_rank,
                 epoch,
                 hps,
-                [net_g, net_d, net_dur_disc, net_wd],
+                [net_g, net_d, net_dur_disc, net_wd, wl],
                 [optim_g, optim_d, optim_dur_disc, optim_wd],
                 [scheduler_g, scheduler_d, scheduler_dur_disc, scheduler_wd],
                 scaler,
@@ -397,18 +404,12 @@ def train_and_evaluate(
     logger,
     writers,
 ):
-    net_g, net_d, net_dur_disc, net_wd = nets
+    net_g, net_d, net_dur_disc, net_wd, wl = nets
     optim_g, optim_d, optim_dur_disc, optim_wd = optims
     scheduler_g, scheduler_d, scheduler_dur_disc, scheduler_wd = schedulers
     train_loader, eval_loader = loaders
     if writers is not None:
         writer, writer_eval = writers
-    wl = WavLMLoss(
-        hps.model.slm.model,
-        net_wd,
-        hps.data.sampling_rate,
-        hps.model.slm.sr,
-    ).to(local_rank)
 
     train_loader.batch_sampler.set_epoch(epoch)
     global global_step
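Read together, the train_ms.py hunks move the WavLMLoss construction out of train_and_evaluate and up into run(): the WavLM checkpoint is now loaded once per process instead of once per epoch, and the ready-built loss module travels to train_and_evaluate as a fifth element of the nets list. A minimal runnable sketch of that build-once, pass-through pattern (ToySLMLoss and the simplified signatures are hypothetical stand-ins, not the repository's actual API):

import torch.nn as nn

class ToySLMLoss(nn.Module):
    # Hypothetical stand-in for WavLMLoss; imagine __init__ loading a large checkpoint.
    def __init__(self):
        super().__init__()
        print("loading pretrained SLM ... (expensive, should happen once)")

def train_and_evaluate(epoch, nets):
    *networks, wl = nets                        # pre-built loss unpacked last, as in the new code
    print(f"epoch {epoch}: reusing {type(wl).__name__} (id {id(wl)})")

def run(epochs=3):
    networks = [nn.Identity(), nn.Identity()]   # stand-ins for net_g, net_d, ...
    wl = ToySLMLoss()                           # built once, before the epoch loop
    for epoch in range(1, epochs + 1):
        train_and_evaluate(epoch, networks + [wl])

run()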
