diff --git a/Cargo.lock b/Cargo.lock index 9ebc275..9de2ed0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -614,6 +614,7 @@ dependencies = [ "scc", "serde", "serde_json", + "serde_qs", "tokio", "tower-http", "tracing", @@ -884,6 +885,17 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_qs" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0431a35568651e363364210c91983c1da5eb29404d9f0928b67d4ebcfa7d330c" +dependencies = [ + "percent-encoding", + "serde", + "thiserror", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" diff --git a/Cargo.toml b/Cargo.toml index 6fb47c0..a7b537e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,4 +20,5 @@ serde_json = "1" rust-argon2 = "1.0" rand = "0.8" scc = "1" -base64 = "0.21" \ No newline at end of file +base64 = "0.21" +serde_qs = "0.12" \ No newline at end of file diff --git a/src/handler.rs b/src/handler.rs index c8b1d4a..5f9a59a 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -1,4 +1,4 @@ -use std::fmt::Display; +use std::{collections::HashMap, fmt::Display}; use super::AppError; use async_tungstenite::{ @@ -47,69 +47,37 @@ fn extract_password_from_basic_auth(auth: &str) -> Result { Ok(auth.trim_start_matches("default:").to_string()) } -struct ComponentDisplay { - left: Left, - right: Right, +struct QueryDisplay<'a, 'b> { + map: HashMap<&'a str, &'b str>, } -impl<'a, 'b> ComponentDisplay<&'a str, &'b str> { - fn is_empty(&self) -> bool { - self.left.is_empty() && self.right.is_empty() - } -} - -impl Display for ComponentDisplay -where - Left: Display, - Right: Display, -{ +impl<'a, 'b> Display for QueryDisplay<'a, 'b> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}{}", self.left, self.right) + let length = self.map.len(); + for (index, (k, v)) in self.map.iter().enumerate() { + write!(f, "{k}={v}")?; + if index != length { + write!(f, "&")?; + } + } + Ok(()) } } fn make_span_trace(req: &Request) -> Span { - let uri_query_filtered = req - .uri() - .query() - .map(|q| { - let token_start = q.find("&token="); - if let Some(pos) = token_start { - let (left, right) = q.split_at(pos); - let (_, right) = right.split_at(pos + 6 + 30); - return ComponentDisplay { left, right }; - } - let token_start = q.find("token="); - if let Some(_) = token_start { - let (_, right) = q.split_at(6 + 30); - return ComponentDisplay { left: "", right }; - } - ComponentDisplay { left: q, right: "" } - }) - .unwrap_or(ComponentDisplay { - left: "", - right: "", - }); - let uri_path = ComponentDisplay { - left: { - if !uri_query_filtered.is_empty() { - ComponentDisplay { - left: req.uri().path(), - right: "?", - } - } else { - ComponentDisplay { - left: req.uri().path(), - right: "", - } - } - }, - right: uri_query_filtered, - }; + let query = req.uri().query(); + let mut query_map = query + .and_then(|v| serde_qs::from_str::>(v).ok()) + .unwrap_or_else(HashMap::new); + if query_map.contains_key("token") { + query_map.insert("token", ""); + } + let query_display = QueryDisplay { map: query_map }; tracing::debug_span!( "request", method = %req.method(), - uri = %uri_path, + path = %req.uri().path(), + query = %query_display, version = ?req.version(), headers = ?req.headers(), ) @@ -209,9 +177,11 @@ async fn get_scoped_music( .expect("cant fail"), ) .await?; - // add cache header - resp.headers_mut() - .insert(CACHE_CONTROL, AUDIO_CACHE_HEADER.clone()); + if resp.status().is_success() { + // add cache header + resp.headers_mut() + .insert(CACHE_CONTROL, AUDIO_CACHE_HEADER.clone()); + } Ok(resp) } else { Ok(Response::builder() @@ -227,9 +197,11 @@ async fn get_music( req: Request, ) -> Result, AppError> { http(State(app), Query(query), req).await.map(|mut resp| { - // add cache header - resp.headers_mut() - .insert(CACHE_CONTROL, AUDIO_CACHE_HEADER.clone()); + if resp.status().is_success() { + // add cache header + resp.headers_mut() + .insert(CACHE_CONTROL, AUDIO_CACHE_HEADER.clone()); + } resp }) }