feat: separate internal and public functions

This commit is contained in:
dusk 2023-05-05 11:32:56 +03:00
parent 52d82eed80
commit e935668b85
Signed by: dusk
GPG Key ID: 1D8F8FAF2294D6EA
2 changed files with 29 additions and 27 deletions

View File

@ -12,7 +12,6 @@ use axum::{
ConnectInfo, Path, Query, State, ConnectInfo, Path, Query, State,
}, },
headers::UserAgent, headers::UserAgent,
middleware::Next,
response::IntoResponse, response::IntoResponse,
routing::{get, post}, routing::{get, post},
Router, TypedHeader, Router, TypedHeader,
@ -46,18 +45,6 @@ fn extract_password_from_basic_auth(auth: &str) -> Result<String, AppError> {
Ok(auth.trim_start_matches("default:").to_string()) Ok(auth.trim_start_matches("default:").to_string())
} }
async fn block_external_ips<B>(
ConnectInfo(addr): ConnectInfo<SocketAddr>,
req: Request<B>,
next: Next<B>,
) -> Result<impl IntoResponse, impl IntoResponse> {
if addr.ip().is_loopback() {
Ok(next.run(req).await)
} else {
Err((StatusCode::FORBIDDEN, "not allowed"))
}
}
struct ComponentDisplay<Left, Right> { struct ComponentDisplay<Left, Right> {
left: Left, left: Left,
right: Right, right: Right,
@ -126,13 +113,13 @@ fn make_span_trace<B>(req: &Request<B>) -> Span {
) )
} }
pub(super) async fn handler(state: AppState) -> Result<Router, AppError> { pub(super) async fn handler(state: AppState) -> Result<(Router, Router), AppError> {
let trace_layer = TraceLayer::new_for_http().make_span_with(make_span_trace); let trace_layer = TraceLayer::new_for_http().make_span_with(make_span_trace);
let internal_router = Router::new() let internal_router = Router::new()
.route("/token/generate", get(generate_token)) .route("/token/generate", get(generate_token))
.route("/token/revoke_all", post(revoke_all_tokens)) .route("/token/revoke_all", post(revoke_all_tokens))
.layer(axum::middleware::from_fn(block_external_ips)); .with_state(state.clone());
let router = Router::new() let router = Router::new()
.route("/token/generate_for_music/:id", get(generate_scoped_token)) .route("/token/generate_for_music/:id", get(generate_scoped_token))
@ -147,9 +134,10 @@ pub(super) async fn handler(state: AppState) -> Result<Router, AppError> {
.allow_origin(tower_http::cors::Any) .allow_origin(tower_http::cors::Any)
.allow_headers([CONTENT_TYPE]) .allow_headers([CONTENT_TYPE])
.allow_methods([Method::GET]), .allow_methods([Method::GET]),
); )
.with_state(state);
Ok(router.merge(internal_router).with_state(state)) Ok((router, internal_router))
} }
async fn revoke_all_tokens(State(app): State<AppState>) -> impl IntoResponse { async fn revoke_all_tokens(State(app): State<AppState>) -> impl IntoResponse {

View File

@ -45,31 +45,45 @@ async fn app() -> Result<(), AppError> {
let cert_path = get_conf("TLS_CERT_PATH").ok(); let cert_path = get_conf("TLS_CERT_PATH").ok();
let key_path = get_conf("TLS_KEY_PATH").ok(); let key_path = get_conf("TLS_KEY_PATH").ok();
let addr: SocketAddr = get_conf("ADDRESS")?.parse()?; let public_port: u16 = get_conf("PORT")?.parse()?;
let internal_port: u16 = get_conf("INTERNAL_PORT")?.parse()?;
let state = AppState::new(AppStateInternal::new(addr.port()).await?); let state = AppState::new(AppStateInternal::new(public_port).await?);
let router = handler::handler(state).await?; let (public_router, internal_router) = handler::handler(state).await?;
let make_service = router.into_make_service_with_connect_info::<SocketAddr>();
let (task, scheme) = if let (Some(cert_path), Some(key_path)) = (cert_path, key_path) { let internal_make_service = internal_router.into_make_service();
let internal_task = tokio::spawn(
axum_server::bind(SocketAddr::from(([127, 0, 0, 1], internal_port)))
.serve(internal_make_service),
);
let public_addr = SocketAddr::from(([127, 0, 0, 1], public_port));
let public_make_service = public_router.into_make_service();
let (pub_task, scheme) = if let (Some(cert_path), Some(key_path)) = (cert_path, key_path) {
info!("cert path is: {cert_path}"); info!("cert path is: {cert_path}");
info!("key path is: {key_path}"); info!("key path is: {key_path}");
let config = RustlsConfig::from_pem_file(cert_path, key_path).await?; let config = RustlsConfig::from_pem_file(cert_path, key_path).await?;
let task = tokio::spawn(axum_server::bind_rustls(addr, config).serve(make_service)); let task =
tokio::spawn(axum_server::bind_rustls(public_addr, config).serve(public_make_service));
(task, "https") (task, "https")
} else if get_conf("INSECURE").ok().is_some() { } else if get_conf("INSECURE").ok().is_some() {
tracing::warn!("RUNNING IN INSECURE MODE (NO TLS)"); tracing::warn!("RUNNING IN INSECURE MODE (NO TLS)");
let task = tokio::spawn(axum_server::bind(addr).serve(make_service)); let task = tokio::spawn(axum_server::bind(public_addr).serve(public_make_service));
(task, "http") (task, "http")
} else { } else {
tracing::warn!("note: either one or both of MUSIKQUAD_TLS_CERT_PATH and MUSIKQUAD_TLS_KEY_PATH has not been set"); tracing::warn!("note: either one or both of MUSIKQUAD_TLS_CERT_PATH and MUSIKQUAD_TLS_KEY_PATH has not been set");
return Err("will not serve HTTP unless the MUSIKQUAD_INSECURE env var is set".into()); return Err("will not serve HTTP unless the MUSIKQUAD_INSECURE env var is set".into());
}; };
info!("listening on {scheme}://{addr}"); info!("listening on {scheme}://{public_addr}");
task.await??;
Ok(()) let res = tokio::select! {
res = internal_task => res,
res = pub_task => res,
};
res.map_err(AppError::from)
.map(|res| res.map_err(AppError::from))
.and_then(std::convert::identity)
} }
fn get_conf(key: &str) -> Result<String, AppError> { fn get_conf(key: &str) -> Result<String, AppError> {