diff options
Diffstat (limited to '')
-rw-r--r-- | src/auth.rs | 151 |
1 files changed, 30 insertions, 121 deletions
diff --git a/src/auth.rs b/src/auth.rs index 31826d8..91c7bb0 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -1,14 +1,12 @@ -use actix_web::dev::{Body, Service, ServiceRequest, ServiceResponse, Transform}; +use actix_web::dev::ServiceRequest; use actix_web::http::{header, StatusCode}; use actix_web::{HttpRequest, HttpResponse, Result}; -use futures::future; +use actix_web_httpauth::extractors::basic::BasicAuth; use sha2::{Digest, Sha256, Sha512}; use crate::errors::{self, ContextualError}; use crate::renderer; -pub struct Auth; - #[derive(Clone, Debug)] /// HTTP Basic authentication parameters pub struct BasicAuthParams { @@ -16,6 +14,15 @@ pub struct BasicAuthParams { pub password: String, } +impl From<BasicAuth> for BasicAuthParams { + fn from(auth: BasicAuth) -> Self { + Self { + username: auth.user_id().to_string(), + password: auth.password().unwrap_or(&"".into()).to_string(), + } + } +} + #[derive(Clone, Debug, PartialEq)] /// `password` field of `RequiredAuth` pub enum RequiredAuthPassword { @@ -31,29 +38,6 @@ pub struct RequiredAuth { pub password: RequiredAuthPassword, } -/// Decode a HTTP basic auth string into a tuple of username and password. -pub fn parse_basic_auth( - authorization_header: &header::HeaderValue, -) -> Result<BasicAuthParams, ContextualError> { - let basic_removed = authorization_header - .to_str() - .map_err(|e| { - ContextualError::ParseError("HTTP authentication header".to_string(), e.to_string()) - })? - .replace("Basic ", ""); - let decoded = base64::decode(&basic_removed).map_err(ContextualError::Base64DecodeError)?; - let decoded_str = String::from_utf8_lossy(&decoded); - let credentials: Vec<&str> = decoded_str.splitn(2, ':').collect(); - - // If argument parsing went fine, it means the HTTP credentials string is well formatted - // So we can safely unpack the username and the password - - Ok(BasicAuthParams { - username: credentials[0].to_owned(), - password: credentials[1].to_owned(), - }) -} - /// Return `true` if `basic_auth` is matches any of `required_auth` pub fn match_auth(basic_auth: BasicAuthParams, required_auth: &[RequiredAuth]) -> bool { required_auth @@ -88,100 +72,25 @@ pub fn get_hash<T: Digest>(text: &str) -> Vec<u8> { hasher.finalize().to_vec() } -pub struct AuthMiddleware<S> { - service: S, -} - -impl<S> Transform<S> for Auth -where - S: Service< - Request = ServiceRequest, - Response = ServiceResponse<Body>, - Error = actix_web::Error, - >, - S::Future: 'static, -{ - type Request = ServiceRequest; - type Response = ServiceResponse<Body>; - type Error = S::Error; - type Future = future::Ready<Result<Self::Transform, Self::InitError>>; - type Transform = AuthMiddleware<S>; - type InitError = (); - fn new_transform(&self, service: S) -> Self::Future { - future::ok(AuthMiddleware { service }) - } -} - -impl<S> Service for AuthMiddleware<S> -where - S: Service< - Request = ServiceRequest, - Response = ServiceResponse<Body>, - Error = actix_web::Error, - >, - S::Future: 'static, -{ - type Request = ServiceRequest; - type Response = ServiceResponse<Body>; - type Error = S::Error; - type Future = - std::pin::Pin<Box<dyn future::Future<Output = Result<Self::Response, Self::Error>>>>; - - fn poll_ready( - &mut self, - cx: &mut std::task::Context, - ) -> std::task::Poll<Result<(), Self::Error>> { - self.service.poll_ready(cx) - } - - fn call(&mut self, req: Self::Request) -> Self::Future { - let (req, pl) = req.into_parts(); - let required_auth = &req.app_data::<crate::MiniserveConfig>().unwrap().auth; - - if required_auth.is_empty() { - let resp = self - .service - .call(ServiceRequest::from_parts(req, pl).unwrap_or_else(|_| unreachable!())); - return Box::pin(async { resp.await }); - } - - if let Some(auth_headers) = req.headers().get(header::AUTHORIZATION) { - let auth_req = match parse_basic_auth(auth_headers) { - Ok(auth_req) => auth_req, - Err(err) => { - let auth_err = ContextualError::HTTPAuthenticationError(Box::new(err)); - let body = - build_unauthorized_response(&req, auth_err, true, StatusCode::BAD_REQUEST); - return Box::pin(future::ok(ServiceResponse::new( - req, - HttpResponse::BadRequest().body(body), - ))); - } - }; - - if match_auth(auth_req, required_auth) { - let resp = self - .service - .call(ServiceRequest::from_parts(req, pl).unwrap_or_else(|_| unreachable!())); - return Box::pin(async { resp.await }); - } - } - - let body = build_unauthorized_response( - &req, - ContextualError::InvalidHTTPCredentials, - true, - StatusCode::UNAUTHORIZED, - ); - Box::pin(future::ok(ServiceResponse::new( - req, - HttpResponse::Unauthorized() - .header( - header::WWW_AUTHENTICATE, - header::HeaderValue::from_static("Basic realm=\"miniserve\""), - ) - .body(body), - ))) +pub async fn handle_auth(req: ServiceRequest, cred: BasicAuth) -> Result<ServiceRequest> { + let (req, pl) = req.into_parts(); + let required_auth = &req.app_data::<crate::MiniserveConfig>().unwrap().auth; + + if match_auth(cred.into(), required_auth) { + Ok(ServiceRequest::from_parts(req, pl).unwrap_or_else(|_| unreachable!())) + } else { + Err(HttpResponse::Unauthorized() + .header( + header::WWW_AUTHENTICATE, + header::HeaderValue::from_static("Basic realm=\"miniserve\""), + ) + .body(build_unauthorized_response( + &req, + ContextualError::InvalidHTTPCredentials, + true, + StatusCode::UNAUTHORIZED, + )) + .into()) } } |