diff options
Diffstat (limited to '')
-rw-r--r-- | src/args.rs | 128 |
1 files changed, 102 insertions, 26 deletions
diff --git a/src/args.rs b/src/args.rs index 8d2e105..4077f35 100644 --- a/src/args.rs +++ b/src/args.rs @@ -38,9 +38,10 @@ struct CLIArgs { )] interfaces: Vec<IpAddr>, - /// Set authentication (username:password) + /// Set authentication. Currently supported formats: + /// username:password, username:sha256:hash, username:sha512:hash #[structopt(short = "a", long = "auth", parse(try_from_str = "parse_auth"))] - auth: Option<(String, String)>, + auth: Option<auth::RequiredAuth>, /// Generate a random 6-hexdigit route #[structopt(long = "random-route")] @@ -77,34 +78,55 @@ fn parse_interface(src: &str) -> Result<IpAddr, std::net::AddrParseError> { } /// Checks wether the auth string is valid, i.e. it follows the syntax username:password -fn parse_auth(src: &str) -> Result<(String, String), ContextualError> { - let mut split = src.splitn(2, ':'); +fn parse_auth(src: &str) -> Result<auth::RequiredAuth, ContextualError> { + let mut split = src.splitn(3, ':'); + let invalid_auth_format = Err( + ContextualError::new(ContextualErrorKind::InvalidAuthFormat) + ); let username = match split.next() { Some(username) => username, - None => { - return Err(ContextualError::new(ContextualErrorKind::InvalidAuthFormat)); - } + None => return invalid_auth_format, }; - let password = match split.next() { + // second_part is either password in username:password or method in username:method:hash + let second_part = match split.next() { // This allows empty passwords, as the spec does not forbid it Some(password) => password, - None => { - return Err(ContextualError::new(ContextualErrorKind::InvalidAuthFormat)); - } + None => return invalid_auth_format, }; - // To make it Windows-compatible, the password needs to be shorter than 255 characters. - // After 255 characters, Windows will truncate the value. - // As for the username, the spec does not mention a limit in length - if password.len() > 255 { - return Err(ContextualError::new( - ContextualErrorKind::PasswordTooLongError, - )); - } + let password = if let Some(hash_hex) = split.next() { + let hash_bin = if let Ok(hash_bin) = hex::decode(hash_hex) { + hash_bin + } else { + return Err(ContextualError::new(ContextualErrorKind::InvalidPasswordHash)) + }; + + match second_part { + "sha256" => auth::RequiredAuthPassword::Sha256(hash_bin.to_owned()), + "sha512" => auth::RequiredAuthPassword::Sha512(hash_bin.to_owned()), + _ => { + return Err(ContextualError::new( + ContextualErrorKind::InvalidHashMethod(second_part.to_owned()) + )) + }, + } + } else { + // To make it Windows-compatible, the password needs to be shorter than 255 characters. + // After 255 characters, Windows will truncate the value. + // As for the username, the spec does not mention a limit in length + if second_part.len() > 255 { + return Err(ContextualError::new(ContextualErrorKind::PasswordTooLongError)); + } + + auth::RequiredAuthPassword::Plain(second_part.to_owned()) + }; - Ok((username.to_owned(), password.to_owned())) + Ok(auth::RequiredAuth { + username: username.to_owned(), + password, + }) } /// Parses the command line arguments @@ -120,11 +142,6 @@ pub fn parse_args() -> crate::MiniserveConfig { ] }; - let auth = match args.auth { - Some((username, password)) => Some(auth::BasicAuthParams { username, password }), - None => None, - }; - let random_route = if args.random_route { Some(nanoid::custom(6, &ROUTE_ALPHABET)) } else { @@ -140,7 +157,7 @@ pub fn parse_args() -> crate::MiniserveConfig { path: args.path.unwrap_or_else(|| PathBuf::from(".")), port: args.port, interfaces, - auth, + auth: args.auth, path_explicitly_chosen, no_symlinks: args.no_symlinks, random_route, @@ -149,3 +166,62 @@ pub fn parse_args() -> crate::MiniserveConfig { file_upload: args.file_upload, } } + +#[cfg(test)] +mod tests { + use super::*; + use rstest::rstest_parametrize; + + /// Helper function that creates a `RequiredAuth` structure + fn create_required_auth(username: &str, password: &str, encrypt: &str) -> auth::RequiredAuth { + use auth::*; + use RequiredAuthPassword::*; + + RequiredAuth { + username: username.to_owned(), + password: match encrypt { + "plain" => Plain(password.to_owned()), + "sha256" => Sha256(hex::decode(password.to_owned()).unwrap()), + "sha512" => Sha512(hex::decode(password.to_owned()).unwrap()), + _ => panic!("Unknown encryption type"), + }, + } + } + + #[rstest_parametrize( + auth_string, username, password, encrypt, + case("username:password", "username", "password", "plain"), + case("username:sha256:abcd", "username", "abcd", "sha256"), + case("username:sha512:abcd", "username", "abcd", "sha512") + )] + fn parse_auth_valid(auth_string: &str, username: &str, password: &str, encrypt: &str) { + assert_eq!( + parse_auth(auth_string).unwrap(), + create_required_auth(username, password, encrypt), + ); + } + + #[rstest_parametrize( + auth_string, err_msg, + case( + "foo", + "Invalid format for credentials string. Expected username:password, username:sha256:hash or username:sha512:hash" + ), + case( + "username:blahblah:abcd", + "blahblah is not a valid hashing method. Expected sha256 or sha512" + ), + case( + "username:sha256:invalid", + "Invalid format for password hash. Expected hex code" + ), + case( + "username:sha512:invalid", + "Invalid format for password hash. Expected hex code" + ), + )] + fn parse_auth_invalid(auth_string: &str, err_msg: &str) { + let err = parse_auth(auth_string).unwrap_err(); + assert_eq!(format!("{}", err), err_msg.to_owned()); + } +} |