feat: add cache header to audio requests

This commit is contained in:
dusk 2023-05-05 12:22:22 +03:00
parent e254bf62fc
commit fba59560d1
Signed by: dusk
GPG Key ID: 1D8F8FAF2294D6EA

View File

@ -19,8 +19,8 @@ use axum::{
use base64::Engine; use base64::Engine;
use futures::{SinkExt, StreamExt}; use futures::{SinkExt, StreamExt};
use http::{ use http::{
header::{AUTHORIZATION, CONTENT_TYPE}, header::{AUTHORIZATION, CACHE_CONTROL, CONTENT_TYPE},
Method, Request, Response, StatusCode, HeaderValue, Method, Request, Response, StatusCode,
}; };
use hyper::Body; use hyper::Body;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -33,6 +33,8 @@ use tracing::{Instrument, Span};
use crate::{AppState, B64}; use crate::{AppState, B64};
const AUDIO_CACHE_HEADER: HeaderValue = HeaderValue::from_static("private, max-age=604800");
#[derive(Deserialize)] #[derive(Deserialize)]
struct Auth { struct Auth {
#[serde(default)] #[serde(default)]
@ -124,7 +126,7 @@ pub(super) async fn handler(state: AppState) -> Result<(Router, Router), AppErro
let router = Router::new() let router = Router::new()
.route("/token/generate_for_music/:id", get(generate_scoped_token)) .route("/token/generate_for_music/:id", get(generate_scoped_token))
.route("/thumbnail/:id", get(http)) .route("/thumbnail/:id", get(http))
.route("/audio/external_id/:id", get(http)) .route("/audio/external_id/:id", get(get_music))
.route("/audio/scoped/:id", get(get_scoped_music)) .route("/audio/scoped/:id", get(get_scoped_music))
.route("/", get(metadata_ws)) .route("/", get(metadata_ws))
.layer(SetSensitiveRequestHeadersLayer::new([AUTHORIZATION])) .layer(SetSensitiveRequestHeadersLayer::new([AUTHORIZATION]))
@ -170,12 +172,13 @@ async fn generate_scoped_token(
let maybe_token = query.token; let maybe_token = query.token;
'ok: { 'ok: {
tracing::debug!("verifying token: {maybe_token:?}");
if let Some(token) = maybe_token { if let Some(token) = maybe_token {
if app.tokens.verify(token).await? { if app.tokens.verify(token).await? {
tracing::debug!("verified token");
break 'ok; break 'ok;
} }
} }
tracing::debug!("invalid token");
return Ok(( return Ok((
StatusCode::UNAUTHORIZED, StatusCode::UNAUTHORIZED,
"Invalid token or token not present", "Invalid token or token not present",
@ -193,7 +196,7 @@ async fn get_scoped_music(
Path(token): Path<String>, Path(token): Path<String>,
) -> Result<Response<Body>, AppError> { ) -> Result<Response<Body>, AppError> {
if let Some(music_id) = app.scoped_tokens.verify(token).await { if let Some(music_id) = app.scoped_tokens.verify(token).await {
Ok(app let mut resp = app
.client .client
.request( .request(
Request::builder() Request::builder()
@ -205,7 +208,11 @@ async fn get_scoped_music(
.body(Body::empty()) .body(Body::empty())
.expect("cant fail"), .expect("cant fail"),
) )
.await?) .await?;
// add cache header
resp.headers_mut()
.insert(CACHE_CONTROL, AUDIO_CACHE_HEADER.clone());
Ok(resp)
} else { } else {
Ok(Response::builder() Ok(Response::builder()
.status(StatusCode::UNAUTHORIZED) .status(StatusCode::UNAUTHORIZED)
@ -214,6 +221,19 @@ async fn get_scoped_music(
} }
} }
async fn get_music(
State(app): State<AppState>,
Query(query): Query<Auth>,
req: Request<Body>,
) -> Result<Response<Body>, AppError> {
http(State(app), Query(query), req).await.map(|mut resp| {
// add cache header
resp.headers_mut()
.insert(CACHE_CONTROL, AUDIO_CACHE_HEADER.clone());
resp
})
}
async fn http( async fn http(
State(app): State<AppState>, State(app): State<AppState>,
Query(query): Query<Auth>, Query(query): Query<Auth>,
@ -240,12 +260,13 @@ async fn http(
}); });
'ok: { 'ok: {
tracing::debug!("verifying token: {maybe_token:?}");
if let Some(token) = maybe_token { if let Some(token) = maybe_token {
if app.tokens.verify(token).await? { if app.tokens.verify(token).await? {
tracing::debug!("verified token");
break 'ok; break 'ok;
} }
} }
tracing::debug!("invalid token");
return Ok(Response::builder() return Ok(Response::builder()
.status(StatusCode::UNAUTHORIZED) .status(StatusCode::UNAUTHORIZED)
.body("Invalid token or token not present".to_string().into()) .body("Invalid token or token not present".to_string().into())