Attaining arbitrarily long audio generation using chunked generation and latent space interpolation #101

Status: Open. Wants to merge 5 commits into main. Changes shown from 1 commit.
7 changes: 5 additions & 2 deletions docker-compose.yml
@@ -10,7 +10,10 @@ services:
network_mode: "host"
stdin_open: true
tty: true
command: ["python3", "gradio_interface.py"]
command: ["bash", "-c", "pip install nltk && python3 -c 'import nltk; nltk.download(\"punkt\"); nltk.download(\"punkt_tab\")' && python3 gradio_interface.py"]
Review comment: Why not install nltk with the rest of the Python packages? Furthermore, don't you already download punkt on line 99 in gradio_interface.py?

environment:
- NVIDIA_VISIBLE_DEVICES=0
- GRADIO_SHARE=False
- GRADIO_SHARE=False
volumes:
- .:/app
- ./cache:/root/.cache/
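As an aside on the review question above: the downloads can be guarded inside the application instead of re-running pip on every container start, mirroring the LookupError check this PR already adds in generate_with_latent_windows. A minimal sketch (the loop is illustrative, not code from this PR):

```python
# Sketch: ensure both tokenizer resources exist at startup, downloading only if missing.
import nltk

for resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{resource}")
    except LookupError:
        nltk.download(resource)
```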
186 changes: 164 additions & 22 deletions gradio_interface.py
@@ -2,6 +2,9 @@
import torchaudio
import gradio as gr
from os import getenv
from typing import Tuple
import numpy as np
import nltk

from zonos.model import Zonos
from zonos.conditioning import make_cond_dict, supported_language_codes
@@ -82,20 +85,106 @@ def update_ui(model_choice):
)


def generate_with_latent_windows(
model: Zonos,
text: str,
cond_dict: dict,
overlap_seconds: float = 0.3,
cfg_scale: float = 2.0,
min_p: float = 0.15,
seed: int = 420,
) -> Tuple[int, np.ndarray]:
"""Generate audio using sliding windows in latent space."""
try:
nltk.data.find('tokenizers/punkt')
except LookupError:
nltk.download('punkt')

# Set global seed at the start
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Split into sentences
sentences = nltk.sent_tokenize(text)
if len(sentences) == 1:
# Single sentence - just generate normally
torch.manual_seed(seed)
conditioning = model.prepare_conditioning(cond_dict)
codes = model.generate(
prefix_conditioning=conditioning,
max_new_tokens=86 * 60, # 60 seconds max
cfg_scale=cfg_scale,
batch_size=1,
sampling_params=dict(min_p=min_p)
)
wav_out = model.autoencoder.decode(codes).cpu().detach()
return model.autoencoder.sampling_rate, wav_out.squeeze().numpy()

# Calculate window sizes in latent tokens
tokens_per_second = 86
overlap_size = int(overlap_seconds * tokens_per_second)

all_codes = []

for i, sentence in enumerate(sentences):
# Create conditioning for this sentence
sent_dict = cond_dict.copy()
sent_dict['espeak'] = ([sentence], [cond_dict['espeak'][1][0]])

# Generate this sentence
torch.manual_seed(seed)
conditioning = model.prepare_conditioning(sent_dict)
codes = model.generate(
prefix_conditioning=conditioning,
max_new_tokens=int(len(sentence) * 1.5 * tokens_per_second / 10), # Rough estimate
cfg_scale=cfg_scale,
batch_size=1,
sampling_params=dict(min_p=min_p)
)

if i == 0:
all_codes.append(codes)
continue

# Crossfade with previous sentence
overlap_a = all_codes[-1][..., -overlap_size:]
overlap_b = codes[..., :overlap_size]

        # Linearly crossfade between the end of the previous chunk and the start of this one
        fade = torch.linspace(0, 1, overlap_size, device=codes.device)
        fade = fade.view(1, 1, -1)
        overlap_region = overlap_a * (1 - fade) + overlap_b * fade

# Replace overlap region in previous sentence
all_codes[-1] = torch.cat([
all_codes[-1][..., :-overlap_size],
overlap_region
], dim=-1)

# Add new sentence (excluding overlap)
all_codes.append(codes[..., overlap_size:])

# Check total length
total_length = sum(c.shape[-1] for c in all_codes)
if total_length >= tokens_per_second * 30: # 30 seconds max
break

# Concatenate all sentences
final_codes = torch.cat(all_codes, dim=-1)
final_codes = final_codes.to(torch.long)

# Decode to audio
wav_out = model.autoencoder.decode(final_codes).cpu().detach()
return model.autoencoder.sampling_rate, wav_out.squeeze().numpy()


def generate_audio(
model_choice,
text,
language,
speaker_audio,
prefix_audio,
e1,
e2,
e3,
e4,
e5,
e6,
e7,
e8,
e1, e2, e3, e4, e5, e6, e7, e8,
vq_single,
fmax,
pitch_std,
@@ -107,6 +196,9 @@ def generate_audio(
seed,
randomize_seed,
unconditional_keys,
use_windowing,
window_size,
window_overlap,
progress=gr.Progress(),
):
"""
@@ -177,21 +269,50 @@ def update_progress(_frame: torch.Tensor, step: int, _total_steps: int) -> bool:
progress((step, estimated_total_steps))
return True

codes = selected_model.generate(
prefix_conditioning=conditioning,
audio_prefix_codes=audio_prefix_codes,
max_new_tokens=max_new_tokens,
cfg_scale=cfg_scale,
batch_size=1,
sampling_params=dict(min_p=min_p),
callback=update_progress,
)
if use_windowing:
return generate_with_latent_windows(
selected_model,
text,
cond_dict,
overlap_seconds=window_overlap,
cfg_scale=cfg_scale,
min_p=min_p,
seed=seed
), seed
else:
codes = selected_model.generate(
prefix_conditioning=conditioning,
audio_prefix_codes=audio_prefix_codes,
max_new_tokens=max_new_tokens,
cfg_scale=cfg_scale,
batch_size=1,
sampling_params=dict(min_p=min_p),
callback=update_progress,
)

wav_out = selected_model.autoencoder.decode(codes).cpu().detach()
sr_out = selected_model.autoencoder.sampling_rate
if wav_out.dim() == 2 and wav_out.size(0) > 1:
wav_out = wav_out[0:1, :]
return (sr_out, wav_out.squeeze().numpy()), seed

wav_out = selected_model.autoencoder.decode(codes).cpu().detach()
sr_out = selected_model.autoencoder.sampling_rate
if wav_out.dim() == 2 and wav_out.size(0) > 1:
wav_out = wav_out[0:1, :]
return (sr_out, wav_out.squeeze().numpy()), seed

def validate_window_params(window_size: float, window_overlap: float) -> tuple[str, bool]:
if window_overlap >= window_size:
return "Overlap size must be smaller than window size", False
if window_size > 60:
return "Window size cannot exceed 60 seconds", False
if window_overlap < 0.1:
return "Overlap must be at least 0.1 seconds", False
return "", True


def on_window_change(window_size: float, window_overlap: float):
error_msg, is_valid = validate_window_params(window_size, window_overlap)
return {
generate_button: gr.Button(interactive=is_valid),
error_text: error_msg
}


def build_interface():
@@ -283,6 +404,12 @@ def build_interface():
emotion8 = gr.Slider(0.0, 1.0, 0.2, 0.05, label="Neutral")

with gr.Column():
gr.Markdown("## Windowing Parameters")
window_size = gr.Slider(1.0, 10.0, value=3.0, step=0.5, label="Window Size (seconds)")
window_overlap = gr.Slider(0.1, 2.0, value=0.3, step=0.1, label="Window Overlap (seconds)")
use_windowing = gr.Checkbox(label="Enable Latent Windowing", value=False)
error_text = gr.Textbox(label="Error", visible=True, interactive=False)

generate_button = gr.Button("Generate Audio")
output_audio = gr.Audio(label="Generated Audio", type="numpy", autoplay=True)

@@ -367,10 +494,25 @@ def build_interface():
seed_number,
randomize_seed_toggle,
unconditional_keys,
use_windowing,
window_size,
window_overlap,
],
outputs=[output_audio, seed_number],
)

with gr.Column():
window_size.change(
fn=on_window_change,
inputs=[window_size, window_overlap],
outputs=[generate_button, error_text]
)
window_overlap.change(
fn=on_window_change,
inputs=[window_size, window_overlap],
outputs=[generate_button, error_text]
)

return demo


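For readers skimming the diff above, the crossfade inside generate_with_latent_windows can be exercised in isolation. A self-contained sketch on toy tensors (shapes and values are invented; the explicit float casts avoid relying on dtype promotion in torch.cat):

```python
import torch

# Toy "latent codes": [batch=1, n_codebooks=2, seq_len=10] integer token ids.
prev = torch.randint(0, 1024, (1, 2, 10))
curr = torch.randint(0, 1024, (1, 2, 10))

overlap_size = 4
fade = torch.linspace(0, 1, overlap_size).view(1, 1, -1)

# Same arithmetic as the PR: the previous chunk fades out while the new one fades in.
overlap = prev[..., -overlap_size:] * (1 - fade) + curr[..., :overlap_size] * fade

merged = torch.cat(
    [prev[..., :-overlap_size].float(), overlap, curr[..., overlap_size:].float()],
    dim=-1,
).to(torch.long)  # codebook indices must be integers again before decoding

# Total length: both chunks minus one shared overlap region.
assert merged.shape[-1] == prev.shape[-1] + curr.shape[-1] - overlap_size
```

Note that the cast back to torch.long truncates the blended values, so the smoothing happens on raw index values rather than on decoded audio.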
41 changes: 37 additions & 4 deletions zonos/codebook_pattern.py
@@ -2,11 +2,44 @@
import torch.nn.functional as F


def apply_delay_pattern(codes: torch.Tensor, mask_token: int):
def apply_delay_pattern(codes: torch.Tensor, mask_token: int, prev_chunk_end: torch.Tensor | None = None):
"""
Apply delay pattern with optional previous chunk ending.

Args:
codes: Input codes [batch_size, n_codebooks, seq_len]
mask_token: Token to use for masking
prev_chunk_end: Optional last n_codebooks tokens from previous chunk
"""
if prev_chunk_end is not None:
# Prepend previous chunk ending
codes = torch.cat([prev_chunk_end, codes], dim=-1)

codes = F.pad(codes, (0, codes.shape[1]), value=mask_token)
return torch.stack([codes[:, k].roll(k + 1) for k in range(codes.shape[1])], dim=1)


def revert_delay_pattern(codes: torch.Tensor):
def revert_delay_pattern(codes: torch.Tensor, remove_overlap: bool = False):
"""
Revert delay pattern, optionally removing overlap region.
"""
_, n_q, seq_len = codes.shape
return torch.stack([codes[:, k, k + 1 : seq_len - n_q + k + 1] for k in range(n_q)], dim=1)
reverted = torch.stack([codes[:, k, k + 1 : seq_len - n_q + k + 1] for k in range(n_q)], dim=1)

if remove_overlap:
reverted = reverted[..., n_q:]

return reverted

def interpolate_latents(codes1: torch.Tensor, codes2: torch.Tensor, overlap_size: int) -> torch.Tensor:
"""
Smoothly interpolate between two chunks of codes in their overlap region.

Args:
codes1: First chunk of codes ending with overlap region
codes2: Second chunk of codes starting with overlap region
overlap_size: Size of the overlap region
"""
weights = torch.linspace(1, 0, overlap_size, device=codes1.device).view(1, 1, -1)
overlap1 = codes1[..., -overlap_size:]
overlap2 = codes2[..., :overlap_size]
return overlap1 * weights + overlap2 * (1 - weights)
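Taking the three functions above as written, their behavior can be pinned down with a quick self-check (toy shapes; the mask token 1025 matches masked_token_id in zonos/config.py):

```python
import torch
from zonos.codebook_pattern import apply_delay_pattern, revert_delay_pattern, interpolate_latents

codes = torch.randint(0, 1024, (1, 9, 20))  # [batch, n_codebooks, seq_len]

# Round trip without a previous chunk reproduces the input exactly.
delayed = apply_delay_pattern(codes, mask_token=1025)
assert torch.equal(revert_delay_pattern(delayed), codes)

# remove_overlap=True additionally drops the first n_codebooks positions.
assert torch.equal(revert_delay_pattern(delayed, remove_overlap=True), codes[..., 9:])

# interpolate_latents blends the trailing/leading overlap of two chunks.
a, b = torch.randn(1, 9, 20), torch.randn(1, 9, 20)
blend = interpolate_latents(a, b, overlap_size=5)
assert blend.shape == (1, 9, 5)
```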
11 changes: 11 additions & 0 deletions zonos/conditioning.py
@@ -229,6 +229,17 @@ def apply_cond(self, texts: list[str], languages: list[str]) -> torch.Tensor:

return phoneme_embeds

class ChunkedEspeakConditioner(EspeakPhonemeConditioner):
def apply_cond(self, texts: list[str], languages: list[str], chunk_size: int = 1000, overlap: int = 200) -> torch.Tensor:
# Split text into chunks with overlap
chunked_texts = self._chunk_texts(texts, chunk_size, overlap)
# Process each chunk
chunk_embeddings = []
for chunk in chunked_texts:
emb = super().apply_cond(chunk, languages)
chunk_embeddings.append(emb)
# Interpolate overlapping regions
return self._interpolate_embeddings(chunk_embeddings, overlap)

# ------- ESPEAK CONTAINMENT ZONE ------------------------------------------------------------------------------------------------------------------------------------------------

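Note that this hunk calls self._chunk_texts and self._interpolate_embeddings without showing their definitions. For orientation only, here is one plausible shape for the chunking helper; the actual implementation in the PR may differ:

```python
def _chunk_texts(self, texts: list[str], chunk_size: int, overlap: int) -> list[list[str]]:
    """Hypothetical sketch: split each text into overlapping character windows."""
    step = chunk_size - overlap
    max_len = max(len(t) for t in texts)
    chunks: list[list[str]] = []
    for start in range(0, max_len, step):
        chunks.append([t[start : start + chunk_size] for t in texts])
    return chunks
```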
9 changes: 9 additions & 0 deletions zonos/config.py
@@ -22,10 +22,19 @@ class PrefixConditionerConfig:
projection: Literal["none", "linear", "mlp"]


@dataclass
class GenerationConfig:
default_chunk_size: int = 86 * 10 # 10 seconds
default_overlap_size: int = 86 * 2 # 2 seconds
max_tokens: int = 86 * 60 # 60 seconds
interpolation_method: Literal["linear", "cosine"] = "linear"


@dataclass
class ZonosConfig:
backbone: BackboneConfig
prefix_conditioner: PrefixConditionerConfig
generation: GenerationConfig = field(default_factory=GenerationConfig)
eos_token_id: int = 1024
masked_token_id: int = 1025

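GenerationConfig centralizes the 86 tokens-per-second constants that gradio_interface.py hard-codes; note that field must be imported from dataclasses for the new default_factory to work, which this hunk does not show. A small usage sketch (the tokens_for_seconds helper is an assumption, not part of the PR):

```python
from zonos.config import GenerationConfig

TOKENS_PER_SECOND = 86  # matches the comments on the config defaults

def tokens_for_seconds(seconds: float) -> int:
    # Hypothetical helper: convert a wall-clock duration to latent tokens.
    return int(seconds * TOKENS_PER_SECOND)

cfg = GenerationConfig()
assert cfg.default_chunk_size == tokens_for_seconds(10)   # 860 tokens
assert cfg.default_overlap_size == tokens_for_seconds(2)  # 172 tokens
assert cfg.max_tokens == tokens_for_seconds(60)           # 5160 tokens
```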