diff --git a/examples/walkdir.rs b/examples/walkdir.rs index eff0d26..774214c 100644 --- a/examples/walkdir.rs +++ b/examples/walkdir.rs @@ -18,6 +18,7 @@ Options: --max-depth NUM Maximum depth. -n, --fd-max NUM Maximum open file descriptors. [default: 32] --tree Show output as a tree. + --sort Sort the output. -q, --ignore-errors Ignore errors. "; @@ -31,6 +32,7 @@ struct Args { flag_fd_max: usize, flag_tree: bool, flag_ignore_errors: bool, + flag_sort: bool, } macro_rules! wout { ($($tt:tt)*) => { {writeln!($($tt)*)}.unwrap() } } @@ -40,12 +42,16 @@ fn main() { .unwrap_or_else(|e| e.exit()); let mind = args.flag_min_depth.unwrap_or(0); let maxd = args.flag_max_depth.unwrap_or(::std::usize::MAX); - let it = WalkDir::new(args.arg_dir.clone().unwrap_or(".".to_owned())) + let dir = args.arg_dir.clone().unwrap_or(".".to_owned()); + let mut walkdir = WalkDir::new(dir) .max_open(args.flag_fd_max) .follow_links(args.flag_follow_links) .min_depth(mind) - .max_depth(maxd) - .into_iter(); + .max_depth(maxd); + if args.flag_sort { + walkdir = walkdir.sort_by(|a,b| a.cmp(b)); + } + let it = walkdir.into_iter(); let mut out = io::BufWriter::new(io::stdout()); let mut eout = io::stderr(); if args.flag_tree { diff --git a/src/lib.rs b/src/lib.rs index 20e3f94..012cac6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -95,6 +95,7 @@ use std::fmt; use std::fs::{self, FileType, ReadDir}; use std::io; use std::ffi::OsStr; +use std::ffi::OsString; use std::path::{Path, PathBuf}; use std::result; use std::vec; @@ -190,6 +191,7 @@ struct WalkDirOptions { max_open: usize, min_depth: usize, max_depth: usize, + sorter: Option std::cmp::Ordering>>, } impl WalkDir { @@ -204,6 +206,7 @@ impl WalkDir { max_open: 10, min_depth: 0, max_depth: ::std::usize::MAX, + sorter: None, }, root: root.as_ref().to_path_buf(), } @@ -285,6 +288,25 @@ impl WalkDir { self.opts.max_open = n; self } + + /// Set a function for sorting directory entries. + /// If a compare function is set, WalkDir will return all paths in sorted + /// order. The compare function will be called to compare names from entries + /// from the same directory. Just the file_name() part of the paths is + /// passed to the compare function. + /// If no function is set, the entries will not be sorted. + /// + /// ```rust,no-run + /// use std::cmp; + /// use std::ffi::OsString; + /// use walkdir::WalkDir; + /// + /// WalkDir::new("foo").sort_by(|a,b| a.cmp(b)); + /// ``` + pub fn sort_by std::cmp::Ordering + 'static>(mut self, cmp: F) -> Self { + self.opts.sorter = Some(Box::new(cmp)); + self + } } impl IntoIterator for WalkDir { @@ -545,7 +567,21 @@ impl Iter { let rd = fs::read_dir(dent.path()).map_err(|err| { Some(Error::from_path(self.depth, dent.path().to_path_buf(), err)) }); - self.stack_list.push(DirList::Opened { depth: self.depth, it: rd }); + let mut list = DirList::Opened { depth: self.depth, it: rd }; + if let Some(ref mut cmp) = self.opts.sorter { + let mut entries = list.collect::>(); + entries.sort_by(|a, b| { + match (a, b) { + (&Ok(ref a), &Ok(ref b)) + => cmp(&a.file_name(), &b.file_name()), + (&Err(_), &Err(_)) => std::cmp::Ordering::Equal, + (&Ok(_), &Err(_)) => std::cmp::Ordering::Greater, + (&Err(_), &Ok(_)) => std::cmp::Ordering::Less, + } + }); + list = DirList::Closed ( entries.into_iter() ); + } + self.stack_list.push(list); if self.opts.follow_links { self.stack_path.push(dent.path().to_path_buf()); } diff --git a/src/tests.rs b/src/tests.rs index a853728..03955f4 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -638,3 +638,24 @@ fn qc_roundtrip_no_symlinks_with_follow() { .max_tests(10_000) .quickcheck(p as fn(Tree) -> bool); } + +#[test] +fn walk_dir_sort() { + let exp = td("foo", vec![ + tf("bar"), + td("abc", vec![tf("fit")]), + tf("faz"), + ]); + let tmp = tmpdir(); + let tmp_path = tmp.path(); + let tmp_len = tmp_path.to_str().unwrap().len(); + exp.create_in(tmp_path).unwrap(); + let it = WalkDir::new(tmp_path).sort_by(|a,b| a.cmp(b)).into_iter(); + let got = it.map(|d| { + let path = d.unwrap(); + let path = &path.path().to_str().unwrap()[tmp_len..]; + path.replace("\\", "/") + }).collect::>(); + assert_eq!(got, + ["", "/foo", "/foo/abc", "/foo/abc/fit", "/foo/bar", "/foo/faz"]); +}