diff --git a/src/handler.rs b/src/handler.rs index 05cdbf9..ce979ee 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -9,7 +9,7 @@ use async_tungstenite::{ use axum::{ extract::{ ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage, WebSocket, WebSocketUpgrade}, - ConnectInfo, Query, State, + ConnectInfo, Path, Query, State, }, headers::UserAgent, middleware::Next, @@ -32,9 +32,7 @@ use tower_http::{ }; use tracing::{Instrument, Span}; -use crate::{get_conf, AppState}; - -const B64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD; +use crate::{get_conf, AppState, B64}; #[derive(Deserialize)] struct Auth { @@ -137,8 +135,10 @@ pub(super) async fn handler(state: AppState) -> Result { .layer(axum::middleware::from_fn(block_external_ips)); let router = Router::new() + .route("/token/generate_for_music/:id", get(generate_scoped_token)) .route("/thumbnail/:id", get(http)) .route("/audio/external_id/:id", get(http)) + .route("/audio/scoped/:id", get(get_scoped_music)) .route("/", get(metadata_ws)) .layer(SetSensitiveRequestHeadersLayer::new([AUTHORIZATION])) .layer(trace_layer) @@ -174,6 +174,58 @@ async fn generate_token(State(app): State) -> Result, + Query(query): Query, + Path(music_id): Path, +) -> Result { + let maybe_token = query.token; + + 'ok: { + tracing::debug!("verifying token: {maybe_token:?}"); + if let Some(token) = maybe_token { + if app.tokens.verify(token).await? { + break 'ok; + } + } + return Ok(( + StatusCode::UNAUTHORIZED, + "Invalid token or token not present", + ) + .into_response()); + } + + // generate token + let token = app.scoped_tokens.generate_for_id(music_id).await; + Ok(token.into_response()) +} + +async fn get_scoped_music( + State(app): State, + Path(token): Path, +) -> Result, AppError> { + if let Some(music_id) = app.scoped_tokens.verify(token).await { + Ok(app + .client + .request( + Request::builder() + .uri(format!( + "http://{}:{}/audio/external_id/{}", + app.musikcubed_address, app.musikcubed_http_port, music_id + )) + .header(AUTHORIZATION, app.musikcubed_auth_header_value.clone()) + .body(Body::empty()) + .expect("cant fail"), + ) + .await?) + } else { + Ok(Response::builder() + .status(StatusCode::UNAUTHORIZED) + .body("Invalid scoped token".to_string().into()) + .expect("cant fail")) + } +} + async fn http( State(app): State, Query(query): Query, @@ -212,11 +264,8 @@ async fn http( .expect("cant fail")); } - let auth = B64.encode(format!("default:{}", app.musikcubed_password)); - req.headers_mut().insert( - AUTHORIZATION, - format!("Basic {auth}").parse().expect("valid header value"), - ); + req.headers_mut() + .insert(AUTHORIZATION, app.musikcubed_auth_header_value.clone()); Ok(app.client.request(req).await?) } diff --git a/src/main.rs b/src/main.rs index a0fae15..ea085e7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,11 @@ use std::{net::SocketAddr, process::ExitCode, sync::Arc}; use axum_server::tls_rustls::RustlsConfig; +use base64::Engine; use dotenvy::Error as DotenvError; use error::AppError; use hyper::{client::HttpConnector, Body}; -use token::Tokens; +use token::{MusicScopedTokens, Tokens}; use tracing::{info, warn}; use tracing_subscriber::prelude::*; @@ -82,29 +83,41 @@ type Client = hyper::Client; type AppState = Arc; +const B64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD; + #[derive(Clone)] struct AppStateInternal { client: Client, tokens: Tokens, + scoped_tokens: MusicScopedTokens, tokens_path: String, public_port: u16, musikcubed_address: String, musikcubed_http_port: u16, musikcubed_metadata_port: u16, musikcubed_password: String, + musikcubed_auth_header_value: http::HeaderValue, } impl AppStateInternal { async fn new(public_port: u16) -> Result { + let musikcubed_password = get_conf("MUSIKCUBED_PASSWORD")?; let tokens_path = get_conf("TOKENS_FILE")?; let this = Self { public_port, musikcubed_address: get_conf("MUSIKCUBED_ADDRESS")?, musikcubed_http_port: get_conf("MUSIKCUBED_HTTP_PORT")?.parse()?, musikcubed_metadata_port: get_conf("MUSIKCUBED_METADATA_PORT")?.parse()?, - musikcubed_password: get_conf("MUSIKCUBED_PASSWORD")?, + musikcubed_auth_header_value: format!( + "Basic {}", + B64.encode(format!("default:{}", musikcubed_password)) + ) + .parse() + .expect("valid header value"), + musikcubed_password, client: Client::new(), tokens: Tokens::read(&tokens_path).await?, + scoped_tokens: MusicScopedTokens::new(get_conf("SCOPED_EXPIRY_DURATION")?.parse()?), tokens_path, }; Ok(this) diff --git a/src/token.rs b/src/token.rs index 9ea81b2..4849730 100644 --- a/src/token.rs +++ b/src/token.rs @@ -1,16 +1,29 @@ use rand::Rng; -use scc::HashSet; +use scc::{HashMap, HashSet}; use std::borrow::Cow; use std::fmt::Write; use std::path::Path; use std::sync::Arc; +use std::time::UNIX_EPOCH; use crate::error::AppError; +fn get_current_time() -> u64 { + UNIX_EPOCH.elapsed().unwrap().as_secs() +} + fn hash_string(data: &[u8]) -> Result { argon2::hash_encoded(data, "11111111".as_bytes(), &argon2::Config::default()) } +fn generate_random_string(len: usize) -> String { + rand::thread_rng() + .sample_iter(rand::distributions::Alphanumeric) + .take(len) + .map(|c| c as char) + .collect::() +} + #[derive(Debug, Clone)] pub(crate) struct Tokens { hashed: Arc>>, @@ -49,11 +62,7 @@ impl Tokens { } pub async fn generate(&self) -> Result { - let token = rand::thread_rng() - .sample_iter(rand::distributions::Alphanumeric) - .take(30) - .map(|c| c as char) - .collect::(); + let token = generate_random_string(30); let token_hash = hash_string(token.as_bytes())?; @@ -71,3 +80,45 @@ impl Tokens { self.hashed.clear_async().await; } } + +#[derive(Clone)] +pub(crate) struct MusicScopedTokens { + map: Arc>, + expiry_time: u64, +} + +impl MusicScopedTokens { + pub fn new(expiry_time: u64) -> Self { + Self { + map: Arc::new(HashMap::new()), + expiry_time, + } + } + + pub async fn generate_for_id(&self, music_id: String) -> String { + let data = MusicScopedToken { + creation: get_current_time(), + music_id, + }; + let token = generate_random_string(12); + + let _ = self.map.insert_async(token.clone(), data).await; + return token; + } + + pub async fn verify(&self, token: impl AsRef) -> Option { + let token = token.as_ref(); + let data = self.map.read_async(token, |_, v| v.clone()).await?; + if get_current_time() - data.creation > self.expiry_time { + self.map.remove_async(token).await; + return None; + } + Some(data.music_id) + } +} + +#[derive(Clone)] +pub(crate) struct MusicScopedToken { + creation: u64, + music_id: String, +}