feat: implement better errors, send less things in requests

This commit is contained in:
dusk 2023-05-05 17:16:03 +03:00
parent 66041f133c
commit e259740d50
Signed by: dusk
GPG Key ID: 1D8F8FAF2294D6EA
2 changed files with 114 additions and 67 deletions

View File

@ -1,4 +1,4 @@
use std::{borrow::Cow, collections::HashMap, fmt::Display}; use std::{collections::HashMap, fmt::Display};
use super::AppError; use super::AppError;
use async_tungstenite::{ use async_tungstenite::{
@ -19,8 +19,8 @@ use axum::{
use base64::Engine; use base64::Engine;
use futures::{SinkExt, StreamExt}; use futures::{SinkExt, StreamExt};
use http::{ use http::{
header::{AUTHORIZATION, CACHE_CONTROL, CONTENT_TYPE}, header::{AUTHORIZATION, CACHE_CONTROL, CONTENT_TYPE, RANGE},
HeaderName, HeaderValue, Method, Request, Response, StatusCode, HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode,
}; };
use hyper::Body; use hyper::Body;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -51,11 +51,19 @@ fn extract_password_from_basic_auth(auth: &str) -> Result<String, AppError> {
Ok(auth.trim_start_matches("default:").to_string()) Ok(auth.trim_start_matches("default:").to_string())
} }
struct QueryDisplay<'a, 'b> { fn remove_token_from_query(query: Option<&str>) -> HashMap<String, String> {
map: HashMap<Cow<'a, str>, Cow<'b, str>>, let mut query_map: HashMap<String, String> = query
.and_then(|v| serde_qs::from_str(v).ok())
.unwrap_or_else(HashMap::new);
query_map.remove("token");
query_map
} }
impl<'a, 'b> Display for QueryDisplay<'a, 'b> { struct QueryDisplay {
map: HashMap<String, String>,
}
impl Display for QueryDisplay {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let length = self.map.len(); let length = self.map.len();
for (index, (k, v)) in self.map.iter().enumerate() { for (index, (k, v)) in self.map.iter().enumerate() {
@ -69,11 +77,7 @@ impl<'a, 'b> Display for QueryDisplay<'a, 'b> {
} }
fn make_span_trace<B>(req: &Request<B>) -> Span { fn make_span_trace<B>(req: &Request<B>) -> Span {
let query = req.uri().query(); let query_map = remove_token_from_query(req.uri().query());
let mut query_map = query
.and_then(|v| serde_qs::from_str::<HashMap<Cow<str>, Cow<str>>>(v).ok())
.unwrap_or_else(HashMap::new);
query_map.remove("token");
let request_id = req let request_id = req
.headers() .headers()
@ -99,6 +103,11 @@ fn make_span_trace<B>(req: &Request<B>) -> Span {
} }
pub(super) async fn handler(state: AppState) -> Result<(Router, Router), AppError> { pub(super) async fn handler(state: AppState) -> Result<(Router, Router), AppError> {
let internal_router = Router::new()
.route("/token/generate", get(generate_token))
.route("/token/revoke_all", post(revoke_all_tokens))
.with_state(state.clone());
let trace_layer = TraceLayer::new_for_http() let trace_layer = TraceLayer::new_for_http()
.make_span_with(make_span_trace) .make_span_with(make_span_trace)
.on_request(|req: &Request<Body>, _span: &Span| { .on_request(|req: &Request<Body>, _span: &Span| {
@ -108,11 +117,12 @@ pub(super) async fn handler(state: AppState) -> Result<(Router, Router), AppErro
req.version(), req.version(),
) )
}); });
let cors_layer = CorsLayer::new()
let internal_router = Router::new() .allow_origin(tower_http::cors::Any)
.route("/token/generate", get(generate_token)) .allow_headers([CONTENT_TYPE, CACHE_CONTROL, REQUEST_ID])
.route("/token/revoke_all", post(revoke_all_tokens)) .allow_methods([Method::GET]);
.with_state(state.clone()); let sensitive_header_layer = SetSensitiveRequestHeadersLayer::new([AUTHORIZATION]);
let request_id_layer = SetRequestIdLayer::new(REQUEST_ID.clone(), MakeRequestUuid);
let router = Router::new() let router = Router::new()
.route("/token/generate_for_music/:id", get(generate_scoped_token)) .route("/token/generate_for_music/:id", get(generate_scoped_token))
@ -121,14 +131,9 @@ pub(super) async fn handler(state: AppState) -> Result<(Router, Router), AppErro
.route("/audio/scoped/:id", get(get_scoped_music)) .route("/audio/scoped/:id", get(get_scoped_music))
.route("/", get(metadata_ws)) .route("/", get(metadata_ws))
.layer(trace_layer) .layer(trace_layer)
.layer(SetSensitiveRequestHeadersLayer::new([AUTHORIZATION])) .layer(sensitive_header_layer)
.layer( .layer(cors_layer)
CorsLayer::new() .layer(request_id_layer)
.allow_origin(tower_http::cors::Any)
.allow_headers([CONTENT_TYPE, CACHE_CONTROL, REQUEST_ID])
.allow_methods([Method::GET]),
)
.layer(SetRequestIdLayer::new(REQUEST_ID.clone(), MakeRequestUuid))
.with_state(state); .with_state(state);
Ok((router, internal_router)) Ok((router, internal_router))
@ -186,21 +191,22 @@ async fn generate_scoped_token(
async fn get_scoped_music( async fn get_scoped_music(
State(app): State<AppState>, State(app): State<AppState>,
Path(token): Path<String>, Path(token): Path<String>,
request: Request<Body>,
) -> Result<Response<Body>, AppError> { ) -> Result<Response<Body>, AppError> {
if let Some(music_id) = app.scoped_tokens.verify(token).await { if let Some(music_id) = app.scoped_tokens.verify(token).await {
let mut resp = app let mut req = Request::builder()
.client
.request(
Request::builder()
.uri(format!( .uri(format!(
"http://{}:{}/audio/external_id/{}", "http://{}:{}/audio/external_id/{}",
app.musikcubed_address, app.musikcubed_http_port, music_id app.musikcubed_address, app.musikcubed_http_port, music_id
)) ))
.header(AUTHORIZATION, app.musikcubed_auth_header_value.clone()) .header(AUTHORIZATION, app.musikcubed_auth_header_value.clone())
.body(Body::empty()) .body(Body::empty())
.expect("cant fail"), .expect("cant fail");
) // proxy any range headers
.await?; if let Some(range) = request.headers().get(RANGE).cloned() {
req.headers_mut().insert(RANGE, range);
}
let mut resp = app.client.request(req).await?;
if resp.status().is_success() { if resp.status().is_success() {
// add cache header // add cache header
resp.headers_mut() resp.headers_mut()
@ -232,23 +238,26 @@ async fn get_music(
async fn http( async fn http(
State(app): State<AppState>, State(app): State<AppState>,
Query(query): Query<Auth>, Query(auth): Query<Auth>,
mut req: Request<Body>, mut req: Request<Body>,
) -> Result<Response<Body>, AppError> { ) -> Result<Response<Body>, AppError> {
// remove token from query
let path = req.uri().path(); let path = req.uri().path();
let path_query = req let query_map = remove_token_from_query(req.uri().query());
.uri() let has_query = !query_map.is_empty();
.path_and_query() let query = has_query
.map(|v| v.as_str()) .then(|| serde_qs::to_string(&query_map).unwrap())
.unwrap_or(path); .unwrap_or_else(String::new);
let query_prefix = has_query.then_some("?").unwrap_or("");
// craft new url
*req.uri_mut() = format!( *req.uri_mut() = format!(
"http://{}:{}{}", "http://{}:{}{path}{query_prefix}{query}",
app.musikcubed_address, app.musikcubed_http_port, path_query app.musikcubed_address, app.musikcubed_http_port
) )
.parse()?; .parse()?;
let maybe_token = query.token.or_else(|| { let maybe_token = auth.token.or_else(|| {
req.headers() req.headers()
.get(AUTHORIZATION) .get(AUTHORIZATION)
.and_then(|h| h.to_str().ok()) .and_then(|h| h.to_str().ok())
@ -269,8 +278,29 @@ async fn http(
.expect("cant fail")); .expect("cant fail"));
} }
req.headers_mut() // proxy only the headers we need
.insert(AUTHORIZATION, app.musikcubed_auth_header_value.clone()); let headers = {
let mut headers = HeaderMap::with_capacity(2);
let mut proxy_header = |header_name: HeaderName| {
if let Some(value) = req.headers().get(&header_name).cloned() {
headers.insert(header_name, value);
}
};
// proxy range header
proxy_header(RANGE);
// add auth
headers.insert(AUTHORIZATION, app.musikcubed_auth_header_value.clone());
headers
};
*req.headers_mut() = headers;
let scheme = req.uri().scheme_str().unwrap();
let authority = req.uri().authority().unwrap().as_str();
tracing::debug!(
"proxying request to {scheme}://{authority} with headers {:?}",
req.headers()
);
Ok(app.client.request(req).await?) Ok(app.client.request(req).await?)
} }
@ -286,6 +316,7 @@ async fn metadata_ws(
"ws://{}:{}", "ws://{}:{}",
app.musikcubed_address, app.musikcubed_metadata_port app.musikcubed_address, app.musikcubed_metadata_port
); );
tracing::debug!("proxying websocket request to {uri}");
let (ws_stream, _) = connect_async(uri).await?; let (ws_stream, _) = connect_async(uri).await?;
let upgrade = ws let upgrade = ws
@ -402,7 +433,8 @@ async fn handle_metadata_socket(
break 'err; break 'err;
} }
// wait for auth reply // wait for auth reply
if let Some(Ok(TungsteniteMessage::Text(raw))) = server_socket.next().await { match server_socket.next().await {
Some(Ok(TungsteniteMessage::Text(raw))) => {
let Ok(parsed) = serde_json::from_str::<WsApiMessage>(&raw) else { let Ok(parsed) = serde_json::from_str::<WsApiMessage>(&raw) else {
tracing::error!("invalid auth response message: {raw}"); tracing::error!("invalid auth response message: {raw}");
break 'err; break 'err;
@ -414,10 +446,21 @@ async fn handle_metadata_socket(
.unwrap_or(false); .unwrap_or(false);
match is_authenticated { match is_authenticated {
true => break 'ok parsed, true => break 'ok parsed,
false => break 'err, false => {
tracing::error!("unauthorized");
break 'err;
} }
} }
} }
Some(Ok(TungsteniteMessage::Close(frame))) => {
let reason = frame
.map(|v| format!("{} (code {})", v.reason, v.code))
.unwrap_or_else(|| "no reason given".to_string());
tracing::error!("socket was closed by musikcubed: {}", reason);
}
_ => tracing::error!("unexpected message from musikcubed"),
}
}
let _ = client_socket.close().await; let _ = client_socket.close().await;
let _ = server_socket.close(None).await; let _ = server_socket.close(None).await;
return; return;

View File

@ -122,12 +122,16 @@ impl AppStateInternal {
musikcubed_address: get_conf("MUSIKCUBED_ADDRESS")?, musikcubed_address: get_conf("MUSIKCUBED_ADDRESS")?,
musikcubed_http_port: get_conf("MUSIKCUBED_HTTP_PORT")?.parse()?, musikcubed_http_port: get_conf("MUSIKCUBED_HTTP_PORT")?.parse()?,
musikcubed_metadata_port: get_conf("MUSIKCUBED_METADATA_PORT")?.parse()?, musikcubed_metadata_port: get_conf("MUSIKCUBED_METADATA_PORT")?.parse()?,
musikcubed_auth_header_value: format!( musikcubed_auth_header_value: {
let mut val: http::HeaderValue = format!(
"Basic {}", "Basic {}",
B64.encode(format!("default:{}", musikcubed_password)) B64.encode(format!("default:{}", musikcubed_password))
) )
.parse() .parse()
.expect("valid header value"), .expect("valid header value");
val.set_sensitive(true);
val
},
musikcubed_password, musikcubed_password,
client: Client::new(), client: Client::new(),
tokens: Tokens::read(&tokens_path).await?, tokens: Tokens::read(&tokens_path).await?,