From 8f8a2d2104283ac62aa411b20e0c444b33ef7eea Mon Sep 17 00:00:00 2001 From: Sumedh Sakdeo Date: Sat, 14 Feb 2026 13:09:49 -0800 Subject: [PATCH 1/4] feat: forward batch_size parameter to PyArrow Scanner Add batch_size parameter to _task_to_record_batches, _record_batches_from_scan_tasks_and_deletes, ArrowScan.to_record_batches, and DataScan.to_arrow_batch_reader so users can control the number of rows per RecordBatch returned by PyArrow's Scanner. Closes partially #3036 Co-Authored-By: Claude Opus 4.6 --- mkdocs/docs/api.md | 16 ++++++++++ pyiceberg/io/pyarrow.py | 24 +++++++++------ pyiceberg/table/__init__.py | 7 +++-- tests/io/test_pyarrow.py | 58 +++++++++++++++++++++++++++++++++++++ 4 files changed, 94 insertions(+), 11 deletions(-) diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md index 506547fcd6..dadaac93b4 100644 --- a/mkdocs/docs/api.md +++ b/mkdocs/docs/api.md @@ -355,6 +355,13 @@ for buf in tbl.scan().to_arrow_batch_reader(): print(f"Buffer contains {len(buf)} rows") ``` +You can control the number of rows per batch using the `batch_size` parameter: + +```python +for buf in tbl.scan().to_arrow_batch_reader(batch_size=1000): + print(f"Buffer contains {len(buf)} rows") +``` + To avoid any type inconsistencies during writing, you can convert the Iceberg table schema to Arrow: ```python @@ -1619,6 +1626,15 @@ table.scan( ).to_arrow_batch_reader() ``` +The `batch_size` parameter controls the maximum number of rows per RecordBatch (default is PyArrow's 131,072 rows): + +```python +table.scan( + row_filter=GreaterThanOrEqual("trip_distance", 10.0), + selected_fields=("VendorID", "tpep_pickup_datetime", "tpep_dropoff_datetime"), +).to_arrow_batch_reader(batch_size=1000) +``` + ### Pandas diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index a120c3b776..e7c0da5262 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1581,6 +1581,7 @@ def _task_to_record_batches( partition_spec: PartitionSpec | None = None, format_version: TableVersion = TableProperties.DEFAULT_FORMAT_VERSION, downcast_ns_timestamp_to_us: bool | None = None, + batch_size: int | None = None, ) -> Iterator[pa.RecordBatch]: arrow_format = _get_file_format(task.file.file_format, pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8)) with io.new_input(task.file.file_path).open() as fin: @@ -1612,14 +1613,18 @@ def _task_to_record_batches( file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False) - fragment_scanner = ds.Scanner.from_fragment( - fragment=fragment, - schema=physical_schema, + scanner_kwargs: dict[str, Any] = { + "fragment": fragment, + "schema": physical_schema, # This will push down the query to Arrow. 
# But in case there are positional deletes, we have to apply them first - filter=pyarrow_filter if not positional_deletes else None, - columns=[col.name for col in file_project_schema.columns], - ) + "filter": pyarrow_filter if not positional_deletes else None, + "columns": [col.name for col in file_project_schema.columns], + } + if batch_size is not None: + scanner_kwargs["batch_size"] = batch_size + + fragment_scanner = ds.Scanner.from_fragment(**scanner_kwargs) next_index = 0 batches = fragment_scanner.to_batches() @@ -1756,7 +1761,7 @@ def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table: return result - def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.RecordBatch]: + def to_record_batches(self, tasks: Iterable[FileScanTask], batch_size: int | None = None) -> Iterator[pa.RecordBatch]: """Scan the Iceberg table and return an Iterator[pa.RecordBatch]. Returns an Iterator of pa.RecordBatch with data from the Iceberg table @@ -1783,7 +1788,7 @@ def batches_for_task(task: FileScanTask) -> list[pa.RecordBatch]: # Materialize the iterator here to ensure execution happens within the executor. # Otherwise, the iterator would be lazily consumed later (in the main thread), # defeating the purpose of using executor.map. - return list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file)) + return list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file, batch_size)) limit_reached = False for batches in executor.map(batches_for_task, tasks): @@ -1803,7 +1808,7 @@ def batches_for_task(task: FileScanTask) -> list[pa.RecordBatch]: break def _record_batches_from_scan_tasks_and_deletes( - self, tasks: Iterable[FileScanTask], deletes_per_file: dict[str, list[ChunkedArray]] + self, tasks: Iterable[FileScanTask], deletes_per_file: dict[str, list[ChunkedArray]], batch_size: int | None = None ) -> Iterator[pa.RecordBatch]: total_row_count = 0 for task in tasks: @@ -1822,6 +1827,7 @@ def _record_batches_from_scan_tasks_and_deletes( self._table_metadata.specs().get(task.file.spec_id), self._table_metadata.format_version, self._downcast_ns_timestamp_to_us, + batch_size, ) for batch in batches: if self._limit is not None: diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index cc0d9ff341..63fd1509e4 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -2157,13 +2157,16 @@ def to_arrow(self) -> pa.Table: self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit ).to_table(self.plan_files()) - def to_arrow_batch_reader(self) -> pa.RecordBatchReader: + def to_arrow_batch_reader(self, batch_size: int | None = None) -> pa.RecordBatchReader: """Return an Arrow RecordBatchReader from this DataScan. For large results, using a RecordBatchReader requires less memory than loading an Arrow Table for the same DataScan, because a RecordBatch is read one at a time. + Args: + batch_size: The number of rows per batch. If None, PyArrow's default is used. + Returns: pa.RecordBatchReader: Arrow RecordBatchReader from the Iceberg table's DataScan which can be used to read a stream of record batches one by one. 
@@ -2175,7 +2178,7 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader: target_schema = schema_to_pyarrow(self.projection()) batches = ArrowScan( self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit - ).to_record_batches(self.plan_files()) + ).to_record_batches(self.plan_files(), batch_size=batch_size) return pa.RecordBatchReader.from_batches( target_schema, diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index 04bc3ecfac..fb03f785c6 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -3048,6 +3048,64 @@ def _expected_batch(unit: str) -> pa.RecordBatch: assert _expected_batch("ns" if format_version > 2 else "us").equals(actual_result) +def test_task_to_record_batches_with_batch_size(tmpdir: str) -> None: + """Test that batch_size controls the number of rows per batch.""" + num_rows = 1000 + arrow_table = pa.table( + {"col": pa.array(range(num_rows))}, + schema=pa.schema([pa.field("col", pa.int64(), nullable=False, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"})]), + ) + data_file = _write_table_to_data_file(f"{tmpdir}/test_batch_size.parquet", arrow_table.schema, arrow_table) + table_schema = Schema(NestedField(1, "col", LongType(), required=True)) + + batches = list( + _task_to_record_batches( + PyArrowFileIO(), + FileScanTask(data_file), + bound_row_filter=AlwaysTrue(), + projected_schema=table_schema, + table_schema=table_schema, + projected_field_ids={1}, + positional_deletes=None, + case_sensitive=True, + batch_size=100, + ) + ) + + assert len(batches) > 1 + for batch in batches: + assert len(batch) <= 100 + assert sum(len(b) for b in batches) == num_rows + + +def test_task_to_record_batches_default_batch_size(tmpdir: str) -> None: + """Test that batch_size=None uses PyArrow default (single batch for small files).""" + num_rows = 100 + arrow_table = pa.table( + {"col": pa.array(range(num_rows))}, + schema=pa.schema([pa.field("col", pa.int64(), nullable=False, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"})]), + ) + data_file = _write_table_to_data_file(f"{tmpdir}/test_default_batch_size.parquet", arrow_table.schema, arrow_table) + table_schema = Schema(NestedField(1, "col", LongType(), required=True)) + + batches = list( + _task_to_record_batches( + PyArrowFileIO(), + FileScanTask(data_file), + bound_row_filter=AlwaysTrue(), + projected_schema=table_schema, + table_schema=table_schema, + projected_field_ids={1}, + positional_deletes=None, + case_sensitive=True, + ) + ) + + # With default batch_size, a small file should produce a single batch + assert len(batches) == 1 + assert len(batches[0]) == num_rows + + def test_parse_location_defaults() -> None: """Test that parse_location uses defaults.""" From b72b7ba8934ffa2682f217c730d55b24e78f475f Mon Sep 17 00:00:00 2001 From: Sumedh Sakdeo Date: Sat, 14 Feb 2026 13:12:44 -0800 Subject: [PATCH 2/4] feat: add streaming flag to ArrowScan.to_record_batches When streaming=True, batches are yielded as they are produced by PyArrow without materializing entire files into memory. Files are still processed sequentially, preserving file ordering. The inner method handles the global limit correctly when called with all tasks, avoiding double-counting. This addresses the OOM issue in #3036 for single-file-at-a-time streaming. 
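Minimal usage sketch (hedged: assumes a table handle `tbl` already loaded from a
catalog, and mirrors the example this patch adds to mkdocs/docs/api.md):

    for buf in tbl.scan().to_arrow_batch_reader(streaming=True, batch_size=1000):
        # Each RecordBatch is yielded as PyArrow produces it, one file at a time,
        # so whole files are never materialized in memory.
        print(f"Buffer contains {len(buf)} rows")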
Co-Authored-By: Claude Opus 4.6 --- mkdocs/docs/api.md | 16 ++++++++ pyiceberg/io/pyarrow.py | 60 +++++++++++++++++----------- pyiceberg/table/__init__.py | 6 ++- tests/io/test_pyarrow.py | 80 +++++++++++++++++++++++++++++++++++++ 4 files changed, 136 insertions(+), 26 deletions(-) diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md index dadaac93b4..2b08d53643 100644 --- a/mkdocs/docs/api.md +++ b/mkdocs/docs/api.md @@ -362,6 +362,13 @@ for buf in tbl.scan().to_arrow_batch_reader(batch_size=1000): print(f"Buffer contains {len(buf)} rows") ``` +By default, each file's batches are materialized in memory before being yielded. For large files that may exceed available memory, use `streaming=True` to yield batches as they are produced without materializing entire files: + +```python +for buf in tbl.scan().to_arrow_batch_reader(streaming=True, batch_size=1000): + print(f"Buffer contains {len(buf)} rows") +``` + To avoid any type inconsistencies during writing, you can convert the Iceberg table schema to Arrow: ```python @@ -1635,6 +1642,15 @@ table.scan( ).to_arrow_batch_reader(batch_size=1000) ``` +Use `streaming=True` to avoid materializing entire files in memory. This yields batches as they are produced by PyArrow, one file at a time: + +```python +table.scan( + row_filter=GreaterThanOrEqual("trip_distance", 10.0), + selected_fields=("VendorID", "tpep_pickup_datetime", "tpep_dropoff_datetime"), +).to_arrow_batch_reader(streaming=True) +``` + ### Pandas diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index e7c0da5262..67f159d805 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1761,7 +1761,9 @@ def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table: return result - def to_record_batches(self, tasks: Iterable[FileScanTask], batch_size: int | None = None) -> Iterator[pa.RecordBatch]: + def to_record_batches( + self, tasks: Iterable[FileScanTask], batch_size: int | None = None, streaming: bool = False + ) -> Iterator[pa.RecordBatch]: """Scan the Iceberg table and return an Iterator[pa.RecordBatch]. Returns an Iterator of pa.RecordBatch with data from the Iceberg table @@ -1770,6 +1772,9 @@ def to_record_batches(self, tasks: Iterable[FileScanTask], batch_size: int | Non Args: tasks: FileScanTasks representing the data files and delete files to read from. + batch_size: The number of rows per batch. If None, PyArrow's default is used. + streaming: If True, yield batches as they are produced without materializing + entire files into memory. Files are still processed sequentially. Returns: An Iterator of PyArrow RecordBatches. @@ -1781,31 +1786,38 @@ def to_record_batches(self, tasks: Iterable[FileScanTask], batch_size: int | Non """ deletes_per_file = _read_all_delete_files(self._io, tasks) - total_row_count = 0 - executor = ExecutorFactory.get_or_create() - - def batches_for_task(task: FileScanTask) -> list[pa.RecordBatch]: - # Materialize the iterator here to ensure execution happens within the executor. - # Otherwise, the iterator would be lazily consumed later (in the main thread), - # defeating the purpose of using executor.map. 
- return list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file, batch_size)) - - limit_reached = False - for batches in executor.map(batches_for_task, tasks): - for batch in batches: - current_batch_size = len(batch) - if self._limit is not None and total_row_count + current_batch_size >= self._limit: - yield batch.slice(0, self._limit - total_row_count) + if streaming: + # Streaming path: process all tasks sequentially, yielding batches as produced. + # _record_batches_from_scan_tasks_and_deletes handles the limit internally + # when called with all tasks, so no outer limit check is needed. + yield from self._record_batches_from_scan_tasks_and_deletes(tasks, deletes_per_file, batch_size) + else: + # Non-streaming path: existing behavior with executor.map + list() + total_row_count = 0 + executor = ExecutorFactory.get_or_create() + + def batches_for_task(task: FileScanTask) -> list[pa.RecordBatch]: + # Materialize the iterator here to ensure execution happens within the executor. + # Otherwise, the iterator would be lazily consumed later (in the main thread), + # defeating the purpose of using executor.map. + return list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file, batch_size)) + + limit_reached = False + for batches in executor.map(batches_for_task, tasks): + for batch in batches: + current_batch_size = len(batch) + if self._limit is not None and total_row_count + current_batch_size >= self._limit: + yield batch.slice(0, self._limit - total_row_count) + + limit_reached = True + break + else: + yield batch + total_row_count += current_batch_size - limit_reached = True + if limit_reached: + # This break will also cancel all running tasks in the executor break - else: - yield batch - total_row_count += current_batch_size - - if limit_reached: - # This break will also cancel all running tasks in the executor - break def _record_batches_from_scan_tasks_and_deletes( self, tasks: Iterable[FileScanTask], deletes_per_file: dict[str, list[ChunkedArray]], batch_size: int | None = None diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 63fd1509e4..b8674a39e6 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -2157,7 +2157,7 @@ def to_arrow(self) -> pa.Table: self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit ).to_table(self.plan_files()) - def to_arrow_batch_reader(self, batch_size: int | None = None) -> pa.RecordBatchReader: + def to_arrow_batch_reader(self, batch_size: int | None = None, streaming: bool = False) -> pa.RecordBatchReader: """Return an Arrow RecordBatchReader from this DataScan. For large results, using a RecordBatchReader requires less memory than @@ -2166,6 +2166,8 @@ def to_arrow_batch_reader(self, batch_size: int | None = None) -> pa.RecordBatch Args: batch_size: The number of rows per batch. If None, PyArrow's default is used. + streaming: If True, yield batches as they are produced without materializing + entire files into memory. Files are still processed sequentially. 
Returns: pa.RecordBatchReader: Arrow RecordBatchReader from the Iceberg table's DataScan @@ -2178,7 +2180,7 @@ def to_arrow_batch_reader(self, batch_size: int | None = None) -> pa.RecordBatch target_schema = schema_to_pyarrow(self.projection()) batches = ArrowScan( self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit - ).to_record_batches(self.plan_files(), batch_size=batch_size) + ).to_record_batches(self.plan_files(), batch_size=batch_size, streaming=streaming) return pa.RecordBatchReader.from_batches( target_schema, diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index fb03f785c6..847dbadad2 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -3106,6 +3106,86 @@ def test_task_to_record_batches_default_batch_size(tmpdir: str) -> None: assert len(batches[0]) == num_rows +def _create_scan_and_tasks( + tmpdir: str, num_files: int = 1, rows_per_file: int = 100, limit: int | None = None +) -> tuple[ArrowScan, list[FileScanTask]]: + """Helper to create an ArrowScan and FileScanTasks for testing.""" + table_schema = Schema(NestedField(1, "col", LongType(), required=True)) + pa_schema = pa.schema([pa.field("col", pa.int64(), nullable=False, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"})]) + tasks = [] + for i in range(num_files): + start = i * rows_per_file + arrow_table = pa.table({"col": pa.array(range(start, start + rows_per_file))}, schema=pa_schema) + data_file = _write_table_to_data_file(f"{tmpdir}/file_{i}.parquet", pa_schema, arrow_table) + data_file.spec_id = 0 + tasks.append(FileScanTask(data_file)) + + scan = ArrowScan( + table_metadata=TableMetadataV2( + location="file://a/b/", + last_column_id=1, + format_version=2, + schemas=[table_schema], + partition_specs=[PartitionSpec()], + ), + io=PyArrowFileIO(), + projected_schema=table_schema, + row_filter=AlwaysTrue(), + case_sensitive=True, + limit=limit, + ) + return scan, tasks + + +def test_streaming_false_produces_same_results(tmpdir: str) -> None: + """Test that streaming=False produces the same results as the default behavior.""" + scan, tasks = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100) + + batches_default = list(scan.to_record_batches(tasks, streaming=False)) + # Re-create tasks since iterators are consumed + _, tasks2 = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100) + batches_streaming = list(scan.to_record_batches(tasks2, streaming=False)) + + total_default = sum(len(b) for b in batches_default) + total_streaming = sum(len(b) for b in batches_streaming) + assert total_default == 300 + assert total_streaming == 300 + + +def test_streaming_true_yields_all_batches(tmpdir: str) -> None: + """Test that streaming=True yields all batches correctly.""" + scan, tasks = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100) + + batches = list(scan.to_record_batches(tasks, streaming=True)) + + total_rows = sum(len(b) for b in batches) + assert total_rows == 300 + # Verify all values are present + all_values = sorted([v for b in batches for v in b.column("col").to_pylist()]) + assert all_values == list(range(300)) + + +def test_streaming_true_with_limit(tmpdir: str) -> None: + """Test that streaming=True respects the row limit.""" + scan, tasks = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100, limit=150) + + batches = list(scan.to_record_batches(tasks, streaming=True)) + + total_rows = sum(len(b) for b in batches) + assert total_rows == 150 + + +def test_streaming_file_ordering_preserved(tmpdir: str) 
-> None: + """Test that file ordering is preserved in both streaming modes.""" + scan, tasks = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100) + + batches = list(scan.to_record_batches(tasks, streaming=True)) + all_values = [v for b in batches for v in b.column("col").to_pylist()] + + # Values should be in file order: 0-99 from file 0, 100-199 from file 1, 200-299 from file 2 + assert all_values == list(range(300)) + + def test_parse_location_defaults() -> None: """Test that parse_location uses defaults.""" From 7cfb2b153abed3b87f2f07b25f478e08000594d2 Mon Sep 17 00:00:00 2001 From: Sumedh Sakdeo Date: Sat, 14 Feb 2026 13:23:56 -0800 Subject: [PATCH 3/4] feat: add concurrent_files flag for bounded concurrent streaming Add _bounded_concurrent_batches() with proper lock discipline: - Queue backpressure caps memory (scan.max-buffered-batches, default 16) - Semaphore limits concurrent file reads (concurrent_files param) - Cancel event with timeouts on all blocking ops (no lock over IO) - Error propagation and early termination support When streaming=True and concurrent_files > 1, batches are yielded as they arrive from parallel file reads. File ordering is not guaranteed (documented). Co-Authored-By: Claude Opus 4.6 --- mkdocs/docs/api.md | 30 +++ pyiceberg/io/pyarrow.py | 131 +++++++++- pyiceberg/table/__init__.py | 20 +- tests/io/test_bounded_concurrent_batches.py | 255 ++++++++++++++++++++ 4 files changed, 430 insertions(+), 6 deletions(-) create mode 100644 tests/io/test_bounded_concurrent_batches.py diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md index 2b08d53643..589c967f49 100644 --- a/mkdocs/docs/api.md +++ b/mkdocs/docs/api.md @@ -369,6 +369,25 @@ for buf in tbl.scan().to_arrow_batch_reader(streaming=True, batch_size=1000): print(f"Buffer contains {len(buf)} rows") ``` +For maximum throughput, use `concurrent_files` to read multiple files in parallel while streaming. Batches are yielded as they arrive from any file — ordering across files is not guaranteed: + +```python +for buf in tbl.scan().to_arrow_batch_reader(streaming=True, concurrent_files=4, batch_size=1000): + print(f"Buffer contains {len(buf)} rows") +``` + +The maximum number of buffered batches can be tuned via the `scan.max-buffered-batches` table property (default 16). + +**Ordering semantics:** + +| Configuration | File ordering | Within-file ordering | +|---|---|---| +| Default (`streaming=False`) | Batches grouped by file, in task submission order | Row order | +| `streaming=True` | Batches grouped by file, sequential | Row order | +| `streaming=True, concurrent_files>1` | Interleaved across files (no grouping guarantee) | Row order within each file | + +In all modes, within-file batch ordering follows row order. The `limit` parameter is enforced correctly regardless of configuration. + To avoid any type inconsistencies during writing, you can convert the Iceberg table schema to Arrow: ```python @@ -1651,6 +1670,17 @@ table.scan( ).to_arrow_batch_reader(streaming=True) ``` +For concurrent file reads with streaming, use `concurrent_files`. Note that batch ordering across files is not guaranteed: + +```python +table.scan( + row_filter=GreaterThanOrEqual("trip_distance", 10.0), + selected_fields=("VendorID", "tpep_pickup_datetime", "tpep_dropoff_datetime"), +).to_arrow_batch_reader(streaming=True, concurrent_files=4) +``` + +When using `concurrent_files > 1`, batches from different files may be interleaved. Within each file, batches are always in row order. 
See the ordering semantics table in the [Arrow Batch Reader section](#arrow-batch-reader) above for details. + ### Pandas diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 67f159d805..cb58d01063 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -33,7 +33,9 @@ import logging import operator import os +import queue import re +import threading import uuid import warnings from abc import ABC, abstractmethod @@ -1682,6 +1684,89 @@ def _read_all_delete_files(io: FileIO, tasks: Iterable[FileScanTask]) -> dict[st return deletes_per_file +_QUEUE_SENTINEL = object() + + +def _bounded_concurrent_batches( + tasks: list[FileScanTask], + batch_fn: Callable[[FileScanTask], Iterator[pa.RecordBatch]], + concurrent_files: int, + max_buffered_batches: int = 16, +) -> Iterator[pa.RecordBatch]: + """Read batches from multiple files concurrently with bounded memory. + + Workers read from files in parallel (up to concurrent_files at a time) and push + batches into a shared queue. The consumer yields batches from the queue. + A sentinel value signals completion, avoiding timeout-based polling. + + Args: + tasks: The file scan tasks to process. + batch_fn: A callable that takes a FileScanTask and returns an iterator of RecordBatches. + concurrent_files: Maximum number of files to read concurrently. + max_buffered_batches: Maximum number of batches to buffer in the queue. + """ + if not tasks: + return + + batch_queue: queue.Queue[pa.RecordBatch | BaseException | object] = queue.Queue(maxsize=max_buffered_batches) + cancel_event = threading.Event() + pending_count = len(tasks) + pending_lock = threading.Lock() + file_semaphore = threading.Semaphore(concurrent_files) + + def worker(task: FileScanTask) -> None: + nonlocal pending_count + acquired = False + try: + # Acquire semaphore — blocks until a slot is available or cancelled + while not cancel_event.is_set(): + if file_semaphore.acquire(timeout=0.5): + acquired = True + break + if cancel_event.is_set(): + return + + for batch in batch_fn(task): + if cancel_event.is_set(): + return + batch_queue.put(batch) + except BaseException as e: + if not cancel_event.is_set(): + batch_queue.put(e) + finally: + if acquired: + file_semaphore.release() + with pending_lock: + pending_count -= 1 + if pending_count == 0: + batch_queue.put(_QUEUE_SENTINEL) + + executor = ExecutorFactory.get_or_create() + futures = [executor.submit(worker, task) for task in tasks] + + try: + while True: + item = batch_queue.get() + + if item is _QUEUE_SENTINEL: + break + + if isinstance(item, BaseException): + raise item + + yield item + finally: + cancel_event.set() + # Drain the queue to unblock any workers stuck on put() + while not batch_queue.empty(): + try: + batch_queue.get_nowait() + except queue.Empty: + break + for future in futures: + future.cancel() + + class ArrowScan: _table_metadata: TableMetadata _io: FileIO @@ -1762,7 +1847,11 @@ def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table: return result def to_record_batches( - self, tasks: Iterable[FileScanTask], batch_size: int | None = None, streaming: bool = False + self, + tasks: Iterable[FileScanTask], + batch_size: int | None = None, + streaming: bool = False, + concurrent_files: int = 1, ) -> Iterator[pa.RecordBatch]: """Scan the Iceberg table and return an Iterator[pa.RecordBatch]. @@ -1770,11 +1859,21 @@ def to_record_batches( by resolving the right columns that match the current table schema. Only data that matches the provided row_filter expression is returned. 
+ Ordering semantics: + - Default (streaming=False): Batches are grouped by file in task submission order. + - streaming=True, concurrent_files=1: Batches are grouped by file, processed sequentially. + - streaming=True, concurrent_files>1: Batches may be interleaved across files. + In all modes, within-file batch ordering follows row order. + Args: tasks: FileScanTasks representing the data files and delete files to read from. batch_size: The number of rows per batch. If None, PyArrow's default is used. streaming: If True, yield batches as they are produced without materializing - entire files into memory. Files are still processed sequentially. + entire files into memory. Files are still processed sequentially when + concurrent_files=1. + concurrent_files: Number of files to read concurrently when streaming=True. + When > 1, batches may arrive interleaved across files. Ignored when + streaming=False. Returns: An Iterator of PyArrow RecordBatches. @@ -1786,7 +1885,33 @@ def to_record_batches( """ deletes_per_file = _read_all_delete_files(self._io, tasks) - if streaming: + if streaming and concurrent_files > 1: + # Concurrent streaming path: read multiple files in parallel with bounded queue. + # Ordering is NOT guaranteed across files — batches arrive as produced. + task_list = list(tasks) + + def batch_fn(task: FileScanTask) -> Iterator[pa.RecordBatch]: + return self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file, batch_size) + + from pyiceberg.table import TableProperties + + max_buffered = int( + self._table_metadata.properties.get( + TableProperties.SCAN_MAX_BUFFERED_BATCHES, + TableProperties.SCAN_MAX_BUFFERED_BATCHES_DEFAULT, + ) + ) + + total_row_count = 0 + for batch in _bounded_concurrent_batches(task_list, batch_fn, concurrent_files, max_buffered): + current_batch_size = len(batch) + if self._limit is not None and total_row_count + current_batch_size >= self._limit: + yield batch.slice(0, self._limit - total_row_count) + return + else: + yield batch + total_row_count += current_batch_size + elif streaming: # Streaming path: process all tasks sequentially, yielding batches as produced. # _record_batches_from_scan_tasks_and_deletes handles the limit internally # when called with all tasks, so no outer limit check is needed. diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index b8674a39e6..19192794cb 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -247,6 +247,9 @@ class TableProperties: MIN_SNAPSHOTS_TO_KEEP = "history.expire.min-snapshots-to-keep" MIN_SNAPSHOTS_TO_KEEP_DEFAULT = 1 + SCAN_MAX_BUFFERED_BATCHES = "scan.max-buffered-batches" + SCAN_MAX_BUFFERED_BATCHES_DEFAULT = 16 + class Transaction: _table: Table @@ -2157,17 +2160,28 @@ def to_arrow(self) -> pa.Table: self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit ).to_table(self.plan_files()) - def to_arrow_batch_reader(self, batch_size: int | None = None, streaming: bool = False) -> pa.RecordBatchReader: + def to_arrow_batch_reader( + self, batch_size: int | None = None, streaming: bool = False, concurrent_files: int = 1 + ) -> pa.RecordBatchReader: """Return an Arrow RecordBatchReader from this DataScan. For large results, using a RecordBatchReader requires less memory than loading an Arrow Table for the same DataScan, because a RecordBatch is read one at a time. + Ordering semantics: + - Default (streaming=False): Batches are grouped by file in task submission order. 
+ - streaming=True, concurrent_files=1: Batches are grouped by file, processed sequentially. + - streaming=True, concurrent_files>1: Batches may be interleaved across files. + In all modes, within-file batch ordering follows row order. + Args: batch_size: The number of rows per batch. If None, PyArrow's default is used. streaming: If True, yield batches as they are produced without materializing - entire files into memory. Files are still processed sequentially. + entire files into memory. Files are still processed sequentially when + concurrent_files=1. + concurrent_files: Number of files to read concurrently when streaming=True. + When > 1, batches may arrive interleaved across files. Returns: pa.RecordBatchReader: Arrow RecordBatchReader from the Iceberg table's DataScan @@ -2180,7 +2194,7 @@ def to_arrow_batch_reader(self, batch_size: int | None = None, streaming: bool = target_schema = schema_to_pyarrow(self.projection()) batches = ArrowScan( self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit - ).to_record_batches(self.plan_files(), batch_size=batch_size, streaming=streaming) + ).to_record_batches(self.plan_files(), batch_size=batch_size, streaming=streaming, concurrent_files=concurrent_files) return pa.RecordBatchReader.from_batches( target_schema, diff --git a/tests/io/test_bounded_concurrent_batches.py b/tests/io/test_bounded_concurrent_batches.py new file mode 100644 index 0000000000..96cd5720e2 --- /dev/null +++ b/tests/io/test_bounded_concurrent_batches.py @@ -0,0 +1,255 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Tests for _bounded_concurrent_batches in pyiceberg.io.pyarrow.""" + +import threading +import time +from collections.abc import Iterator +from unittest.mock import MagicMock + +import pyarrow as pa +import pytest + +from pyiceberg.io.pyarrow import _bounded_concurrent_batches +from pyiceberg.table import FileScanTask + + +def _make_task() -> FileScanTask: + """Create a mock FileScanTask.""" + task = MagicMock(spec=FileScanTask) + return task + + +def _make_batches(num_batches: int, rows_per_batch: int = 10, start: int = 0) -> list[pa.RecordBatch]: + """Create a list of simple RecordBatches.""" + return [pa.record_batch({"col": list(range(start + i * rows_per_batch, start + (i + 1) * rows_per_batch))}) for i in range(num_batches)] + + +def test_correctness_single_file() -> None: + """Test that a single file produces correct results.""" + task = _make_task() + expected_batches = _make_batches(3) + + def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]: + yield from expected_batches + + result = list(_bounded_concurrent_batches([task], batch_fn, concurrent_files=1, max_buffered_batches=16)) + + assert len(result) == 3 + total_rows = sum(len(b) for b in result) + assert total_rows == 30 + + +def test_correctness_multiple_files() -> None: + """Test that multiple files produce all expected batches.""" + tasks = [_make_task() for _ in range(4)] + batches_per_file = 3 + + def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]: + idx = tasks.index(t) + yield from _make_batches(batches_per_file, start=idx * 100) + + result = list(_bounded_concurrent_batches(tasks, batch_fn, concurrent_files=2, max_buffered_batches=16)) + + total_rows = sum(len(b) for b in result) + assert total_rows == batches_per_file * len(tasks) * 10 # 3 batches * 4 files * 10 rows + + +def test_streaming_yields_incrementally() -> None: + """Test that batches are yielded incrementally, not all at once.""" + barrier = threading.Event() + tasks = [_make_task(), _make_task()] + + def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]: + yield pa.record_batch({"col": [1, 2, 3]}) + barrier.wait(timeout=5.0) + yield pa.record_batch({"col": [4, 5, 6]}) + + gen = _bounded_concurrent_batches(tasks, batch_fn, concurrent_files=2, max_buffered_batches=16) + + # Should get at least one batch before all are done + first = next(gen) + assert first.num_rows == 3 + + # Unblock remaining batches + barrier.set() + + remaining = list(gen) + total = 1 + len(remaining) + assert total >= 3 # At least 3 more batches (one blocked from each task + the unblocked ones) + + +def test_backpressure() -> None: + """Test that workers block when the queue is full.""" + max_buffered = 2 + tasks = [_make_task()] + produced_count = 0 + produce_lock = threading.Lock() + + def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]: + nonlocal produced_count + for i in range(10): + with produce_lock: + produced_count += 1 + yield pa.record_batch({"col": [i]}) + + gen = _bounded_concurrent_batches(tasks, batch_fn, concurrent_files=1, max_buffered_batches=max_buffered) + + # Consume slowly and check that not all batches are produced immediately + first = next(gen) + assert first is not None + time.sleep(0.3) + + # The producer should be blocked by backpressure at some point + # (not all 10 batches produced instantly) + remaining = list(gen) + assert len(remaining) + 1 == 10 + + +def test_error_propagation() -> None: + """Test that errors from workers are propagated to the consumer.""" + tasks = [_make_task()] + + def batch_fn(t: FileScanTask) -> 
Iterator[pa.RecordBatch]: + yield pa.record_batch({"col": [1]}) + raise ValueError("test error") + + gen = _bounded_concurrent_batches(tasks, batch_fn, concurrent_files=1, max_buffered_batches=16) + + # Should get the first batch + first = next(gen) + assert first.num_rows == 1 + + # Should get the error + with pytest.raises(ValueError, match="test error"): + list(gen) + + +def test_early_termination() -> None: + """Test that stopping consumption cancels workers.""" + tasks = [_make_task() for _ in range(5)] + worker_started = threading.Event() + + def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]: + worker_started.set() + for i in range(100): + yield pa.record_batch({"col": [i]}) + time.sleep(0.01) + + gen = _bounded_concurrent_batches(tasks, batch_fn, concurrent_files=3, max_buffered_batches=4) + + # Consume a few batches then stop + worker_started.wait(timeout=5.0) + batches = [] + for _ in range(5): + batches.append(next(gen)) + + # Close the generator, triggering finally block + gen.close() + + assert len(batches) == 5 + + +def test_concurrency_limit() -> None: + """Test that at most concurrent_files files are read concurrently.""" + concurrent_files = 2 + tasks = [_make_task() for _ in range(6)] + active_count = 0 + max_active = 0 + active_lock = threading.Lock() + + def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]: + nonlocal active_count, max_active + with active_lock: + active_count += 1 + max_active = max(max_active, active_count) + try: + time.sleep(0.05) + yield pa.record_batch({"col": [1]}) + finally: + with active_lock: + active_count -= 1 + + result = list(_bounded_concurrent_batches(tasks, batch_fn, concurrent_files=concurrent_files, max_buffered_batches=16)) + + assert len(result) == 6 + assert max_active <= concurrent_files + + +def test_empty_tasks() -> None: + """Test that no tasks produces no batches.""" + + def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]: + yield from [] + + result = list(_bounded_concurrent_batches([], batch_fn, concurrent_files=2, max_buffered_batches=16)) + assert result == [] + + +def test_concurrent_with_limit_via_arrowscan(tmpdir: str) -> None: + """Test concurrent_files with limit through ArrowScan integration.""" + from pyiceberg.expressions import AlwaysTrue + from pyiceberg.io.pyarrow import ArrowScan, PyArrowFileIO, schema_to_pyarrow, write_file + from pyiceberg.manifest import DataFileContent, FileFormat + from pyiceberg.partitioning import PartitionSpec + from pyiceberg.schema import Schema + from pyiceberg.table.metadata import TableMetadataV2 + from pyiceberg.types import LongType, NestedField + + PYARROW_PARQUET_FIELD_ID_KEY = b"PARQUET:field_id" + + table_schema = Schema(NestedField(1, "col", LongType(), required=True)) + pa_schema = pa.schema([pa.field("col", pa.int64(), nullable=False, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"})]) + + tasks = [] + for i in range(4): + filepath = f"{tmpdir}/file_{i}.parquet" + arrow_table = pa.table({"col": pa.array(range(i * 100, (i + 1) * 100))}, schema=pa_schema) + import pyarrow.parquet as pq + + pq.write_table(arrow_table, filepath) + from pyiceberg.manifest import DataFile + + data_file = DataFile.from_args( + content=DataFileContent.DATA, + file_path=filepath, + file_format=FileFormat.PARQUET, + partition={}, + record_count=100, + file_size_in_bytes=22, + ) + data_file.spec_id = 0 + tasks.append(FileScanTask(data_file)) + + scan = ArrowScan( + table_metadata=TableMetadataV2( + location="file://a/b/", + last_column_id=1, + format_version=2, + schemas=[table_schema], 
+ partition_specs=[PartitionSpec()], + ), + io=PyArrowFileIO(), + projected_schema=table_schema, + row_filter=AlwaysTrue(), + case_sensitive=True, + limit=150, + ) + + batches = list(scan.to_record_batches(tasks, streaming=True, concurrent_files=2)) + total_rows = sum(len(b) for b in batches) + assert total_rows == 150 From e8dcd7a20811b858becfcf8deeb60b57479509b3 Mon Sep 17 00:00:00 2001 From: Sumedh Sakdeo Date: Sat, 14 Feb 2026 14:06:27 -0800 Subject: [PATCH 4/4] feat: add read throughput micro-benchmark for ArrowScan configurations Measures records/sec and peak memory across streaming, concurrent_files, and batch_size configurations to validate performance characteristics of the new scan modes introduced for #3036. Co-Authored-By: Claude Opus 4.6 --- tests/benchmark/test_read_benchmark.py | 181 +++++++++++++++++++++++++ 1 file changed, 181 insertions(+) create mode 100644 tests/benchmark/test_read_benchmark.py diff --git a/tests/benchmark/test_read_benchmark.py b/tests/benchmark/test_read_benchmark.py new file mode 100644 index 0000000000..6547999889 --- /dev/null +++ b/tests/benchmark/test_read_benchmark.py @@ -0,0 +1,181 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Read throughput micro-benchmark for ArrowScan configurations. + +Measures records/sec and peak Arrow memory across streaming, concurrent_files, +and batch_size configurations introduced for issue #3036. + +Memory is measured using pa.total_allocated_bytes() which tracks PyArrow's C++ +memory pool (Arrow buffers, Parquet decompression), not Python heap allocations. 
+ +Run with: uv run pytest tests/benchmark/test_read_benchmark.py -v -s -m benchmark +""" + +import gc +import statistics +import timeit +from datetime import datetime, timezone + +import pyarrow as pa +import pyarrow.parquet as pq +import pytest + +from pyiceberg.catalog.sql import SqlCatalog + +NUM_FILES = 32 +ROWS_PER_FILE = 500_000 +TOTAL_ROWS = NUM_FILES * ROWS_PER_FILE +NUM_RUNS = 3 + + +def _generate_parquet_file(path: str, num_rows: int, seed: int) -> pa.Schema: + """Write a synthetic Parquet file and return its schema.""" + table = pa.table( + { + "id": pa.array(range(seed, seed + num_rows), type=pa.int64()), + "value": pa.array([float(i) * 0.1 for i in range(num_rows)], type=pa.float64()), + "label": pa.array([f"row_{i}" for i in range(num_rows)], type=pa.string()), + "flag": pa.array([i % 2 == 0 for i in range(num_rows)], type=pa.bool_()), + "ts": pa.array([datetime.now(timezone.utc)] * num_rows, type=pa.timestamp("us", tz="UTC")), + } + ) + pq.write_table(table, path) + return table.schema + + +@pytest.fixture(scope="session") +def benchmark_table(tmp_path_factory: pytest.TempPathFactory) -> "pyiceberg.table.Table": # noqa: F821 + """Create a catalog and table with synthetic Parquet files for benchmarking.""" + warehouse_path = str(tmp_path_factory.mktemp("benchmark_warehouse")) + catalog = SqlCatalog( + "benchmark", + uri=f"sqlite:///{warehouse_path}/pyiceberg_catalog.db", + warehouse=f"file://{warehouse_path}", + ) + catalog.create_namespace("default") + + # Generate files and append to table + table = None + for i in range(NUM_FILES): + file_path = f"{warehouse_path}/data_{i}.parquet" + _generate_parquet_file(file_path, ROWS_PER_FILE, seed=i * ROWS_PER_FILE) + + file_table = pq.read_table(file_path) + if table is None: + table = catalog.create_table("default.benchmark_read", schema=file_table.schema) + table.append(file_table) + + return table + + +def _measure_peak_arrow_memory(benchmark_table, batch_size, streaming, concurrent_files): + """Run a scan and track peak PyArrow C++ memory allocation.""" + gc.collect() + pa.default_memory_pool().release_unused() + baseline = pa.total_allocated_bytes() + peak = baseline + + total_rows = 0 + for batch in benchmark_table.scan().to_arrow_batch_reader( + batch_size=batch_size, + streaming=streaming, + concurrent_files=concurrent_files, + ): + total_rows += len(batch) + current = pa.total_allocated_bytes() + if current > peak: + peak = current + # Release the batch immediately to simulate a streaming consumer + del batch + + return total_rows, peak - baseline + + +@pytest.mark.benchmark +@pytest.mark.parametrize( + "streaming,concurrent_files,batch_size", + [ + pytest.param(False, 1, None, id="default"), + pytest.param(True, 1, None, id="streaming-cf1"), + pytest.param(True, 2, None, id="streaming-cf2"), + pytest.param(True, 4, None, id="streaming-cf4"), + pytest.param(True, 8, None, id="streaming-cf8"), + pytest.param(True, 16, None, id="streaming-cf16"), + ], +) +def test_read_throughput( + benchmark_table: "pyiceberg.table.Table", # noqa: F821 + streaming: bool, + concurrent_files: int, + batch_size: int | None, +) -> None: + """Measure records/sec and peak Arrow memory for a scan configuration.""" + effective_batch_size = batch_size or 131_072 # PyArrow default + if streaming: + config_str = f"streaming=True, concurrent_files={concurrent_files}, batch_size={effective_batch_size}" + else: + config_str = f"streaming=False (executor.map, all files parallel), batch_size={effective_batch_size}" + print(f"\n--- ArrowScan Read Throughput 
Benchmark ---") + print(f"Config: {config_str}") + print(f" Files: {NUM_FILES}, Rows per file: {ROWS_PER_FILE}, Total rows: {TOTAL_ROWS}") + + elapsed_times: list[float] = [] + throughputs: list[float] = [] + peak_memories: list[int] = [] + + for run in range(NUM_RUNS): + # Measure throughput + gc.collect() + pa.default_memory_pool().release_unused() + baseline_mem = pa.total_allocated_bytes() + peak_mem = baseline_mem + + start = timeit.default_timer() + total_rows = 0 + for batch in benchmark_table.scan().to_arrow_batch_reader( + batch_size=batch_size, + streaming=streaming, + concurrent_files=concurrent_files, + ): + total_rows += len(batch) + current_mem = pa.total_allocated_bytes() + if current_mem > peak_mem: + peak_mem = current_mem + elapsed = timeit.default_timer() - start + + peak_above_baseline = peak_mem - baseline_mem + rows_per_sec = total_rows / elapsed if elapsed > 0 else 0 + elapsed_times.append(elapsed) + throughputs.append(rows_per_sec) + peak_memories.append(peak_above_baseline) + + print( + f" Run {run + 1}: {elapsed:.2f}s, {rows_per_sec:,.0f} rows/s, " + f"peak arrow mem: {peak_above_baseline / (1024 * 1024):.1f} MB" + ) + + assert total_rows == TOTAL_ROWS, f"Expected {TOTAL_ROWS} rows, got {total_rows}" + + mean_elapsed = statistics.mean(elapsed_times) + stdev_elapsed = statistics.stdev(elapsed_times) if len(elapsed_times) > 1 else 0.0 + mean_throughput = statistics.mean(throughputs) + mean_peak_mem = statistics.mean(peak_memories) + + print( + f" Mean: {mean_elapsed:.2f}s ± {stdev_elapsed:.2f}s, {mean_throughput:,.0f} rows/s, " + f"peak arrow mem: {mean_peak_mem / (1024 * 1024):.1f} MB" + )