diff --git a/src/handler.rs b/src/handler.rs index 4df1044..d4c1769 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -12,7 +12,6 @@ use axum::{ ConnectInfo, Path, Query, State, }, headers::UserAgent, - middleware::Next, response::IntoResponse, routing::{get, post}, Router, TypedHeader, @@ -46,18 +45,6 @@ fn extract_password_from_basic_auth(auth: &str) -> Result { Ok(auth.trim_start_matches("default:").to_string()) } -async fn block_external_ips( - ConnectInfo(addr): ConnectInfo, - req: Request, - next: Next, -) -> Result { - if addr.ip().is_loopback() { - Ok(next.run(req).await) - } else { - Err((StatusCode::FORBIDDEN, "not allowed")) - } -} - struct ComponentDisplay { left: Left, right: Right, @@ -126,13 +113,13 @@ fn make_span_trace(req: &Request) -> Span { ) } -pub(super) async fn handler(state: AppState) -> Result { +pub(super) async fn handler(state: AppState) -> Result<(Router, Router), AppError> { let trace_layer = TraceLayer::new_for_http().make_span_with(make_span_trace); let internal_router = Router::new() .route("/token/generate", get(generate_token)) .route("/token/revoke_all", post(revoke_all_tokens)) - .layer(axum::middleware::from_fn(block_external_ips)); + .with_state(state.clone()); let router = Router::new() .route("/token/generate_for_music/:id", get(generate_scoped_token)) @@ -147,9 +134,10 @@ pub(super) async fn handler(state: AppState) -> Result { .allow_origin(tower_http::cors::Any) .allow_headers([CONTENT_TYPE]) .allow_methods([Method::GET]), - ); + ) + .with_state(state); - Ok(router.merge(internal_router).with_state(state)) + Ok((router, internal_router)) } async fn revoke_all_tokens(State(app): State) -> impl IntoResponse { diff --git a/src/main.rs b/src/main.rs index ea085e7..f1b6875 100644 --- a/src/main.rs +++ b/src/main.rs @@ -45,31 +45,45 @@ async fn app() -> Result<(), AppError> { let cert_path = get_conf("TLS_CERT_PATH").ok(); let key_path = get_conf("TLS_KEY_PATH").ok(); - let addr: SocketAddr = get_conf("ADDRESS")?.parse()?; + let public_port: u16 = get_conf("PORT")?.parse()?; + let internal_port: u16 = get_conf("INTERNAL_PORT")?.parse()?; - let state = AppState::new(AppStateInternal::new(addr.port()).await?); - let router = handler::handler(state).await?; - let make_service = router.into_make_service_with_connect_info::(); + let state = AppState::new(AppStateInternal::new(public_port).await?); + let (public_router, internal_router) = handler::handler(state).await?; - let (task, scheme) = if let (Some(cert_path), Some(key_path)) = (cert_path, key_path) { + let internal_make_service = internal_router.into_make_service(); + let internal_task = tokio::spawn( + axum_server::bind(SocketAddr::from(([127, 0, 0, 1], internal_port))) + .serve(internal_make_service), + ); + + let public_addr = SocketAddr::from(([127, 0, 0, 1], public_port)); + let public_make_service = public_router.into_make_service(); + let (pub_task, scheme) = if let (Some(cert_path), Some(key_path)) = (cert_path, key_path) { info!("cert path is: {cert_path}"); info!("key path is: {key_path}"); let config = RustlsConfig::from_pem_file(cert_path, key_path).await?; - let task = tokio::spawn(axum_server::bind_rustls(addr, config).serve(make_service)); + let task = + tokio::spawn(axum_server::bind_rustls(public_addr, config).serve(public_make_service)); (task, "https") } else if get_conf("INSECURE").ok().is_some() { tracing::warn!("RUNNING IN INSECURE MODE (NO TLS)"); - let task = tokio::spawn(axum_server::bind(addr).serve(make_service)); + let task = tokio::spawn(axum_server::bind(public_addr).serve(public_make_service)); (task, "http") } else { tracing::warn!("note: either one or both of MUSIKQUAD_TLS_CERT_PATH and MUSIKQUAD_TLS_KEY_PATH has not been set"); return Err("will not serve HTTP unless the MUSIKQUAD_INSECURE env var is set".into()); }; - info!("listening on {scheme}://{addr}"); - task.await??; + info!("listening on {scheme}://{public_addr}"); - Ok(()) + let res = tokio::select! { + res = internal_task => res, + res = pub_task => res, + }; + res.map_err(AppError::from) + .map(|res| res.map_err(AppError::from)) + .and_then(std::convert::identity) } fn get_conf(key: &str) -> Result {