diff --git a/Cargo.lock b/Cargo.lock index e4150ce..cc369df 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -619,6 +619,7 @@ dependencies = [ "tower-http", "tracing", "tracing-subscriber", + "tungstenite 0.18.0", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 7f97bd0..3a00970 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ tower-http = {version = "0.4", features = ["trace", "cors", "sensitive-headers", hyper = {version = "0.14", features = ["client"]} http = "0.2" async-tungstenite = {version = "0.21", features = ["tokio-runtime"]} +axum-tungstenite = {package = "tungstenite", version = "0.18"} futures = {version = "0.3"} serde = {version = "1", features = ["derive"]} serde_json = "1" diff --git a/flake.nix b/flake.nix index 26752ab..98ba00a 100644 --- a/flake.nix +++ b/flake.nix @@ -27,18 +27,20 @@ }; devShells.default = crateOutputs.devShell.overrideAttrs (old: { RUST_SRC_PATH = "${config.nci.toolchains.shell}/lib/rustlib/src/rust/library"; - packages = (old.packages or []) ++ [ - pkgs.rust-analyzer - (pkgs.writeShellApplication { - name = "generate-cert"; - runtimeInputs = with pkgs; [mkcert coreutils]; - text = '' - mkcert localhost 127.0.0.1 ::1 - mv localhost+2.pem cert.pem - mv localhost+2-key.pem key.pem - ''; - }) - ]; + packages = + (old.packages or []) + ++ [ + pkgs.rust-analyzer + (pkgs.writeShellApplication { + name = "generate-cert"; + runtimeInputs = with pkgs; [mkcert coreutils]; + text = '' + mkcert localhost 127.0.0.1 ::1 + mv localhost+2.pem cert.pem + mv localhost+2-key.pem key.pem + ''; + }) + ]; }); packages.default = crateOutputs.packages.release; }; diff --git a/src/api.rs b/src/api.rs new file mode 100644 index 0000000..9238e93 --- /dev/null +++ b/src/api.rs @@ -0,0 +1,64 @@ +use serde::{Deserialize, Serialize}; +use serde_json::{Map, Value}; + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct WsApiMessage { + pub(crate) name: String, + #[serde(rename = "type")] + pub(crate) kind: WsApiMessageType, + pub(crate) id: String, + #[serde(default)] + pub(crate) device_id: Option, + pub(crate) options: Map, +} + +impl WsApiMessage { + pub(crate) fn new(name: impl Into, kind: WsApiMessageType) -> Self { + Self { + name: name.into(), + kind, + id: String::new(), + device_id: None, + options: Map::new(), + } + } + + pub(crate) fn request(name: impl Into) -> Self { + Self::new(name, WsApiMessageType::Request) + } + + pub(crate) fn authenticate(password: impl Into) -> Self { + Self::new("authenticate", WsApiMessageType::Request).option("password", password.into()) + } + + pub(crate) fn id(mut self, id: impl Into) -> Self { + self.id = id.into(); + self + } + + pub(crate) fn device_id(mut self, device_id: impl Into) -> Self { + self.device_id = Some(device_id.into()); + self + } + + pub(crate) fn option(mut self, key: impl Into, value: impl Into) -> Self { + self.options.insert(key.into(), value.into()); + self + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) enum WsApiMessageType { + #[serde(rename = "request")] + Request, + #[serde(rename = "response")] + Response, + #[serde(rename = "broadcast")] + Broadcast, +} + +impl ToString for WsApiMessage { + fn to_string(&self) -> String { + serde_json::to_string(self).unwrap() + } +} diff --git a/src/error.rs b/src/error.rs index 43ba6a2..0d91121 100644 --- a/src/error.rs +++ b/src/error.rs @@ -8,6 +8,14 @@ type BoxedError = Box; #[derive(Debug)] pub(crate) struct AppError { internal: BoxedError, + status: Option, +} + +impl AppError { + pub(crate) fn status(mut self, code: StatusCode) -> Self { + self.status = Some(code); + self + } } impl From for AppError @@ -17,6 +25,7 @@ where fn from(err: E) -> Self { Self { internal: err.into(), + status: None, } } } @@ -24,7 +33,7 @@ where impl IntoResponse for AppError { fn into_response(self) -> axum::response::Response { ( - StatusCode::INTERNAL_SERVER_ERROR, + self.status.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), format!("Something went wrong: {}", self.internal), ) .into_response() diff --git a/src/handler.rs b/src/handler.rs index 391c511..58fe879 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -1,14 +1,12 @@ -use std::{collections::HashMap, fmt::Display}; +use std::collections::HashMap; use super::AppError; use async_tungstenite::{ - tokio::TokioAdapter, - tungstenite::{protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage}, - WebSocketStream, + tokio::TokioAdapter, tungstenite::Message as TungsteniteMessage, WebSocketStream, }; use axum::{ extract::{ - ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage, WebSocket, WebSocketUpgrade}, + ws::{Message as AxumMessage, WebSocket, WebSocketUpgrade}, Path, Query, State, }, headers::UserAgent, @@ -23,7 +21,7 @@ use http::{ HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode, }; use hyper::Body; -use serde::{Deserialize, Serialize}; +use serde::Deserialize; use serde_json::Value; use tokio::net::TcpStream; use tower_http::{ @@ -34,7 +32,11 @@ use tower_http::{ }; use tracing::{Instrument, Span}; -use crate::{AppState, B64}; +use crate::{ + api::WsApiMessage, + utils::{axum_msg_to_tungstenite, tungstenite_msg_to_axum, QueryDisplay, WsError}, + AppState, B64, +}; const AUDIO_CACHE_HEADER: HeaderValue = HeaderValue::from_static("private, max-age=604800"); const REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id"); @@ -59,23 +61,6 @@ fn remove_token_from_query(query: Option<&str>) -> HashMap { query_map } -struct QueryDisplay { - map: HashMap, -} - -impl Display for QueryDisplay { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let length = self.map.len(); - for (index, (k, v)) in self.map.iter().enumerate() { - write!(f, "{k}={v}")?; - if index < length - 1 { - write!(f, "&")?; - } - } - Ok(()) - } -} - fn make_span_trace(req: &Request) -> Span { let query_map = remove_token_from_query(req.uri().query()); @@ -92,7 +77,7 @@ fn make_span_trace(req: &Request) -> Span { id = %request_id, ) } else { - let query_display = QueryDisplay { map: query_map }; + let query_display = QueryDisplay::new(query_map); tracing::debug_span!( "request", path = %req.uri().path(), @@ -128,7 +113,9 @@ pub(super) async fn handler(state: AppState) -> Result<(Router, Router), AppErro .route("/token/generate_for_music/:id", get(generate_scoped_token)) .route("/thumbnail/:id", get(http)) .route("/audio/external_id/:id", get(get_music)) - .route("/audio/scoped/:id", get(get_scoped_music)) + .route("/share/audio/:token", get(get_scoped_music_file)) + .route("/share/thumbnail/:token", get(get_scoped_music_thumbnail)) + .route("/share/info/:token", get(get_scoped_music_info)) .route("/", get(metadata_ws)) .layer(trace_layer) .layer(sensitive_header_layer) @@ -188,37 +175,62 @@ async fn generate_scoped_token( Ok(token.into_response()) } -async fn get_scoped_music( +async fn get_scoped_music_info( + State(app): State, + Path(token): Path, +) -> Result { + let music_id = app.verify_scoped_token(token).await?; + let Some(info) = app.music_info.get(music_id).await else { + return Err("music id not found".into()); + }; + Ok(serde_json::to_string(&info).unwrap()) +} + +async fn get_scoped_music_thumbnail( + State(app): State, + Path(token): Path, +) -> Result, AppError> { + let music_id = app.verify_scoped_token(token).await?; + let Some(info) = app.music_info.get(music_id).await else { + return Err("music id not found".into()); + }; + let req = Request::builder() + .uri(format!( + "http://{}:{}/thumbnail/{}", + app.musikcubed_address, app.musikcubed_http_port, info.thumbnail_id + )) + .header(AUTHORIZATION, app.musikcubed_auth_header_value.clone()) + .body(Body::empty()) + .expect("cant fail"); + let resp = app.client.request(req).await?; + Ok(resp) +} + +async fn get_scoped_music_file( State(app): State, Path(token): Path, request: Request, ) -> Result, AppError> { - if let Some(music_id) = app.scoped_tokens.verify(token).await { - let mut req = Request::builder() - .uri(format!( - "http://{}:{}/audio/external_id/{}", - app.musikcubed_address, app.musikcubed_http_port, music_id - )) - .header(AUTHORIZATION, app.musikcubed_auth_header_value.clone()) - .body(Body::empty()) - .expect("cant fail"); - // proxy any range headers - if let Some(range) = request.headers().get(RANGE).cloned() { - req.headers_mut().insert(RANGE, range); - } - let mut resp = app.client.request(req).await?; - if resp.status().is_success() { - // add cache header - resp.headers_mut() - .insert(CACHE_CONTROL, AUDIO_CACHE_HEADER.clone()); - } - Ok(resp) - } else { - Ok(Response::builder() - .status(StatusCode::UNAUTHORIZED) - .body("Invalid scoped token".to_string().into()) - .expect("cant fail")) + let music_id = app.verify_scoped_token(token).await?; + let mut req = Request::builder() + .uri(format!( + "http://{}:{}/audio/external_id/{}", + app.musikcubed_address, app.musikcubed_http_port, music_id + )) + .header(AUTHORIZATION, app.musikcubed_auth_header_value.clone()) + .body(Body::empty()) + .expect("cant fail"); + // proxy any range headers + if let Some(range) = request.headers().get(RANGE).cloned() { + req.headers_mut().insert(RANGE, range); } + let mut resp = app.client.request(req).await?; + if resp.status().is_success() { + // add cache header + resp.headers_mut() + .insert(CACHE_CONTROL, AUDIO_CACHE_HEADER.clone()); + } + Ok(resp) } async fn get_music( @@ -340,16 +352,6 @@ async fn metadata_ws( Ok(upgrade) } -#[derive(Serialize, Deserialize)] -struct WsApiMessage { - name: String, - r#type: String, - id: String, - #[serde(default)] - device_id: Option, - options: serde_json::value::Map, -} - async fn handle_metadata_socket( mut server_socket: WebSocketStream>, mut client_socket: WebSocket, @@ -410,20 +412,9 @@ async fn handle_metadata_socket( let og_auth_reply = 'ok: { 'err: { // send actual auth message to the musikcubed server - let auth_msg = WsApiMessage { - name: "authenticate".to_string(), - r#type: "request".to_string(), - id: og_auth_msg.id, - device_id: og_auth_msg.device_id, - options: { - let mut map = serde_json::Map::with_capacity(1); - map.insert( - "password".to_string(), - app.musikcubed_password.clone().into(), - ); - map - }, - }; + let auth_msg = WsApiMessage::authenticate(app.musikcubed_password.clone()) + .id(og_auth_msg.id) + .device_id(og_auth_msg.device_id.unwrap_or_default()); let auth_msg_ser = serde_json::to_string(&auth_msg).expect(""); if let Err(err) = server_socket .send(TungsteniteMessage::Text(auth_msg_ser)) @@ -507,19 +498,35 @@ async fn handle_metadata_socket( let in_read_fut = async move { while let Some(res) = in_read.next().await { + let res = res.map_err(WsError::from); match res { Ok(msg) => { tracing::trace!("got message from client: {msg:?}"); - let res = out_write.send(axum_msg_to_tungstenite(msg)).await; + let res = out_write + .send(axum_msg_to_tungstenite(msg)) + .await + .map_err(WsError::from); if let Err(err) = res { - tracing::error!("could not write to server socket: {err}"); - break; + match err { + WsError::Closed(reason) => { + tracing::error!("server socket was closed: {reason}"); + break; + } + err => { + tracing::error!("could not write to server socket: {err}"); + } + } } } - Err(err) => { - tracing::error!("could not read from client socket: {err}"); - break; - } + Err(err) => match err { + WsError::Closed(reason) => { + tracing::error!("client socket was closed, {reason}"); + break; + } + err => { + tracing::error!("could not read from client socket: {err}"); + } + }, } } let _ = out_write.send(TungsteniteMessage::Close(None)).await; @@ -528,19 +535,36 @@ async fn handle_metadata_socket( let in_write_fut = async move { while let Some(res) = out_read.next().await { + let res = res.map_err(WsError::from); match res { Ok(msg) => { tracing::trace!("got message from server: {msg:?}"); - let res = in_write.send(tungstenite_msg_to_axum(msg)).await; + let res = in_write + .send(tungstenite_msg_to_axum(msg)) + .await + .map_err(WsError::from); if let Err(err) = res { - tracing::error!("could not write to client socket: {err}"); + match err { + WsError::Closed(reason) => { + tracing::error!("client socket was closed, {reason}"); + break; + } + err => { + tracing::error!("could not write to server socket: {err}"); + } + } break; } } - Err(err) => { - tracing::error!("could not read from server socket: {err}"); - break; - } + Err(err) => match err { + WsError::Closed(reason) => { + tracing::error!("server socket was closed, {reason}"); + break; + } + err => { + tracing::error!("could not read from server socket: {err}"); + } + }, } } let _ = in_write.send(AxumMessage::Close(None)).await; @@ -551,46 +575,3 @@ async fn handle_metadata_socket( tracing::debug!("ending metadata ws task"); } - -#[inline(always)] -fn tungstenite_msg_to_axum(msg: TungsteniteMessage) -> AxumMessage { - match msg { - TungsteniteMessage::Text(data) => AxumMessage::Text(data), - TungsteniteMessage::Binary(data) => AxumMessage::Binary(data), - TungsteniteMessage::Ping(data) => AxumMessage::Ping(data), - TungsteniteMessage::Pong(data) => AxumMessage::Pong(data), - TungsteniteMessage::Close(frame) => { - AxumMessage::Close(frame.map(tungstenite_close_frame_to_axum)) - } - TungsteniteMessage::Frame(_) => unreachable!("we don't use raw frames"), - } -} - -#[inline(always)] -fn axum_msg_to_tungstenite(msg: AxumMessage) -> TungsteniteMessage { - match msg { - AxumMessage::Text(data) => TungsteniteMessage::Text(data), - AxumMessage::Binary(data) => TungsteniteMessage::Binary(data), - AxumMessage::Ping(data) => TungsteniteMessage::Ping(data), - AxumMessage::Pong(data) => TungsteniteMessage::Pong(data), - AxumMessage::Close(frame) => { - TungsteniteMessage::Close(frame.map(axum_close_frame_to_tungstenite)) - } - } -} - -#[inline(always)] -fn tungstenite_close_frame_to_axum(frame: TungsteniteCloseFrame) -> AxumCloseFrame { - AxumCloseFrame { - code: frame.code.into(), - reason: frame.reason, - } -} - -#[inline(always)] -fn axum_close_frame_to_tungstenite(frame: AxumCloseFrame) -> TungsteniteCloseFrame { - TungsteniteCloseFrame { - code: frame.code.into(), - reason: frame.reason, - } -} diff --git a/src/main.rs b/src/main.rs index c99bea5..f4d00a2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,14 +4,25 @@ use axum_server::tls_rustls::RustlsConfig; use base64::Engine; use dotenvy::Error as DotenvError; use error::AppError; +use futures::{SinkExt, StreamExt}; use hyper::{client::HttpConnector, Body}; +use scc::HashMap; +use serde::{Deserialize, Serialize}; +use serde_json::Value; use token::{MusicScopedTokens, Tokens}; use tracing::{info, warn}; use tracing_subscriber::prelude::*; +use crate::{ + api::WsApiMessage, + utils::{HandleWsItem, WsError}, +}; + +mod api; mod error; mod handler; mod token; +mod utils; #[tokio::main] async fn main() -> ExitCode { @@ -103,6 +114,7 @@ const B64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STA struct AppStateInternal { client: Client, tokens: Tokens, + music_info: MusicInfoMap, scoped_tokens: MusicScopedTokens, tokens_path: String, public_port: u16, @@ -115,29 +127,135 @@ struct AppStateInternal { impl AppStateInternal { async fn new(public_port: u16) -> Result { + let musikcubed_http_port = get_conf("MUSIKCUBED_HTTP_PORT")?.parse()?; + let musikcubed_metadata_port = get_conf("MUSIKCUBED_METADATA_PORT")?.parse()?; + let musikcubed_address = get_conf("MUSIKCUBED_ADDRESS")?; let musikcubed_password = get_conf("MUSIKCUBED_PASSWORD")?; + let musikcubed_auth_header_value = { + let mut val: http::HeaderValue = format!( + "Basic {}", + B64.encode(format!("default:{}", musikcubed_password)) + ) + .parse() + .expect("valid header value"); + val.set_sensitive(true); + val + }; + let tokens_path = get_conf("TOKENS_FILE")?; + + let music_info = MusicInfoMap::new(); + music_info + .read( + musikcubed_password.clone(), + &musikcubed_address, + musikcubed_metadata_port, + ) + .await?; + let this = Self { - public_port, - musikcubed_address: get_conf("MUSIKCUBED_ADDRESS")?, - musikcubed_http_port: get_conf("MUSIKCUBED_HTTP_PORT")?.parse()?, - musikcubed_metadata_port: get_conf("MUSIKCUBED_METADATA_PORT")?.parse()?, - musikcubed_auth_header_value: { - let mut val: http::HeaderValue = format!( - "Basic {}", - B64.encode(format!("default:{}", musikcubed_password)) - ) - .parse() - .expect("valid header value"); - val.set_sensitive(true); - val - }, - musikcubed_password, client: Client::new(), tokens: Tokens::read(&tokens_path).await?, scoped_tokens: MusicScopedTokens::new(get_conf("SCOPED_EXPIRY_DURATION")?.parse()?), + musikcubed_address, + musikcubed_http_port, + musikcubed_metadata_port, + musikcubed_auth_header_value, + musikcubed_password, tokens_path, + public_port, + music_info, }; Ok(this) } + + async fn verify_scoped_token(&self, token: impl AsRef) -> Result { + self.scoped_tokens.verify(token).await.ok_or_else(|| { + AppError::from("Invalid token or not authorized").status(http::StatusCode::UNAUTHORIZED) + }) + } +} + +#[derive(Clone, Deserialize, Serialize)] +struct MusicInfo { + external_id: String, + title: String, + album: String, + artist: String, + thumbnail_id: u32, +} + +#[derive(Clone)] +struct MusicInfoMap { + map: HashMap, +} + +impl MusicInfoMap { + fn new() -> Self { + Self { + map: HashMap::new(), + } + } + + async fn get(&self, id: impl AsRef) -> Option { + self.map.read_async(id.as_ref(), |_, v| v.clone()).await + } + + async fn read( + &self, + password: impl Into, + address: impl AsRef, + port: u16, + ) -> Result<(), AppError> { + use async_tungstenite::tungstenite::Message; + + let uri = format!("ws://{}:{}", address.as_ref(), port); + let (mut ws_stream, _) = async_tungstenite::tokio::connect_async(uri) + .await + .map_err(WsError::from)?; + + let device_id = "musikquadrupled"; + + // do the authentication + let auth_msg = WsApiMessage::authenticate(password.into()) + .id("auth") + .device_id(device_id); + ws_stream + .send(Message::Text(auth_msg.to_string())) + .await + .map_err(WsError::from)?; + let auth_reply: WsApiMessage = ws_stream.next().await.handle_item()?; + let is_authenticated = auth_reply + .options + .get("authenticated") + .and_then(Value::as_bool) + .unwrap_or_default(); + if !is_authenticated { + return Err("not authenticated".into()); + } + + // fetch the tracks + let fetch_tracks_msg = WsApiMessage::request("query_tracks") + .device_id(device_id) + .id("fetch_tracks") + .option("limit", u32::MAX) + .option("offset", 0); + ws_stream + .send(Message::Text(fetch_tracks_msg.to_string())) + .await + .map_err(WsError::from)?; + let mut tracks_reply: WsApiMessage = ws_stream.next().await.handle_item()?; + let Some(Value::Array(tracks)) = tracks_reply.options.remove("data") else { + tracing::debug!("reply: {tracks_reply:#?}"); + return Err("must have tracks".into()); + }; + for track in tracks { + let info: MusicInfo = serde_json::from_value(track).unwrap(); + let _ = self.map.insert_async(info.external_id.clone(), info).await; + } + + ws_stream.close(None).await.map_err(WsError::from)?; + + Ok(()) + } } diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..e4b91ff --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,198 @@ +use std::{borrow::Cow, collections::HashMap, error::Error, fmt::Display}; + +use async_tungstenite::tungstenite::{ + protocol::CloseFrame as TungsteniteCloseFrame, Error as TungsteniteError, + Message as TungsteniteMessage, +}; +use axum::{ + extract::ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage}, + Error as AxumError, +}; + +#[derive(Debug)] +pub(crate) enum WsError { + Closed(Cow<'static, str>), + InvalidMessage(serde_json::Error), + Other(Box), + Ping(Vec), + Pong(Vec), +} + +impl WsError { + pub(crate) fn is_closed(&self) -> bool { + match self { + WsError::Closed(_) => true, + _ => false, + } + } +} + +impl Display for WsError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Closed(reason) => write!(f, "ws closed, reason: {reason}"), + Self::InvalidMessage(err) => write!(f, "invalid JSON message, {err}"), + Self::Other(err) => write!(f, "error: {err}"), + Self::Ping(_) => write!(f, "was a ping"), + Self::Pong(_) => write!(f, "was a pong"), + } + } +} + +impl Error for WsError {} + +impl From for WsError { + fn from(err: TungsteniteError) -> Self { + match err { + TungsteniteError::ConnectionClosed => WsError::Closed("was closed".into()), + TungsteniteError::AlreadyClosed => WsError::Closed("was already closed".into()), + err => WsError::Other(Box::new(err)), + } + } +} + +impl From for WsError { + fn from(err: AxumError) -> Self { + use axum_tungstenite::Error as AxumError; + + let err = match err.into_inner().downcast::() { + Ok(err) => *err, + Err(err) => return WsError::Other(err), + }; + + match err { + AxumError::ConnectionClosed => WsError::Closed("was closed".into()), + AxumError::AlreadyClosed => WsError::Closed("was already closed".into()), + err => WsError::Other(Box::new(err)), + } + } +} + +pub(crate) trait HandleWsItem { + fn handle_item(self) -> Result; +} + +impl HandleWsItem for Option> { + fn handle_item(self) -> Result { + let Some(result) = self else { + return Err(WsError::Closed("was already closed".into())); + }; + + match result? { + TungsteniteMessage::Binary(data) => { + serde_json::from_slice(&data).map_err(WsError::InvalidMessage) + } + TungsteniteMessage::Text(data) => { + serde_json::from_str(&data).map_err(WsError::InvalidMessage) + } + TungsteniteMessage::Close(frame) => Err(WsError::Closed( + frame + .map(|f| format!("was closed, reason {} (code {})", f.reason, f.code).into()) + .unwrap_or_else(|| "was closed".into()), + )), + TungsteniteMessage::Ping(data) => Err(WsError::Ping(data)), + TungsteniteMessage::Pong(data) => Err(WsError::Pong(data)), + TungsteniteMessage::Frame(_) => unreachable!("we don't read raw frames"), + } + } +} + +impl HandleWsItem for Result { + fn handle_item(self) -> Result { + Some(self).handle_item() + } +} + +impl HandleWsItem for Option> { + fn handle_item(self) -> Result { + let Some(result) = self else { + return Err(WsError::Closed("was already closed".into())); + }; + + match result? { + AxumMessage::Binary(data) => { + serde_json::from_slice(&data).map_err(WsError::InvalidMessage) + } + AxumMessage::Text(data) => serde_json::from_str(&data).map_err(WsError::InvalidMessage), + AxumMessage::Close(frame) => Err(WsError::Closed( + frame + .map(|f| format!("was closed, reason {} (code {})", f.reason, f.code).into()) + .unwrap_or_else(|| "was closed".into()), + )), + AxumMessage::Ping(data) => Err(WsError::Ping(data)), + AxumMessage::Pong(data) => Err(WsError::Pong(data)), + } + } +} + +impl HandleWsItem for Result { + fn handle_item(self) -> Result { + Some(self).handle_item() + } +} + +#[inline(always)] +pub(crate) fn tungstenite_msg_to_axum(msg: TungsteniteMessage) -> AxumMessage { + match msg { + TungsteniteMessage::Text(data) => AxumMessage::Text(data), + TungsteniteMessage::Binary(data) => AxumMessage::Binary(data), + TungsteniteMessage::Ping(data) => AxumMessage::Ping(data), + TungsteniteMessage::Pong(data) => AxumMessage::Pong(data), + TungsteniteMessage::Close(frame) => { + AxumMessage::Close(frame.map(tungstenite_close_frame_to_axum)) + } + TungsteniteMessage::Frame(_) => unreachable!("we don't use raw frames"), + } +} + +#[inline(always)] +pub(crate) fn axum_msg_to_tungstenite(msg: AxumMessage) -> TungsteniteMessage { + match msg { + AxumMessage::Text(data) => TungsteniteMessage::Text(data), + AxumMessage::Binary(data) => TungsteniteMessage::Binary(data), + AxumMessage::Ping(data) => TungsteniteMessage::Ping(data), + AxumMessage::Pong(data) => TungsteniteMessage::Pong(data), + AxumMessage::Close(frame) => { + TungsteniteMessage::Close(frame.map(axum_close_frame_to_tungstenite)) + } + } +} + +#[inline(always)] +fn tungstenite_close_frame_to_axum(frame: TungsteniteCloseFrame) -> AxumCloseFrame { + AxumCloseFrame { + code: frame.code.into(), + reason: frame.reason, + } +} + +#[inline(always)] +fn axum_close_frame_to_tungstenite(frame: AxumCloseFrame) -> TungsteniteCloseFrame { + TungsteniteCloseFrame { + code: frame.code.into(), + reason: frame.reason, + } +} + +pub(crate) struct QueryDisplay { + map: HashMap, +} + +impl QueryDisplay { + pub(crate) fn new(map: HashMap) -> Self { + Self { map } + } +} + +impl Display for QueryDisplay { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let length = self.map.len(); + for (index, (k, v)) in self.map.iter().enumerate() { + write!(f, "{k}={v}")?; + if index < length - 1 { + write!(f, "&")?; + } + } + Ok(()) + } +}