From 589472d75ae9ad3aac6672bb19919f25a28f9b24 Mon Sep 17 00:00:00 2001 From: Mario Taddeucci Date: Fri, 13 Mar 2026 23:04:37 -0300 Subject: [PATCH 1/4] Add DataFrame iterator methods and corresponding tests - Implement `toLocalIterator`, `foreach`, and `foreachPartition` methods in the DataFrame class. - Add tests for `isEmpty`, `foreach`, `foreachPartition`, and `toLocalIterator` methods in the test suite. --- duckdb/experimental/spark/sql/dataframe.py | 157 ++++++++++++++++++++- tests/fast/spark/test_spark_dataframe.py | 39 +++++ 2 files changed, 189 insertions(+), 7 deletions(-) diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py index e7519e81..b309d519 100644 --- a/duckdb/experimental/spark/sql/dataframe.py +++ b/duckdb/experimental/spark/sql/dataframe.py @@ -1,4 +1,6 @@ -import uuid # noqa: D100 +import itertools # noqa: D100 +import uuid +from collections.abc import Iterable, Iterator from functools import reduce from keyword import iskeyword from typing import ( @@ -31,6 +33,12 @@ from duckdb.experimental.spark.sql import functions as spark_sql_functions +def _construct_row(values: Iterable, names: list[str]) -> Row: + row = tuple.__new__(Row, list(values)) + row.__fields__ = list(names) + return row + + class DataFrame: # noqa: D101 def __init__(self, relation: duckdb.DuckDBPyRelation, session: "SparkSession") -> None: # noqa: D107 self.relation = relation @@ -71,6 +79,146 @@ def toArrow(self) -> "pa.Table": """ return self.relation.to_arrow_table() + def toLocalIterator(self, prefetchPartitions: bool = False) -> Iterator[Row]: + """Returns an iterator that contains all of the rows in this :class:`DataFrame`. + + The iterator will consume as much memory as the largest partition in this + :class:`DataFrame`. With prefetch it may consume up to the memory of the 2 largest + partitions. + + .. versionadded:: 2.0.0 + + .. versionchanged:: 3.4.0 + Supports Spark Connect. 
+ + Parameters + ---------- + prefetchPartitions : bool, optional + If Spark should pre-fetch the next partition before it is needed. + + .. versionchanged:: 3.4.0 + This argument does not take effect for Spark Connect. + + Returns: + ------- + Iterator + Iterator of rows. + + Examples: + -------- + >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) + >>> list(df.toLocalIterator()) + [Row(age=14, name='Tom'), Row(age=23, name='Alice'), Row(age=16, name='Bob')] + """ + columns = self.relation.columns + cur = self.relation.execute() + + while rows := cur.fetchmany(10_000): + yield from (_construct_row(x, columns) for x in rows) + + def foreach(self, f: Callable[[Row], None]) -> None: + """Applies the ``f`` function to all :class:`Row` of this :class:`DataFrame`. + + This is a shorthand for ``df.rdd.foreach()``. + + .. versionadded:: 1.3.0 + + .. versionchanged:: 4.0.0 + Supports Spark Connect. + + Parameters + ---------- + f : function + A function that accepts one parameter which will + receive each row to process. + + Examples: + -------- + >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) + >>> def func(person): + ... print(person.name) + >>> df.foreach(func) + """ + for row in self.toLocalIterator(): + f(row) + + def foreachPartition(self, f: Callable[[Iterator[Row]], None]) -> None: + """Applies the ``f`` function to each partition of this :class:`DataFrame`. + + This is a shorthand for ``df.rdd.foreachPartition()``. + + .. versionadded:: 1.3.0 + + .. versionchanged:: 4.0.0 + Supports Spark Connect. + + Parameters + ---------- + f : function + A function that accepts one parameter which will receive + each partition to process. + + Examples: + -------- + >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) + >>> def func(itr): + ... for person in itr: + ... 
print(person.name) + >>> df.foreachPartition(func) + """ + rows_generator = self.toLocalIterator() + while rows := itertools.islice(rows_generator, 10_000): + f(rows) + + def isEmpty(self) -> bool: + """Checks if the :class:`DataFrame` is empty and returns a boolean value. + + .. versionadded:: 3.3.0 + + .. versionchanged:: 3.4.0 + Supports Spark Connect. + + Returns: + ------- + bool + Returns ``True`` if the DataFrame is empty, ``False`` otherwise. + + See Also: + -------- + DataFrame.count : Counts the number of rows in DataFrame. + + Notes: + ----- + - An empty DataFrame has no rows. It may have columns, but no data. + + Examples: + -------- + Example 1: Checking if an empty DataFrame is empty + + >>> df_empty = spark.createDataFrame([], "a STRING") + >>> df_empty.isEmpty() + True + + Example 2: Checking if a non-empty DataFrame is empty + + >>> df_non_empty = spark.createDataFrame(["a"], "STRING") + >>> df_non_empty.isEmpty() + False + + Example 3: Checking if a DataFrame with null values is empty + + >>> df_nulls = spark.createDataFrame([(None, None)], "a STRING, b INT") + >>> df_nulls.isEmpty() + False + + Example 4: Checking if a DataFrame with no rows but with columns is empty + + >>> df_no_rows = spark.createDataFrame([], "id INT, value STRING") + >>> df_no_rows.isEmpty() + True + """ + return self.first() is None + def createOrReplaceTempView(self, name: str) -> None: """Creates or replaces a local temporary view with this :class:`DataFrame`. 
@@ -1381,12 +1529,7 @@ def collect(self) -> list[Row]: # noqa: D102 columns = self.relation.columns result = self.relation.fetchall() - def construct_row(values: list, names: list[str]) -> Row: - row = tuple.__new__(Row, list(values)) - row.__fields__ = list(names) - return row - - rows = [construct_row(x, columns) for x in result] + rows = [_construct_row(x, columns) for x in result] return rows def cache(self) -> "DataFrame": diff --git a/tests/fast/spark/test_spark_dataframe.py b/tests/fast/spark/test_spark_dataframe.py index e242092e..19e99921 100644 --- a/tests/fast/spark/test_spark_dataframe.py +++ b/tests/fast/spark/test_spark_dataframe.py @@ -1,3 +1,5 @@ +from unittest import mock + import pytest _ = pytest.importorskip("duckdb.experimental.spark") @@ -597,3 +599,40 @@ def test_treeString_array_type(self, spark): assert " |-- name:" in tree assert " |-- hobbies: array<" in tree assert "(nullable = true)" in tree + + def test_method_is_empty(self, spark): + data = [(1, "Alice"), (2, "Bob")] + df = spark.createDataFrame(data, ["id", "name"]) + empty_df = spark.createDataFrame([], schema=df.schema) + + assert not df.isEmpty() + assert empty_df.isEmpty() + + def test_dataframe_foreach(self, spark): + data = [(56, "Carol"), (20, "Alice"), (3, "Dave")] + df = spark.createDataFrame(data, ["age", "name"]) + expected = [Row(age=56, name="Carol"), Row(age=20, name="Alice"), Row(age=3, name="Dave")] + + mock_callable = mock.MagicMock() + df.foreach(mock_callable) + mock_callable.assert_has_calls( + [mock.call(expected[0]), mock.call(expected[1]), mock.call(expected[2])], + any_order=True, + ) + + def test_dataframe_foreach_partition(self, spark): + data = [(56, "Carol"), (20, "Alice"), (3, "Dave")] + df = spark.createDataFrame(data, ["age", "name"]) + expected = [Row(age=56, name="Carol"), Row(age=20, name="Alice"), Row(age=3, name="Dave")] + + mock_callable = mock.MagicMock() + df.foreachPartition(mock_callable) + mock_callable.assert_called_once_with(expected) + + 
def test_to_local_iterator(self, spark): + data = [(56, "Carol"), (20, "Alice"), (3, "Dave")] + df = spark.createDataFrame(data, ["age", "name"]) + expected = [Row(age=56, name="Carol"), Row(age=20, name="Alice"), Row(age=3, name="Dave")] + + res = list(df.toLocalIterator()) + assert res == expected From c51f54d6e541059c7d357c75e1f08995a70304f2 Mon Sep 17 00:00:00 2001 From: Mario Taddeucci Date: Fri, 13 Mar 2026 23:14:49 -0300 Subject: [PATCH 2/4] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- duckdb/experimental/spark/sql/dataframe.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py index b309d519..ba207e24 100644 --- a/duckdb/experimental/spark/sql/dataframe.py +++ b/duckdb/experimental/spark/sql/dataframe.py @@ -22,6 +22,8 @@ from .type_utils import duckdb_to_spark_schema from .types import Row, StructType +_LOCAL_ITERATOR_BATCH_SIZE = 10_000 + if TYPE_CHECKING: import pyarrow as pa from pandas.core.frame import DataFrame as PandasDataFrame @@ -167,7 +169,7 @@ def foreachPartition(self, f: Callable[[Iterator[Row]], None]) -> None: >>> df.foreachPartition(func) """ rows_generator = self.toLocalIterator() - while rows := itertools.islice(rows_generator, 10_000): + while rows := itertools.islice(rows_generator, _LOCAL_ITERATOR_BATCH_SIZE): f(rows) def isEmpty(self) -> bool: From 6078ebd80e97f4201d5a53fa5a951b1c85edf84f Mon Sep 17 00:00:00 2001 From: Mario Taddeucci Date: Fri, 13 Mar 2026 23:18:01 -0300 Subject: [PATCH 3/4] Ensure proper resource management in toLocalIterator and update foreach to handle iterables --- duckdb/experimental/spark/sql/dataframe.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py index ba207e24..2368765f 100644 --- 
a/duckdb/experimental/spark/sql/dataframe.py +++ b/duckdb/experimental/spark/sql/dataframe.py @@ -115,8 +115,11 @@ def toLocalIterator(self, prefetchPartitions: bool = False) -> Iterator[Row]: columns = self.relation.columns cur = self.relation.execute() - while rows := cur.fetchmany(10_000): - yield from (_construct_row(x, columns) for x in rows) + try: + while rows := cur.fetchmany(10_000): + yield from (_construct_row(x, columns) for x in rows) + finally: + cur.close() def foreach(self, f: Callable[[Row], None]) -> None: """Applies the ``f`` function to all :class:`Row` of this :class:`DataFrame`. @@ -170,7 +173,7 @@ def foreachPartition(self, f: Callable[[Iterator[Row]], None]) -> None: """ rows_generator = self.toLocalIterator() while rows := itertools.islice(rows_generator, _LOCAL_ITERATOR_BATCH_SIZE): - f(rows) + f(iter(rows)) def isEmpty(self) -> bool: """Checks if the :class:`DataFrame` is empty and returns a boolean value. From ea9b5805b85a56d62bdf6c2ce6fd8f9fc4787840 Mon Sep 17 00:00:00 2001 From: Mario Taddeucci Date: Fri, 13 Mar 2026 23:19:20 -0300 Subject: [PATCH 4/4] Use constant for batch size in toLocalIterator --- duckdb/experimental/spark/sql/dataframe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py index 2368765f..36ef419d 100644 --- a/duckdb/experimental/spark/sql/dataframe.py +++ b/duckdb/experimental/spark/sql/dataframe.py @@ -116,7 +116,7 @@ def toLocalIterator(self, prefetchPartitions: bool = False) -> Iterator[Row]: cur = self.relation.execute() try: - while rows := cur.fetchmany(10_000): + while rows := cur.fetchmany(_LOCAL_ITERATOR_BATCH_SIZE): yield from (_construct_row(x, columns) for x in rows) finally: cur.close()