diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/args.rs | 27 | ||||
-rw-r--r-- | src/auth.rs | 190 | ||||
-rw-r--r-- | src/main.rs | 2 |
3 files changed, 153 insertions, 66 deletions
diff --git a/src/args.rs b/src/args.rs index 3f48329..925d3dd 100644 --- a/src/args.rs +++ b/src/args.rs @@ -41,8 +41,13 @@ struct CLIArgs { /// Set authentication. Currently supported formats: /// username:password, username:sha256:hash, username:sha512:hash /// (e.g. joe:123, joe:sha256:a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae3) - #[structopt(short = "a", long = "auth", parse(try_from_str = "parse_auth"))] - auth: Option<auth::RequiredAuth>, + #[structopt( + short = "a", + long = "auth", + parse(try_from_str = "parse_auth"), + raw(number_of_values = "1") + )] + auth: Vec<auth::RequiredAuth>, /// Generate a random 6-hexdigit route #[structopt(long = "random-route")] @@ -78,7 +83,7 @@ fn parse_interface(src: &str) -> Result<IpAddr, std::net::AddrParseError> { src.parse::<IpAddr>() } -/// Checks wether the auth string is valid, i.e. it follows the syntax username:password +/// Parse authentication requirement fn parse_auth(src: &str) -> Result<auth::RequiredAuth, ContextualError> { let mut split = src.splitn(3, ':'); let invalid_auth_format = Err(ContextualError::InvalidAuthFormat); @@ -173,14 +178,16 @@ mod tests { use auth::*; use RequiredAuthPassword::*; - RequiredAuth { + let password = match encrypt { + "plain" => Plain(password.to_owned()), + "sha256" => Sha256(hex::decode(password.to_owned()).unwrap()), + "sha512" => Sha512(hex::decode(password.to_owned()).unwrap()), + _ => panic!("Unknown encryption type"), + }; + + auth::RequiredAuth { username: username.to_owned(), - password: match encrypt { - "plain" => Plain(password.to_owned()), - "sha256" => Sha256(hex::decode(password.to_owned()).unwrap()), - "sha512" => Sha512(hex::decode(password.to_owned()).unwrap()), - _ => panic!("Unknown encryption type"), - }, + password, } } diff --git a/src/auth.rs b/src/auth.rs index f2e5fcf..2c98622 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -53,32 +53,37 @@ pub fn parse_basic_auth( }) } -/// Verify authentication -pub fn match_auth(basic_auth: BasicAuthParams, required_auth: &RequiredAuth) -> bool { - if basic_auth.username != required_auth.username { - return false; - } +/// Return `true` if `basic_auth` is matches any of `required_auth` +pub fn match_auth(basic_auth: BasicAuthParams, required_auth: &[RequiredAuth]) -> bool { + required_auth.iter().any( + |RequiredAuth { username, password }| + basic_auth.username == *username && + compare_password(&basic_auth.password, password) + ) +} - match &required_auth.password { - RequiredAuthPassword::Plain(ref required_password) => { - basic_auth.password == *required_password +/// Return `true` if `basic_auth_pwd` meets `required_auth_pwd`'s requirement +pub fn compare_password (basic_auth_pwd: &str, required_auth_pwd: &RequiredAuthPassword) -> bool { + match &required_auth_pwd { + RequiredAuthPassword::Plain(required_password) => { + *basic_auth_pwd == *required_password } RequiredAuthPassword::Sha256(password_hash) => { - compare_hash::<Sha256>(basic_auth.password, password_hash) + compare_hash::<Sha256>(basic_auth_pwd, password_hash) } RequiredAuthPassword::Sha512(password_hash) => { - compare_hash::<Sha512>(basic_auth.password, password_hash) + compare_hash::<Sha512>(basic_auth_pwd, password_hash) } } } /// Return `true` if hashing of `password` by `T` algorithm equals to `hash` -pub fn compare_hash<T: Digest>(password: String, hash: &[u8]) -> bool { +pub fn compare_hash<T: Digest>(password: &str, hash: &[u8]) -> bool { get_hash::<T>(password) == hash } /// Get hash of a `text` -pub fn get_hash<T: Digest>(text: String) -> Vec<u8> { +pub fn get_hash<T: Digest>(text: &str) -> Vec<u8> { let mut hasher = T::new(); hasher.input(text); hasher.result().to_vec() @@ -90,48 +95,48 @@ impl Middleware<crate::MiniserveConfig> for Auth { req: &HttpRequest<crate::MiniserveConfig>, resp: HttpResponse, ) -> Result<Response> { - if let Some(ref required_auth) = req.state().auth { - if let Some(auth_headers) = req.headers().get(header::AUTHORIZATION) { - let auth_req = match parse_basic_auth(auth_headers) { - Ok(auth_req) => auth_req, - Err(err) => { - let auth_err = ContextualError::HTTPAuthenticationError(Box::new(err)); - return Ok(Response::Done(HttpResponse::BadRequest().body( + let required_auth = &req.state().auth; + + if required_auth.is_empty() { + return Ok(Response::Done(resp)); + } + + if let Some(auth_headers) = req.headers().get(header::AUTHORIZATION) { + let auth_req = match parse_basic_auth(auth_headers) { + Ok(auth_req) => auth_req, + Err(err) => { + let auth_err = ContextualError::HTTPAuthenticationError(Box::new(err)); + return Ok(Response::Done( + HttpResponse::BadRequest().body( build_unauthorized_response( &req, auth_err, true, StatusCode::BAD_REQUEST, ), - ))); - } - }; - if !match_auth(auth_req, required_auth) { - return Ok(Response::Done(HttpResponse::Unauthorized().body( - build_unauthorized_response( - &req, - ContextualError::InvalidHTTPCredentials, - true, - StatusCode::UNAUTHORIZED, ), - ))); - } - } else { - let new_resp = HttpResponse::Unauthorized() - .header( - header::WWW_AUTHENTICATE, - header::HeaderValue::from_static("Basic realm=\"miniserve\""), - ) - .body(build_unauthorized_response( - &req, - ContextualError::InvalidHTTPCredentials, - false, - StatusCode::UNAUTHORIZED, )); - return Ok(Response::Done(new_resp)); + } + }; + + if match_auth(auth_req, required_auth) { + return Ok(Response::Done(resp)); } } - Ok(Response::Done(resp)) + + Ok(Response::Done( + HttpResponse::Unauthorized() + .header( + header::WWW_AUTHENTICATE, + header::HeaderValue::from_static("Basic realm=\"miniserve\""), + ) + .body(build_unauthorized_response( + &req, + ContextualError::InvalidHTTPCredentials, + true, + StatusCode::UNAUTHORIZED, + )) + )) } } @@ -172,10 +177,10 @@ fn build_unauthorized_response( #[cfg(test)] mod tests { use super::*; - use rstest::rstest_parametrize; + use rstest::{rstest, rstest_parametrize, fixture}; /// Return a hashing function corresponds to given name - fn get_hash_func(name: &str) -> impl FnOnce(String) -> Vec<u8> { + fn get_hash_func(name: &str) -> impl FnOnce(&str) -> Vec<u8> { match name { "sha256" => get_hash::<Sha256>, "sha512" => get_hash::<Sha512>, @@ -191,7 +196,7 @@ mod tests { fn test_get_hash(password: &str, hash_method: &str, hash: &str) { let hash_func = get_hash_func(hash_method); let expected = hex::decode(hash).expect("Provided hash is not a valid hex code"); - let received = hash_func(password.to_owned()); + let received = hash_func(&password.to_owned()); assert_eq!(received, expected); } @@ -199,14 +204,16 @@ mod tests { fn create_required_auth(username: &str, password: &str, encrypt: &str) -> RequiredAuth { use RequiredAuthPassword::*; + let password = match encrypt { + "plain" => Plain(password.to_owned()), + "sha256" => Sha256(get_hash::<sha2::Sha256>(&password.to_owned())), + "sha512" => Sha512(get_hash::<sha2::Sha512>(&password.to_owned())), + _ => panic!("Unknown encryption type"), + }; + RequiredAuth { username: username.to_owned(), - password: match encrypt { - "plain" => Plain(password.to_owned()), - "sha256" => Sha256(get_hash::<sha2::Sha256>(password.to_owned())), - "sha512" => Sha512(get_hash::<sha2::Sha512>(password.to_owned())), - _ => panic!("Unknown encryption type"), - }, + password, } } @@ -219,7 +226,7 @@ mod tests { case(true, "obi", "hello there", "obi", "hello there", "sha512"), case(false, "obi", "hello there", "obi", "hi!", "sha512") )] - fn test_auth( + fn test_single_auth( should_pass: bool, param_username: &str, param_password: &str, @@ -233,9 +240,82 @@ mod tests { username: param_username.to_owned(), password: param_password.to_owned(), }, - &create_required_auth(required_username, required_password, encrypt), + &[create_required_auth(required_username, required_password, encrypt)], ), should_pass, ) } + + /// Helper function that creates a sample of multiple accounts + #[fixture] + fn account_sample() -> Vec<RequiredAuth> { + [ + ("usr0", "pwd0", "plain"), + ("usr1", "pwd1", "plain"), + ("usr2", "pwd2", "sha256"), + ("usr3", "pwd3", "sha256"), + ("usr4", "pwd4", "sha512"), + ("usr5", "pwd5", "sha512"), + ] + .iter() + .map(|(username, password, encrypt)| create_required_auth(username, password, encrypt)) + .collect() + } + + #[rstest_parametrize( + username, password, + case("usr0", "pwd0"), + case("usr1", "pwd1"), + case("usr2", "pwd2"), + case("usr3", "pwd3"), + case("usr4", "pwd4"), + case("usr5", "pwd5"), + )] + fn test_multiple_auth_pass( + account_sample: Vec<RequiredAuth>, + username: &str, + password: &str, + ) { + assert!(match_auth( + BasicAuthParams { + username: username.to_owned(), + password: password.to_owned(), + }, + &account_sample, + )); + } + + #[rstest] + fn test_multiple_auth_wrong_username(account_sample: Vec<RequiredAuth>) { + assert_eq!(match_auth( + BasicAuthParams { + username: "unregistered user".to_owned(), + password: "pwd0".to_owned(), + }, + &account_sample, + ), false); + } + + #[rstest_parametrize( + username, password, + case("usr0", "pwd5"), + case("usr1", "pwd4"), + case("usr2", "pwd3"), + case("usr3", "pwd2"), + case("usr4", "pwd1"), + case("usr5", "pwd0"), + )] + fn test_multiple_auth_wrong_password( + account_sample: Vec<RequiredAuth>, + username: &str, + password: &str, + ) { + assert_eq!(match_auth( + BasicAuthParams { + username: username.to_owned(), + password: password.to_owned(), + }, + &account_sample, + ), false); + } } diff --git a/src/main.rs b/src/main.rs index bb61edc..f26369a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -37,7 +37,7 @@ pub struct MiniserveConfig { pub interfaces: Vec<IpAddr>, /// Enable HTTP basic authentication - pub auth: Option<auth::RequiredAuth>, + pub auth: Vec<auth::RequiredAuth>, /// If false, miniserve will serve the current working directory pub path_explicitly_chosen: bool, |