From 336be837af96a88ebd8dc8368bf7c2716ef2649a Mon Sep 17 00:00:00 2001 From: tison Date: Wed, 3 Jun 2026 21:29:37 +0800 Subject: [PATCH 1/4] Add traversable fold support --- README.md | 43 +++- traversable-derive/Cargo.toml | 2 +- traversable-derive/README.md | 3 +- traversable-derive/src/lib.rs | 222 ++++++++++++++++++++ traversable/src/function.rs | 150 ++++++++++++++ traversable/src/impls/mod.rs | 15 ++ traversable/src/impls/ordered_float_5.rs | 10 + traversable/src/impls/std_container.rs | 145 +++++++++++++ traversable/src/impls/std_primary.rs | 2 + traversable/src/impls/trivial.rs | 2 + traversable/src/impls/tuple.rs | 18 ++ traversable/src/lib.rs | 97 ++++++++- traversable/tests/test_folder.rs | 251 +++++++++++++++++++++++ 13 files changed, 950 insertions(+), 10 deletions(-) create mode 100644 traversable/tests/test_folder.rs diff --git a/README.md b/README.md index 26b5a24..ea36f1b 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ ## Overview -This crate provides traits and proc macros to implement the visitor pattern for arbitrary data structures. This pattern is particularly useful when dealing with complex nested data structures, abstract trees and hierarchies of all kinds. +This crate provides traits and proc macros to implement visitor and folder patterns for arbitrary data structures. These patterns are particularly useful when dealing with complex nested data structures, abstract trees and hierarchies of all kinds. ## Quick Start @@ -87,12 +87,49 @@ fn main() { } ``` +Use `TraversableFold` for owned bottom-up rewrites: + +```rust +use std::ops::ControlFlow; + +use traversable::TraversableFold; +use traversable::function::folder_leave; + +#[derive(TraversableFold)] +enum Expr { + Add(Box, Box), + Literal(i32), +} + +fn simplify(expr: Expr) -> Expr { + match expr { + Expr::Add(left, right) => match (*left, *right) { + (Expr::Literal(0), expr) | (expr, Expr::Literal(0)) => expr, + (left, right) => Expr::Add(Box::new(left), Box::new(right)), + }, + expr => expr, + } +} + +fn main() { + let expr = Expr::Add(Box::new(Expr::Literal(0)), Box::new(Expr::Literal(1))); + let mut folder = folder_leave::(|expr| ControlFlow::Continue(simplify(expr))); + + let expr = match expr.traverse_fold(&mut folder) { + ControlFlow::Continue(expr) => expr, + ControlFlow::Break(()) => unreachable!(), + }; + + assert!(matches!(expr, Expr::Literal(1))); +} +``` + ## Attributes The derive macro supports the following attributes on structs and enums: -* `#[traverse(skip_self)]`: Skips calling the visitor for the annotated type while still traversing its children. -* `#[traverse(skip_children)]`: Calls the visitor for the annotated type without traversing its children. +* `#[traverse(skip_self)]`: Skips calling the visitor or folder for the annotated type while still traversing its children. +* `#[traverse(skip_children)]`: Calls the visitor or folder for the annotated type without traversing its children. The derive macro supports the following attributes on fields and variants: diff --git a/traversable-derive/Cargo.toml b/traversable-derive/Cargo.toml index 34e51f9..cec60de 100644 --- a/traversable-derive/Cargo.toml +++ b/traversable-derive/Cargo.toml @@ -16,7 +16,7 @@ name = "traversable-derive" version = "0.2.0" -description = "Procedural macro to derive Traversable and TraversableMut" +description = "Procedural macro to derive Traversable, TraversableMut, and TraversableFold" documentation = "https://docs.rs/traversable-derive" keywords = ["visitor", "traverse", "traversable"] readme = "README.md" diff --git a/traversable-derive/README.md b/traversable-derive/README.md index 4669c15..3aced0e 100644 --- a/traversable-derive/README.md +++ b/traversable-derive/README.md @@ -20,4 +20,5 @@ This is an implementation crate for the [`traversable`](https://crates.io/crates Please refer to the main [`traversable`](https://crates.io/crates/traversable) crate for documentation and usage examples. -This crate contains procedural macros that derive `Traversable` and `TraversableMut` implementations. +This crate contains procedural macros that derive `Traversable`, `TraversableMut`, and +`TraversableFold` implementations. diff --git a/traversable-derive/src/lib.rs b/traversable-derive/src/lib.rs index a3f67fc..e42ca48 100644 --- a/traversable-derive/src/lib.rs +++ b/traversable-derive/src/lib.rs @@ -55,6 +55,11 @@ pub fn derive_traversable_mut(input: proc_macro::TokenStream) -> proc_macro::Tok expand_with(input, |stream| impl_traversable(stream, true)) } +#[proc_macro_derive(TraversableFold, attributes(traverse))] +pub fn derive_traversable_fold(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + expand_with(input, impl_traversable_fold) +} + fn expand_with( input: proc_macro::TokenStream, handler: impl Fn(DeriveInput) -> Result, @@ -442,3 +447,220 @@ fn traverse_field(value: &TokenStream, field: Field, mutable: bool) -> Result Result { + let mut params = Params::from_attrs(input.attrs, "traverse")?; + params.validate(&["skip_self", "skip_children"])?; + + let skip_visit_self = params + .param("skip_self")? + .map(Param::unit) + .transpose()? + .is_some(); + let skip_children = params + .param("skip_children")? + .map(Param::unit) + .transpose()? + .is_some(); + + let name = input.ident; + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let crate_name = resolve_crate_name(); + + let enter_self = if skip_visit_self { + quote! { + let this = self; + } + } else { + quote! { + let this = #crate_name::Folder::enter(folder, self)?; + } + }; + + let fold_children = match input.data { + Data::Struct(struct_) => { + if skip_children { + Ok(TokenStream::new()) + } else { + fold_struct(struct_) + } + } + Data::Enum(enum_) => { + if skip_children { + Ok(TokenStream::new()) + } else { + fold_enum(enum_) + } + } + Data::Union(union_) => { + return Err(Error::new_spanned( + union_.union_token, + "unions are not supported", + )); + } + }?; + + let leave_self = if skip_visit_self { + TokenStream::new() + } else { + quote! { + let this = #crate_name::Folder::leave(folder, this)?; + } + }; + + Ok(quote! { + impl #impl_generics #crate_name::TraversableFold for #name #ty_generics #where_clause { + fn traverse_fold( + self, + folder: &mut V + ) -> ::core::ops::ControlFlow { + #enter_self + #fold_children + #leave_self + ::core::ops::ControlFlow::Continue(this) + } + } + }) +} + +fn fold_struct(s: DataStruct) -> Result { + Ok(match s.fields { + Fields::Named(fields) => { + let mut field_names = Vec::new(); + let mut fold_fields = Vec::new(); + + for field in fields.named { + let field_name = field.ident.clone().unwrap(); + fold_fields.push(fold_field(&field_name.to_token_stream(), field)?); + field_names.push(field_name); + } + + quote! { + let this = match this { + Self { #( #field_names ),* } => { + #( #fold_fields )* + Self { #( #field_names ),* } + } + }; + } + } + Fields::Unnamed(fields) => { + let mut field_names = Vec::new(); + let mut fold_fields = Vec::new(); + + for (index, field) in fields.unnamed.into_iter().enumerate() { + let field_name = Ident::new(&format!("i{index}"), Span::call_site()); + fold_fields.push(fold_field(&field_name.to_token_stream(), field)?); + field_names.push(field_name); + } + + quote! { + let this = match this { + Self( #( #field_names ),* ) => { + #( #fold_fields )* + Self( #( #field_names ),* ) + } + }; + } + } + Fields::Unit => TokenStream::new(), + }) +} + +fn fold_enum(e: DataEnum) -> Result { + let variants = e + .variants + .into_iter() + .map(fold_variant) + .collect::>()?; + Ok(quote! { + let this = match this { + #variants + }; + }) +} + +fn fold_variant(v: Variant) -> Result { + let mut params = Params::from_attrs(v.attrs, "traverse")?; + params.validate(&["skip"])?; + let skip = params + .param("skip")? + .map(Param::unit) + .transpose()? + .is_some(); + + let name = v.ident; + Ok(match v.fields { + Fields::Named(fields) => { + let mut field_names = Vec::new(); + let mut fold_fields = Vec::new(); + + for field in fields.named { + let field_name = field.ident.clone().unwrap(); + if !skip { + fold_fields.push(fold_field(&field_name.to_token_stream(), field)?); + } + field_names.push(field_name); + } + + quote! { + Self::#name { #( #field_names ),* } => { + #( #fold_fields )* + Self::#name { #( #field_names ),* } + } + } + } + Fields::Unnamed(fields) => { + let mut field_names = Vec::new(); + let mut fold_fields = Vec::new(); + + for (index, field) in fields.unnamed.into_iter().enumerate() { + let field_name = Ident::new(&format!("i{index}"), Span::call_site()); + if !skip { + fold_fields.push(fold_field(&field_name.to_token_stream(), field)?); + } + field_names.push(field_name); + } + + quote! { + Self::#name( #( #field_names ),* ) => { + #( #fold_fields )* + Self::#name( #( #field_names ),* ) + } + } + } + Fields::Unit => { + quote! { + Self::#name => Self::#name + } + } + }) +} + +fn fold_field(value: &TokenStream, field: Field) -> Result { + let mut params = Params::from_attrs(field.attrs, "traverse")?; + params.validate(&["skip", "with"])?; + + if params + .param("skip")? + .map(Param::unit) + .transpose()? + .is_some() + { + return Ok(TokenStream::new()); + } + + let crate_name = resolve_crate_name(); + + match params.param("with")? { + None => Ok(quote! { + let #value = #crate_name::TraversableFold::traverse_fold(#value, folder)?; + }), + Some(traverse_fn) => { + let traverse_fn = traverse_fn.string_literal()?.parse::()?; + Ok(quote! { + let #value = #traverse_fn(#value, folder)?; + }) + } + } +} diff --git a/traversable/src/function.rs b/traversable/src/function.rs index 0e558ae..58d9428 100644 --- a/traversable/src/function.rs +++ b/traversable/src/function.rs @@ -15,9 +15,13 @@ //! Visitors from functions or closures. use core::any::Any; +use core::any::TypeId; use core::marker::PhantomData; +use core::mem::ManuallyDrop; use core::ops::ControlFlow; +use core::ptr; +use crate::Folder; use crate::Visitor; use crate::VisitorMut; @@ -29,6 +33,14 @@ pub struct FnVisitor { marker_break: PhantomData, } +/// Type returned by `folder` factories. +pub struct FnFolder { + enter: F1, + leave: F2, + marker_type: PhantomData, + marker_break: PhantomData, +} + impl Visitor for FnVisitor where T: Any, @@ -75,8 +87,60 @@ where } } +impl Folder for FnFolder +where + T: Any, + F1: FnMut(T) -> ControlFlow, + F2: FnMut(T) -> ControlFlow, +{ + type Break = B; + + fn enter(&mut self, this: U) -> ControlFlow { + fold_if_type(this, &mut self.enter) + } + + fn leave(&mut self, this: U) -> ControlFlow { + fold_if_type(this, &mut self.leave) + } +} + +fn fold_if_type(this: U, fold: &mut F) -> ControlFlow +where + T: Any, + U: Any, + F: FnMut(T) -> ControlFlow, +{ + if TypeId::of::() != TypeId::of::() { + return ControlFlow::Continue(this); + } + + let this = cast_between_equal_any_types::(this); + match fold(this) { + ControlFlow::Continue(this) => { + ControlFlow::Continue(cast_between_equal_any_types::(this)) + } + ControlFlow::Break(break_value) => ControlFlow::Break(break_value), + } +} + +fn cast_between_equal_any_types(from: From) -> To +where + From: Any, + To: Any, +{ + debug_assert_eq!(TypeId::of::(), TypeId::of::()); + + let from = ManuallyDrop::new(from); + // SAFETY: `fold_if_type` only calls this function after verifying that `From` and `To` have + // the same `TypeId`. `Any` is implemented only for `'static` concrete types, so equal TypeIds + // identify the same type. `ManuallyDrop` prevents dropping the source after ownership has been + // transferred through `ptr::read`. + unsafe { ptr::read((&*from as *const From).cast::()) } +} + type DefaultVisitFn = fn(&T) -> ControlFlow; type DefaultVisitFnMut = fn(&mut T) -> ControlFlow; +type DefaultFoldFn = fn(T) -> ControlFlow; /// Create a visitor that only visits items of a specific type from `enter` and `leave` closures. /// @@ -379,3 +443,89 @@ where marker_break: PhantomData, } } + +/// Create a folder that only folds items of a specific type from `enter` and `leave` closures. +/// +/// This is a convenience function for creating simple owned transforms without defining a new +/// struct and implementing the [`Folder`] trait manually. +/// +/// # Example +/// +/// ```rust +/// # #[cfg(not(feature = "derive"))] +/// # fn main() {} +/// # +/// # #[cfg(feature = "derive")] +/// # fn main() { +/// use core::ops::ControlFlow; +/// +/// use traversable::TraversableFold; +/// use traversable::function::folder_leave; +/// +/// #[derive(TraversableFold)] +/// enum Expr { +/// Add(Box, Box), +/// Literal(i32), +/// } +/// +/// fn simplify(expr: Expr) -> Expr { +/// match expr { +/// Expr::Add(left, right) => match (*left, *right) { +/// (Expr::Literal(0), expr) | (expr, Expr::Literal(0)) => expr, +/// (left, right) => Expr::Add(Box::new(left), Box::new(right)), +/// }, +/// expr => expr, +/// } +/// } +/// +/// let expr = Expr::Add(Box::new(Expr::Literal(0)), Box::new(Expr::Literal(1))); +/// let mut folder = folder_leave::(|expr| ControlFlow::Continue(simplify(expr))); +/// let expr = match expr.traverse_fold(&mut folder) { +/// ControlFlow::Continue(expr) => expr, +/// ControlFlow::Break(()) => unreachable!(), +/// }; +/// +/// assert!(matches!(expr, Expr::Literal(1))); +/// # } +/// ``` +pub fn folder(enter: F1, leave: F2) -> FnFolder +where + T: Any, + F1: FnMut(T) -> ControlFlow, + F2: FnMut(T) -> ControlFlow, +{ + FnFolder { + enter, + leave, + marker_type: PhantomData, + marker_break: PhantomData, + } +} + +/// Similar to [`folder`], but the closure will only be called on entering. +pub fn folder_enter(enter: F) -> FnFolder> +where + T: Any, + F: FnMut(T) -> ControlFlow, +{ + FnFolder { + enter, + leave: ControlFlow::Continue, + marker_type: PhantomData, + marker_break: PhantomData, + } +} + +/// Similar to [`folder`], but the closure will only be called on leaving. +pub fn folder_leave(leave: F) -> FnFolder, F> +where + T: Any, + F: FnMut(T) -> ControlFlow, +{ + FnFolder { + enter: ControlFlow::Continue, + leave, + marker_type: PhantomData, + marker_break: PhantomData, + } +} diff --git a/traversable/src/impls/mod.rs b/traversable/src/impls/mod.rs index 7397469..aef9a7e 100644 --- a/traversable/src/impls/mod.rs +++ b/traversable/src/impls/mod.rs @@ -28,6 +28,13 @@ macro_rules! blank_traverse_impl { ControlFlow::Continue(()) } } + + impl TraversableFold for $type { + #[inline] + fn traverse_fold(self, _folder: &mut V) -> ControlFlow { + ControlFlow::Continue(self) + } + } }; } @@ -49,6 +56,14 @@ macro_rules! trivial_traverse_impl { ControlFlow::Continue(()) } } + + impl TraversableFold for $type { + fn traverse_fold(self, folder: &mut V) -> ControlFlow { + let this = folder.enter(self)?; + let this = folder.leave(this)?; + ControlFlow::Continue(this) + } + } }; } diff --git a/traversable/src/impls/ordered_float_5.rs b/traversable/src/impls/ordered_float_5.rs index 527b684..02364d6 100644 --- a/traversable/src/impls/ordered_float_5.rs +++ b/traversable/src/impls/ordered_float_5.rs @@ -16,7 +16,9 @@ use core::ops::ControlFlow; use ordered_float_5::OrderedFloat; +use crate::Folder; use crate::Traversable; +use crate::TraversableFold; use crate::TraversableMut; use crate::Visitor; use crate::VisitorMut; @@ -36,3 +38,11 @@ impl TraversableMut for OrderedFloat { ControlFlow::Continue(()) } } + +impl TraversableFold for OrderedFloat { + fn traverse_fold(self, folder: &mut V) -> ControlFlow { + let this = folder.enter(self)?; + let this = folder.leave(this)?; + ControlFlow::Continue(this) + } +} diff --git a/traversable/src/impls/std_container.rs b/traversable/src/impls/std_container.rs index dbdd203..5de4ca7 100644 --- a/traversable/src/impls/std_container.rs +++ b/traversable/src/impls/std_container.rs @@ -19,7 +19,9 @@ use std::sync::Arc; use std::sync::Mutex; use std::sync::RwLock; +use crate::Folder; use crate::Traversable; +use crate::TraversableFold; use crate::TraversableMut; use crate::Visitor; use crate::VisitorMut; @@ -58,6 +60,37 @@ impl DerefAndTraverseMut for (TK, &mut TV) { } } +fn traverse_fold_items(items: I, folder: &mut V) -> ControlFlow +where + T: TraversableFold, + V: Folder, + I: IntoIterator, + C: FromIterator, +{ + let mut folded = std::vec::Vec::new(); + for item in items { + folded.push(item.traverse_fold(folder)?); + } + ControlFlow::Continue(folded.into_iter().collect()) +} + +fn traverse_fold_pairs(items: I, folder: &mut V) -> ControlFlow +where + K: TraversableFold, + Value: TraversableFold, + V: Folder, + I: IntoIterator, + C: FromIterator<(K, Value)>, +{ + let mut folded = std::vec::Vec::new(); + for (key, value) in items { + let key = key.traverse_fold(folder)?; + let value = value.traverse_fold(folder)?; + folded.push((key, value)); + } + ControlFlow::Continue(folded.into_iter().collect()) +} + macro_rules! impl_drive_for_into_iterator { ( $type:ty ; $($generics:tt)+ ) => { impl< $($generics)+ > Traversable for $type @@ -105,6 +138,80 @@ impl_drive_for_into_iterator! { std::collections::HashMap ; T, U } impl_drive_for_into_iterator! { Option ; T } impl_drive_for_into_iterator! { Result ; T, U } +macro_rules! impl_fold_for_collection { + ( $type:ty ; $($generics:tt)+ ) => { + impl< $($generics)+ > TraversableFold for $type + where + $type: 'static + IntoIterator + FromIterator<<$type as IntoIterator>::Item>, + <$type as IntoIterator>::Item: TraversableFold, + { + fn traverse_fold(self, folder: &mut V) -> ControlFlow { + traverse_fold_items(self, folder) + } + } + }; +} + +macro_rules! impl_fold_for_map { + ( $type:ty ; $($generics:tt)+ ) => { + impl< $($generics)+ > TraversableFold for $type + where + $type: 'static + IntoIterator + FromIterator<(T, U)>, + T: TraversableFold, + U: TraversableFold, + { + fn traverse_fold(self, folder: &mut V) -> ControlFlow { + traverse_fold_pairs(self, folder) + } + } + }; +} + +impl TraversableFold for [T; N] { + fn traverse_fold(self, folder: &mut V) -> ControlFlow { + let mut folded = std::vec::Vec::with_capacity(N); + for item in self { + folded.push(item.traverse_fold(folder)?); + } + let folded = match folded.try_into() { + Ok(folded) => folded, + Err(_) => unreachable!("folded array length must match the input array length"), + }; + ControlFlow::Continue(folded) + } +} + +impl_fold_for_collection! { std::vec::Vec ; T } +impl_fold_for_collection! { std::collections::BTreeSet ; T } +impl_fold_for_collection! { std::collections::BinaryHeap ; T } +impl_fold_for_collection! { std::collections::HashSet ; T } +impl_fold_for_collection! { std::collections::LinkedList ; T } +impl_fold_for_collection! { std::collections::VecDeque ; T } +impl_fold_for_map! { std::collections::BTreeMap ; T, U } +impl_fold_for_map! { std::collections::HashMap ; T, U } + +impl TraversableFold for Option { + fn traverse_fold(self, folder: &mut V) -> ControlFlow { + match self { + Some(item) => ControlFlow::Continue(Some(item.traverse_fold(folder)?)), + None => ControlFlow::Continue(None), + } + } +} + +impl TraversableFold for Result +where + T: TraversableFold, + U: 'static, +{ + fn traverse_fold(self, folder: &mut V) -> ControlFlow { + match self { + Ok(item) => ControlFlow::Continue(Ok(item.traverse_fold(folder)?)), + Err(error) => ControlFlow::Continue(Err(error)), + } + } +} + impl Traversable for Box { fn traverse(&self, visitor: &mut V) -> ControlFlow { (**self).traverse(visitor) @@ -117,6 +224,12 @@ impl TraversableMut for Box { } } +impl TraversableFold for Box { + fn traverse_fold(self, folder: &mut V) -> ControlFlow { + ControlFlow::Continue(Box::new((*self).traverse_fold(folder)?)) + } +} + impl Traversable for Arc { fn traverse(&self, visitor: &mut V) -> ControlFlow { (**self).traverse(visitor) @@ -143,6 +256,17 @@ where } } +impl TraversableFold for Mutex +where + T: TraversableFold, +{ + fn traverse_fold(self, folder: &mut V) -> ControlFlow { + let item = self.into_inner().unwrap(); + let item = item.traverse_fold(folder)?; + ControlFlow::Continue(Mutex::new(item)) + } +} + impl Traversable for RwLock where T: Traversable, @@ -163,6 +287,17 @@ where } } +impl TraversableFold for RwLock +where + T: TraversableFold, +{ + fn traverse_fold(self, folder: &mut V) -> ControlFlow { + let item = self.into_inner().unwrap(); + let item = item.traverse_fold(folder)?; + ControlFlow::Continue(RwLock::new(item)) + } +} + impl TraversableMut for Arc> where T: TraversableMut, @@ -200,3 +335,13 @@ where self.get_mut().traverse_mut(visitor) } } + +impl TraversableFold for Cell +where + T: TraversableFold, +{ + fn traverse_fold(self, folder: &mut V) -> ControlFlow { + let item = self.into_inner().traverse_fold(folder)?; + ControlFlow::Continue(Cell::new(item)) + } +} diff --git a/traversable/src/impls/std_primary.rs b/traversable/src/impls/std_primary.rs index cc9ba39..7f3dde1 100644 --- a/traversable/src/impls/std_primary.rs +++ b/traversable/src/impls/std_primary.rs @@ -15,7 +15,9 @@ use core::ops::ControlFlow; use std::string::String; +use crate::Folder; use crate::Traversable; +use crate::TraversableFold; use crate::TraversableMut; use crate::Visitor; use crate::VisitorMut; diff --git a/traversable/src/impls/trivial.rs b/traversable/src/impls/trivial.rs index ec3f37a..66eb6f9 100644 --- a/traversable/src/impls/trivial.rs +++ b/traversable/src/impls/trivial.rs @@ -14,7 +14,9 @@ use core::ops::ControlFlow; +use crate::Folder; use crate::Traversable; +use crate::TraversableFold; use crate::TraversableMut; use crate::Visitor; use crate::VisitorMut; diff --git a/traversable/src/impls/tuple.rs b/traversable/src/impls/tuple.rs index 62b0d9c..09616bb 100644 --- a/traversable/src/impls/tuple.rs +++ b/traversable/src/impls/tuple.rs @@ -14,7 +14,9 @@ use core::ops::ControlFlow; +use crate::Folder; use crate::Traversable; +use crate::TraversableFold; use crate::TraversableMut; use crate::Visitor; use crate::VisitorMut; @@ -49,6 +51,22 @@ macro_rules! tuple_impl { ControlFlow::Continue(()) } } + + impl<$( $type ),+> TraversableFold for ($($type,)+) + where + $( + $type: TraversableFold + ),+ + { + #[allow(non_snake_case)] + fn traverse_fold(self, folder: &mut V) -> ControlFlow { + let ($($type,)+) = self; + $( + let $type = $type.traverse_fold(folder)?; + )+ + ControlFlow::Continue(($($type,)+)) + } + } )+ }; } diff --git a/traversable/src/lib.rs b/traversable/src/lib.rs index 2ea856e..912e300 100644 --- a/traversable/src/lib.rs +++ b/traversable/src/lib.rs @@ -16,9 +16,9 @@ //! //! A visitor pattern implementation for traversing data structures. //! -//! This crate provides [`Traversable`] and [`TraversableMut`] traits for types that can be -//! traversed, as well as [`Visitor`] and [`VisitorMut`] traits for types that perform the -//! traversal. +//! This crate provides [`Traversable`], [`TraversableMut`], and [`TraversableFold`] traits for +//! types that can be traversed, as well as [`Visitor`], [`VisitorMut`], and [`Folder`] traits for +//! types that perform the traversal. //! //! It is designed to be flexible and efficient, allowing for deep traversal of complex data //! structures. @@ -29,7 +29,7 @@ //! //! ```toml //! [dependencies] -//! traversable = { version = "0.2", features = ["derive", "std"] } +//! traversable = { version = "0.3", features = ["derive", "std"] } //! ``` //! //! Define your data structures and derive [`Traversable`]: @@ -119,7 +119,8 @@ //! //! ## Features //! -//! * `derive`: Enables procedural macros `#[derive(Traversable)]` and `#[derive(TraversableMut)]`. +//! * `derive`: Enables procedural macros `#[derive(Traversable)]`, `#[derive(TraversableMut)]`, +//! and `#[derive(TraversableFold)]`. //! * `std`: Enables support for standard library types (e.g., `Vec`, `HashMap`, `Box`). //! * `traverse-trivial`: Enables traversal for primitive types (`u8`, `i32`, `bool`, etc.). By //! default, these are ignored. @@ -140,6 +141,9 @@ use core::ops::ControlFlow; /// See [`Traversable`]. pub use traversable_derive::Traversable; #[cfg(feature = "derive")] +/// See [`TraversableFold`]. +pub use traversable_derive::TraversableFold; +#[cfg(feature = "derive")] /// See [`TraversableMut`]. pub use traversable_derive::TraversableMut; @@ -251,6 +255,37 @@ pub trait VisitorMut { } } +/// A folder that can transform an owned data structure while traversing it. +/// +/// Implement this trait to define custom logic that receives owned nodes and returns the node that +/// should continue through traversal. This is useful for bottom-up rewrites such as simplifying an +/// expression tree without using temporary replacement values. +/// +/// [`TraversableFold`] calls [`Folder::enter`] before folding children and [`Folder::leave`] after +/// folding children. The default implementation returns each node unchanged. +/// +/// You can also use [`folder`] to create a folder from closures. +/// +/// [`folder`]: function::folder +pub trait Folder { + /// The type that can be used to break traversal early. + type Break; + + /// Called when the folder is entering an owned node. + /// + /// Default implementation returns the node unchanged and continues traversal. + fn enter(&mut self, this: T) -> ControlFlow { + ControlFlow::Continue(this) + } + + /// Called when the folder is leaving an owned node. + /// + /// Default implementation returns the node unchanged and continues traversal. + fn leave(&mut self, this: T) -> ControlFlow { + ControlFlow::Continue(this) + } +} + /// A trait for types that can be traversed by a visitor. /// /// This trait is the core of the traversable pattern. It allows a [`Visitor`] to @@ -408,3 +443,55 @@ pub trait TraversableMut: core::any::Any { /// Traverse the mutable data structure with the given visitor. fn traverse_mut(&mut self, visitor: &mut V) -> ControlFlow; } + +/// A trait for types that can be traversed and transformed by a folder. +/// +/// This trait consumes `self`, folds its children, and returns the rebuilt value. It is intended for +/// owned transformations where a node may need to be replaced by another value of the same type. +/// +/// # Deriving `TraversableFold` +/// +/// The easiest way to implement `TraversableFold` is to use the derive macro. +/// +/// ```rust +/// # #[cfg(not(feature = "derive"))] +/// # fn main() {} +/// # +/// # #[cfg(feature = "derive")] +/// # fn main() { +/// use traversable::TraversableFold; +/// +/// #[derive(TraversableFold)] +/// struct MyStruct { +/// data: u64, +/// #[traverse(skip)] +/// hidden: String, +/// } +/// # } +/// ``` +/// +/// # Attributes +/// +/// The derive macro supports the following attributes on structs and enums: +/// +/// * `#[traverse(skip_self)]`: Skips calling the folder for the annotated type while still folding +/// its children. +/// * `#[traverse(skip_children)]`: Calls the folder for the annotated type without folding its +/// children. +/// +/// The derive macro supports the following attributes on fields and variants: +/// +/// * `#[traverse(skip)]`: Skips folding into the annotated field or variant. +/// * `#[traverse(with = "function_name")]`: Uses a custom function to fold the field. +/// +/// ## Custom Fold Function +/// +/// When using `#[traverse(with = "path::to::func")]`, the function must have the signature: +/// +/// ```rust,ignore +/// fn func(item: ItemType, folder: &mut V) -> ControlFlow +/// ``` +pub trait TraversableFold: core::any::Any + Sized { + /// Traverse and transform the data structure with the given folder. + fn traverse_fold(self, folder: &mut V) -> ControlFlow; +} diff --git a/traversable/tests/test_folder.rs b/traversable/tests/test_folder.rs new file mode 100644 index 0000000..55344dc --- /dev/null +++ b/traversable/tests/test_folder.rs @@ -0,0 +1,251 @@ +// Copyright 2025 FastLabs Developers +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![cfg(all(feature = "std", feature = "derive"))] + +use core::any::Any; +use core::any::TypeId; +use core::ops::ControlFlow; + +use traversable::Folder; +use traversable::TraversableFold; +use traversable::function::folder_leave; + +fn into_continue(flow: ControlFlow<(), T>) -> T { + match flow { + ControlFlow::Continue(value) => value, + ControlFlow::Break(()) => unreachable!(), + } +} + +#[derive(Debug, PartialEq, Eq, TraversableFold)] +enum Expr { + Add(Box, Box), + Literal(i32), +} + +fn simplify(expr: Expr) -> Expr { + match expr { + Expr::Add(left, right) => match (*left, *right) { + (Expr::Literal(0), expr) | (expr, Expr::Literal(0)) => expr, + (left, right) => Expr::Add(Box::new(left), Box::new(right)), + }, + expr => expr, + } +} + +#[test] +fn folder_leave_rewrites_bottom_up() { + let expr = Expr::Add( + Box::new(Expr::Add( + Box::new(Expr::Literal(0)), + Box::new(Expr::Literal(1)), + )), + Box::new(Expr::Literal(0)), + ); + + let mut folder = folder_leave::(|expr| ControlFlow::Continue(simplify(expr))); + let expr = into_continue(expr.traverse_fold(&mut folder)); + + assert_eq!(expr, Expr::Literal(1)); +} + +#[derive(Debug, PartialEq, Eq, TraversableFold)] +struct Child { + value: u64, +} + +#[derive(TraversableFold)] +struct Pair { + #[traverse(with = "fold_and_double")] + folded: Child, + #[traverse(skip)] + skipped: Child, +} + +fn fold_and_double(child: Child, folder: &mut V) -> ControlFlow { + let mut child = child.traverse_fold(folder)?; + child.value *= 2; + ControlFlow::Continue(child) +} + +#[test] +fn field_attributes_control_fold_behavior() { + let pair = Pair { + folded: Child { value: 1 }, + skipped: Child { value: 10 }, + }; + + let mut folder = folder_leave::(|mut child| { + child.value += 1; + ControlFlow::Continue(child) + }); + let pair = into_continue(pair.traverse_fold(&mut folder)); + + assert_eq!(pair.folded, Child { value: 4 }); + assert_eq!(pair.skipped, Child { value: 10 }); +} + +#[test] +fn std_containers_fold_owned_items() { + let mut folder = folder_leave::(|mut child| { + child.value += 1; + ControlFlow::Continue(child) + }); + + let values = vec![Child { value: 1 }, Child { value: 2 }]; + let values = into_continue(values.traverse_fold(&mut folder)); + + assert_eq!(values, vec![Child { value: 2 }, Child { value: 3 }]); + + let value = Some(Child { value: 4 }); + let value = into_continue(value.traverse_fold(&mut folder)); + + assert_eq!(value, Some(Child { value: 5 })); + + let value = Ok::<_, Child>(Child { value: 6 }); + let value = into_continue(value.traverse_fold(&mut folder)); + + assert_eq!(value, Ok(Child { value: 7 })); +} + +#[derive(TraversableFold)] +struct Parent { + child: Child, +} + +#[derive(TraversableFold)] +#[traverse(skip_self)] +struct SkipSelfParent { + child: Child, +} + +#[derive(TraversableFold)] +#[traverse(skip_children)] +#[allow(dead_code)] +struct SkipChildrenParent { + child: Child, +} + +#[derive(TraversableFold)] +#[traverse(skip_self, skip_children)] +#[allow(dead_code)] +struct SkipSelfAndChildrenParent { + child: Child, +} + +#[derive(Default)] +struct Counts { + parent_enter: usize, + parent_leave: usize, + child_enter: usize, + child_leave: usize, + skip_self_parent_enter: usize, + skip_self_parent_leave: usize, + skip_children_parent_enter: usize, + skip_children_parent_leave: usize, + skip_self_and_children_parent_enter: usize, + skip_self_and_children_parent_leave: usize, +} + +impl Folder for Counts { + type Break = (); + + fn enter(&mut self, this: T) -> ControlFlow { + if TypeId::of::() == TypeId::of::() { + self.parent_enter += 1; + } else if TypeId::of::() == TypeId::of::() { + self.child_enter += 1; + } else if TypeId::of::() == TypeId::of::() { + self.skip_self_parent_enter += 1; + } else if TypeId::of::() == TypeId::of::() { + self.skip_children_parent_enter += 1; + } else if TypeId::of::() == TypeId::of::() { + self.skip_self_and_children_parent_enter += 1; + } + + ControlFlow::Continue(this) + } + + fn leave(&mut self, this: T) -> ControlFlow { + if TypeId::of::() == TypeId::of::() { + self.parent_leave += 1; + } else if TypeId::of::() == TypeId::of::() { + self.child_leave += 1; + } else if TypeId::of::() == TypeId::of::() { + self.skip_self_parent_leave += 1; + } else if TypeId::of::() == TypeId::of::() { + self.skip_children_parent_leave += 1; + } else if TypeId::of::() == TypeId::of::() { + self.skip_self_and_children_parent_leave += 1; + } + + ControlFlow::Continue(this) + } +} + +#[test] +fn type_level_attributes_control_fold_behavior() { + let mut counts = Counts::default(); + let parent = Parent { + child: Child { value: 1 }, + }; + + let result = parent.traverse_fold(&mut counts); + + assert!(result.is_continue()); + assert_eq!(counts.parent_enter, 1); + assert_eq!(counts.parent_leave, 1); + assert_eq!(counts.child_enter, 1); + assert_eq!(counts.child_leave, 1); + + let mut counts = Counts::default(); + let skip_self = SkipSelfParent { + child: Child { value: 1 }, + }; + + let result = skip_self.traverse_fold(&mut counts); + + assert!(result.is_continue()); + assert_eq!(counts.skip_self_parent_enter, 0); + assert_eq!(counts.skip_self_parent_leave, 0); + assert_eq!(counts.child_enter, 1); + assert_eq!(counts.child_leave, 1); + + let mut counts = Counts::default(); + let skip_children = SkipChildrenParent { + child: Child { value: 1 }, + }; + + let result = skip_children.traverse_fold(&mut counts); + + assert!(result.is_continue()); + assert_eq!(counts.skip_children_parent_enter, 1); + assert_eq!(counts.skip_children_parent_leave, 1); + assert_eq!(counts.child_enter, 0); + assert_eq!(counts.child_leave, 0); + + let mut counts = Counts::default(); + let skip_self_and_children = SkipSelfAndChildrenParent { + child: Child { value: 1 }, + }; + + let result = skip_self_and_children.traverse_fold(&mut counts); + + assert!(result.is_continue()); + assert_eq!(counts.skip_self_and_children_parent_enter, 0); + assert_eq!(counts.skip_self_and_children_parent_leave, 0); + assert_eq!(counts.child_enter, 0); + assert_eq!(counts.child_leave, 0); +} From d9978c1eea0183ddf0996362e1884bf9a21a0884 Mon Sep 17 00:00:00 2001 From: tison Date: Wed, 3 Jun 2026 21:42:36 +0800 Subject: [PATCH 2/4] fixup Signed-off-by: tison --- traversable/src/function.rs | 4 ++-- traversable/src/lib.rs | 9 +++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/traversable/src/function.rs b/traversable/src/function.rs index 58d9428..84baa9d 100644 --- a/traversable/src/function.rs +++ b/traversable/src/function.rs @@ -452,10 +452,10 @@ where /// # Example /// /// ```rust -/// # #[cfg(not(feature = "derive"))] +/// # #[cfg(not(all(feature = "derive", feature = "std")))] /// # fn main() {} /// # -/// # #[cfg(feature = "derive")] +/// # #[cfg(all(feature = "derive", feature = "std"))] /// # fn main() { /// use core::ops::ControlFlow; /// diff --git a/traversable/src/lib.rs b/traversable/src/lib.rs index 912e300..ca889c4 100644 --- a/traversable/src/lib.rs +++ b/traversable/src/lib.rs @@ -119,8 +119,8 @@ //! //! ## Features //! -//! * `derive`: Enables procedural macros `#[derive(Traversable)]`, `#[derive(TraversableMut)]`, -//! and `#[derive(TraversableFold)]`. +//! * `derive`: Enables procedural macros `#[derive(Traversable)]`, `#[derive(TraversableMut)]`, and +//! `#[derive(TraversableFold)]`. //! * `std`: Enables support for standard library types (e.g., `Vec`, `HashMap`, `Box`). //! * `traverse-trivial`: Enables traversal for primitive types (`u8`, `i32`, `bool`, etc.). By //! default, these are ignored. @@ -446,8 +446,9 @@ pub trait TraversableMut: core::any::Any { /// A trait for types that can be traversed and transformed by a folder. /// -/// This trait consumes `self`, folds its children, and returns the rebuilt value. It is intended for -/// owned transformations where a node may need to be replaced by another value of the same type. +/// This trait consumes `self`, folds its children, and returns the rebuilt value. It is intended +/// for owned transformations where a node may need to be replaced by another value of the same +/// type. /// /// # Deriving `TraversableFold` /// From f7b61c98ba59bd492fd42fea8767d9f4b69dcf09 Mon Sep 17 00:00:00 2001 From: tison Date: Thu, 4 Jun 2026 09:36:46 +0800 Subject: [PATCH 3/4] fixup Signed-off-by: tison --- traversable/src/function.rs | 45 ++++++++++++++----------------------- 1 file changed, 17 insertions(+), 28 deletions(-) diff --git a/traversable/src/function.rs b/traversable/src/function.rs index 84baa9d..9ac753f 100644 --- a/traversable/src/function.rs +++ b/traversable/src/function.rs @@ -15,11 +15,8 @@ //! Visitors from functions or closures. use core::any::Any; -use core::any::TypeId; use core::marker::PhantomData; -use core::mem::ManuallyDrop; use core::ops::ControlFlow; -use core::ptr; use crate::Folder; use crate::Visitor; @@ -96,46 +93,38 @@ where type Break = B; fn enter(&mut self, this: U) -> ControlFlow { - fold_if_type(this, &mut self.enter) + fold(this, &mut self.enter) } fn leave(&mut self, this: U) -> ControlFlow { - fold_if_type(this, &mut self.leave) + fold(this, &mut self.leave) } } -fn fold_if_type(this: U, fold: &mut F) -> ControlFlow +fn fold(this: U, fold: &mut F) -> ControlFlow where T: Any, U: Any, F: FnMut(T) -> ControlFlow, { + use core::any::TypeId; + use core::mem::ManuallyDrop; + use core::mem::transmute_copy; + if TypeId::of::() != TypeId::of::() { return ControlFlow::Continue(this); } - let this = cast_between_equal_any_types::(this); - match fold(this) { - ControlFlow::Continue(this) => { - ControlFlow::Continue(cast_between_equal_any_types::(this)) - } - ControlFlow::Break(break_value) => ControlFlow::Break(break_value), - } -} - -fn cast_between_equal_any_types(from: From) -> To -where - From: Any, - To: Any, -{ - debug_assert_eq!(TypeId::of::(), TypeId::of::()); - - let from = ManuallyDrop::new(from); - // SAFETY: `fold_if_type` only calls this function after verifying that `From` and `To` have - // the same `TypeId`. `Any` is implemented only for `'static` concrete types, so equal TypeIds - // identify the same type. `ManuallyDrop` prevents dropping the source after ownership has been - // transferred through `ptr::read`. - unsafe { ptr::read((&*from as *const From).cast::()) } + let this = unsafe { + let this = ManuallyDrop::new(this); + transmute_copy(&this) + }; + let this = fold(this)?; + let this = unsafe { + let this = ManuallyDrop::new(this); + transmute_copy(&this) + }; + ControlFlow::Continue(this) } type DefaultVisitFn = fn(&T) -> ControlFlow; From b61ecb4383474da2acd32b61ca7f03084f497eac Mon Sep 17 00:00:00 2001 From: tison Date: Thu, 4 Jun 2026 09:39:41 +0800 Subject: [PATCH 4/4] Document folder downcast safety --- traversable-derive/src/lib.rs | 49 ++++++++++------------------------- traversable/src/function.rs | 6 +++++ 2 files changed, 20 insertions(+), 35 deletions(-) diff --git a/traversable-derive/src/lib.rs b/traversable-derive/src/lib.rs index e42ca48..6377637 100644 --- a/traversable-derive/src/lib.rs +++ b/traversable-derive/src/lib.rs @@ -208,20 +208,16 @@ fn resolve_crate_name() -> Path { parse_quote!(::traversable) } +fn take_unit_param(params: &mut Params, name: &str) -> Result { + Ok(params.param(name)?.map(Param::unit).transpose()?.is_some()) +} + fn impl_traversable(input: DeriveInput, mutable: bool) -> Result { let mut params = Params::from_attrs(input.attrs, "traverse")?; params.validate(&["skip_self", "skip_children"])?; - let skip_visit_self = params - .param("skip_self")? - .map(Param::unit) - .transpose()? - .is_some(); - let skip_children = params - .param("skip_children")? - .map(Param::unit) - .transpose()? - .is_some(); + let skip_visit_self = take_unit_param(&mut params, "skip_self")?; + let skip_children = take_unit_param(&mut params, "skip_children")?; let name = input.ident; let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); @@ -353,7 +349,7 @@ fn traverse_enum(e: DataEnum, mutable: bool) -> Result { fn traverse_variant(v: Variant, mutable: bool) -> Result { let mut params = Params::from_attrs(v.attrs, "traverse")?; params.validate(&["skip"])?; - if params.param("skip")?.map(Param::unit).is_some() { + if take_unit_param(&mut params, "skip")? { return Ok(TokenStream::new()); } let name = v.ident; @@ -390,7 +386,7 @@ fn destructure_fields(fields: Fields) -> Result { .map(|field| { let mut params = Params::from_attrs(field.attrs, "traverse")?; let field_name = field.ident.unwrap(); - Ok(if params.param("skip")?.map(Param::unit).is_some() { + Ok(if take_unit_param(&mut params, "skip")? { quote! { #field_name: _ } } else { field_name.into_token_stream() @@ -408,7 +404,7 @@ fn destructure_fields(fields: Fields) -> Result { .enumerate() .map(|(index, field)| { let mut params = Params::from_attrs(field.attrs, "traverse")?; - Ok(if params.param("skip")?.map(Param::unit).is_some() { + Ok(if take_unit_param(&mut params, "skip")? { quote! { _ } } else { Ident::new(&format!("i{index}",), Span::call_site()).into_token_stream() @@ -427,7 +423,7 @@ fn traverse_field(value: &TokenStream, field: Field, mutable: bool) -> Result Result { let mut params = Params::from_attrs(input.attrs, "traverse")?; params.validate(&["skip_self", "skip_children"])?; - let skip_visit_self = params - .param("skip_self")? - .map(Param::unit) - .transpose()? - .is_some(); - let skip_children = params - .param("skip_children")? - .map(Param::unit) - .transpose()? - .is_some(); + let skip_visit_self = take_unit_param(&mut params, "skip_self")?; + let skip_children = take_unit_param(&mut params, "skip_children")?; let name = input.ident; let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); @@ -583,11 +571,7 @@ fn fold_enum(e: DataEnum) -> Result { fn fold_variant(v: Variant) -> Result { let mut params = Params::from_attrs(v.attrs, "traverse")?; params.validate(&["skip"])?; - let skip = params - .param("skip")? - .map(Param::unit) - .transpose()? - .is_some(); + let skip = take_unit_param(&mut params, "skip")?; let name = v.ident; Ok(match v.fields { @@ -641,12 +625,7 @@ fn fold_field(value: &TokenStream, field: Field) -> Result { let mut params = Params::from_attrs(field.attrs, "traverse")?; params.validate(&["skip", "with"])?; - if params - .param("skip")? - .map(Param::unit) - .transpose()? - .is_some() - { + if take_unit_param(&mut params, "skip")? { return Ok(TokenStream::new()); } diff --git a/traversable/src/function.rs b/traversable/src/function.rs index 9ac753f..945e5e5 100644 --- a/traversable/src/function.rs +++ b/traversable/src/function.rs @@ -115,11 +115,17 @@ where return ControlFlow::Continue(this); } + // SAFETY: The `TypeId` check above proves that `T` and `U` are the same concrete `'static` + // type. `ManuallyDrop` prevents the original `U` value from being dropped after its bits are + // copied into the owned `T` value that is passed to the typed closure. let this = unsafe { let this = ManuallyDrop::new(this); transmute_copy(&this) }; let this = fold(this)?; + // SAFETY: The same `TypeId` equality still proves that `T` and `U` are the same concrete + // `'static` type. `ManuallyDrop` prevents the closure result from being dropped after its bits + // are copied back into the caller's expected `U` type. let this = unsafe { let this = ManuallyDrop::new(this); transmute_copy(&this)