From e259740d501d00fb3548b0be145817ef8fdc35ae Mon Sep 17 00:00:00 2001 From: Yusuf Bera Ertan Date: Fri, 5 May 2023 17:16:03 +0300 Subject: [PATCH] feat: implement better errors, send less things in requests --- src/handler.rs | 165 +++++++++++++++++++++++++++++++------------------ src/main.rs | 16 +++-- 2 files changed, 114 insertions(+), 67 deletions(-) diff --git a/src/handler.rs b/src/handler.rs index 2a2c88d..391c511 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -1,4 +1,4 @@ -use std::{borrow::Cow, collections::HashMap, fmt::Display}; +use std::{collections::HashMap, fmt::Display}; use super::AppError; use async_tungstenite::{ @@ -19,8 +19,8 @@ use axum::{ use base64::Engine; use futures::{SinkExt, StreamExt}; use http::{ - header::{AUTHORIZATION, CACHE_CONTROL, CONTENT_TYPE}, - HeaderName, HeaderValue, Method, Request, Response, StatusCode, + header::{AUTHORIZATION, CACHE_CONTROL, CONTENT_TYPE, RANGE}, + HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode, }; use hyper::Body; use serde::{Deserialize, Serialize}; @@ -51,11 +51,19 @@ fn extract_password_from_basic_auth(auth: &str) -> Result { Ok(auth.trim_start_matches("default:").to_string()) } -struct QueryDisplay<'a, 'b> { - map: HashMap, Cow<'b, str>>, +fn remove_token_from_query(query: Option<&str>) -> HashMap { + let mut query_map: HashMap = query + .and_then(|v| serde_qs::from_str(v).ok()) + .unwrap_or_else(HashMap::new); + query_map.remove("token"); + query_map } -impl<'a, 'b> Display for QueryDisplay<'a, 'b> { +struct QueryDisplay { + map: HashMap, +} + +impl Display for QueryDisplay { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let length = self.map.len(); for (index, (k, v)) in self.map.iter().enumerate() { @@ -69,11 +77,7 @@ impl<'a, 'b> Display for QueryDisplay<'a, 'b> { } fn make_span_trace(req: &Request) -> Span { - let query = req.uri().query(); - let mut query_map = query - .and_then(|v| serde_qs::from_str::, Cow>>(v).ok()) - .unwrap_or_else(HashMap::new); - query_map.remove("token"); + let query_map = remove_token_from_query(req.uri().query()); let request_id = req .headers() @@ -99,6 +103,11 @@ fn make_span_trace(req: &Request) -> Span { } pub(super) async fn handler(state: AppState) -> Result<(Router, Router), AppError> { + let internal_router = Router::new() + .route("/token/generate", get(generate_token)) + .route("/token/revoke_all", post(revoke_all_tokens)) + .with_state(state.clone()); + let trace_layer = TraceLayer::new_for_http() .make_span_with(make_span_trace) .on_request(|req: &Request, _span: &Span| { @@ -108,11 +117,12 @@ pub(super) async fn handler(state: AppState) -> Result<(Router, Router), AppErro req.version(), ) }); - - let internal_router = Router::new() - .route("/token/generate", get(generate_token)) - .route("/token/revoke_all", post(revoke_all_tokens)) - .with_state(state.clone()); + let cors_layer = CorsLayer::new() + .allow_origin(tower_http::cors::Any) + .allow_headers([CONTENT_TYPE, CACHE_CONTROL, REQUEST_ID]) + .allow_methods([Method::GET]); + let sensitive_header_layer = SetSensitiveRequestHeadersLayer::new([AUTHORIZATION]); + let request_id_layer = SetRequestIdLayer::new(REQUEST_ID.clone(), MakeRequestUuid); let router = Router::new() .route("/token/generate_for_music/:id", get(generate_scoped_token)) @@ -121,14 +131,9 @@ pub(super) async fn handler(state: AppState) -> Result<(Router, Router), AppErro .route("/audio/scoped/:id", get(get_scoped_music)) .route("/", get(metadata_ws)) .layer(trace_layer) - .layer(SetSensitiveRequestHeadersLayer::new([AUTHORIZATION])) - .layer( - CorsLayer::new() - .allow_origin(tower_http::cors::Any) - .allow_headers([CONTENT_TYPE, CACHE_CONTROL, REQUEST_ID]) - .allow_methods([Method::GET]), - ) - .layer(SetRequestIdLayer::new(REQUEST_ID.clone(), MakeRequestUuid)) + .layer(sensitive_header_layer) + .layer(cors_layer) + .layer(request_id_layer) .with_state(state); Ok((router, internal_router)) @@ -186,21 +191,22 @@ async fn generate_scoped_token( async fn get_scoped_music( State(app): State, Path(token): Path, + request: Request, ) -> Result, AppError> { if let Some(music_id) = app.scoped_tokens.verify(token).await { - let mut resp = app - .client - .request( - Request::builder() - .uri(format!( - "http://{}:{}/audio/external_id/{}", - app.musikcubed_address, app.musikcubed_http_port, music_id - )) - .header(AUTHORIZATION, app.musikcubed_auth_header_value.clone()) - .body(Body::empty()) - .expect("cant fail"), - ) - .await?; + let mut req = Request::builder() + .uri(format!( + "http://{}:{}/audio/external_id/{}", + app.musikcubed_address, app.musikcubed_http_port, music_id + )) + .header(AUTHORIZATION, app.musikcubed_auth_header_value.clone()) + .body(Body::empty()) + .expect("cant fail"); + // proxy any range headers + if let Some(range) = request.headers().get(RANGE).cloned() { + req.headers_mut().insert(RANGE, range); + } + let mut resp = app.client.request(req).await?; if resp.status().is_success() { // add cache header resp.headers_mut() @@ -232,23 +238,26 @@ async fn get_music( async fn http( State(app): State, - Query(query): Query, + Query(auth): Query, mut req: Request, ) -> Result, AppError> { + // remove token from query let path = req.uri().path(); - let path_query = req - .uri() - .path_and_query() - .map(|v| v.as_str()) - .unwrap_or(path); + let query_map = remove_token_from_query(req.uri().query()); + let has_query = !query_map.is_empty(); + let query = has_query + .then(|| serde_qs::to_string(&query_map).unwrap()) + .unwrap_or_else(String::new); + let query_prefix = has_query.then_some("?").unwrap_or(""); + // craft new url *req.uri_mut() = format!( - "http://{}:{}{}", - app.musikcubed_address, app.musikcubed_http_port, path_query + "http://{}:{}{path}{query_prefix}{query}", + app.musikcubed_address, app.musikcubed_http_port ) .parse()?; - let maybe_token = query.token.or_else(|| { + let maybe_token = auth.token.or_else(|| { req.headers() .get(AUTHORIZATION) .and_then(|h| h.to_str().ok()) @@ -269,8 +278,29 @@ async fn http( .expect("cant fail")); } - req.headers_mut() - .insert(AUTHORIZATION, app.musikcubed_auth_header_value.clone()); + // proxy only the headers we need + let headers = { + let mut headers = HeaderMap::with_capacity(2); + let mut proxy_header = |header_name: HeaderName| { + if let Some(value) = req.headers().get(&header_name).cloned() { + headers.insert(header_name, value); + } + }; + // proxy range header + proxy_header(RANGE); + // add auth + headers.insert(AUTHORIZATION, app.musikcubed_auth_header_value.clone()); + headers + }; + + *req.headers_mut() = headers; + + let scheme = req.uri().scheme_str().unwrap(); + let authority = req.uri().authority().unwrap().as_str(); + tracing::debug!( + "proxying request to {scheme}://{authority} with headers {:?}", + req.headers() + ); Ok(app.client.request(req).await?) } @@ -286,6 +316,7 @@ async fn metadata_ws( "ws://{}:{}", app.musikcubed_address, app.musikcubed_metadata_port ); + tracing::debug!("proxying websocket request to {uri}"); let (ws_stream, _) = connect_async(uri).await?; let upgrade = ws @@ -402,20 +433,32 @@ async fn handle_metadata_socket( break 'err; } // wait for auth reply - if let Some(Ok(TungsteniteMessage::Text(raw))) = server_socket.next().await { - let Ok(parsed) = serde_json::from_str::(&raw) else { - tracing::error!("invalid auth response message: {raw}"); - break 'err; - }; - let is_authenticated = parsed - .options - .get("authenticated") - .and_then(Value::as_bool) - .unwrap_or(false); - match is_authenticated { - true => break 'ok parsed, - false => break 'err, + match server_socket.next().await { + Some(Ok(TungsteniteMessage::Text(raw))) => { + let Ok(parsed) = serde_json::from_str::(&raw) else { + tracing::error!("invalid auth response message: {raw}"); + break 'err; + }; + let is_authenticated = parsed + .options + .get("authenticated") + .and_then(Value::as_bool) + .unwrap_or(false); + match is_authenticated { + true => break 'ok parsed, + false => { + tracing::error!("unauthorized"); + break 'err; + } + } } + Some(Ok(TungsteniteMessage::Close(frame))) => { + let reason = frame + .map(|v| format!("{} (code {})", v.reason, v.code)) + .unwrap_or_else(|| "no reason given".to_string()); + tracing::error!("socket was closed by musikcubed: {}", reason); + } + _ => tracing::error!("unexpected message from musikcubed"), } } let _ = client_socket.close().await; diff --git a/src/main.rs b/src/main.rs index f1b6875..c99bea5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -122,12 +122,16 @@ impl AppStateInternal { musikcubed_address: get_conf("MUSIKCUBED_ADDRESS")?, musikcubed_http_port: get_conf("MUSIKCUBED_HTTP_PORT")?.parse()?, musikcubed_metadata_port: get_conf("MUSIKCUBED_METADATA_PORT")?.parse()?, - musikcubed_auth_header_value: format!( - "Basic {}", - B64.encode(format!("default:{}", musikcubed_password)) - ) - .parse() - .expect("valid header value"), + musikcubed_auth_header_value: { + let mut val: http::HeaderValue = format!( + "Basic {}", + B64.encode(format!("default:{}", musikcubed_password)) + ) + .parse() + .expect("valid header value"); + val.set_sensitive(true); + val + }, musikcubed_password, client: Client::new(), tokens: Tokens::read(&tokens_path).await?,