diff options
Diffstat (limited to 'src/auth.rs')
-rw-r--r-- | src/auth.rs | 105 |
1 files changed, 80 insertions, 25 deletions
diff --git a/src/auth.rs b/src/auth.rs index 6081a9d..31826d8 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -1,6 +1,7 @@ +use actix_web::dev::{Body, Service, ServiceRequest, ServiceResponse, Transform}; use actix_web::http::{header, StatusCode}; -use actix_web::middleware::{Middleware, Response}; use actix_web::{HttpRequest, HttpResponse, Result}; +use futures::future; use sha2::{Digest, Sha256, Sha512}; use crate::errors::{self, ContextualError}; @@ -87,16 +88,61 @@ pub fn get_hash<T: Digest>(text: &str) -> Vec<u8> { hasher.finalize().to_vec() } -impl Middleware<crate::MiniserveConfig> for Auth { - fn response( - &self, - req: &HttpRequest<crate::MiniserveConfig>, - resp: HttpResponse, - ) -> Result<Response> { - let required_auth = &req.state().auth; +pub struct AuthMiddleware<S> { + service: S, +} + +impl<S> Transform<S> for Auth +where + S: Service< + Request = ServiceRequest, + Response = ServiceResponse<Body>, + Error = actix_web::Error, + >, + S::Future: 'static, +{ + type Request = ServiceRequest; + type Response = ServiceResponse<Body>; + type Error = S::Error; + type Future = future::Ready<Result<Self::Transform, Self::InitError>>; + type Transform = AuthMiddleware<S>; + type InitError = (); + fn new_transform(&self, service: S) -> Self::Future { + future::ok(AuthMiddleware { service }) + } +} + +impl<S> Service for AuthMiddleware<S> +where + S: Service< + Request = ServiceRequest, + Response = ServiceResponse<Body>, + Error = actix_web::Error, + >, + S::Future: 'static, +{ + type Request = ServiceRequest; + type Response = ServiceResponse<Body>; + type Error = S::Error; + type Future = + std::pin::Pin<Box<dyn future::Future<Output = Result<Self::Response, Self::Error>>>>; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context, + ) -> std::task::Poll<Result<(), Self::Error>> { + self.service.poll_ready(cx) + } + + fn call(&mut self, req: Self::Request) -> Self::Future { + let (req, pl) = req.into_parts(); + let required_auth = &req.app_data::<crate::MiniserveConfig>().unwrap().auth; if required_auth.is_empty() { - return Ok(Response::Done(resp)); + let resp = self + .service + .call(ServiceRequest::from_parts(req, pl).unwrap_or_else(|_| unreachable!())); + return Box::pin(async { resp.await }); } if let Some(auth_headers) = req.headers().get(header::AUTHORIZATION) { @@ -104,30 +150,38 @@ impl Middleware<crate::MiniserveConfig> for Auth { Ok(auth_req) => auth_req, Err(err) => { let auth_err = ContextualError::HTTPAuthenticationError(Box::new(err)); - return Ok(Response::Done(HttpResponse::BadRequest().body( - build_unauthorized_response(&req, auth_err, true, StatusCode::BAD_REQUEST), + let body = + build_unauthorized_response(&req, auth_err, true, StatusCode::BAD_REQUEST); + return Box::pin(future::ok(ServiceResponse::new( + req, + HttpResponse::BadRequest().body(body), ))); } }; if match_auth(auth_req, required_auth) { - return Ok(Response::Done(resp)); + let resp = self + .service + .call(ServiceRequest::from_parts(req, pl).unwrap_or_else(|_| unreachable!())); + return Box::pin(async { resp.await }); } } - Ok(Response::Done( + let body = build_unauthorized_response( + &req, + ContextualError::InvalidHTTPCredentials, + true, + StatusCode::UNAUTHORIZED, + ); + Box::pin(future::ok(ServiceResponse::new( + req, HttpResponse::Unauthorized() .header( header::WWW_AUTHENTICATE, header::HeaderValue::from_static("Basic realm=\"miniserve\""), ) - .body(build_unauthorized_response( - &req, - ContextualError::InvalidHTTPCredentials, - true, - StatusCode::UNAUTHORIZED, - )), - )) + .body(body), + ))) } } @@ -135,18 +189,19 @@ impl Middleware<crate::MiniserveConfig> for Auth { /// The reason why log_error_chain is optional is to handle cases where the auth pop-up appears and when the user clicks Cancel. /// In those case, we do not log the error to the terminal since it does not really matter. fn build_unauthorized_response( - req: &HttpRequest<crate::MiniserveConfig>, + req: &HttpRequest, error: ContextualError, log_error_chain: bool, error_code: StatusCode, ) -> String { + let state = req.app_data::<crate::MiniserveConfig>().unwrap(); let error = ContextualError::HTTPAuthenticationError(Box::new(error)); if log_error_chain { errors::log_error_chain(error.to_string()); } - let return_path = match &req.state().random_route { - Some(random_route) => format!("/{}", random_route), + let return_path = match state.random_route { + Some(ref random_route) => format!("/{}", random_route), None => "/".to_string(), }; @@ -156,8 +211,8 @@ fn build_unauthorized_response( &return_path, None, None, - req.state().default_color_scheme, - req.state().default_color_scheme, + state.default_color_scheme, + state.default_color_scheme, false, false, ) |