diff --git a/src/handlers/data.rs b/src/handlers/data.rs new file mode 100644 index 0000000..515ecb1 --- /dev/null +++ b/src/handlers/data.rs @@ -0,0 +1,67 @@ +use axum::extract::{Query, State}; +use http::{ + header::{AUTHORIZATION, CACHE_CONTROL, RANGE}, + Request, Response, +}; +use hyper::Body; + +use crate::{ + error::AppError, + utils::{extract_password_from_basic_auth, remove_token_from_query}, + AppState, +}; + +use super::{Auth, AUDIO_CACHE_HEADER}; + +pub(crate) async fn get_music( + State(app): State, + Query(query): Query, + req: Request, +) -> Result, AppError> { + http(State(app), Query(query), req).await.map(|mut resp| { + if resp.status().is_success() { + // add cache header + resp.headers_mut() + .insert(CACHE_CONTROL, AUDIO_CACHE_HEADER.clone()); + } + resp + }) +} + +pub(crate) async fn http( + State(app): State, + Query(auth): Query, + req: Request, +) -> Result, AppError> { + let maybe_token = auth.token.or_else(|| { + req.headers() + .get(AUTHORIZATION) + .and_then(|h| h.to_str().ok()) + .and_then(|auth| extract_password_from_basic_auth(auth).ok()) + }); + app.verify_token(maybe_token).await?; + + // remove token from query + let path = req.uri().path(); + let query_map = remove_token_from_query(req.uri().query()); + let has_query = !query_map.is_empty(); + let query = has_query + .then(|| serde_qs::to_string(&query_map).unwrap()) + .unwrap_or_else(String::new); + let query_prefix = has_query.then_some("?").unwrap_or(""); + + let mut request = Request::new(Body::empty()); + if let Some(range) = req.headers().get(RANGE).cloned() { + request.headers_mut().insert(RANGE, range); + } + + tracing::debug!( + "proxying request to {}:{} with headers {:?}", + app.musikcubed_address, + app.musikcubed_http_port, + req.headers() + ); + + app.make_musikcubed_request(format!("{path}{query_prefix}{query}"), request) + .await +} diff --git a/src/handlers/internal.rs b/src/handlers/internal.rs new file mode 100644 index 0000000..d736ec0 --- /dev/null +++ b/src/handlers/internal.rs @@ -0,0 +1,28 @@ +use axum::{extract::State, response::IntoResponse}; +use http::StatusCode; + +use crate::{error::AppError, AppState}; + +pub(crate) async fn revoke_all_tokens(State(app): State) -> impl IntoResponse { + app.tokens.revoke_all().await; + tokio::spawn(async move { + if let Err(err) = app.tokens.write(&app.tokens_path).await { + tracing::error!("couldn't write tokens file: {err}"); + } + }); + StatusCode::OK +} + +pub(crate) async fn generate_token( + State(app): State, +) -> Result { + // generate token + let token = app.tokens.generate().await?; + // start task to write tokens + tokio::spawn(async move { + if let Err(err) = app.tokens.write(&app.tokens_path).await { + tracing::error!("couldn't write tokens file: {err}"); + } + }); + Ok(token.into_response()) +} diff --git a/src/handler.rs b/src/handlers/metadata.rs similarity index 55% rename from src/handler.rs rename to src/handlers/metadata.rs index b3f851b..31a99e6 100644 --- a/src/handler.rs +++ b/src/handlers/metadata.rs @@ -1,267 +1,28 @@ -use std::collections::HashMap; - -use super::AppError; use async_tungstenite::{ tokio::TokioAdapter, tungstenite::Message as TungsteniteMessage, WebSocketStream, }; use axum::{ extract::{ - ws::{Message as AxumMessage, WebSocket, WebSocketUpgrade}, - Path, Query, State, + ws::{Message as AxumMessage, WebSocket}, + State, WebSocketUpgrade, }, headers::UserAgent, response::IntoResponse, - routing::{get, post}, - Router, TypedHeader, + TypedHeader, }; -use base64::Engine; use futures::{SinkExt, StreamExt}; -use http::{ - header::{AUTHORIZATION, CACHE_CONTROL, CONTENT_TYPE, RANGE}, - HeaderName, HeaderValue, Method, Request, Response, StatusCode, -}; -use hyper::Body; -use serde::Deserialize; use serde_json::Value; use tokio::net::TcpStream; -use tower_http::{ - cors::CorsLayer, - request_id::{MakeRequestUuid, SetRequestIdLayer}, - sensitive_headers::SetSensitiveRequestHeadersLayer, - trace::TraceLayer, -}; use tracing::{Instrument, Span}; use crate::{ api::WsApiMessage, - utils::{axum_msg_to_tungstenite, tungstenite_msg_to_axum, QueryDisplay, WsError}, - AppState, B64, + error::AppError, + utils::{axum_msg_to_tungstenite, tungstenite_msg_to_axum, WsError}, + AppState, }; -const AUDIO_CACHE_HEADER: HeaderValue = HeaderValue::from_static("private, max-age=604800"); -const REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id"); - -#[derive(Deserialize)] -struct Auth { - #[serde(default)] - token: Option, -} - -fn extract_password_from_basic_auth(auth: &str) -> Result { - let decoded = B64.decode(auth.trim_start_matches("Basic "))?; - let auth = String::from_utf8(decoded)?; - Ok(auth.trim_start_matches("default:").to_string()) -} - -fn remove_token_from_query(query: Option<&str>) -> HashMap { - let mut query_map: HashMap = query - .and_then(|v| serde_qs::from_str(v).ok()) - .unwrap_or_else(HashMap::new); - query_map.remove("token"); - query_map -} - -fn make_span_trace(req: &Request) -> Span { - let query_map = remove_token_from_query(req.uri().query()); - - let request_id = req - .headers() - .get(REQUEST_ID) - .and_then(|v| v.to_str().ok()) - .unwrap_or("no id set"); - - if query_map.is_empty() { - tracing::debug_span!( - "request", - path = %req.uri().path(), - id = %request_id, - ) - } else { - let query_display = QueryDisplay::new(query_map); - tracing::debug_span!( - "request", - path = %req.uri().path(), - query = %query_display, - id = %request_id, - ) - } -} - -pub(super) async fn handler(state: AppState) -> Result<(Router, Router), AppError> { - let internal_router = Router::new() - .route("/token/generate", get(generate_token)) - .route("/token/revoke_all", post(revoke_all_tokens)) - .with_state(state.clone()); - - let trace_layer = TraceLayer::new_for_http() - .make_span_with(make_span_trace) - .on_request(|req: &Request, _span: &Span| { - tracing::debug!( - "started processing request {} on {:?}", - req.method(), - req.version(), - ) - }); - let cors_layer = CorsLayer::new() - .allow_origin(tower_http::cors::Any) - .allow_headers([CONTENT_TYPE, CACHE_CONTROL, REQUEST_ID]) - .allow_methods([Method::GET]); - let sensitive_header_layer = SetSensitiveRequestHeadersLayer::new([AUTHORIZATION]); - let request_id_layer = SetRequestIdLayer::new(REQUEST_ID.clone(), MakeRequestUuid); - - let router = Router::new() - .route("/token/generate_for_music/:id", get(generate_scoped_token)) - .route("/thumbnail/:id", get(http)) - .route("/audio/external_id/:id", get(get_music)) - .route("/share/audio/:token", get(get_scoped_music_file)) - .route("/share/thumbnail/:token", get(get_scoped_music_thumbnail)) - .route("/share/info/:token", get(get_scoped_music_info)) - .route("/", get(metadata_ws)) - .layer(trace_layer) - .layer(sensitive_header_layer) - .layer(cors_layer) - .layer(request_id_layer) - .with_state(state); - - Ok((router, internal_router)) -} - -async fn revoke_all_tokens(State(app): State) -> impl IntoResponse { - app.tokens.revoke_all().await; - tokio::spawn(async move { - if let Err(err) = app.tokens.write(&app.tokens_path).await { - tracing::error!("couldn't write tokens file: {err}"); - } - }); - StatusCode::OK -} - -async fn generate_token(State(app): State) -> Result { - // generate token - let token = app.tokens.generate().await?; - // start task to write tokens - tokio::spawn(async move { - if let Err(err) = app.tokens.write(&app.tokens_path).await { - tracing::error!("couldn't write tokens file: {err}"); - } - }); - Ok(token.into_response()) -} - -async fn generate_scoped_token( - State(app): State, - Query(query): Query, - Path(music_id): Path, -) -> Result { - app.verify_token(query.token).await?; - - // generate token - let token = app.scoped_tokens.generate_for_id(music_id).await; - Ok(token.into_response()) -} - -async fn get_scoped_music_info( - State(app): State, - Path(token): Path, -) -> Result { - let music_id = app.verify_scoped_token(token).await?; - let Some(info) = app.music_info.get(music_id).await else { - return Err("music id not found".into()); - }; - Ok(serde_json::to_string(&info).unwrap()) -} - -async fn get_scoped_music_thumbnail( - State(app): State, - Path(token): Path, -) -> Result, AppError> { - let music_id = app.verify_scoped_token(token).await?; - let Some(info) = app.music_info.get(music_id).await else { - return Err("music id not found".into()); - }; - app.make_musikcubed_request( - format!("thumbnail/{}", info.thumbnail_id), - Request::new(Body::empty()), - ) - .await -} - -async fn get_scoped_music_file( - State(app): State, - Path(token): Path, - request: Request, -) -> Result, AppError> { - let music_id = app.verify_scoped_token(token).await?; - let mut req = Request::new(Body::empty()); - // proxy any range headers - if let Some(range) = request.headers().get(RANGE).cloned() { - req.headers_mut().insert(RANGE, range); - } - let mut resp = app - .make_musikcubed_request(format!("audio/external_id/{music_id}"), req) - .await?; - if resp.status().is_success() { - // add cache header - resp.headers_mut() - .insert(CACHE_CONTROL, AUDIO_CACHE_HEADER.clone()); - } - Ok(resp) -} - -async fn get_music( - State(app): State, - Query(query): Query, - req: Request, -) -> Result, AppError> { - http(State(app), Query(query), req).await.map(|mut resp| { - if resp.status().is_success() { - // add cache header - resp.headers_mut() - .insert(CACHE_CONTROL, AUDIO_CACHE_HEADER.clone()); - } - resp - }) -} - -async fn http( - State(app): State, - Query(auth): Query, - req: Request, -) -> Result, AppError> { - let maybe_token = auth.token.or_else(|| { - req.headers() - .get(AUTHORIZATION) - .and_then(|h| h.to_str().ok()) - .and_then(|auth| extract_password_from_basic_auth(auth).ok()) - }); - app.verify_token(maybe_token).await?; - - // remove token from query - let path = req.uri().path(); - let query_map = remove_token_from_query(req.uri().query()); - let has_query = !query_map.is_empty(); - let query = has_query - .then(|| serde_qs::to_string(&query_map).unwrap()) - .unwrap_or_else(String::new); - let query_prefix = has_query.then_some("?").unwrap_or(""); - - let mut request = Request::new(Body::empty()); - if let Some(range) = req.headers().get(RANGE).cloned() { - request.headers_mut().insert(RANGE, range); - } - - tracing::debug!( - "proxying request to {}:{} with headers {:?}", - app.musikcubed_address, - app.musikcubed_http_port, - req.headers() - ); - - app.make_musikcubed_request(format!("{path}{query_prefix}{query}"), request) - .await -} - -async fn metadata_ws( +pub(crate) async fn metadata_ws( State(app): State, TypedHeader(user_agent): TypedHeader, ws: WebSocketUpgrade, @@ -296,7 +57,7 @@ async fn metadata_ws( Ok(upgrade) } -async fn handle_metadata_socket( +pub(crate) async fn handle_metadata_socket( mut server_socket: WebSocketStream>, mut client_socket: WebSocket, app: AppState, diff --git a/src/handlers/mod.rs b/src/handlers/mod.rs new file mode 100644 index 0000000..69586f0 --- /dev/null +++ b/src/handlers/mod.rs @@ -0,0 +1,21 @@ +use ::http::HeaderValue; +use serde::Deserialize; + +pub(crate) mod data; +pub(crate) mod internal; +pub(crate) mod metadata; +pub(crate) mod share; + +pub(crate) const AUDIO_CACHE_HEADER: HeaderValue = + HeaderValue::from_static("private, max-age=604800"); + +#[derive(Deserialize)] +pub(crate) struct Auth { + #[serde(default)] + token: Option, +} + +pub(crate) use self::data::*; +pub(crate) use internal::*; +pub(crate) use metadata::*; +pub(crate) use share::*; diff --git a/src/handlers/share.rs b/src/handlers/share.rs new file mode 100644 index 0000000..590bcbd --- /dev/null +++ b/src/handlers/share.rs @@ -0,0 +1,73 @@ +use axum::{ + extract::{Path, Query, State}, + response::IntoResponse, +}; +use http::{ + header::{CACHE_CONTROL, RANGE}, + Request, Response, +}; +use hyper::Body; + +use crate::{error::AppError, AppState}; + +use super::{Auth, AUDIO_CACHE_HEADER}; + +pub(crate) async fn generate_scoped_token( + State(app): State, + Query(query): Query, + Path(music_id): Path, +) -> Result { + app.verify_token(query.token).await?; + + // generate token + let token = app.scoped_tokens.generate_for_id(music_id).await; + Ok(token) +} + +pub(crate) async fn get_scoped_music_info( + State(app): State, + Path(token): Path, +) -> Result { + let music_id = app.verify_scoped_token(token).await?; + let Some(info) = app.music_info.get(music_id).await else { + return Err("music id not found".into()); + }; + Ok(serde_json::to_string(&info).unwrap()) +} + +pub(crate) async fn get_scoped_music_thumbnail( + State(app): State, + Path(token): Path, +) -> Result, AppError> { + let music_id = app.verify_scoped_token(token).await?; + let Some(info) = app.music_info.get(music_id).await else { + return Err("music id not found".into()); + }; + app.make_musikcubed_request( + format!("thumbnail/{}", info.thumbnail_id), + Request::new(Body::empty()), + ) + .await +} + +pub(crate) async fn get_scoped_music_file( + State(app): State, + Path(token): Path, + request: Request, +) -> Result, AppError> { + let music_id = app.verify_scoped_token(token).await?; + let mut req = Request::new(Body::empty()); + // proxy any range headers + if let Some(range) = request.headers().get(RANGE).cloned() { + req.headers_mut().insert(RANGE, range); + } + let mut resp = app + .make_musikcubed_request(format!("audio/external_id/{music_id}"), req) + .await?; + if resp.status().is_success() { + // add cache header + resp.headers_mut() + .insert(CACHE_CONTROL, AUDIO_CACHE_HEADER.clone()); + } + Ok(resp) +} diff --git a/src/main.rs b/src/main.rs index 2650593..0a61702 100644 --- a/src/main.rs +++ b/src/main.rs @@ -20,7 +20,8 @@ use crate::{ mod api; mod error; -mod handler; +mod handlers; +mod router; mod token; mod utils; @@ -60,7 +61,7 @@ async fn app() -> Result<(), AppError> { let internal_port: u16 = get_conf("INTERNAL_PORT")?.parse()?; let state = AppState::new(AppStateInternal::new(public_port).await?); - let (public_router, internal_router) = handler::handler(state).await?; + let (public_router, internal_router) = router::handler(state).await?; let internal_make_service = internal_router.into_make_service(); let internal_task = tokio::spawn( diff --git a/src/router.rs b/src/router.rs new file mode 100644 index 0000000..051cfaa --- /dev/null +++ b/src/router.rs @@ -0,0 +1,96 @@ +use super::AppError; +use axum::{ + routing::{get, post}, + Router, +}; +use http::{ + header::{AUTHORIZATION, CACHE_CONTROL, CONTENT_TYPE}, + HeaderName, Method, Request, +}; +use hyper::Body; +use tower_http::{ + cors::CorsLayer, + request_id::{MakeRequestUuid, SetRequestIdLayer}, + sensitive_headers::SetSensitiveRequestHeadersLayer, + trace::TraceLayer, +}; +use tracing::Span; + +use crate::{ + handlers, + utils::{remove_token_from_query, QueryDisplay}, + AppState, +}; + +const REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id"); + +fn make_span_trace(req: &Request) -> Span { + let query_map = remove_token_from_query(req.uri().query()); + + let request_id = req + .headers() + .get(REQUEST_ID) + .and_then(|v| v.to_str().ok()) + .unwrap_or("no id set"); + + if query_map.is_empty() { + tracing::debug_span!( + "request", + path = %req.uri().path(), + id = %request_id, + ) + } else { + let query_display = QueryDisplay::new(query_map); + tracing::debug_span!( + "request", + path = %req.uri().path(), + query = %query_display, + id = %request_id, + ) + } +} + +pub(super) async fn handler(state: AppState) -> Result<(Router, Router), AppError> { + let internal_router = Router::new() + .route("/token/generate", get(handlers::generate_token)) + .route("/token/revoke_all", post(handlers::revoke_all_tokens)) + .with_state(state.clone()); + + let trace_layer = TraceLayer::new_for_http() + .make_span_with(make_span_trace) + .on_request(|req: &Request, _span: &Span| { + tracing::debug!( + "started processing request {} on {:?}", + req.method(), + req.version(), + ) + }); + let cors_layer = CorsLayer::new() + .allow_origin(tower_http::cors::Any) + .allow_headers([CONTENT_TYPE, CACHE_CONTROL, REQUEST_ID]) + .allow_methods([Method::GET]); + let sensitive_header_layer = SetSensitiveRequestHeadersLayer::new([AUTHORIZATION]); + let request_id_layer = SetRequestIdLayer::new(REQUEST_ID.clone(), MakeRequestUuid); + + let router = Router::new() + .route( + "/token/generate_for_music/:id", + get(handlers::generate_scoped_token), + ) + .route("/thumbnail/:id", get(handlers::http)) + .route("/audio/external_id/:id", get(handlers::get_music)) + .route("/share/audio/:token", get(handlers::get_scoped_music_file)) + .route( + "/share/thumbnail/:token", + get(handlers::get_scoped_music_thumbnail), + ) + .route("/share/info/:token", get(handlers::get_scoped_music_info)) + .route("/", get(handlers::metadata_ws)) + .layer(trace_layer) + .layer(sensitive_header_layer) + .layer(cors_layer) + .layer(request_id_layer) + .with_state(state); + + Ok((router, internal_router)) +} diff --git a/src/utils.rs b/src/utils.rs index e4b91ff..936f367 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -8,6 +8,9 @@ use axum::{ extract::ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage}, Error as AxumError, }; +use base64::Engine; + +use crate::{error::AppError, B64}; #[derive(Debug)] pub(crate) enum WsError { @@ -196,3 +199,17 @@ impl Display for QueryDisplay { Ok(()) } } + +pub(crate) fn extract_password_from_basic_auth(auth: &str) -> Result { + let decoded = B64.decode(auth.trim_start_matches("Basic "))?; + let auth = String::from_utf8(decoded)?; + Ok(auth.trim_start_matches("default:").to_string()) +} + +pub(crate) fn remove_token_from_query(query: Option<&str>) -> HashMap { + let mut query_map: HashMap = query + .and_then(|v| serde_qs::from_str(v).ok()) + .unwrap_or_else(HashMap::new); + query_map.remove("token"); + query_map +}