diff --git a/src/handler.rs b/src/handler.rs index 6a04018..f8da7ba 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -12,8 +12,9 @@ use axum::{ ConnectInfo, Query, State, }, headers::UserAgent, + middleware::Next, response::IntoResponse, - routing::get, + routing::{get, post}, Router, TypedHeader, }; use base64::Engine; @@ -48,12 +49,28 @@ fn extract_password_from_basic_auth(auth: &str) -> Result { Ok(auth.trim_start_matches("default:").to_string()) } +async fn block_external_ips( + ConnectInfo(addr): ConnectInfo, + req: Request, + next: Next, +) -> Result { + if addr.ip().is_loopback() { + Ok(next.run(req).await) + } else { + Err((StatusCode::FORBIDDEN, "not allowed")) + } +} + pub(super) async fn handler(state: AppState) -> Result { let trace_layer = TraceLayer::new_for_http().make_span_with(DefaultMakeSpan::default().include_headers(true)); + let internal_router = Router::new() + .route("/token/generate", get(generate_token)) + .route("/token/revoke_all", post(revoke_all_tokens)) + .layer(axum::middleware::from_fn(block_external_ips)); + let router = Router::new() - .route("/generate_token", get(generate_token)) .route("/thumbnails/:id", get(http)) .route("/audio/id/:id", get(http)) .route("/", get(metadata_ws)) @@ -62,23 +79,23 @@ pub(super) async fn handler(state: AppState) -> Result { CorsLayer::new() .allow_origin(get_conf("CORS_ALLOW_ORIGIN")?.parse::()?) .allow_headers([CONTENT_TYPE]) - .allow_methods([Method::GET]) - .allow_credentials(true), - ) - .with_state(state.clone()); + .allow_methods([Method::GET]), + ); - Ok(router) + Ok(router.merge(internal_router).with_state(state)) } -async fn generate_token( - State(app): State, - ConnectInfo(addr): ConnectInfo, -) -> Result { - // disallow any connections other than local ones - if !addr.ip().is_loopback() { - return Ok((StatusCode::FORBIDDEN, "not allowed").into_response()); - } +async fn revoke_all_tokens(State(app): State) -> impl IntoResponse { + app.tokens.revoke_all().await; + tokio::spawn(async move { + if let Err(err) = app.tokens.write(&app.tokens_path).await { + tracing::error!("couldn't write tokens file: {err}"); + } + }); + StatusCode::OK +} +async fn generate_token(State(app): State) -> Result { // generate token let token = app.tokens.generate().await?; // start task to write tokens @@ -87,7 +104,7 @@ async fn generate_token( tracing::error!("couldn't write tokens file: {err}"); } }); - return Ok(token.into_response()); + Ok(token.into_response()) } async fn http( diff --git a/src/token.rs b/src/token.rs index 5ad7c44..9ea81b2 100644 --- a/src/token.rs +++ b/src/token.rs @@ -66,4 +66,8 @@ impl Tokens { let token_hash = hash_string(token.as_bytes())?; Ok(self.hashed.contains_async(&Cow::Owned(token_hash)).await) } + + pub async fn revoke_all(&self) { + self.hashed.clear_async().await; + } }