Skip to content

Commit

Permalink
feat: give rustbot a voice (#7)
Browse files Browse the repository at this point in the history
Just like with gobot we use playht API and stream the audio on the default audio device using
rodio crate.. We introduce a tts module that handles the TTS tasks in rustbot. 
We use the playht_rs crate for TTS synthesis and stream the audio to the default audio device.

Signed-off-by: Milos Gajdos <[email protected]>
  • Loading branch information
milosgajdos authored Apr 23, 2024
1 parent c8a9c4d commit 77ce5c0
Show file tree
Hide file tree
Showing 9 changed files with 315 additions and 24 deletions.
2 changes: 2 additions & 0 deletions rustbot/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,5 @@ rand = "0.8"
ollama-rs = { version = "0.1", features = ["stream"] }
bytes = { version = "1", features = ["serde"] }
clap = { version = "4.5.4", features = ["derive"] }
playht_rs = "0.2.0"
rodio = "0.17.3"
86 changes: 86 additions & 0 deletions rustbot/src/audio.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
use crate::prelude::*;
use bytes::BytesMut;
use rodio::{Decoder, Sink};
use std::io::Cursor;
use tokio::{
self,
io::{self, AsyncReadExt},
sync::watch,
time::{self, Duration, Instant},
};

pub async fn play(
mut audio_rd: io::DuplexStream,
sink: Sink,
audio_done: watch::Sender<bool>,
mut done: watch::Receiver<bool>,
) -> Result<()> {
println!("launching audio player");
let mut audio_data = BytesMut::new();
// TODO: make this a cli switch as this value has been picked rather arbitrarily
let interval_duration = Duration::from_millis(AUDIO_INTERVAL);
let mut interval = time::interval(interval_duration);
let mut last_play_time = Instant::now();
let mut has_played_audio = false;

loop {
tokio::select! {
_ = done.changed() => {
if *done.borrow() {
break;
}
}
result = audio_rd.read_buf(&mut audio_data) => {
if let Ok(chunk) = result {
if chunk == 0 {
break;
}
if audio_data.len() > AUDIO_BUFFER_SIZE {
// NOTE: this avoids unnecessary data duplication and manages the buffer efficiently
let cursor = Cursor::new(audio_data.split_to(AUDIO_BUFFER_SIZE).freeze().to_vec());
match Decoder::new(cursor) {
Ok(source) => {
sink.append(source);
last_play_time = Instant::now();
has_played_audio = true;
}
Err(e) => {
eprintln!("Failed to decode received audio: {}", e);
}
}
}
}
}
_ = interval.tick() => {
// No audio data received in the past interval_duration ms and we've
// already played some audio -- that means we can proceed with dialogue
// by writing a followup question into JetStream through jet::writer.
if has_played_audio && last_play_time.elapsed() >= interval_duration && sink.empty() {
if !audio_data.is_empty() {
let cursor = Cursor::new(audio_data.clone().freeze().to_vec());
if let Ok(source) = Decoder::new(cursor) {
sink.append(source);
audio_data.clear();
}
}
sink.sleep_until_end();
// NOTE: notify jet::writer
audio_done.send(true)?;
has_played_audio = false;
}
}
}
}

// Flush any remaining data
if !audio_data.is_empty() {
let cursor = Cursor::new(audio_data.clone().to_vec());
if let Ok(source) = Decoder::new(cursor) {
sink.append(source);
}
}
if !sink.empty() {
sink.sleep_until_end();
}
Ok(())
}
52 changes: 52 additions & 0 deletions rustbot/src/buffer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use bytes::{BufMut, Bytes, BytesMut};
use std::error::Error;
use std::fmt;

#[derive(Debug)]
pub struct BufferFullError {
pub bytes_written: usize,
}

impl fmt::Display for BufferFullError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "buffer is full, {} bytes written", self.bytes_written)
}
}

impl Error for BufferFullError {}

pub struct Buffer {
buffer: BytesMut,
max_size: usize,
}

impl Buffer {
pub fn new(max_size: usize) -> Self {
Buffer {
buffer: BytesMut::with_capacity(max_size),
max_size,
}
}

pub fn write(&mut self, data: &[u8]) -> Result<usize, BufferFullError> {
let available = self.max_size - self.buffer.len();
let write_len = std::cmp::min(data.len(), available);

self.buffer.put_slice(&data[..write_len]);

if self.buffer.len() == self.max_size {
return Err(BufferFullError {
bytes_written: write_len,
});
}
Ok(write_len)
}

pub fn reset(&mut self) {
self.buffer.clear();
}

pub fn as_bytes(&self) -> Bytes {
self.buffer.clone().freeze()
}
}
10 changes: 8 additions & 2 deletions rustbot/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ pub struct App {
pub llm: LLM,
#[command(flatten)]
pub bot: Bot,
#[command(flatten)]
pub tts: TTS,
}

#[derive(Args, Debug)]
pub struct Prompt {
#[arg(long, default_value = DEFAULT_SYSTEM_PROMPT, help = "system prompt")]
pub system: Option<String>,
#[arg(long, default_value = DEFAULT_SEED_PROMPT, help = "instruction prompt")]
pub seed: Option<String>,
}
Expand All @@ -40,3 +40,9 @@ pub struct Bot {
#[arg(short = 'b', long, default_value = BOT_SUB_SUBJECT, help = "jetstream subscribe subject")]
pub sub_subject: String,
}

#[derive(Args, Debug)]
pub struct TTS {
#[arg(short, default_value = DEFAULT_VOICE_ID, help = "bot name")]
pub voice_id: String,
}
17 changes: 13 additions & 4 deletions rustbot/src/jet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ impl Writer {
pub async fn write(
self,
mut chunks: Receiver<Bytes>,
mut audio_done: watch::Receiver<bool>,
mut done: watch::Receiver<bool>,
) -> Result<()> {
println!("launching JetStream Writer");
Expand All @@ -129,10 +130,18 @@ impl Writer {
if chunk.is_empty() {
let msg = String::from_utf8(b.to_vec()).unwrap();
println!("\n[A]: {}", msg);
self.tx.publish(self.subject.to_string(), b.clone().freeze())
.await?;
b.clear();
continue;
loop {
tokio::select! {
_ = audio_done.changed() => {
if *audio_done.borrow() {
self.tx.publish(self.subject.to_string(), b.clone().freeze())
.await?;
b.clear();
break;
}
},
}
}
}
b.extend_from_slice(&chunk);
}
Expand Down
24 changes: 22 additions & 2 deletions rustbot/src/llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use tokio::{
self,
sync::mpsc::{Receiver, Sender},
sync::watch,
task::JoinHandle,
};
use tokio_stream::StreamExt;

Expand Down Expand Up @@ -46,7 +47,8 @@ impl LLM {
pub async fn stream(
self,
mut prompts: Receiver<String>,
chunks: Sender<Bytes>,
jet_chunks: Sender<Bytes>,
tts_chunks: Sender<Bytes>,
mut done: watch::Receiver<bool>,
) -> Result<()> {
println!("launching LLM stream");
Expand Down Expand Up @@ -77,7 +79,25 @@ impl LLM {
while let Some(res) = stream.next().await {
let responses = res?;
for resp in responses {
chunks.send(Bytes::from(resp.response)).await?;
let resp_bytes = Bytes::from(resp.response);
let jet_bytes = resp_bytes.clone();
let jet_ch = jet_chunks.clone();
let jet_task: JoinHandle<Result<()>> = tokio::spawn(async move {
jet_ch.send(Bytes::from(jet_bytes)).await?;
Ok(())
});
let tts_bytes = resp_bytes.clone();
let tts_ch = tts_chunks.clone();
let tts_task: JoinHandle<Result<()>> = tokio::spawn(async move {
tts_ch.send(Bytes::from(tts_bytes)).await?;
Ok(())
});
match tokio::try_join!(jet_task, tts_task) {
Ok(_) => {}
Err(e) => {
return Err(Box::new(e));
}
}
}
}
},
Expand Down
40 changes: 30 additions & 10 deletions rustbot/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,29 @@
use bytes::Bytes;
use clap::Parser;
use prelude::*;
use rodio::{OutputStream, Sink};
use tokio::{
self, signal,
self, io, signal,
sync::{mpsc, watch},
task::JoinHandle,
};

mod audio;
mod buffer;
mod cli;
mod history;
mod jet;
mod llm;
mod prelude;
mod tts;

#[tokio::main]
async fn main() -> Result<()> {
let args = cli::App::parse();

let system_prompt = args.prompt.system.unwrap();
let seed_prompt = args.prompt.seed.unwrap();
let prompt = system_prompt + "\n" + &seed_prompt;

// NOTE: we could also add Stream::builder to jet module
// NOTE: we could also add Stream::builder to the jet module
// and instead of passing config we could build it by chaining methods.
let c = jet::Config {
durable_name: args.bot.name,
Expand All @@ -32,29 +34,47 @@ async fn main() -> Result<()> {
};
let s = jet::Stream::new(c).await?;

// NOTE: we could also add LLM::builder to llm module
// NOTE: we could also add LLM::builder to the llm module
// and instead of passing config we could build it by chaining methods.
let c = llm::Config {
hist_size: args.llm.hist_size,
model_name: args.llm.model_name,
seed_prompt: Some(prompt),
seed_prompt: Some(seed_prompt),
..llm::Config::default()
};
let l = llm::LLM::new(c);

// NOTE: we could also add TTS::builder to the tts module
// and instead of passing config we could build it by chaining methods.
let c = tts::Config {
voice_id: Some(args.tts.voice_id),
..tts::Config::default()
};
let t = tts::TTS::new(c);

let (prompts_tx, prompts_rx) = mpsc::channel::<String>(32);
let (chunks_tx, chunks_rx) = mpsc::channel::<Bytes>(32);
let (jet_chunks_tx, jet_chunks_rx) = mpsc::channel::<Bytes>(32);
let (tts_chunks_tx, tts_chunks_rx) = mpsc::channel::<Bytes>(32);
let (aud_done_tx, aud_done_rx) = watch::channel(false);

// NOTE: used for cancellation when SIGINT is trapped.
let (watch_tx, watch_rx) = watch::channel(false);
let jet_wr_watch_rx = watch_rx.clone();
let jet_rd_watch_rx = watch_rx.clone();
let tts_watch_rx = watch_rx.clone();
let aud_watch_rx = watch_rx.clone();

println!("launching workers");

let llm_stream = tokio::spawn(l.stream(prompts_rx, chunks_tx, watch_rx));
let jet_write = tokio::spawn(s.writer.write(chunks_rx, jet_wr_watch_rx));
let (_stream, stream_handle) = OutputStream::try_default().unwrap();
let sink = Sink::try_new(&stream_handle).unwrap();
let (audio_wr, audio_rd) = io::duplex(1024);

let tts_stream = tokio::spawn(t.stream(audio_wr, tts_chunks_rx, tts_watch_rx));
let llm_stream = tokio::spawn(l.stream(prompts_rx, jet_chunks_tx, tts_chunks_tx, watch_rx));
let jet_write = tokio::spawn(s.writer.write(jet_chunks_rx, aud_done_rx, jet_wr_watch_rx));
let jet_read = tokio::spawn(s.reader.read(prompts_tx, jet_rd_watch_rx));
let audio_task = tokio::spawn(audio::play(audio_rd, sink, aud_done_tx, aud_watch_rx));
let sig_handler: JoinHandle<Result<()>> = tokio::spawn(async move {
tokio::select! {
_ = signal::ctrl_c() => {
Expand All @@ -65,7 +85,7 @@ async fn main() -> Result<()> {
Ok(())
});

match tokio::try_join!(llm_stream, jet_write, jet_read) {
match tokio::try_join!(tts_stream, llm_stream, jet_write, jet_read, audio_task) {
Ok(_) => {}
Err(e) => {
println!("Error running bot: {}", e);
Expand Down
18 changes: 12 additions & 6 deletions rustbot/src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,21 @@ pub const BOT_NAME: &str = "rustbot";
pub const BOT_SUB_SUBJECT: &str = "rust";
pub const BOT_PUB_SUBJECT: &str = "go";

pub const DEFAULT_SYSTEM_PROMPT: &str = "You are a Rust programming language expert \
pub const DEFAULT_SEED_PROMPT: &str = "You are a Rust programming language expert \
and a helpful AI assistant trying to learn about Go programming language. \
You will answer questions ONLY about Rust and ONLY ask questions about Go. \
You do NOT explain how Go works, you ONLY compare Go to Rust. When you receive \
the response you will evaluate it from a Rust programmer point of view and ask \
followup questions about Go. NEVER use emojis in your answers! Your response \
must NOT be longer than 100 words!";
pub const DEFAULT_SEED_PROMPT: &str = "Question: What is the biggest strength of Rust?
You do NOT explain how Go works. You are NOT Go expert! You ONLY compare Go \
to Rust. When you receive the response you will evaluate it from a Rust programmer \
point of view and ask followup questions about Go. NEVER use emojis in your answers! \
Your answers must NOT be longer than 100 words! \
Question: What is the biggest strength of Rust?
Assistant: Rust's biggest strength lies in its focus on safety, particularly memory \
safety, without sacrificing performance. Can you tell me what are some of the biggest \
strengths of Go that make it stand out from other programming languages?
Question: ";

pub const DEFAULT_VOICE_ID: &str =
"s3://voice-cloning-zero-shot/b3def996-302e-486f-a234-172fa0279f0e/anthonysaad/manifest.json";
pub const MAX_TTS_BUFFER_SIZE: usize = 1000;
pub const AUDIO_BUFFER_SIZE: usize = 1024 * 10;
pub const AUDIO_INTERVAL: u64 = 200;
Loading

0 comments on commit 77ce5c0

Please sign in to comment.