aboutsummaryrefslogtreecommitdiffstats
path: root/src/args.rs
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/args.rs128
1 files changed, 102 insertions, 26 deletions
diff --git a/src/args.rs b/src/args.rs
index 8d2e105..4077f35 100644
--- a/src/args.rs
+++ b/src/args.rs
@@ -38,9 +38,10 @@ struct CLIArgs {
)]
interfaces: Vec<IpAddr>,
- /// Set authentication (username:password)
+ /// Set authentication. Currently supported formats:
+ /// username:password, username:sha256:hash, username:sha512:hash
#[structopt(short = "a", long = "auth", parse(try_from_str = "parse_auth"))]
- auth: Option<(String, String)>,
+ auth: Option<auth::RequiredAuth>,
/// Generate a random 6-hexdigit route
#[structopt(long = "random-route")]
@@ -77,34 +78,55 @@ fn parse_interface(src: &str) -> Result<IpAddr, std::net::AddrParseError> {
}
/// Checks wether the auth string is valid, i.e. it follows the syntax username:password
-fn parse_auth(src: &str) -> Result<(String, String), ContextualError> {
- let mut split = src.splitn(2, ':');
+fn parse_auth(src: &str) -> Result<auth::RequiredAuth, ContextualError> {
+ let mut split = src.splitn(3, ':');
+ let invalid_auth_format = Err(
+ ContextualError::new(ContextualErrorKind::InvalidAuthFormat)
+ );
let username = match split.next() {
Some(username) => username,
- None => {
- return Err(ContextualError::new(ContextualErrorKind::InvalidAuthFormat));
- }
+ None => return invalid_auth_format,
};
- let password = match split.next() {
+ // second_part is either password in username:password or method in username:method:hash
+ let second_part = match split.next() {
// This allows empty passwords, as the spec does not forbid it
Some(password) => password,
- None => {
- return Err(ContextualError::new(ContextualErrorKind::InvalidAuthFormat));
- }
+ None => return invalid_auth_format,
};
- // To make it Windows-compatible, the password needs to be shorter than 255 characters.
- // After 255 characters, Windows will truncate the value.
- // As for the username, the spec does not mention a limit in length
- if password.len() > 255 {
- return Err(ContextualError::new(
- ContextualErrorKind::PasswordTooLongError,
- ));
- }
+ let password = if let Some(hash_hex) = split.next() {
+ let hash_bin = if let Ok(hash_bin) = hex::decode(hash_hex) {
+ hash_bin
+ } else {
+ return Err(ContextualError::new(ContextualErrorKind::InvalidPasswordHash))
+ };
+
+ match second_part {
+ "sha256" => auth::RequiredAuthPassword::Sha256(hash_bin.to_owned()),
+ "sha512" => auth::RequiredAuthPassword::Sha512(hash_bin.to_owned()),
+ _ => {
+ return Err(ContextualError::new(
+ ContextualErrorKind::InvalidHashMethod(second_part.to_owned())
+ ))
+ },
+ }
+ } else {
+ // To make it Windows-compatible, the password needs to be shorter than 255 characters.
+ // After 255 characters, Windows will truncate the value.
+ // As for the username, the spec does not mention a limit in length
+ if second_part.len() > 255 {
+ return Err(ContextualError::new(ContextualErrorKind::PasswordTooLongError));
+ }
+
+ auth::RequiredAuthPassword::Plain(second_part.to_owned())
+ };
- Ok((username.to_owned(), password.to_owned()))
+ Ok(auth::RequiredAuth {
+ username: username.to_owned(),
+ password,
+ })
}
/// Parses the command line arguments
@@ -120,11 +142,6 @@ pub fn parse_args() -> crate::MiniserveConfig {
]
};
- let auth = match args.auth {
- Some((username, password)) => Some(auth::BasicAuthParams { username, password }),
- None => None,
- };
-
let random_route = if args.random_route {
Some(nanoid::custom(6, &ROUTE_ALPHABET))
} else {
@@ -140,7 +157,7 @@ pub fn parse_args() -> crate::MiniserveConfig {
path: args.path.unwrap_or_else(|| PathBuf::from(".")),
port: args.port,
interfaces,
- auth,
+ auth: args.auth,
path_explicitly_chosen,
no_symlinks: args.no_symlinks,
random_route,
@@ -149,3 +166,62 @@ pub fn parse_args() -> crate::MiniserveConfig {
file_upload: args.file_upload,
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use rstest::rstest_parametrize;
+
+ /// Helper function that creates a `RequiredAuth` structure
+ fn create_required_auth(username: &str, password: &str, encrypt: &str) -> auth::RequiredAuth {
+ use auth::*;
+ use RequiredAuthPassword::*;
+
+ RequiredAuth {
+ username: username.to_owned(),
+ password: match encrypt {
+ "plain" => Plain(password.to_owned()),
+ "sha256" => Sha256(hex::decode(password.to_owned()).unwrap()),
+ "sha512" => Sha512(hex::decode(password.to_owned()).unwrap()),
+ _ => panic!("Unknown encryption type"),
+ },
+ }
+ }
+
+ #[rstest_parametrize(
+ auth_string, username, password, encrypt,
+ case("username:password", "username", "password", "plain"),
+ case("username:sha256:abcd", "username", "abcd", "sha256"),
+ case("username:sha512:abcd", "username", "abcd", "sha512")
+ )]
+ fn parse_auth_valid(auth_string: &str, username: &str, password: &str, encrypt: &str) {
+ assert_eq!(
+ parse_auth(auth_string).unwrap(),
+ create_required_auth(username, password, encrypt),
+ );
+ }
+
+ #[rstest_parametrize(
+ auth_string, err_msg,
+ case(
+ "foo",
+ "Invalid format for credentials string. Expected username:password, username:sha256:hash or username:sha512:hash"
+ ),
+ case(
+ "username:blahblah:abcd",
+ "blahblah is not a valid hashing method. Expected sha256 or sha512"
+ ),
+ case(
+ "username:sha256:invalid",
+ "Invalid format for password hash. Expected hex code"
+ ),
+ case(
+ "username:sha512:invalid",
+ "Invalid format for password hash. Expected hex code"
+ ),
+ )]
+ fn parse_auth_invalid(auth_string: &str, err_msg: &str) {
+ let err = parse_auth(auth_string).unwrap_err();
+ assert_eq!(format!("{}", err), err_msg.to_owned());
+ }
+}