diff options
-rw-r--r-- | src/main.rs | 91 |
1 files changed, 63 insertions, 28 deletions
diff --git a/src/main.rs b/src/main.rs index d291aa5..3db64d8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,37 +2,45 @@ extern crate actix; extern crate actix_web; extern crate simplelog; extern crate base64; +#[macro_use] extern crate clap; -#[macro_use] -extern crate clap; - -use actix_web::http::{StatusCode, header}; +use actix_web::http::header; use actix_web::{server, App, fs, middleware, HttpRequest, HttpResponse, HttpMessage, Result}; use actix_web::middleware::{Middleware, Response}; use simplelog::{TermLogger, LevelFilter, Config}; use std::path::PathBuf; use std::net::{IpAddr, Ipv4Addr}; -use std::error::Error; -/// Decode a HTTP basic auth string into a tuple of username and password. -fn parse_basic_auth(auth: String) -> Result<(String, String), String> { - let decoded = base64::decode(&auth).map_err(|e| e.description().to_owned())?; - let decoded_str = String::from_utf8_lossy(&decoded); - let strings: Vec<&str> = decoded_str.splitn(2, ':').collect(); - if strings.len() != 2 { - return Err("Invalid username/password format".to_owned()); - } - let (user, password) = (strings[0], strings[1]); - Ok((user.to_owned(), password.to_owned())) +enum BasicAuthError { + Base64DecodeError, + InvalidUsernameFormat, +} + +#[derive(Clone, Debug)] +struct BasicAuthParams { + username: String, + password: String, } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct MiniserveConfig { verbose: bool, path: std::path::PathBuf, port: u16, interface: IpAddr, - auth: Option<String>, + auth: Option<BasicAuthParams>, +} + +/// Decode a HTTP basic auth string into a tuple of username and password. +fn parse_basic_auth(authorization_header: &header::HeaderValue) -> Result<BasicAuthParams, BasicAuthError> { + let basic_removed = authorization_header.to_str().unwrap().replace("Basic ", ""); + let decoded = base64::decode(&basic_removed).map_err(|_| BasicAuthError::Base64DecodeError)?; + let decoded_str = String::from_utf8_lossy(&decoded); + let strings: Vec<&str> = decoded_str.splitn(2, ':').collect(); + if strings.len() != 2 { + return Err(BasicAuthError::InvalidUsernameFormat); + } + Ok(BasicAuthParams { username: strings[0].to_owned(), password: strings[1].to_owned() }) } fn is_valid_path(path: String) -> Result<(), String> { @@ -52,7 +60,7 @@ fn is_valid_interface(interface: String) -> Result<(), String> { } fn is_valid_auth(auth: String) -> Result<(), String> { - auth.find(':').ok_or("Correct format is user:password".to_owned()).map(|_| ()) + auth.find(':').ok_or("Correct format is username:password".to_owned()).map(|_| ()) } pub fn parse_args() -> MiniserveConfig { @@ -99,7 +107,7 @@ pub fn parse_args() -> MiniserveConfig { .short("a") .long("auth") .validator(is_valid_auth) - .help("Set authentication (user:password)") + .help("Set authentication (username:password)") .takes_value(true), ) .get_matches(); @@ -108,7 +116,16 @@ pub fn parse_args() -> MiniserveConfig { let path = matches.value_of("PATH").unwrap(); let port = matches.value_of("port").unwrap().parse().unwrap(); let interface = matches.value_of("interface").unwrap().parse().unwrap(); - let auth = matches.value_of("auth").map(|a| a.to_owned()); + let auth = if let Some(auth_split) = matches.value_of("auth").map(|x| x.splitn(2, ':')) { + let auth_vec = auth_split.collect::<Vec<&str>>(); + if auth_vec.len() == 2 { + Some(BasicAuthParams { username: auth_vec[0].to_owned(), password: auth_vec[1].to_owned() }) + } else { + None + } + } else { + None + }; MiniserveConfig { verbose, @@ -144,15 +161,33 @@ fn configure_app(app: App<MiniserveConfig>) -> App<MiniserveConfig> { struct Auth; impl Middleware<MiniserveConfig> for Auth { - fn response(&self, req: &mut HttpRequest<MiniserveConfig>, mut resp: HttpResponse) -> Result<Response> { - let required_auth = &req.state().auth; - if required_auth.is_some() { - // parse_basic_auth(pass) - println!("{:?}", required_auth); - println!("{:?}", req.headers().get(header::AUTHORIZATION)); + fn response(&self, req: &mut HttpRequest<MiniserveConfig>, resp: HttpResponse) -> Result<Response> { + if let Some(required_auth) = &req.state().auth { + if let Some(auth_headers) = req.headers().get(header::AUTHORIZATION) { + let auth_req = match parse_basic_auth(auth_headers) { + Ok(auth_req) => auth_req, + Err(BasicAuthError::Base64DecodeError) => + return Ok(Response::Done(HttpResponse::BadRequest() + .body(format!("Error decoding basic auth base64: '{}'", + auth_headers.to_str().unwrap())))), + Err(BasicAuthError::InvalidUsernameFormat) => + return Ok(Response::Done(HttpResponse::BadRequest() + .body("Invalid basic auth format"))), + }; + if auth_req.username != required_auth.username + || auth_req.password != required_auth.password { + let new_resp = HttpResponse::Forbidden() + .finish(); + return Ok(Response::Done(new_resp)); + } + } else { + let new_resp = HttpResponse::Unauthorized() + .header(header::WWW_AUTHENTICATE, + header::HeaderValue::from_static("Basic realm=\"miniserve\"")) + .finish(); + return Ok(Response::Done(new_resp)); + } } - resp.headers_mut().insert(header::WWW_AUTHENTICATE, header::HeaderValue::from_static("Basic realm=\"lol\"")); - *resp.status_mut() = StatusCode::UNAUTHORIZED; Ok(Response::Done(resp)) } } |