From 8f8a2d2104283ac62aa411b20e0c444b33ef7eea Mon Sep 17 00:00:00 2001 From: Sumedh Sakdeo Date: Sat, 14 Feb 2026 13:09:49 -0800 Subject: [PATCH] feat: forward batch_size parameter to PyArrow Scanner Add batch_size parameter to _task_to_record_batches, _record_batches_from_scan_tasks_and_deletes, ArrowScan.to_record_batches, and DataScan.to_arrow_batch_reader so users can control the number of rows per RecordBatch returned by PyArrow's Scanner. Closes partially #3036 Co-Authored-By: Claude Opus 4.6 --- mkdocs/docs/api.md | 16 ++++++++++ pyiceberg/io/pyarrow.py | 24 +++++++++------ pyiceberg/table/__init__.py | 7 +++-- tests/io/test_pyarrow.py | 58 +++++++++++++++++++++++++++++++++++++ 4 files changed, 94 insertions(+), 11 deletions(-) diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md index 506547fcd6..dadaac93b4 100644 --- a/mkdocs/docs/api.md +++ b/mkdocs/docs/api.md @@ -355,6 +355,13 @@ for buf in tbl.scan().to_arrow_batch_reader(): print(f"Buffer contains {len(buf)} rows") ``` +You can control the number of rows per batch using the `batch_size` parameter: + +```python +for buf in tbl.scan().to_arrow_batch_reader(batch_size=1000): + print(f"Buffer contains {len(buf)} rows") +``` + To avoid any type inconsistencies during writing, you can convert the Iceberg table schema to Arrow: ```python @@ -1619,6 +1626,15 @@ table.scan( ).to_arrow_batch_reader() ``` +The `batch_size` parameter controls the maximum number of rows per RecordBatch (default is PyArrow's 131,072 rows): + +```python +table.scan( + row_filter=GreaterThanOrEqual("trip_distance", 10.0), + selected_fields=("VendorID", "tpep_pickup_datetime", "tpep_dropoff_datetime"), +).to_arrow_batch_reader(batch_size=1000) +``` + ### Pandas diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index a120c3b776..e7c0da5262 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1581,6 +1581,7 @@ def _task_to_record_batches( partition_spec: PartitionSpec | None = None, format_version: TableVersion = TableProperties.DEFAULT_FORMAT_VERSION, downcast_ns_timestamp_to_us: bool | None = None, + batch_size: int | None = None, ) -> Iterator[pa.RecordBatch]: arrow_format = _get_file_format(task.file.file_format, pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8)) with io.new_input(task.file.file_path).open() as fin: @@ -1612,14 +1613,18 @@ def _task_to_record_batches( file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False) - fragment_scanner = ds.Scanner.from_fragment( - fragment=fragment, - schema=physical_schema, + scanner_kwargs: dict[str, Any] = { + "fragment": fragment, + "schema": physical_schema, # This will push down the query to Arrow. 
# But in case there are positional deletes, we have to apply them first - filter=pyarrow_filter if not positional_deletes else None, - columns=[col.name for col in file_project_schema.columns], - ) + "filter": pyarrow_filter if not positional_deletes else None, + "columns": [col.name for col in file_project_schema.columns], + } + if batch_size is not None: + scanner_kwargs["batch_size"] = batch_size + + fragment_scanner = ds.Scanner.from_fragment(**scanner_kwargs) next_index = 0 batches = fragment_scanner.to_batches() @@ -1756,7 +1761,7 @@ def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table: return result - def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.RecordBatch]: + def to_record_batches(self, tasks: Iterable[FileScanTask], batch_size: int | None = None) -> Iterator[pa.RecordBatch]: """Scan the Iceberg table and return an Iterator[pa.RecordBatch]. Returns an Iterator of pa.RecordBatch with data from the Iceberg table @@ -1783,7 +1788,7 @@ def batches_for_task(task: FileScanTask) -> list[pa.RecordBatch]: # Materialize the iterator here to ensure execution happens within the executor. # Otherwise, the iterator would be lazily consumed later (in the main thread), # defeating the purpose of using executor.map. - return list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file)) + return list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file, batch_size)) limit_reached = False for batches in executor.map(batches_for_task, tasks): @@ -1803,7 +1808,7 @@ def batches_for_task(task: FileScanTask) -> list[pa.RecordBatch]: break def _record_batches_from_scan_tasks_and_deletes( - self, tasks: Iterable[FileScanTask], deletes_per_file: dict[str, list[ChunkedArray]] + self, tasks: Iterable[FileScanTask], deletes_per_file: dict[str, list[ChunkedArray]], batch_size: int | None = None ) -> Iterator[pa.RecordBatch]: total_row_count = 0 for task in tasks: @@ -1822,6 +1827,7 @@ def _record_batches_from_scan_tasks_and_deletes( self._table_metadata.specs().get(task.file.spec_id), self._table_metadata.format_version, self._downcast_ns_timestamp_to_us, + batch_size, ) for batch in batches: if self._limit is not None: diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index cc0d9ff341..63fd1509e4 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -2157,13 +2157,16 @@ def to_arrow(self) -> pa.Table: self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit ).to_table(self.plan_files()) - def to_arrow_batch_reader(self) -> pa.RecordBatchReader: + def to_arrow_batch_reader(self, batch_size: int | None = None) -> pa.RecordBatchReader: """Return an Arrow RecordBatchReader from this DataScan. For large results, using a RecordBatchReader requires less memory than loading an Arrow Table for the same DataScan, because a RecordBatch is read one at a time. + Args: + batch_size: The number of rows per batch. If None, PyArrow's default is used. + Returns: pa.RecordBatchReader: Arrow RecordBatchReader from the Iceberg table's DataScan which can be used to read a stream of record batches one by one. 
@@ -2175,7 +2178,7 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader: target_schema = schema_to_pyarrow(self.projection()) batches = ArrowScan( self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit - ).to_record_batches(self.plan_files()) + ).to_record_batches(self.plan_files(), batch_size=batch_size) return pa.RecordBatchReader.from_batches( target_schema, diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index 04bc3ecfac..fb03f785c6 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -3048,6 +3048,64 @@ def _expected_batch(unit: str) -> pa.RecordBatch: assert _expected_batch("ns" if format_version > 2 else "us").equals(actual_result) +def test_task_to_record_batches_with_batch_size(tmpdir: str) -> None: + """Test that batch_size controls the number of rows per batch.""" + num_rows = 1000 + arrow_table = pa.table( + {"col": pa.array(range(num_rows))}, + schema=pa.schema([pa.field("col", pa.int64(), nullable=False, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"})]), + ) + data_file = _write_table_to_data_file(f"{tmpdir}/test_batch_size.parquet", arrow_table.schema, arrow_table) + table_schema = Schema(NestedField(1, "col", LongType(), required=True)) + + batches = list( + _task_to_record_batches( + PyArrowFileIO(), + FileScanTask(data_file), + bound_row_filter=AlwaysTrue(), + projected_schema=table_schema, + table_schema=table_schema, + projected_field_ids={1}, + positional_deletes=None, + case_sensitive=True, + batch_size=100, + ) + ) + + assert len(batches) > 1 + for batch in batches: + assert len(batch) <= 100 + assert sum(len(b) for b in batches) == num_rows + + +def test_task_to_record_batches_default_batch_size(tmpdir: str) -> None: + """Test that batch_size=None uses PyArrow default (single batch for small files).""" + num_rows = 100 + arrow_table = pa.table( + {"col": pa.array(range(num_rows))}, + schema=pa.schema([pa.field("col", pa.int64(), nullable=False, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"})]), + ) + data_file = _write_table_to_data_file(f"{tmpdir}/test_default_batch_size.parquet", arrow_table.schema, arrow_table) + table_schema = Schema(NestedField(1, "col", LongType(), required=True)) + + batches = list( + _task_to_record_batches( + PyArrowFileIO(), + FileScanTask(data_file), + bound_row_filter=AlwaysTrue(), + projected_schema=table_schema, + table_schema=table_schema, + projected_field_ids={1}, + positional_deletes=None, + case_sensitive=True, + ) + ) + + # With default batch_size, a small file should produce a single batch + assert len(batches) == 1 + assert len(batches[0]) == num_rows + + def test_parse_location_defaults() -> None: """Test that parse_location uses defaults."""
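
Note on the underlying behaviour: the forwarded `batch_size` is ultimately handled by PyArrow's `Scanner`, which caps the number of rows in each emitted `RecordBatch` (the final batch may be smaller). Below is a minimal standalone sketch of that PyArrow behaviour, separate from this patch; the file path and column name are illustrative only.

```python
import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq

# Write a small Parquet file to scan (illustrative path, not from the patch).
pq.write_table(pa.table({"col": list(range(1_000))}), "/tmp/batch_size_demo.parquet")

# batch_size bounds the rows per RecordBatch produced by the scanner; this is
# the same knob that Scanner.from_fragment receives in the patch above.
scanner = ds.dataset("/tmp/batch_size_demo.parquet").scanner(batch_size=100)

batches = list(scanner.to_batches())
assert all(len(batch) <= 100 for batch in batches)
assert sum(len(batch) for batch in batches) == 1_000
```

When `batch_size` is left as `None`, the patch omits the keyword entirely so that PyArrow's own default (131,072 rows) applies, rather than hard-coding that value in pyiceberg.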