feat: implement sharing a single music

This commit is contained in:
dusk 2023-05-09 07:29:35 +03:00
parent e259740d50
commit 3726d637f5
Signed by: dusk
GPG Key ID: 1D8F8FAF2294D6EA
8 changed files with 534 additions and 160 deletions

1
Cargo.lock generated
View File

@ -619,6 +619,7 @@ dependencies = [
"tower-http", "tower-http",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
"tungstenite 0.18.0",
] ]
[[package]] [[package]]

View File

@ -14,6 +14,7 @@ tower-http = {version = "0.4", features = ["trace", "cors", "sensitive-headers",
hyper = {version = "0.14", features = ["client"]} hyper = {version = "0.14", features = ["client"]}
http = "0.2" http = "0.2"
async-tungstenite = {version = "0.21", features = ["tokio-runtime"]} async-tungstenite = {version = "0.21", features = ["tokio-runtime"]}
axum-tungstenite = {package = "tungstenite", version = "0.18"}
futures = {version = "0.3"} futures = {version = "0.3"}
serde = {version = "1", features = ["derive"]} serde = {version = "1", features = ["derive"]}
serde_json = "1" serde_json = "1"

View File

@ -27,7 +27,9 @@
}; };
devShells.default = crateOutputs.devShell.overrideAttrs (old: { devShells.default = crateOutputs.devShell.overrideAttrs (old: {
RUST_SRC_PATH = "${config.nci.toolchains.shell}/lib/rustlib/src/rust/library"; RUST_SRC_PATH = "${config.nci.toolchains.shell}/lib/rustlib/src/rust/library";
packages = (old.packages or []) ++ [ packages =
(old.packages or [])
++ [
pkgs.rust-analyzer pkgs.rust-analyzer
(pkgs.writeShellApplication { (pkgs.writeShellApplication {
name = "generate-cert"; name = "generate-cert";

64
src/api.rs Normal file
View File

@ -0,0 +1,64 @@
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct WsApiMessage {
pub(crate) name: String,
#[serde(rename = "type")]
pub(crate) kind: WsApiMessageType,
pub(crate) id: String,
#[serde(default)]
pub(crate) device_id: Option<String>,
pub(crate) options: Map<String, Value>,
}
impl WsApiMessage {
pub(crate) fn new(name: impl Into<String>, kind: WsApiMessageType) -> Self {
Self {
name: name.into(),
kind,
id: String::new(),
device_id: None,
options: Map::new(),
}
}
pub(crate) fn request(name: impl Into<String>) -> Self {
Self::new(name, WsApiMessageType::Request)
}
pub(crate) fn authenticate(password: impl Into<String>) -> Self {
Self::new("authenticate", WsApiMessageType::Request).option("password", password.into())
}
pub(crate) fn id(mut self, id: impl Into<String>) -> Self {
self.id = id.into();
self
}
pub(crate) fn device_id(mut self, device_id: impl Into<String>) -> Self {
self.device_id = Some(device_id.into());
self
}
pub(crate) fn option(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
self.options.insert(key.into(), value.into());
self
}
}
#[derive(Debug, Serialize, Deserialize)]
pub(crate) enum WsApiMessageType {
#[serde(rename = "request")]
Request,
#[serde(rename = "response")]
Response,
#[serde(rename = "broadcast")]
Broadcast,
}
impl ToString for WsApiMessage {
fn to_string(&self) -> String {
serde_json::to_string(self).unwrap()
}
}

View File

@ -8,6 +8,14 @@ type BoxedError = Box<dyn std::error::Error>;
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct AppError { pub(crate) struct AppError {
internal: BoxedError, internal: BoxedError,
status: Option<StatusCode>,
}
impl AppError {
pub(crate) fn status(mut self, code: StatusCode) -> Self {
self.status = Some(code);
self
}
} }
impl<E> From<E> for AppError impl<E> From<E> for AppError
@ -17,6 +25,7 @@ where
fn from(err: E) -> Self { fn from(err: E) -> Self {
Self { Self {
internal: err.into(), internal: err.into(),
status: None,
} }
} }
} }
@ -24,7 +33,7 @@ where
impl IntoResponse for AppError { impl IntoResponse for AppError {
fn into_response(self) -> axum::response::Response { fn into_response(self) -> axum::response::Response {
( (
StatusCode::INTERNAL_SERVER_ERROR, self.status.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
format!("Something went wrong: {}", self.internal), format!("Something went wrong: {}", self.internal),
) )
.into_response() .into_response()

View File

@ -1,14 +1,12 @@
use std::{collections::HashMap, fmt::Display}; use std::collections::HashMap;
use super::AppError; use super::AppError;
use async_tungstenite::{ use async_tungstenite::{
tokio::TokioAdapter, tokio::TokioAdapter, tungstenite::Message as TungsteniteMessage, WebSocketStream,
tungstenite::{protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage},
WebSocketStream,
}; };
use axum::{ use axum::{
extract::{ extract::{
ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage, WebSocket, WebSocketUpgrade}, ws::{Message as AxumMessage, WebSocket, WebSocketUpgrade},
Path, Query, State, Path, Query, State,
}, },
headers::UserAgent, headers::UserAgent,
@ -23,7 +21,7 @@ use http::{
HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode, HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode,
}; };
use hyper::Body; use hyper::Body;
use serde::{Deserialize, Serialize}; use serde::Deserialize;
use serde_json::Value; use serde_json::Value;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tower_http::{ use tower_http::{
@ -34,7 +32,11 @@ use tower_http::{
}; };
use tracing::{Instrument, Span}; use tracing::{Instrument, Span};
use crate::{AppState, B64}; use crate::{
api::WsApiMessage,
utils::{axum_msg_to_tungstenite, tungstenite_msg_to_axum, QueryDisplay, WsError},
AppState, B64,
};
const AUDIO_CACHE_HEADER: HeaderValue = HeaderValue::from_static("private, max-age=604800"); const AUDIO_CACHE_HEADER: HeaderValue = HeaderValue::from_static("private, max-age=604800");
const REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id"); const REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
@ -59,23 +61,6 @@ fn remove_token_from_query(query: Option<&str>) -> HashMap<String, String> {
query_map query_map
} }
struct QueryDisplay {
map: HashMap<String, String>,
}
impl Display for QueryDisplay {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let length = self.map.len();
for (index, (k, v)) in self.map.iter().enumerate() {
write!(f, "{k}={v}")?;
if index < length - 1 {
write!(f, "&")?;
}
}
Ok(())
}
}
fn make_span_trace<B>(req: &Request<B>) -> Span { fn make_span_trace<B>(req: &Request<B>) -> Span {
let query_map = remove_token_from_query(req.uri().query()); let query_map = remove_token_from_query(req.uri().query());
@ -92,7 +77,7 @@ fn make_span_trace<B>(req: &Request<B>) -> Span {
id = %request_id, id = %request_id,
) )
} else { } else {
let query_display = QueryDisplay { map: query_map }; let query_display = QueryDisplay::new(query_map);
tracing::debug_span!( tracing::debug_span!(
"request", "request",
path = %req.uri().path(), path = %req.uri().path(),
@ -128,7 +113,9 @@ pub(super) async fn handler(state: AppState) -> Result<(Router, Router), AppErro
.route("/token/generate_for_music/:id", get(generate_scoped_token)) .route("/token/generate_for_music/:id", get(generate_scoped_token))
.route("/thumbnail/:id", get(http)) .route("/thumbnail/:id", get(http))
.route("/audio/external_id/:id", get(get_music)) .route("/audio/external_id/:id", get(get_music))
.route("/audio/scoped/:id", get(get_scoped_music)) .route("/share/audio/:token", get(get_scoped_music_file))
.route("/share/thumbnail/:token", get(get_scoped_music_thumbnail))
.route("/share/info/:token", get(get_scoped_music_info))
.route("/", get(metadata_ws)) .route("/", get(metadata_ws))
.layer(trace_layer) .layer(trace_layer)
.layer(sensitive_header_layer) .layer(sensitive_header_layer)
@ -188,12 +175,43 @@ async fn generate_scoped_token(
Ok(token.into_response()) Ok(token.into_response())
} }
async fn get_scoped_music( async fn get_scoped_music_info(
State(app): State<AppState>,
Path(token): Path<String>,
) -> Result<impl IntoResponse, AppError> {
let music_id = app.verify_scoped_token(token).await?;
let Some(info) = app.music_info.get(music_id).await else {
return Err("music id not found".into());
};
Ok(serde_json::to_string(&info).unwrap())
}
async fn get_scoped_music_thumbnail(
State(app): State<AppState>,
Path(token): Path<String>,
) -> Result<Response<Body>, AppError> {
let music_id = app.verify_scoped_token(token).await?;
let Some(info) = app.music_info.get(music_id).await else {
return Err("music id not found".into());
};
let req = Request::builder()
.uri(format!(
"http://{}:{}/thumbnail/{}",
app.musikcubed_address, app.musikcubed_http_port, info.thumbnail_id
))
.header(AUTHORIZATION, app.musikcubed_auth_header_value.clone())
.body(Body::empty())
.expect("cant fail");
let resp = app.client.request(req).await?;
Ok(resp)
}
async fn get_scoped_music_file(
State(app): State<AppState>, State(app): State<AppState>,
Path(token): Path<String>, Path(token): Path<String>,
request: Request<Body>, request: Request<Body>,
) -> Result<Response<Body>, AppError> { ) -> Result<Response<Body>, AppError> {
if let Some(music_id) = app.scoped_tokens.verify(token).await { let music_id = app.verify_scoped_token(token).await?;
let mut req = Request::builder() let mut req = Request::builder()
.uri(format!( .uri(format!(
"http://{}:{}/audio/external_id/{}", "http://{}:{}/audio/external_id/{}",
@ -213,12 +231,6 @@ async fn get_scoped_music(
.insert(CACHE_CONTROL, AUDIO_CACHE_HEADER.clone()); .insert(CACHE_CONTROL, AUDIO_CACHE_HEADER.clone());
} }
Ok(resp) Ok(resp)
} else {
Ok(Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body("Invalid scoped token".to_string().into())
.expect("cant fail"))
}
} }
async fn get_music( async fn get_music(
@ -340,16 +352,6 @@ async fn metadata_ws(
Ok(upgrade) Ok(upgrade)
} }
#[derive(Serialize, Deserialize)]
struct WsApiMessage {
name: String,
r#type: String,
id: String,
#[serde(default)]
device_id: Option<String>,
options: serde_json::value::Map<String, Value>,
}
async fn handle_metadata_socket( async fn handle_metadata_socket(
mut server_socket: WebSocketStream<TokioAdapter<TcpStream>>, mut server_socket: WebSocketStream<TokioAdapter<TcpStream>>,
mut client_socket: WebSocket, mut client_socket: WebSocket,
@ -410,20 +412,9 @@ async fn handle_metadata_socket(
let og_auth_reply = 'ok: { let og_auth_reply = 'ok: {
'err: { 'err: {
// send actual auth message to the musikcubed server // send actual auth message to the musikcubed server
let auth_msg = WsApiMessage { let auth_msg = WsApiMessage::authenticate(app.musikcubed_password.clone())
name: "authenticate".to_string(), .id(og_auth_msg.id)
r#type: "request".to_string(), .device_id(og_auth_msg.device_id.unwrap_or_default());
id: og_auth_msg.id,
device_id: og_auth_msg.device_id,
options: {
let mut map = serde_json::Map::with_capacity(1);
map.insert(
"password".to_string(),
app.musikcubed_password.clone().into(),
);
map
},
};
let auth_msg_ser = serde_json::to_string(&auth_msg).expect(""); let auth_msg_ser = serde_json::to_string(&auth_msg).expect("");
if let Err(err) = server_socket if let Err(err) = server_socket
.send(TungsteniteMessage::Text(auth_msg_ser)) .send(TungsteniteMessage::Text(auth_msg_ser))
@ -507,19 +498,35 @@ async fn handle_metadata_socket(
let in_read_fut = async move { let in_read_fut = async move {
while let Some(res) = in_read.next().await { while let Some(res) = in_read.next().await {
let res = res.map_err(WsError::from);
match res { match res {
Ok(msg) => { Ok(msg) => {
tracing::trace!("got message from client: {msg:?}"); tracing::trace!("got message from client: {msg:?}");
let res = out_write.send(axum_msg_to_tungstenite(msg)).await; let res = out_write
.send(axum_msg_to_tungstenite(msg))
.await
.map_err(WsError::from);
if let Err(err) = res { if let Err(err) = res {
match err {
WsError::Closed(reason) => {
tracing::error!("server socket was closed: {reason}");
break;
}
err => {
tracing::error!("could not write to server socket: {err}"); tracing::error!("could not write to server socket: {err}");
}
}
}
}
Err(err) => match err {
WsError::Closed(reason) => {
tracing::error!("client socket was closed, {reason}");
break; break;
} }
} err => {
Err(err) => {
tracing::error!("could not read from client socket: {err}"); tracing::error!("could not read from client socket: {err}");
break;
} }
},
} }
} }
let _ = out_write.send(TungsteniteMessage::Close(None)).await; let _ = out_write.send(TungsteniteMessage::Close(None)).await;
@ -528,19 +535,36 @@ async fn handle_metadata_socket(
let in_write_fut = async move { let in_write_fut = async move {
while let Some(res) = out_read.next().await { while let Some(res) = out_read.next().await {
let res = res.map_err(WsError::from);
match res { match res {
Ok(msg) => { Ok(msg) => {
tracing::trace!("got message from server: {msg:?}"); tracing::trace!("got message from server: {msg:?}");
let res = in_write.send(tungstenite_msg_to_axum(msg)).await; let res = in_write
.send(tungstenite_msg_to_axum(msg))
.await
.map_err(WsError::from);
if let Err(err) = res { if let Err(err) = res {
tracing::error!("could not write to client socket: {err}"); match err {
WsError::Closed(reason) => {
tracing::error!("client socket was closed, {reason}");
break;
}
err => {
tracing::error!("could not write to server socket: {err}");
}
}
break; break;
} }
} }
Err(err) => { Err(err) => match err {
WsError::Closed(reason) => {
tracing::error!("server socket was closed, {reason}");
break;
}
err => {
tracing::error!("could not read from server socket: {err}"); tracing::error!("could not read from server socket: {err}");
break;
} }
},
} }
} }
let _ = in_write.send(AxumMessage::Close(None)).await; let _ = in_write.send(AxumMessage::Close(None)).await;
@ -551,46 +575,3 @@ async fn handle_metadata_socket(
tracing::debug!("ending metadata ws task"); tracing::debug!("ending metadata ws task");
} }
#[inline(always)]
fn tungstenite_msg_to_axum(msg: TungsteniteMessage) -> AxumMessage {
match msg {
TungsteniteMessage::Text(data) => AxumMessage::Text(data),
TungsteniteMessage::Binary(data) => AxumMessage::Binary(data),
TungsteniteMessage::Ping(data) => AxumMessage::Ping(data),
TungsteniteMessage::Pong(data) => AxumMessage::Pong(data),
TungsteniteMessage::Close(frame) => {
AxumMessage::Close(frame.map(tungstenite_close_frame_to_axum))
}
TungsteniteMessage::Frame(_) => unreachable!("we don't use raw frames"),
}
}
#[inline(always)]
fn axum_msg_to_tungstenite(msg: AxumMessage) -> TungsteniteMessage {
match msg {
AxumMessage::Text(data) => TungsteniteMessage::Text(data),
AxumMessage::Binary(data) => TungsteniteMessage::Binary(data),
AxumMessage::Ping(data) => TungsteniteMessage::Ping(data),
AxumMessage::Pong(data) => TungsteniteMessage::Pong(data),
AxumMessage::Close(frame) => {
TungsteniteMessage::Close(frame.map(axum_close_frame_to_tungstenite))
}
}
}
#[inline(always)]
fn tungstenite_close_frame_to_axum(frame: TungsteniteCloseFrame) -> AxumCloseFrame {
AxumCloseFrame {
code: frame.code.into(),
reason: frame.reason,
}
}
#[inline(always)]
fn axum_close_frame_to_tungstenite(frame: AxumCloseFrame) -> TungsteniteCloseFrame {
TungsteniteCloseFrame {
code: frame.code.into(),
reason: frame.reason,
}
}

View File

@ -4,14 +4,25 @@ use axum_server::tls_rustls::RustlsConfig;
use base64::Engine; use base64::Engine;
use dotenvy::Error as DotenvError; use dotenvy::Error as DotenvError;
use error::AppError; use error::AppError;
use futures::{SinkExt, StreamExt};
use hyper::{client::HttpConnector, Body}; use hyper::{client::HttpConnector, Body};
use scc::HashMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use token::{MusicScopedTokens, Tokens}; use token::{MusicScopedTokens, Tokens};
use tracing::{info, warn}; use tracing::{info, warn};
use tracing_subscriber::prelude::*; use tracing_subscriber::prelude::*;
use crate::{
api::WsApiMessage,
utils::{HandleWsItem, WsError},
};
mod api;
mod error; mod error;
mod handler; mod handler;
mod token; mod token;
mod utils;
#[tokio::main] #[tokio::main]
async fn main() -> ExitCode { async fn main() -> ExitCode {
@ -103,6 +114,7 @@ const B64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STA
struct AppStateInternal { struct AppStateInternal {
client: Client, client: Client,
tokens: Tokens, tokens: Tokens,
music_info: MusicInfoMap,
scoped_tokens: MusicScopedTokens, scoped_tokens: MusicScopedTokens,
tokens_path: String, tokens_path: String,
public_port: u16, public_port: u16,
@ -115,14 +127,11 @@ struct AppStateInternal {
impl AppStateInternal { impl AppStateInternal {
async fn new(public_port: u16) -> Result<Self, AppError> { async fn new(public_port: u16) -> Result<Self, AppError> {
let musikcubed_http_port = get_conf("MUSIKCUBED_HTTP_PORT")?.parse()?;
let musikcubed_metadata_port = get_conf("MUSIKCUBED_METADATA_PORT")?.parse()?;
let musikcubed_address = get_conf("MUSIKCUBED_ADDRESS")?;
let musikcubed_password = get_conf("MUSIKCUBED_PASSWORD")?; let musikcubed_password = get_conf("MUSIKCUBED_PASSWORD")?;
let tokens_path = get_conf("TOKENS_FILE")?; let musikcubed_auth_header_value = {
let this = Self {
public_port,
musikcubed_address: get_conf("MUSIKCUBED_ADDRESS")?,
musikcubed_http_port: get_conf("MUSIKCUBED_HTTP_PORT")?.parse()?,
musikcubed_metadata_port: get_conf("MUSIKCUBED_METADATA_PORT")?.parse()?,
musikcubed_auth_header_value: {
let mut val: http::HeaderValue = format!( let mut val: http::HeaderValue = format!(
"Basic {}", "Basic {}",
B64.encode(format!("default:{}", musikcubed_password)) B64.encode(format!("default:{}", musikcubed_password))
@ -131,13 +140,122 @@ impl AppStateInternal {
.expect("valid header value"); .expect("valid header value");
val.set_sensitive(true); val.set_sensitive(true);
val val
}, };
musikcubed_password,
let tokens_path = get_conf("TOKENS_FILE")?;
let music_info = MusicInfoMap::new();
music_info
.read(
musikcubed_password.clone(),
&musikcubed_address,
musikcubed_metadata_port,
)
.await?;
let this = Self {
client: Client::new(), client: Client::new(),
tokens: Tokens::read(&tokens_path).await?, tokens: Tokens::read(&tokens_path).await?,
scoped_tokens: MusicScopedTokens::new(get_conf("SCOPED_EXPIRY_DURATION")?.parse()?), scoped_tokens: MusicScopedTokens::new(get_conf("SCOPED_EXPIRY_DURATION")?.parse()?),
musikcubed_address,
musikcubed_http_port,
musikcubed_metadata_port,
musikcubed_auth_header_value,
musikcubed_password,
tokens_path, tokens_path,
public_port,
music_info,
}; };
Ok(this) Ok(this)
} }
async fn verify_scoped_token(&self, token: impl AsRef<str>) -> Result<String, AppError> {
self.scoped_tokens.verify(token).await.ok_or_else(|| {
AppError::from("Invalid token or not authorized").status(http::StatusCode::UNAUTHORIZED)
})
}
}
#[derive(Clone, Deserialize, Serialize)]
struct MusicInfo {
external_id: String,
title: String,
album: String,
artist: String,
thumbnail_id: u32,
}
#[derive(Clone)]
struct MusicInfoMap {
map: HashMap<String, MusicInfo>,
}
impl MusicInfoMap {
fn new() -> Self {
Self {
map: HashMap::new(),
}
}
async fn get(&self, id: impl AsRef<str>) -> Option<MusicInfo> {
self.map.read_async(id.as_ref(), |_, v| v.clone()).await
}
async fn read(
&self,
password: impl Into<String>,
address: impl AsRef<str>,
port: u16,
) -> Result<(), AppError> {
use async_tungstenite::tungstenite::Message;
let uri = format!("ws://{}:{}", address.as_ref(), port);
let (mut ws_stream, _) = async_tungstenite::tokio::connect_async(uri)
.await
.map_err(WsError::from)?;
let device_id = "musikquadrupled";
// do the authentication
let auth_msg = WsApiMessage::authenticate(password.into())
.id("auth")
.device_id(device_id);
ws_stream
.send(Message::Text(auth_msg.to_string()))
.await
.map_err(WsError::from)?;
let auth_reply: WsApiMessage = ws_stream.next().await.handle_item()?;
let is_authenticated = auth_reply
.options
.get("authenticated")
.and_then(Value::as_bool)
.unwrap_or_default();
if !is_authenticated {
return Err("not authenticated".into());
}
// fetch the tracks
let fetch_tracks_msg = WsApiMessage::request("query_tracks")
.device_id(device_id)
.id("fetch_tracks")
.option("limit", u32::MAX)
.option("offset", 0);
ws_stream
.send(Message::Text(fetch_tracks_msg.to_string()))
.await
.map_err(WsError::from)?;
let mut tracks_reply: WsApiMessage = ws_stream.next().await.handle_item()?;
let Some(Value::Array(tracks)) = tracks_reply.options.remove("data") else {
tracing::debug!("reply: {tracks_reply:#?}");
return Err("must have tracks".into());
};
for track in tracks {
let info: MusicInfo = serde_json::from_value(track).unwrap();
let _ = self.map.insert_async(info.external_id.clone(), info).await;
}
ws_stream.close(None).await.map_err(WsError::from)?;
Ok(())
}
} }

198
src/utils.rs Normal file
View File

@ -0,0 +1,198 @@
use std::{borrow::Cow, collections::HashMap, error::Error, fmt::Display};
use async_tungstenite::tungstenite::{
protocol::CloseFrame as TungsteniteCloseFrame, Error as TungsteniteError,
Message as TungsteniteMessage,
};
use axum::{
extract::ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
Error as AxumError,
};
#[derive(Debug)]
pub(crate) enum WsError {
Closed(Cow<'static, str>),
InvalidMessage(serde_json::Error),
Other(Box<dyn Error + Send>),
Ping(Vec<u8>),
Pong(Vec<u8>),
}
impl WsError {
pub(crate) fn is_closed(&self) -> bool {
match self {
WsError::Closed(_) => true,
_ => false,
}
}
}
impl Display for WsError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Closed(reason) => write!(f, "ws closed, reason: {reason}"),
Self::InvalidMessage(err) => write!(f, "invalid JSON message, {err}"),
Self::Other(err) => write!(f, "error: {err}"),
Self::Ping(_) => write!(f, "was a ping"),
Self::Pong(_) => write!(f, "was a pong"),
}
}
}
impl Error for WsError {}
impl From<TungsteniteError> for WsError {
fn from(err: TungsteniteError) -> Self {
match err {
TungsteniteError::ConnectionClosed => WsError::Closed("was closed".into()),
TungsteniteError::AlreadyClosed => WsError::Closed("was already closed".into()),
err => WsError::Other(Box::new(err)),
}
}
}
impl From<AxumError> for WsError {
fn from(err: AxumError) -> Self {
use axum_tungstenite::Error as AxumError;
let err = match err.into_inner().downcast::<AxumError>() {
Ok(err) => *err,
Err(err) => return WsError::Other(err),
};
match err {
AxumError::ConnectionClosed => WsError::Closed("was closed".into()),
AxumError::AlreadyClosed => WsError::Closed("was already closed".into()),
err => WsError::Other(Box::new(err)),
}
}
}
pub(crate) trait HandleWsItem {
fn handle_item<T: serde::de::DeserializeOwned>(self) -> Result<T, WsError>;
}
impl HandleWsItem for Option<Result<TungsteniteMessage, TungsteniteError>> {
fn handle_item<T: serde::de::DeserializeOwned>(self) -> Result<T, WsError> {
let Some(result) = self else {
return Err(WsError::Closed("was already closed".into()));
};
match result? {
TungsteniteMessage::Binary(data) => {
serde_json::from_slice(&data).map_err(WsError::InvalidMessage)
}
TungsteniteMessage::Text(data) => {
serde_json::from_str(&data).map_err(WsError::InvalidMessage)
}
TungsteniteMessage::Close(frame) => Err(WsError::Closed(
frame
.map(|f| format!("was closed, reason {} (code {})", f.reason, f.code).into())
.unwrap_or_else(|| "was closed".into()),
)),
TungsteniteMessage::Ping(data) => Err(WsError::Ping(data)),
TungsteniteMessage::Pong(data) => Err(WsError::Pong(data)),
TungsteniteMessage::Frame(_) => unreachable!("we don't read raw frames"),
}
}
}
impl HandleWsItem for Result<TungsteniteMessage, TungsteniteError> {
fn handle_item<T: serde::de::DeserializeOwned>(self) -> Result<T, WsError> {
Some(self).handle_item()
}
}
impl HandleWsItem for Option<Result<AxumMessage, AxumError>> {
fn handle_item<T: serde::de::DeserializeOwned>(self) -> Result<T, WsError> {
let Some(result) = self else {
return Err(WsError::Closed("was already closed".into()));
};
match result? {
AxumMessage::Binary(data) => {
serde_json::from_slice(&data).map_err(WsError::InvalidMessage)
}
AxumMessage::Text(data) => serde_json::from_str(&data).map_err(WsError::InvalidMessage),
AxumMessage::Close(frame) => Err(WsError::Closed(
frame
.map(|f| format!("was closed, reason {} (code {})", f.reason, f.code).into())
.unwrap_or_else(|| "was closed".into()),
)),
AxumMessage::Ping(data) => Err(WsError::Ping(data)),
AxumMessage::Pong(data) => Err(WsError::Pong(data)),
}
}
}
impl HandleWsItem for Result<AxumMessage, AxumError> {
fn handle_item<T: serde::de::DeserializeOwned>(self) -> Result<T, WsError> {
Some(self).handle_item()
}
}
#[inline(always)]
pub(crate) fn tungstenite_msg_to_axum(msg: TungsteniteMessage) -> AxumMessage {
match msg {
TungsteniteMessage::Text(data) => AxumMessage::Text(data),
TungsteniteMessage::Binary(data) => AxumMessage::Binary(data),
TungsteniteMessage::Ping(data) => AxumMessage::Ping(data),
TungsteniteMessage::Pong(data) => AxumMessage::Pong(data),
TungsteniteMessage::Close(frame) => {
AxumMessage::Close(frame.map(tungstenite_close_frame_to_axum))
}
TungsteniteMessage::Frame(_) => unreachable!("we don't use raw frames"),
}
}
#[inline(always)]
pub(crate) fn axum_msg_to_tungstenite(msg: AxumMessage) -> TungsteniteMessage {
match msg {
AxumMessage::Text(data) => TungsteniteMessage::Text(data),
AxumMessage::Binary(data) => TungsteniteMessage::Binary(data),
AxumMessage::Ping(data) => TungsteniteMessage::Ping(data),
AxumMessage::Pong(data) => TungsteniteMessage::Pong(data),
AxumMessage::Close(frame) => {
TungsteniteMessage::Close(frame.map(axum_close_frame_to_tungstenite))
}
}
}
#[inline(always)]
fn tungstenite_close_frame_to_axum(frame: TungsteniteCloseFrame) -> AxumCloseFrame {
AxumCloseFrame {
code: frame.code.into(),
reason: frame.reason,
}
}
#[inline(always)]
fn axum_close_frame_to_tungstenite(frame: AxumCloseFrame) -> TungsteniteCloseFrame {
TungsteniteCloseFrame {
code: frame.code.into(),
reason: frame.reason,
}
}
pub(crate) struct QueryDisplay {
map: HashMap<String, String>,
}
impl QueryDisplay {
pub(crate) fn new(map: HashMap<String, String>) -> Self {
Self { map }
}
}
impl Display for QueryDisplay {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let length = self.map.len();
for (index, (k, v)) in self.map.iter().enumerate() {
write!(f, "{k}={v}")?;
if index < length - 1 {
write!(f, "&")?;
}
}
Ok(())
}
}