# -- duckdb/experimental/spark/sql/dataframe.py: helper and new DataFrame
#    methods added by this patch (toLocalIterator / foreach /
#    foreachPartition / isEmpty). ---------------------------------------------
import itertools
from collections.abc import Iterable, Iterator
from typing import Callable

# Rows are pulled from the underlying relation in batches of this size, so
# iteration never materialises the whole result set at once.
_LOCAL_ITERATOR_BATCH_SIZE = 10_000


def _construct_row(values: Iterable, names: "list[str]") -> "Row":
    """Build a :class:`Row` from raw values without calling ``Row.__init__``.

    ``tuple.__new__`` is used directly because it is much cheaper than the
    Row constructor when converting large result sets; ``__fields__`` carries
    the column names.
    """
    row = tuple.__new__(Row, list(values))
    row.__fields__ = list(names)
    return row


class DataFrame:  # noqa: D101
    def toLocalIterator(self, prefetchPartitions: bool = False) -> "Iterator[Row]":
        """Return an iterator over all of the rows in this :class:`DataFrame`.

        Rows are fetched lazily in batches of ``_LOCAL_ITERATOR_BATCH_SIZE``,
        so memory consumption stays bounded by a single batch.

        Parameters
        ----------
        prefetchPartitions : bool, optional
            Accepted for PySpark API compatibility; DuckDB executes locally,
            so the flag has no effect here.

        Returns
        -------
        Iterator[Row]
            Iterator of rows.

        Examples
        --------
        >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"])
        >>> list(df.toLocalIterator())
        [Row(age=14, name='Tom'), Row(age=23, name='Alice'), Row(age=16, name='Bob')]
        """
        columns = self.relation.columns
        cur = self.relation.execute()
        try:
            # fetchmany returns an empty (falsy) list once exhausted, which
            # cleanly terminates the walrus loop.
            while rows := cur.fetchmany(_LOCAL_ITERATOR_BATCH_SIZE):
                yield from (_construct_row(values, columns) for values in rows)
        finally:
            # Release the cursor even if the consumer abandons the generator.
            cur.close()

    def foreach(self, f: "Callable[[Row], None]") -> None:
        """Apply ``f`` to every :class:`Row` of this :class:`DataFrame`.

        Shorthand for ``df.rdd.foreach()`` in PySpark.

        Parameters
        ----------
        f : function
            A function that accepts one parameter which will receive each row
            to process.
        """
        for row in self.toLocalIterator():
            f(row)

    def foreachPartition(self, f: "Callable[[Iterator[Row]], None]") -> None:
        """Apply ``f`` once per partition, passing an iterator over its rows.

        Rows are grouped into partitions of ``_LOCAL_ITERATOR_BATCH_SIZE``.

        Parameters
        ----------
        f : function
            A function that accepts one parameter which will receive each
            partition (an iterator of rows) to process.
        """
        rows_generator = self.toLocalIterator()
        # NOTE: the islice MUST be materialised with list() before the truth
        # test: an itertools.islice object has no __bool__/__len__ and is
        # therefore always truthy, even when its source is exhausted — testing
        # the islice object itself would loop forever.
        while batch := list(itertools.islice(rows_generator, _LOCAL_ITERATOR_BATCH_SIZE)):
            f(iter(batch))

    def isEmpty(self) -> bool:
        """Check whether this :class:`DataFrame` is empty.

        A DataFrame may have columns but no rows; such a frame is empty.

        Returns
        -------
        bool
            ``True`` if the DataFrame has no rows, ``False`` otherwise.
        """
        # first() returns None exactly when there are zero rows.
        return self.first() is None
# -- duckdb/experimental/spark/sql/dataframe.py: collect() now delegates to
#    the shared module-level _construct_row helper instead of re-defining an
#    identical closure on every call. -----------------------------------------
class DataFrame:  # noqa: D101
    def collect(self) -> "list[Row]":  # noqa: D102
        # Fetch everything eagerly; column names are attached to each raw
        # tuple via the shared Row-construction helper.
        columns = self.relation.columns
        result = self.relation.fetchall()
        rows = [_construct_row(x, columns) for x in result]
        return rows


# -- tests/fast/spark/test_spark_dataframe.py: coverage for isEmpty, foreach
#    and foreachPartition. -----------------------------------------------------
class TestDataFrame:  # fragment of the existing test class
    def test_method_is_empty(self, spark):
        data = [(1, "Alice"), (2, "Bob")]
        df = spark.createDataFrame(data, ["id", "name"])
        empty_df = spark.createDataFrame([], schema=df.schema)

        assert not df.isEmpty()
        assert empty_df.isEmpty()

    def test_dataframe_foreach(self, spark):
        data = [(56, "Carol"), (20, "Alice"), (3, "Dave")]
        df = spark.createDataFrame(data, ["age", "name"])
        expected = [Row(age=56, name="Carol"), Row(age=20, name="Alice"), Row(age=3, name="Dave")]

        mock_callable = mock.MagicMock()
        df.foreach(mock_callable)
        mock_callable.assert_has_calls(
            [mock.call(row) for row in expected],
            any_order=True,
        )

    def test_dataframe_foreach_partition(self, spark):
        data = [(56, "Carol"), (20, "Alice"), (3, "Dave")]
        df = spark.createDataFrame(data, ["age", "name"])
        expected = [Row(age=56, name="Carol"), Row(age=20, name="Alice"), Row(age=3, name="Dave")]

        # foreachPartition hands the callback an *iterator* of rows, so a
        # MagicMock's recorded call args would never compare equal to a list
        # (mock matches call args with ==, and iterator != list).  Materialise
        # each partition as it arrives instead.
        received = []
        df.foreachPartition(lambda part: received.append(list(part)))

        # All three rows fit in a single batch, so exactly one partition.
        assert received == [expected]
class TestDataFrame:  # fragment of the existing test class
    def test_to_local_iterator(self, spark):
        """toLocalIterator must yield every row of the frame, in order."""
        source = [(56, "Carol"), (20, "Alice"), (3, "Dave")]
        df = spark.createDataFrame(source, ["age", "name"])
        expected = [Row(age=age, name=name) for age, name in source]

        assert list(df.toLocalIterator()) == expected