aboutsummaryrefslogtreecommitdiffstats
path: root/src/file_op.rs
diff options
context:
space:
mode:
authorAlec Di Vito <me@alecdivito.com>2025-02-17 04:35:26 +0000
committerAlec Di Vito <me@alecdivito.com>2025-02-17 04:35:26 +0000
commit413a63a60307bdf60229670b0f858963604d62a3 (patch)
tree45c4e203bf2d39d7cc13b96b30813be6ce44fb74 /src/file_op.rs
parentMerge branch 'svenstaro:master' into upload-progress-bar (diff)
downloadminiserve-413a63a60307bdf60229670b0f858963604d62a3.tar.gz
miniserve-413a63a60307bdf60229670b0f858963604d62a3.zip
feat: implement temporary file uploads and tweak mobile design
Diffstat (limited to 'src/file_op.rs')
-rw-r--r--src/file_op.rs140
1 files changed, 113 insertions, 27 deletions
diff --git a/src/file_op.rs b/src/file_op.rs
index 76a7234..367517a 100644
--- a/src/file_op.rs
+++ b/src/file_op.rs
@@ -4,10 +4,10 @@ use std::io::ErrorKind;
use std::path::{Component, Path, PathBuf};
use actix_web::{http::header, web, HttpRequest, HttpResponse};
-use futures::{StreamExt, TryFutureExt};
-use futures::TryStreamExt;
+use futures::{StreamExt, TryStreamExt};
use serde::Deserialize;
-use tokio::fs::File;
+use sha2::{Digest, Sha256};
+use tempfile::NamedTempFile;
use tokio::io::AsyncWriteExt;
use crate::{
@@ -15,6 +15,18 @@ use crate::{
file_utils::sanitize_path,
};
+enum FileHash {
+ SHA256(String),
+}
+
+impl FileHash {
+ pub fn get_hasher(&self) -> impl Digest {
+ match self {
+ Self::SHA256(_) => Sha256::new(),
+ }
+ }
+}
+
/// Saves file data from a multipart form field (`field`) to `file_path`, optionally overwriting
/// existing file.
///
@@ -23,31 +35,84 @@ async fn save_file(
field: &mut actix_multipart::Field,
file_path: PathBuf,
overwrite_files: bool,
+ file_hash: Option<&FileHash>,
) -> Result<u64, RuntimeError> {
if !overwrite_files && file_path.exists() {
return Err(RuntimeError::DuplicateFileError);
}
- let file = match File::create(&file_path).await {
- Err(err) if err.kind() == ErrorKind::PermissionDenied => Err(
+ let named_temp_file = match tokio::task::spawn_blocking(|| NamedTempFile::new()).await {
+ Err(err) => Err(RuntimeError::MultipartError(format!(
+ "Failed to complete spawned task to create named temp file. {}",
+ err
+ ))),
+ Ok(Err(err)) if err.kind() == ErrorKind::PermissionDenied => Err(
RuntimeError::InsufficientPermissionsError(file_path.display().to_string()),
),
- Err(err) => Err(RuntimeError::IoError(
- format!("Failed to create {}", file_path.display()),
- err,
+ Ok(Err(err)) => Err(RuntimeError::IoError(
+ format!("Failed to create temporary file {}", file_path.display()),
+ err,
)),
- Ok(v) => Ok(v),
+ Ok(Ok(file)) => Ok(file),
}?;
- let (_, written_len) = field
- .map_err(|x| RuntimeError::MultipartError(x.to_string()))
- .try_fold((file, 0u64), |(mut file, written_len), bytes| async move {
- file.write_all(bytes.as_ref())
- .map_err(|e| RuntimeError::IoError("Failed to write to file".to_string(), e))
- .await?;
- Ok((file, written_len + bytes.len() as u64))
- })
- .await?;
+ let (file, temp_path) = named_temp_file.keep().map_err(|err| {
+ RuntimeError::IoError("Failed to keep temporary file".into(), err.error.into())
+ })?;
+ let mut temp_file = tokio::fs::File::from_std(file);
+
+ let mut written_len = 0;
+ let mut hasher = file_hash.as_ref().map(|h| h.get_hasher());
+ let mut error: Option<RuntimeError> = None;
+
+ while let Some(Ok(bytes)) = field.next().await {
+ if let Some(hasher) = hasher.as_mut() {
+ hasher.update(&bytes)
+ }
+ if let Err(e) = temp_file.write_all(&bytes).await {
+ error = Some(RuntimeError::IoError(
+ "Failed to write to file".to_string(),
+ e,
+ ));
+ break;
+ }
+ written_len += bytes.len() as u64;
+ }
+
+ drop(temp_file);
+
+ if let Some(e) = error {
+ let _ = tokio::fs::remove_file(temp_path).await;
+ return Err(e);
+ }
+
+ // There isn't a way to get notified when a request is cancelled
+ // by the user in actix it seems. References:
+ // - https://github.com/actix/actix-web/issues/1313
+ // - https://github.com/actix/actix-web/discussions/3011
+ // Therefore, we are relying on the fact that the web UI
+ // uploads a hash of the file.
+ if let Some(hasher) = hasher {
+ if let Some(FileHash::SHA256(expected_hash)) = file_hash {
+ let actual_hash = hex::encode(hasher.finalize());
+ if &actual_hash != expected_hash {
+ let _ = tokio::fs::remove_file(&temp_path).await;
+ return Err(RuntimeError::UploadHashMismatchError);
+ }
+ }
+ }
+
+ if let Err(e) = tokio::fs::rename(&temp_path, &file_path).await {
+ let _ = tokio::fs::remove_file(&temp_path).await;
+ return Err(RuntimeError::IoError(
+ format!(
+ "Failed to move temporary file {} to {}",
+ temp_path.display(),
+ file_path.display()
+ ),
+ e,
+ ));
+ }
Ok(written_len)
}
@@ -60,6 +125,7 @@ async fn handle_multipart(
allow_mkdir: bool,
allow_hidden_paths: bool,
allow_symlinks: bool,
+ file_hash: Option<&FileHash>,
) -> Result<u64, RuntimeError> {
let field_name = field.name().expect("No name field found").to_string();
@@ -168,15 +234,13 @@ async fn handle_multipart(
}
}
- match save_file(&mut field, path.join(filename_path), overwrite_files).await {
- Ok(bytes) => Ok(bytes),
- Err(err) => {
- // Required for file upload. If entire stream is not consumed, javascript
- // XML HTTP Request will never complete.
- while field.next().await.is_some() {}
- Err(err)
- },
- }
+ save_file(
+ &mut field,
+ path.join(filename_path),
+ overwrite_files,
+ file_hash,
+ )
+ .await
}
/// Query parameters used by upload and rm APIs
@@ -226,6 +290,27 @@ pub async fn upload_file(
)),
}?;
+ let mut file_hash: Option<FileHash> = None;
+ if let Some(hash) = req
+ .headers()
+ .get("X-File-Hash")
+ .and_then(|h| h.to_str().ok())
+ {
+ if let Some(hash_funciton) = req
+ .headers()
+ .get("X-File-Hash-Function")
+ .and_then(|h| h.to_str().ok())
+ {
+ match hash_funciton.to_ascii_uppercase().as_str() {
+ "SHA256" => {
+ file_hash = Some(FileHash::SHA256(hash.to_string()));
+ }
+ _ => {}
+ }
+ }
+ }
+
+ let hash_ref = file_hash.as_ref();
actix_multipart::Multipart::new(req.headers(), payload)
.map_err(|x| RuntimeError::MultipartError(x.to_string()))
.and_then(|field| {
@@ -236,6 +321,7 @@ pub async fn upload_file(
conf.mkdir_enabled,
conf.show_hidden,
!conf.no_symlinks,
+ hash_ref,
)
})
.try_collect::<Vec<u64>>()