Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 24 additions & 49 deletions src/iter/iterator.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use pyo3::types::PyAnyMethods;
use std::sync::atomic;
use std::sync::Arc;

Expand All @@ -12,43 +11,31 @@ pub struct PyIterator {
#[pyo3::pymethods]
impl PyIterator {
#[new]
fn new(dom: &pyo3::Bound<'_, pyo3::PyAny>) -> pyo3::PyResult<Self> {
let dom = dom
.extract::<pyo3::PyRef<'_, crate::tree::PyTreeDom>>()
.map_err(|_| {
pyo3::PyErr::new::<pyo3::exceptions::PyTypeError, _>(format!(
"expected TreeDom for dom, got {}",
crate::tools::get_type_name(dom)
))
})?;

Ok(Self {
fn new(dom: &crate::tree::PyTreeDom) -> Self {
Self {
dom: dom.dom.clone(),
index: atomic::AtomicUsize::new(0),
})
}
}

fn __iter__(self_: pyo3::PyRef<'_, Self>) -> pyo3::PyRef<'_, Self> {
self_
}

fn __next__(self_: pyo3::PyRef<'_, Self>) -> pyo3::PyResult<pyo3::Py<pyo3::PyAny>> {
fn __next__(&self) -> pyo3::PyResult<crate::nodes::NodeGuard> {
let node = {
let tree = self_.dom.lock();
let tree = self.dom.lock();

// NOTE:
// Unfortunately the ego_tree crate does not let us to use directly usize for getting nodes.
match tree
.nodes()
.nth(self_.index.load(atomic::Ordering::Relaxed))
{
Some(x) => crate::nodes::NodeGuard::from_noderef(self_.dom.clone(), x),
match tree.nodes().nth(self.index.load(atomic::Ordering::Relaxed)) {
Some(x) => crate::nodes::NodeGuard::from_noderef(self.dom.clone(), x),
None => return Err(pyo3::PyErr::new::<pyo3::exceptions::PyStopIteration, _>(())),
}
};

self_.index.fetch_add(1, atomic::Ordering::Relaxed);
Ok(node.into_any(self_.py()))
self.index.fetch_add(1, atomic::Ordering::Relaxed);
Ok(node)
}
}

Expand All @@ -69,13 +56,8 @@ macro_rules! axis_iterators {
#[pyo3::pymethods]
impl $name {
#[new]
fn new(node: &pyo3::Bound<'_, pyo3::PyAny>) -> pyo3::PyResult<Self> {
let node = crate::nodes::NodeGuard::from_pyobject(node).map_err(|_| {
pyo3::PyErr::new::<pyo3::exceptions::PyTypeError, _>(format!(
"expected a node (such as Element, Text, Comment, ...) for node, got {}",
crate::tools::get_type_name(node)
))
})?;
fn new(node: crate::nodes::PyNodeRef) -> pyo3::PyResult<Self> {
let node = node.as_node_guard();

Ok(Self { guard: $f(&node) })
}
Expand All @@ -84,12 +66,11 @@ macro_rules! axis_iterators {
self_
}

fn __next__(mut self_: pyo3::PyRefMut<'_, Self>) -> pyo3::PyResult<pyo3::Py<pyo3::PyAny>> {
let node = self_.guard.take();
self_.guard = node.as_ref().and_then($f);
fn __next__(&mut self) -> pyo3::PyResult<crate::nodes::NodeGuard> {
let node = self.guard.take();
self.guard = node.as_ref().and_then($f);

node.map(|x| x.into_any(self_.py()))
.ok_or_else(|| pyo3::PyErr::new::<pyo3::exceptions::PyStopIteration, _>(()))
node.ok_or_else(|| pyo3::PyErr::new::<pyo3::exceptions::PyStopIteration, _>(()))
}
}
)*
Expand Down Expand Up @@ -122,13 +103,8 @@ pub struct PyChildren {
#[pyo3::pymethods]
impl PyChildren {
#[new]
fn new(node: &pyo3::Bound<'_, pyo3::PyAny>) -> pyo3::PyResult<Self> {
let node = crate::nodes::NodeGuard::from_pyobject(node).map_err(|_| {
pyo3::PyErr::new::<pyo3::exceptions::PyTypeError, _>(format!(
"expected a node (such as Element, Text, Comment, ...) for node, got {}",
crate::tools::get_type_name(node)
))
})?;
fn new(node: crate::nodes::PyNodeRef) -> pyo3::PyResult<Self> {
let node = node.as_node_guard();

let front = node.first_child();
let back = node.last_child();
Expand All @@ -140,30 +116,29 @@ impl PyChildren {
self_
}

fn __next__(mut self_: pyo3::PyRefMut<'_, Self>) -> pyo3::PyResult<pyo3::Py<pyo3::PyAny>> {
fn __next__(&mut self) -> pyo3::PyResult<crate::nodes::NodeGuard> {
let mut is_same = false;

if let (Some(x), Some(y)) = (&self_.front, &self_.back) {
if let (Some(x), Some(y)) = (&self.front, &self.back) {
if x.id == y.id {
is_same = true;
}
}

let node = {
if is_same {
let node = self_.front.take();
self_.back = None;
let node = self.front.take();
self.back = None;
node
} else {
let node = self_.front.take();
self_.front = node
let node = self.front.take();
self.front = node
.as_ref()
.and_then(crate::nodes::NodeGuard::next_sibling);
node
}
};

node.map(|x| x.into_any(self_.py()))
.ok_or_else(|| pyo3::PyErr::new::<pyo3::exceptions::PyStopIteration, _>(()))
node.ok_or_else(|| pyo3::PyErr::new::<pyo3::exceptions::PyStopIteration, _>(()))
}
}
34 changes: 10 additions & 24 deletions src/iter/traverse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,25 +55,18 @@ impl PyTraverse {
#[pyo3::pymethods]
impl PyTraverse {
#[new]
fn new(node: &pyo3::Bound<'_, pyo3::PyAny>) -> pyo3::PyResult<Self> {
let node = crate::nodes::NodeGuard::from_pyobject(node).map_err(|_| {
pyo3::PyErr::new::<pyo3::exceptions::PyTypeError, _>(format!(
"expected a node (such as Element, Text, Comment, ...) for node, got {}",
crate::tools::get_type_name(node)
))
})?;

fn new(node: crate::nodes::PyNodeRef) -> pyo3::PyResult<Self> {
let node = node.as_node_guard().clone();
Ok(Self::from_nodeguard(node))
}

fn __iter__(self_: pyo3::PyRef<'_, Self>) -> pyo3::PyRef<'_, Self> {
self_
}

pub fn __next__(mut self_: pyo3::PyRefMut<'_, Self>) -> pyo3::PyResult<(pyo3::Py<pyo3::PyAny>, bool)> {
let py = self_.py();
match self_.next_edge() {
Some((x, y)) => Ok((x.into_any(py), y)),
pub fn __next__(&mut self) -> pyo3::PyResult<(crate::nodes::NodeGuard, bool)> {
match self.next_edge() {
Some((x, y)) => Ok((x, y)),
None => Err(pyo3::PyErr::new::<pyo3::exceptions::PyStopIteration, _>(())),
}
}
Expand All @@ -86,13 +79,8 @@ pub struct PyDescendants(PyTraverse);
#[pyo3::pymethods]
impl PyDescendants {
#[new]
fn new(node: &pyo3::Bound<'_, pyo3::PyAny>) -> pyo3::PyResult<Self> {
let node = crate::nodes::NodeGuard::from_pyobject(node).map_err(|_| {
pyo3::PyErr::new::<pyo3::exceptions::PyTypeError, _>(format!(
"expected a node (such as Element, Text, Comment, ...) for node, got {}",
crate::tools::get_type_name(node)
))
})?;
fn new(node: crate::nodes::PyNodeRef) -> pyo3::PyResult<Self> {
let node = node.as_node_guard().clone();

Ok(Self(PyTraverse {
root: Some(node),
Expand All @@ -104,15 +92,13 @@ impl PyDescendants {
self_
}

fn __next__(mut self_: pyo3::PyRefMut<'_, Self>) -> pyo3::PyResult<pyo3::Py<pyo3::PyAny>> {
let py = self_.py();

while let Some((node, is_close)) = self_.0.next_edge() {
fn __next__(&mut self) -> pyo3::PyResult<crate::nodes::NodeGuard> {
while let Some((node, is_close)) = self.0.next_edge() {
if is_close {
continue;
}

return Ok(node.into_any(py));
return Ok(node);
}

Err(pyo3::PyErr::new::<pyo3::exceptions::PyStopIteration, _>(()))
Expand Down
Loading