diff --git a/crates/librqbit/src/api_error.rs b/crates/librqbit/src/api_error.rs index b3822f1d..0c2f8f48 100644 --- a/crates/librqbit/src/api_error.rs +++ b/crates/librqbit/src/api_error.rs @@ -55,6 +55,14 @@ impl ApiError { } } + pub const fn unathorized() -> Self { + Self { + status: Some(StatusCode::UNAUTHORIZED), + kind: ApiErrorKind::Unauthorized, + plaintext: true, + } + } + pub fn status(&self) -> StatusCode { self.status.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR) } @@ -80,6 +88,7 @@ impl ApiError { enum ApiErrorKind { TorrentNotFound(TorrentIdOrHash), DhtDisabled, + Unauthorized, Text(&'static str), Other(anyhow::Error), } @@ -102,6 +111,7 @@ impl Serialize for ApiError { error_kind: match self.kind { ApiErrorKind::TorrentNotFound(_) => "torrent_not_found", ApiErrorKind::DhtDisabled => "dht_disabled", + ApiErrorKind::Unauthorized => "unathorized", ApiErrorKind::Other(_) => "internal_error", ApiErrorKind::Text(_) => "internal_error", }, @@ -142,6 +152,7 @@ impl std::fmt::Display for ApiError { match &self.kind { ApiErrorKind::TorrentNotFound(idx) => write!(f, "torrent {idx} not found"), ApiErrorKind::Other(err) => write!(f, "{err:?}"), + ApiErrorKind::Unauthorized => write!(f, "unathorized"), ApiErrorKind::DhtDisabled => write!(f, "DHT is disabled"), ApiErrorKind::Text(t) => write!(f, "{t}"), } diff --git a/crates/librqbit/src/http_api.rs b/crates/librqbit/src/http_api.rs index d89b0802..bd229bee 100644 --- a/crates/librqbit/src/http_api.rs +++ b/crates/librqbit/src/http_api.rs @@ -1,8 +1,10 @@ use anyhow::Context; use axum::body::Bytes; use axum::extract::{ConnectInfo, Path, Query, Request, State}; +use axum::middleware::Next; use axum::response::IntoResponse; use axum::routing::{get, post}; +use base64::Engine; use bencode::AsDisplay; use buffers::ByteBuf; use futures::future::BoxFuture; @@ -19,7 +21,7 @@ use std::time::Duration; use tokio::io::AsyncSeekExt; use tokio::net::TcpListener; use tower_http::trace::{DefaultOnFailure, DefaultOnResponse, OnFailure}; -use tracing::{debug, error_span, trace, Span}; +use tracing::{debug, error_span, info, trace, Span}; use axum::{Json, Router}; @@ -43,6 +45,41 @@ pub struct HttpApi { #[derive(Debug, Default)] pub struct HttpApiOptions { pub read_only: bool, + pub basic_auth: Option<(String, String)>, +} + +async fn simple_basic_auth( + expected_username: Option<&str>, + expected_password: Option<&str>, + headers: HeaderMap, + request: axum::extract::Request, + next: Next, +) -> Result { + let (expected_user, expected_pass) = match (expected_username, expected_password) { + (Some(u), Some(p)) => (u, p), + _ => return Ok(next.run(request).await), + }; + let user_pass = headers + .get("Authorization") + .and_then(|h| h.to_str().ok()) + .and_then(|h| h.strip_prefix("Basic ")) + .and_then(|v| base64::engine::general_purpose::STANDARD.decode(v).ok()) + .and_then(|v| String::from_utf8(v).ok()); + let user_pass = match user_pass { + Some(user_pass) => user_pass, + None => { + return Ok(( + StatusCode::UNAUTHORIZED, + [("WWW-Authenticate", "Basic realm=\"API\"")], + ) + .into_response()) + } + }; + // TODO: constant time compare + match user_pass.split_once(':') { + Some((u, p)) if u == expected_user && p == expected_pass => Ok(next.run(request).await), + _ => Err(ApiError::unathorized()), + } } impl HttpApi { @@ -57,7 +94,7 @@ impl HttpApi { /// If read_only is passed, no state-modifying methods will be exposed. #[inline(never)] pub fn make_http_api_and_run( - self, + mut self, listener: TcpListener, upnp_router: Option, ) -> BoxFuture<'static, anyhow::Result<()>> { @@ -615,6 +652,19 @@ impl HttpApi { let mut app = app.with_state(state); + // Simple one-user basic auth + if let Some((user, pass)) = self.opts.basic_auth.take() { + info!("Enabling simple basic authentication in HTTP API"); + app = + app.route_layer(axum::middleware::from_fn(move |headers, request, next| { + let user = user.clone(); + let pass = pass.clone(); + async move { + simple_basic_auth(Some(&user), Some(&pass), headers, request, next).await + } + })); + } + if let Some(upnp_router) = upnp_router { app = app.nest("/upnp", upnp_router); } diff --git a/crates/rqbit/src/main.rs b/crates/rqbit/src/main.rs index 726cfd7a..3089ad87 100644 --- a/crates/rqbit/src/main.rs +++ b/crates/rqbit/src/main.rs @@ -496,6 +496,15 @@ async fn async_main(opts: Opts, cancel: CancellationToken) -> anyhow::Result<()> }, }; + let http_api_basic_auth = if let Ok(up) = std::env::var("RQBIT_HTTP_BASIC_AUTH_USERPASS") { + let (u, p) = up + .split_once(":") + .context("basic auth credentials should be in format username:password")?; + Some((u.to_owned(), p.to_owned())) + } else { + None + }; + let stats_printer = |session: Arc| async move { loop { session.with_torrents(|torrents| { @@ -615,7 +624,13 @@ async fn async_main(opts: Opts, cancel: CancellationToken) -> anyhow::Result<()> Some(log_config.rust_log_reload_tx), Some(log_config.line_broadcast), ); - let http_api = HttpApi::new(api, Some(HttpApiOptions { read_only: false })); + let http_api = HttpApi::new( + api, + Some(HttpApiOptions { + read_only: false, + basic_auth: http_api_basic_auth, + }), + ); let http_api_listen_addr = opts.http_api_listen_addr; info!("starting HTTP API at http://{http_api_listen_addr}"); @@ -735,7 +750,13 @@ async fn async_main(opts: Opts, cancel: CancellationToken) -> anyhow::Result<()> Some(log_config.rust_log_reload_tx), Some(log_config.line_broadcast), ); - let http_api = HttpApi::new(api, Some(HttpApiOptions { read_only: true })); + let http_api = HttpApi::new( + api, + Some(HttpApiOptions { + read_only: true, + basic_auth: http_api_basic_auth, + }), + ); let http_api_listen_addr = opts.http_api_listen_addr; info!("starting HTTP API at http://{http_api_listen_addr}"); diff --git a/desktop/src-tauri/src/main.rs b/desktop/src-tauri/src/main.rs index 6bcece50..3b215ca9 100644 --- a/desktop/src-tauri/src/main.rs +++ b/desktop/src-tauri/src/main.rs @@ -157,7 +157,10 @@ async fn api_from_config( .with_context(|| format!("error listening on {}", listen_addr))?; librqbit::http_api::HttpApi::new( api.clone(), - Some(librqbit::http_api::HttpApiOptions { read_only }), + Some(librqbit::http_api::HttpApiOptions { + read_only, + basic_auth: None, + }), ) .make_http_api_and_run(listener, upnp_router) .await