refactor routers, add revoke all tokens

This commit is contained in:
dusk 2023-04-21 20:37:13 +03:00
parent 93e3617472
commit 8d509136ba
Signed by: dusk
GPG Key ID: 1D8F8FAF2294D6EA
2 changed files with 37 additions and 16 deletions

View File

@ -12,8 +12,9 @@ use axum::{
ConnectInfo, Query, State, ConnectInfo, Query, State,
}, },
headers::UserAgent, headers::UserAgent,
middleware::Next,
response::IntoResponse, response::IntoResponse,
routing::get, routing::{get, post},
Router, TypedHeader, Router, TypedHeader,
}; };
use base64::Engine; use base64::Engine;
@ -48,12 +49,28 @@ fn extract_password_from_basic_auth(auth: &str) -> Result<String, AppError> {
Ok(auth.trim_start_matches("default:").to_string()) Ok(auth.trim_start_matches("default:").to_string())
} }
async fn block_external_ips<B>(
ConnectInfo(addr): ConnectInfo<SocketAddr>,
req: Request<B>,
next: Next<B>,
) -> Result<impl IntoResponse, impl IntoResponse> {
if addr.ip().is_loopback() {
Ok(next.run(req).await)
} else {
Err((StatusCode::FORBIDDEN, "not allowed"))
}
}
pub(super) async fn handler(state: AppState) -> Result<Router, AppError> { pub(super) async fn handler(state: AppState) -> Result<Router, AppError> {
let trace_layer = let trace_layer =
TraceLayer::new_for_http().make_span_with(DefaultMakeSpan::default().include_headers(true)); TraceLayer::new_for_http().make_span_with(DefaultMakeSpan::default().include_headers(true));
let internal_router = Router::new()
.route("/token/generate", get(generate_token))
.route("/token/revoke_all", post(revoke_all_tokens))
.layer(axum::middleware::from_fn(block_external_ips));
let router = Router::new() let router = Router::new()
.route("/generate_token", get(generate_token))
.route("/thumbnails/:id", get(http)) .route("/thumbnails/:id", get(http))
.route("/audio/id/:id", get(http)) .route("/audio/id/:id", get(http))
.route("/", get(metadata_ws)) .route("/", get(metadata_ws))
@ -62,23 +79,23 @@ pub(super) async fn handler(state: AppState) -> Result<Router, AppError> {
CorsLayer::new() CorsLayer::new()
.allow_origin(get_conf("CORS_ALLOW_ORIGIN")?.parse::<HeaderValue>()?) .allow_origin(get_conf("CORS_ALLOW_ORIGIN")?.parse::<HeaderValue>()?)
.allow_headers([CONTENT_TYPE]) .allow_headers([CONTENT_TYPE])
.allow_methods([Method::GET]) .allow_methods([Method::GET]),
.allow_credentials(true), );
)
.with_state(state.clone());
Ok(router) Ok(router.merge(internal_router).with_state(state))
} }
async fn generate_token( async fn revoke_all_tokens(State(app): State<AppState>) -> impl IntoResponse {
State(app): State<AppState>, app.tokens.revoke_all().await;
ConnectInfo(addr): ConnectInfo<SocketAddr>, tokio::spawn(async move {
) -> Result<axum::response::Response, AppError> { if let Err(err) = app.tokens.write(&app.tokens_path).await {
// disallow any connections other than local ones tracing::error!("couldn't write tokens file: {err}");
if !addr.ip().is_loopback() { }
return Ok((StatusCode::FORBIDDEN, "not allowed").into_response()); });
StatusCode::OK
} }
async fn generate_token(State(app): State<AppState>) -> Result<axum::response::Response, AppError> {
// generate token // generate token
let token = app.tokens.generate().await?; let token = app.tokens.generate().await?;
// start task to write tokens // start task to write tokens
@ -87,7 +104,7 @@ async fn generate_token(
tracing::error!("couldn't write tokens file: {err}"); tracing::error!("couldn't write tokens file: {err}");
} }
}); });
return Ok(token.into_response()); Ok(token.into_response())
} }
async fn http( async fn http(

View File

@ -66,4 +66,8 @@ impl Tokens {
let token_hash = hash_string(token.as_bytes())?; let token_hash = hash_string(token.as_bytes())?;
Ok(self.hashed.contains_async(&Cow::Owned(token_hash)).await) Ok(self.hashed.contains_async(&Cow::Owned(token_hash)).await)
} }
pub async fn revoke_all(&self) {
self.hashed.clear_async().await;
}
} }