diff --git a/rustbot/Cargo.toml b/rustbot/Cargo.toml index 7a7f3c5..ccbdbef 100644 --- a/rustbot/Cargo.toml +++ b/rustbot/Cargo.toml @@ -15,3 +15,5 @@ rand = "0.8" ollama-rs = { version = "0.1", features = ["stream"] } bytes = { version = "1", features = ["serde"] } clap = { version = "4.5.4", features = ["derive"] } +playht_rs = "0.2.0" +rodio = "0.17.3" diff --git a/rustbot/src/audio.rs b/rustbot/src/audio.rs new file mode 100644 index 0000000..6132be7 --- /dev/null +++ b/rustbot/src/audio.rs @@ -0,0 +1,86 @@ +use crate::prelude::*; +use bytes::BytesMut; +use rodio::{Decoder, Sink}; +use std::io::Cursor; +use tokio::{ + self, + io::{self, AsyncReadExt}, + sync::watch, + time::{self, Duration, Instant}, +}; + +pub async fn play( + mut audio_rd: io::DuplexStream, + sink: Sink, + audio_done: watch::Sender, + mut done: watch::Receiver, +) -> Result<()> { + println!("launching audio player"); + let mut audio_data = BytesMut::new(); + // TODO: make this a cli switch as this value has been picked rather arbitrarily + let interval_duration = Duration::from_millis(AUDIO_INTERVAL); + let mut interval = time::interval(interval_duration); + let mut last_play_time = Instant::now(); + let mut has_played_audio = false; + + loop { + tokio::select! { + _ = done.changed() => { + if *done.borrow() { + break; + } + } + result = audio_rd.read_buf(&mut audio_data) => { + if let Ok(chunk) = result { + if chunk == 0 { + break; + } + if audio_data.len() > AUDIO_BUFFER_SIZE { + // NOTE: this avoids unnecessary data duplication and manages the buffer efficiently + let cursor = Cursor::new(audio_data.split_to(AUDIO_BUFFER_SIZE).freeze().to_vec()); + match Decoder::new(cursor) { + Ok(source) => { + sink.append(source); + last_play_time = Instant::now(); + has_played_audio = true; + } + Err(e) => { + eprintln!("Failed to decode received audio: {}", e); + } + } + } + } + } + _ = interval.tick() => { + // No audio data received in the past interval_duration ms and we've + // already played some audio -- that means we can proceed with dialogue + // by writing a followup question into JetStream through jet::writer. + if has_played_audio && last_play_time.elapsed() >= interval_duration && sink.empty() { + if !audio_data.is_empty() { + let cursor = Cursor::new(audio_data.clone().freeze().to_vec()); + if let Ok(source) = Decoder::new(cursor) { + sink.append(source); + audio_data.clear(); + } + } + sink.sleep_until_end(); + // NOTE: notify jet::writer + audio_done.send(true)?; + has_played_audio = false; + } + } + } + } + + // Flush any remaining data + if !audio_data.is_empty() { + let cursor = Cursor::new(audio_data.clone().to_vec()); + if let Ok(source) = Decoder::new(cursor) { + sink.append(source); + } + } + if !sink.empty() { + sink.sleep_until_end(); + } + Ok(()) +} diff --git a/rustbot/src/buffer.rs b/rustbot/src/buffer.rs new file mode 100644 index 0000000..e60d665 --- /dev/null +++ b/rustbot/src/buffer.rs @@ -0,0 +1,52 @@ +use bytes::{BufMut, Bytes, BytesMut}; +use std::error::Error; +use std::fmt; + +#[derive(Debug)] +pub struct BufferFullError { + pub bytes_written: usize, +} + +impl fmt::Display for BufferFullError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "buffer is full, {} bytes written", self.bytes_written) + } +} + +impl Error for BufferFullError {} + +pub struct Buffer { + buffer: BytesMut, + max_size: usize, +} + +impl Buffer { + pub fn new(max_size: usize) -> Self { + Buffer { + buffer: BytesMut::with_capacity(max_size), + max_size, + } + } + + pub fn write(&mut self, data: &[u8]) -> Result { + let available = self.max_size - self.buffer.len(); + let write_len = std::cmp::min(data.len(), available); + + self.buffer.put_slice(&data[..write_len]); + + if self.buffer.len() == self.max_size { + return Err(BufferFullError { + bytes_written: write_len, + }); + } + Ok(write_len) + } + + pub fn reset(&mut self) { + self.buffer.clear(); + } + + pub fn as_bytes(&self) -> Bytes { + self.buffer.clone().freeze() + } +} diff --git a/rustbot/src/cli.rs b/rustbot/src/cli.rs index e041708..a1e0f92 100644 --- a/rustbot/src/cli.rs +++ b/rustbot/src/cli.rs @@ -10,12 +10,12 @@ pub struct App { pub llm: LLM, #[command(flatten)] pub bot: Bot, + #[command(flatten)] + pub tts: TTS, } #[derive(Args, Debug)] pub struct Prompt { - #[arg(long, default_value = DEFAULT_SYSTEM_PROMPT, help = "system prompt")] - pub system: Option, #[arg(long, default_value = DEFAULT_SEED_PROMPT, help = "instruction prompt")] pub seed: Option, } @@ -40,3 +40,9 @@ pub struct Bot { #[arg(short = 'b', long, default_value = BOT_SUB_SUBJECT, help = "jetstream subscribe subject")] pub sub_subject: String, } + +#[derive(Args, Debug)] +pub struct TTS { + #[arg(short, default_value = DEFAULT_VOICE_ID, help = "bot name")] + pub voice_id: String, +} diff --git a/rustbot/src/jet.rs b/rustbot/src/jet.rs index 6580038..3d45fce 100644 --- a/rustbot/src/jet.rs +++ b/rustbot/src/jet.rs @@ -114,6 +114,7 @@ impl Writer { pub async fn write( self, mut chunks: Receiver, + mut audio_done: watch::Receiver, mut done: watch::Receiver, ) -> Result<()> { println!("launching JetStream Writer"); @@ -129,10 +130,18 @@ impl Writer { if chunk.is_empty() { let msg = String::from_utf8(b.to_vec()).unwrap(); println!("\n[A]: {}", msg); - self.tx.publish(self.subject.to_string(), b.clone().freeze()) - .await?; - b.clear(); - continue; + loop { + tokio::select! { + _ = audio_done.changed() => { + if *audio_done.borrow() { + self.tx.publish(self.subject.to_string(), b.clone().freeze()) + .await?; + b.clear(); + break; + } + }, + } + } } b.extend_from_slice(&chunk); } diff --git a/rustbot/src/llm.rs b/rustbot/src/llm.rs index 4a065f9..f70c7e2 100644 --- a/rustbot/src/llm.rs +++ b/rustbot/src/llm.rs @@ -5,6 +5,7 @@ use tokio::{ self, sync::mpsc::{Receiver, Sender}, sync::watch, + task::JoinHandle, }; use tokio_stream::StreamExt; @@ -46,7 +47,8 @@ impl LLM { pub async fn stream( self, mut prompts: Receiver, - chunks: Sender, + jet_chunks: Sender, + tts_chunks: Sender, mut done: watch::Receiver, ) -> Result<()> { println!("launching LLM stream"); @@ -77,7 +79,25 @@ impl LLM { while let Some(res) = stream.next().await { let responses = res?; for resp in responses { - chunks.send(Bytes::from(resp.response)).await?; + let resp_bytes = Bytes::from(resp.response); + let jet_bytes = resp_bytes.clone(); + let jet_ch = jet_chunks.clone(); + let jet_task: JoinHandle> = tokio::spawn(async move { + jet_ch.send(Bytes::from(jet_bytes)).await?; + Ok(()) + }); + let tts_bytes = resp_bytes.clone(); + let tts_ch = tts_chunks.clone(); + let tts_task: JoinHandle> = tokio::spawn(async move { + tts_ch.send(Bytes::from(tts_bytes)).await?; + Ok(()) + }); + match tokio::try_join!(jet_task, tts_task) { + Ok(_) => {} + Err(e) => { + return Err(Box::new(e)); + } + } } } }, diff --git a/rustbot/src/main.rs b/rustbot/src/main.rs index be1a666..43ec2fe 100644 --- a/rustbot/src/main.rs +++ b/rustbot/src/main.rs @@ -1,27 +1,29 @@ use bytes::Bytes; use clap::Parser; use prelude::*; +use rodio::{OutputStream, Sink}; use tokio::{ - self, signal, + self, io, signal, sync::{mpsc, watch}, task::JoinHandle, }; +mod audio; +mod buffer; mod cli; mod history; mod jet; mod llm; mod prelude; +mod tts; #[tokio::main] async fn main() -> Result<()> { let args = cli::App::parse(); - let system_prompt = args.prompt.system.unwrap(); let seed_prompt = args.prompt.seed.unwrap(); - let prompt = system_prompt + "\n" + &seed_prompt; - // NOTE: we could also add Stream::builder to jet module + // NOTE: we could also add Stream::builder to the jet module // and instead of passing config we could build it by chaining methods. let c = jet::Config { durable_name: args.bot.name, @@ -32,29 +34,47 @@ async fn main() -> Result<()> { }; let s = jet::Stream::new(c).await?; - // NOTE: we could also add LLM::builder to llm module + // NOTE: we could also add LLM::builder to the llm module // and instead of passing config we could build it by chaining methods. let c = llm::Config { hist_size: args.llm.hist_size, model_name: args.llm.model_name, - seed_prompt: Some(prompt), + seed_prompt: Some(seed_prompt), ..llm::Config::default() }; let l = llm::LLM::new(c); + // NOTE: we could also add TTS::builder to the tts module + // and instead of passing config we could build it by chaining methods. + let c = tts::Config { + voice_id: Some(args.tts.voice_id), + ..tts::Config::default() + }; + let t = tts::TTS::new(c); + let (prompts_tx, prompts_rx) = mpsc::channel::(32); - let (chunks_tx, chunks_rx) = mpsc::channel::(32); + let (jet_chunks_tx, jet_chunks_rx) = mpsc::channel::(32); + let (tts_chunks_tx, tts_chunks_rx) = mpsc::channel::(32); + let (aud_done_tx, aud_done_rx) = watch::channel(false); // NOTE: used for cancellation when SIGINT is trapped. let (watch_tx, watch_rx) = watch::channel(false); let jet_wr_watch_rx = watch_rx.clone(); let jet_rd_watch_rx = watch_rx.clone(); + let tts_watch_rx = watch_rx.clone(); + let aud_watch_rx = watch_rx.clone(); println!("launching workers"); - let llm_stream = tokio::spawn(l.stream(prompts_rx, chunks_tx, watch_rx)); - let jet_write = tokio::spawn(s.writer.write(chunks_rx, jet_wr_watch_rx)); + let (_stream, stream_handle) = OutputStream::try_default().unwrap(); + let sink = Sink::try_new(&stream_handle).unwrap(); + let (audio_wr, audio_rd) = io::duplex(1024); + + let tts_stream = tokio::spawn(t.stream(audio_wr, tts_chunks_rx, tts_watch_rx)); + let llm_stream = tokio::spawn(l.stream(prompts_rx, jet_chunks_tx, tts_chunks_tx, watch_rx)); + let jet_write = tokio::spawn(s.writer.write(jet_chunks_rx, aud_done_rx, jet_wr_watch_rx)); let jet_read = tokio::spawn(s.reader.read(prompts_tx, jet_rd_watch_rx)); + let audio_task = tokio::spawn(audio::play(audio_rd, sink, aud_done_tx, aud_watch_rx)); let sig_handler: JoinHandle> = tokio::spawn(async move { tokio::select! { _ = signal::ctrl_c() => { @@ -65,7 +85,7 @@ async fn main() -> Result<()> { Ok(()) }); - match tokio::try_join!(llm_stream, jet_write, jet_read) { + match tokio::try_join!(tts_stream, llm_stream, jet_write, jet_read, audio_task) { Ok(_) => {} Err(e) => { println!("Error running bot: {}", e); diff --git a/rustbot/src/prelude.rs b/rustbot/src/prelude.rs index 6ce9164..c5bde69 100644 --- a/rustbot/src/prelude.rs +++ b/rustbot/src/prelude.rs @@ -8,15 +8,21 @@ pub const BOT_NAME: &str = "rustbot"; pub const BOT_SUB_SUBJECT: &str = "rust"; pub const BOT_PUB_SUBJECT: &str = "go"; -pub const DEFAULT_SYSTEM_PROMPT: &str = "You are a Rust programming language expert \ +pub const DEFAULT_SEED_PROMPT: &str = "You are a Rust programming language expert \ and a helpful AI assistant trying to learn about Go programming language. \ You will answer questions ONLY about Rust and ONLY ask questions about Go. \ - You do NOT explain how Go works, you ONLY compare Go to Rust. When you receive \ - the response you will evaluate it from a Rust programmer point of view and ask \ - followup questions about Go. NEVER use emojis in your answers! Your response \ - must NOT be longer than 100 words!"; -pub const DEFAULT_SEED_PROMPT: &str = "Question: What is the biggest strength of Rust? + You do NOT explain how Go works. You are NOT Go expert! You ONLY compare Go \ + to Rust. When you receive the response you will evaluate it from a Rust programmer \ + point of view and ask followup questions about Go. NEVER use emojis in your answers! \ + Your answers must NOT be longer than 100 words! \ + Question: What is the biggest strength of Rust? Assistant: Rust's biggest strength lies in its focus on safety, particularly memory \ safety, without sacrificing performance. Can you tell me what are some of the biggest \ strengths of Go that make it stand out from other programming languages? Question: "; + +pub const DEFAULT_VOICE_ID: &str = + "s3://voice-cloning-zero-shot/b3def996-302e-486f-a234-172fa0279f0e/anthonysaad/manifest.json"; +pub const MAX_TTS_BUFFER_SIZE: usize = 1000; +pub const AUDIO_BUFFER_SIZE: usize = 1024 * 10; +pub const AUDIO_INTERVAL: u64 = 200; diff --git a/rustbot/src/tts.rs b/rustbot/src/tts.rs new file mode 100644 index 0000000..58b6848 --- /dev/null +++ b/rustbot/src/tts.rs @@ -0,0 +1,90 @@ +use crate::{buffer, prelude::*}; +use bytes::Bytes; +use playht_rs::api::{self, stream::TTSStreamReq, tts::Quality}; +use tokio::{self, sync::mpsc::Receiver, sync::watch}; + +#[derive(Debug, Clone)] +pub struct Config { + pub voice_id: Option, + pub quality: Option, + pub speed: Option, + pub sample_rate: Option, + pub buf_size: usize, +} + +impl Default for Config { + fn default() -> Self { + Config { + voice_id: Some(DEFAULT_VOICE_ID.to_string()), + quality: Some(Quality::Low), + speed: Some(1.0), + sample_rate: Some(24000), + buf_size: MAX_TTS_BUFFER_SIZE, + } + } +} + +pub struct TTS { + client: api::Client, + config: Config, +} + +impl TTS { + pub fn new(c: Config) -> TTS { + TTS { + client: api::Client::new(), + config: c, + } + } + + pub async fn stream( + self, + mut w: W, + mut chunks: Receiver, + mut done: watch::Receiver, + ) -> Result<()> + where + W: tokio::io::AsyncWriteExt + Unpin, + { + println!("launching TTS stream"); + let mut buf = buffer::Buffer::new(self.config.buf_size); + let mut req = TTSStreamReq { + voice: self.config.voice_id, + quality: self.config.quality, + speed: self.config.speed, + sample_rate: self.config.sample_rate, + ..Default::default() + }; + + loop { + tokio::select! { + _ = done.changed() => { + if *done.borrow() { + return Ok(()) + } + }, + Some(chunk) = chunks.recv() => { + if chunk.is_empty() { + let text = String::from_utf8(buf.as_bytes().to_vec())?; + req.text = Some(text); + self.client.write_audio_stream(&mut w, &req).await?; + buf.reset(); + continue + } + match buf.write(chunk.as_ref()) { + Ok(_) => {}, + Err(e) => { + let text = String::from_utf8(buf.as_bytes().to_vec())?; + req.text = Some(text); + self.client.write_audio_stream(&mut w, &req).await?; + buf.reset(); + let rem = chunk.len() - e.bytes_written; + let chunk_slice = chunk.as_ref(); + buf.write(&chunk_slice[rem..])?; + } + } + } + } + } + } +}