diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md index 506547fcd6..589c967f49 100644 --- a/mkdocs/docs/api.md +++ b/mkdocs/docs/api.md @@ -355,6 +355,39 @@ for buf in tbl.scan().to_arrow_batch_reader(): print(f"Buffer contains {len(buf)} rows") ``` +You can control the number of rows per batch using the `batch_size` parameter: + +```python +for buf in tbl.scan().to_arrow_batch_reader(batch_size=1000): + print(f"Buffer contains {len(buf)} rows") +``` + +By default, each file's batches are materialized in memory before being yielded. For large files that may exceed available memory, use `streaming=True` to yield batches as they are produced without materializing entire files: + +```python +for buf in tbl.scan().to_arrow_batch_reader(streaming=True, batch_size=1000): + print(f"Buffer contains {len(buf)} rows") +``` + +For maximum throughput, use `concurrent_files` to read multiple files in parallel while streaming. Batches are yielded as they arrive from any file — ordering across files is not guaranteed: + +```python +for buf in tbl.scan().to_arrow_batch_reader(streaming=True, concurrent_files=4, batch_size=1000): + print(f"Buffer contains {len(buf)} rows") +``` + +The maximum number of buffered batches can be tuned via the `scan.max-buffered-batches` table property (default 16). + +**Ordering semantics:** + +| Configuration | File ordering | Within-file ordering | +|---|---|---| +| Default (`streaming=False`) | Batches grouped by file, in task submission order | Row order | +| `streaming=True` | Batches grouped by file, sequential | Row order | +| `streaming=True, concurrent_files>1` | Interleaved across files (no grouping guarantee) | Row order within each file | + +In all modes, within-file batch ordering follows row order. The `limit` parameter is enforced correctly regardless of configuration. + To avoid any type inconsistencies during writing, you can convert the Iceberg table schema to Arrow: ```python @@ -1619,6 +1652,35 @@ table.scan( ).to_arrow_batch_reader() ``` +The `batch_size` parameter controls the maximum number of rows per RecordBatch (default is PyArrow's 131,072 rows): + +```python +table.scan( + row_filter=GreaterThanOrEqual("trip_distance", 10.0), + selected_fields=("VendorID", "tpep_pickup_datetime", "tpep_dropoff_datetime"), +).to_arrow_batch_reader(batch_size=1000) +``` + +Use `streaming=True` to avoid materializing entire files in memory. This yields batches as they are produced by PyArrow, one file at a time: + +```python +table.scan( + row_filter=GreaterThanOrEqual("trip_distance", 10.0), + selected_fields=("VendorID", "tpep_pickup_datetime", "tpep_dropoff_datetime"), +).to_arrow_batch_reader(streaming=True) +``` + +For concurrent file reads with streaming, use `concurrent_files`. Note that batch ordering across files is not guaranteed: + +```python +table.scan( + row_filter=GreaterThanOrEqual("trip_distance", 10.0), + selected_fields=("VendorID", "tpep_pickup_datetime", "tpep_dropoff_datetime"), +).to_arrow_batch_reader(streaming=True, concurrent_files=4) +``` + +When using `concurrent_files > 1`, batches from different files may be interleaved. Within each file, batches are always in row order. See the ordering semantics table in the [Arrow Batch Reader section](#arrow-batch-reader) above for details. 
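+
+When tuning `concurrent_files`, the `scan.max-buffered-batches` property mentioned earlier can be set like any other table property. A minimal sketch, assuming `tbl` is the table from the examples above and using an illustrative value of `"32"`:
+
+```python
+with tbl.transaction() as transaction:
+    transaction.set_properties({"scan.max-buffered-batches": "32"})
+```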
+ ### Pandas diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index a120c3b776..cb58d01063 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -33,7 +33,9 @@ import logging import operator import os +import queue import re +import threading import uuid import warnings from abc import ABC, abstractmethod @@ -1581,6 +1583,7 @@ def _task_to_record_batches( partition_spec: PartitionSpec | None = None, format_version: TableVersion = TableProperties.DEFAULT_FORMAT_VERSION, downcast_ns_timestamp_to_us: bool | None = None, + batch_size: int | None = None, ) -> Iterator[pa.RecordBatch]: arrow_format = _get_file_format(task.file.file_format, pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8)) with io.new_input(task.file.file_path).open() as fin: @@ -1612,14 +1615,18 @@ def _task_to_record_batches( file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False) - fragment_scanner = ds.Scanner.from_fragment( - fragment=fragment, - schema=physical_schema, + scanner_kwargs: dict[str, Any] = { + "fragment": fragment, + "schema": physical_schema, # This will push down the query to Arrow. # But in case there are positional deletes, we have to apply them first - filter=pyarrow_filter if not positional_deletes else None, - columns=[col.name for col in file_project_schema.columns], - ) + "filter": pyarrow_filter if not positional_deletes else None, + "columns": [col.name for col in file_project_schema.columns], + } + if batch_size is not None: + scanner_kwargs["batch_size"] = batch_size + + fragment_scanner = ds.Scanner.from_fragment(**scanner_kwargs) next_index = 0 batches = fragment_scanner.to_batches() @@ -1677,6 +1684,89 @@ def _read_all_delete_files(io: FileIO, tasks: Iterable[FileScanTask]) -> dict[st return deletes_per_file +_QUEUE_SENTINEL = object() + + +def _bounded_concurrent_batches( + tasks: list[FileScanTask], + batch_fn: Callable[[FileScanTask], Iterator[pa.RecordBatch]], + concurrent_files: int, + max_buffered_batches: int = 16, +) -> Iterator[pa.RecordBatch]: + """Read batches from multiple files concurrently with bounded memory. + + Workers read from files in parallel (up to concurrent_files at a time) and push + batches into a shared queue. The consumer yields batches from the queue. + A sentinel value signals completion, avoiding timeout-based polling. + + Args: + tasks: The file scan tasks to process. + batch_fn: A callable that takes a FileScanTask and returns an iterator of RecordBatches. + concurrent_files: Maximum number of files to read concurrently. + max_buffered_batches: Maximum number of batches to buffer in the queue. 
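+
+    Yields:
+        pa.RecordBatch: Batches from all files as they become available. Ordering across
+            files is not guaranteed; batches from the same file are yielded in row order.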
+ """ + if not tasks: + return + + batch_queue: queue.Queue[pa.RecordBatch | BaseException | object] = queue.Queue(maxsize=max_buffered_batches) + cancel_event = threading.Event() + pending_count = len(tasks) + pending_lock = threading.Lock() + file_semaphore = threading.Semaphore(concurrent_files) + + def worker(task: FileScanTask) -> None: + nonlocal pending_count + acquired = False + try: + # Acquire semaphore — blocks until a slot is available or cancelled + while not cancel_event.is_set(): + if file_semaphore.acquire(timeout=0.5): + acquired = True + break + if cancel_event.is_set(): + return + + for batch in batch_fn(task): + if cancel_event.is_set(): + return + batch_queue.put(batch) + except BaseException as e: + if not cancel_event.is_set(): + batch_queue.put(e) + finally: + if acquired: + file_semaphore.release() + with pending_lock: + pending_count -= 1 + if pending_count == 0: + batch_queue.put(_QUEUE_SENTINEL) + + executor = ExecutorFactory.get_or_create() + futures = [executor.submit(worker, task) for task in tasks] + + try: + while True: + item = batch_queue.get() + + if item is _QUEUE_SENTINEL: + break + + if isinstance(item, BaseException): + raise item + + yield item + finally: + cancel_event.set() + # Drain the queue to unblock any workers stuck on put() + while not batch_queue.empty(): + try: + batch_queue.get_nowait() + except queue.Empty: + break + for future in futures: + future.cancel() + + class ArrowScan: _table_metadata: TableMetadata _io: FileIO @@ -1756,15 +1846,34 @@ def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table: return result - def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.RecordBatch]: + def to_record_batches( + self, + tasks: Iterable[FileScanTask], + batch_size: int | None = None, + streaming: bool = False, + concurrent_files: int = 1, + ) -> Iterator[pa.RecordBatch]: """Scan the Iceberg table and return an Iterator[pa.RecordBatch]. Returns an Iterator of pa.RecordBatch with data from the Iceberg table by resolving the right columns that match the current table schema. Only data that matches the provided row_filter expression is returned. + Ordering semantics: + - Default (streaming=False): Batches are grouped by file in task submission order. + - streaming=True, concurrent_files=1: Batches are grouped by file, processed sequentially. + - streaming=True, concurrent_files>1: Batches may be interleaved across files. + In all modes, within-file batch ordering follows row order. + Args: tasks: FileScanTasks representing the data files and delete files to read from. + batch_size: The number of rows per batch. If None, PyArrow's default is used. + streaming: If True, yield batches as they are produced without materializing + entire files into memory. Files are still processed sequentially when + concurrent_files=1. + concurrent_files: Number of files to read concurrently when streaming=True. + When > 1, batches may arrive interleaved across files. Ignored when + streaming=False. Returns: An Iterator of PyArrow RecordBatches. @@ -1776,34 +1885,67 @@ def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.Record """ deletes_per_file = _read_all_delete_files(self._io, tasks) - total_row_count = 0 - executor = ExecutorFactory.get_or_create() + if streaming and concurrent_files > 1: + # Concurrent streaming path: read multiple files in parallel with bounded queue. + # Ordering is NOT guaranteed across files — batches arrive as produced. 
+ task_list = list(tasks) - def batches_for_task(task: FileScanTask) -> list[pa.RecordBatch]: - # Materialize the iterator here to ensure execution happens within the executor. - # Otherwise, the iterator would be lazily consumed later (in the main thread), - # defeating the purpose of using executor.map. - return list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file)) + def batch_fn(task: FileScanTask) -> Iterator[pa.RecordBatch]: + return self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file, batch_size) - limit_reached = False - for batches in executor.map(batches_for_task, tasks): - for batch in batches: + from pyiceberg.table import TableProperties + + max_buffered = int( + self._table_metadata.properties.get( + TableProperties.SCAN_MAX_BUFFERED_BATCHES, + TableProperties.SCAN_MAX_BUFFERED_BATCHES_DEFAULT, + ) + ) + + total_row_count = 0 + for batch in _bounded_concurrent_batches(task_list, batch_fn, concurrent_files, max_buffered): current_batch_size = len(batch) if self._limit is not None and total_row_count + current_batch_size >= self._limit: yield batch.slice(0, self._limit - total_row_count) - - limit_reached = True - break + return else: yield batch total_row_count += current_batch_size + elif streaming: + # Streaming path: process all tasks sequentially, yielding batches as produced. + # _record_batches_from_scan_tasks_and_deletes handles the limit internally + # when called with all tasks, so no outer limit check is needed. + yield from self._record_batches_from_scan_tasks_and_deletes(tasks, deletes_per_file, batch_size) + else: + # Non-streaming path: existing behavior with executor.map + list() + total_row_count = 0 + executor = ExecutorFactory.get_or_create() + + def batches_for_task(task: FileScanTask) -> list[pa.RecordBatch]: + # Materialize the iterator here to ensure execution happens within the executor. + # Otherwise, the iterator would be lazily consumed later (in the main thread), + # defeating the purpose of using executor.map. 
+ return list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file, batch_size)) + + limit_reached = False + for batches in executor.map(batches_for_task, tasks): + for batch in batches: + current_batch_size = len(batch) + if self._limit is not None and total_row_count + current_batch_size >= self._limit: + yield batch.slice(0, self._limit - total_row_count) + + limit_reached = True + break + else: + yield batch + total_row_count += current_batch_size - if limit_reached: - # This break will also cancel all running tasks in the executor - break + if limit_reached: + # This break will also cancel all running tasks in the executor + break def _record_batches_from_scan_tasks_and_deletes( - self, tasks: Iterable[FileScanTask], deletes_per_file: dict[str, list[ChunkedArray]] + self, tasks: Iterable[FileScanTask], deletes_per_file: dict[str, list[ChunkedArray]], batch_size: int | None = None ) -> Iterator[pa.RecordBatch]: total_row_count = 0 for task in tasks: @@ -1822,6 +1964,7 @@ def _record_batches_from_scan_tasks_and_deletes( self._table_metadata.specs().get(task.file.spec_id), self._table_metadata.format_version, self._downcast_ns_timestamp_to_us, + batch_size, ) for batch in batches: if self._limit is not None: diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index cc0d9ff341..19192794cb 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -247,6 +247,9 @@ class TableProperties: MIN_SNAPSHOTS_TO_KEEP = "history.expire.min-snapshots-to-keep" MIN_SNAPSHOTS_TO_KEEP_DEFAULT = 1 + SCAN_MAX_BUFFERED_BATCHES = "scan.max-buffered-batches" + SCAN_MAX_BUFFERED_BATCHES_DEFAULT = 16 + class Transaction: _table: Table @@ -2157,13 +2160,29 @@ def to_arrow(self) -> pa.Table: self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit ).to_table(self.plan_files()) - def to_arrow_batch_reader(self) -> pa.RecordBatchReader: + def to_arrow_batch_reader( + self, batch_size: int | None = None, streaming: bool = False, concurrent_files: int = 1 + ) -> pa.RecordBatchReader: """Return an Arrow RecordBatchReader from this DataScan. For large results, using a RecordBatchReader requires less memory than loading an Arrow Table for the same DataScan, because a RecordBatch is read one at a time. + Ordering semantics: + - Default (streaming=False): Batches are grouped by file in task submission order. + - streaming=True, concurrent_files=1: Batches are grouped by file, processed sequentially. + - streaming=True, concurrent_files>1: Batches may be interleaved across files. + In all modes, within-file batch ordering follows row order. + + Args: + batch_size: The number of rows per batch. If None, PyArrow's default is used. + streaming: If True, yield batches as they are produced without materializing + entire files into memory. Files are still processed sequentially when + concurrent_files=1. + concurrent_files: Number of files to read concurrently when streaming=True. + When > 1, batches may arrive interleaved across files. + Returns: pa.RecordBatchReader: Arrow RecordBatchReader from the Iceberg table's DataScan which can be used to read a stream of record batches one by one. 
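
For orientation, a minimal consumer sketch for the new reader parameters (assuming `tbl` is an open table; the per-batch processing shown is only illustrative):

```python
total = 0
reader = tbl.scan().to_arrow_batch_reader(streaming=True, concurrent_files=4, batch_size=10_000)
for batch in reader:
    total += len(batch)  # process one RecordBatch at a time to keep memory bounded
print(f"read {total} rows")
```
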
@@ -2175,7 +2194,7 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader: target_schema = schema_to_pyarrow(self.projection()) batches = ArrowScan( self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit - ).to_record_batches(self.plan_files()) + ).to_record_batches(self.plan_files(), batch_size=batch_size, streaming=streaming, concurrent_files=concurrent_files) return pa.RecordBatchReader.from_batches( target_schema, diff --git a/tests/benchmark/test_read_benchmark.py b/tests/benchmark/test_read_benchmark.py new file mode 100644 index 0000000000..6547999889 --- /dev/null +++ b/tests/benchmark/test_read_benchmark.py @@ -0,0 +1,181 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Read throughput micro-benchmark for ArrowScan configurations. + +Measures records/sec and peak Arrow memory across streaming, concurrent_files, +and batch_size configurations introduced for issue #3036. + +Memory is measured using pa.total_allocated_bytes() which tracks PyArrow's C++ +memory pool (Arrow buffers, Parquet decompression), not Python heap allocations. 
+ +Run with: uv run pytest tests/benchmark/test_read_benchmark.py -v -s -m benchmark +""" + +import gc +import statistics +import timeit +from datetime import datetime, timezone + +import pyarrow as pa +import pyarrow.parquet as pq +import pytest + +from pyiceberg.catalog.sql import SqlCatalog + +NUM_FILES = 32 +ROWS_PER_FILE = 500_000 +TOTAL_ROWS = NUM_FILES * ROWS_PER_FILE +NUM_RUNS = 3 + + +def _generate_parquet_file(path: str, num_rows: int, seed: int) -> pa.Schema: + """Write a synthetic Parquet file and return its schema.""" + table = pa.table( + { + "id": pa.array(range(seed, seed + num_rows), type=pa.int64()), + "value": pa.array([float(i) * 0.1 for i in range(num_rows)], type=pa.float64()), + "label": pa.array([f"row_{i}" for i in range(num_rows)], type=pa.string()), + "flag": pa.array([i % 2 == 0 for i in range(num_rows)], type=pa.bool_()), + "ts": pa.array([datetime.now(timezone.utc)] * num_rows, type=pa.timestamp("us", tz="UTC")), + } + ) + pq.write_table(table, path) + return table.schema + + +@pytest.fixture(scope="session") +def benchmark_table(tmp_path_factory: pytest.TempPathFactory) -> "pyiceberg.table.Table": # noqa: F821 + """Create a catalog and table with synthetic Parquet files for benchmarking.""" + warehouse_path = str(tmp_path_factory.mktemp("benchmark_warehouse")) + catalog = SqlCatalog( + "benchmark", + uri=f"sqlite:///{warehouse_path}/pyiceberg_catalog.db", + warehouse=f"file://{warehouse_path}", + ) + catalog.create_namespace("default") + + # Generate files and append to table + table = None + for i in range(NUM_FILES): + file_path = f"{warehouse_path}/data_{i}.parquet" + _generate_parquet_file(file_path, ROWS_PER_FILE, seed=i * ROWS_PER_FILE) + + file_table = pq.read_table(file_path) + if table is None: + table = catalog.create_table("default.benchmark_read", schema=file_table.schema) + table.append(file_table) + + return table + + +def _measure_peak_arrow_memory(benchmark_table, batch_size, streaming, concurrent_files): + """Run a scan and track peak PyArrow C++ memory allocation.""" + gc.collect() + pa.default_memory_pool().release_unused() + baseline = pa.total_allocated_bytes() + peak = baseline + + total_rows = 0 + for batch in benchmark_table.scan().to_arrow_batch_reader( + batch_size=batch_size, + streaming=streaming, + concurrent_files=concurrent_files, + ): + total_rows += len(batch) + current = pa.total_allocated_bytes() + if current > peak: + peak = current + # Release the batch immediately to simulate a streaming consumer + del batch + + return total_rows, peak - baseline + + +@pytest.mark.benchmark +@pytest.mark.parametrize( + "streaming,concurrent_files,batch_size", + [ + pytest.param(False, 1, None, id="default"), + pytest.param(True, 1, None, id="streaming-cf1"), + pytest.param(True, 2, None, id="streaming-cf2"), + pytest.param(True, 4, None, id="streaming-cf4"), + pytest.param(True, 8, None, id="streaming-cf8"), + pytest.param(True, 16, None, id="streaming-cf16"), + ], +) +def test_read_throughput( + benchmark_table: "pyiceberg.table.Table", # noqa: F821 + streaming: bool, + concurrent_files: int, + batch_size: int | None, +) -> None: + """Measure records/sec and peak Arrow memory for a scan configuration.""" + effective_batch_size = batch_size or 131_072 # PyArrow default + if streaming: + config_str = f"streaming=True, concurrent_files={concurrent_files}, batch_size={effective_batch_size}" + else: + config_str = f"streaming=False (executor.map, all files parallel), batch_size={effective_batch_size}" + print(f"\n--- ArrowScan Read Throughput 
Benchmark ---") + print(f"Config: {config_str}") + print(f" Files: {NUM_FILES}, Rows per file: {ROWS_PER_FILE}, Total rows: {TOTAL_ROWS}") + + elapsed_times: list[float] = [] + throughputs: list[float] = [] + peak_memories: list[int] = [] + + for run in range(NUM_RUNS): + # Measure throughput + gc.collect() + pa.default_memory_pool().release_unused() + baseline_mem = pa.total_allocated_bytes() + peak_mem = baseline_mem + + start = timeit.default_timer() + total_rows = 0 + for batch in benchmark_table.scan().to_arrow_batch_reader( + batch_size=batch_size, + streaming=streaming, + concurrent_files=concurrent_files, + ): + total_rows += len(batch) + current_mem = pa.total_allocated_bytes() + if current_mem > peak_mem: + peak_mem = current_mem + elapsed = timeit.default_timer() - start + + peak_above_baseline = peak_mem - baseline_mem + rows_per_sec = total_rows / elapsed if elapsed > 0 else 0 + elapsed_times.append(elapsed) + throughputs.append(rows_per_sec) + peak_memories.append(peak_above_baseline) + + print( + f" Run {run + 1}: {elapsed:.2f}s, {rows_per_sec:,.0f} rows/s, " + f"peak arrow mem: {peak_above_baseline / (1024 * 1024):.1f} MB" + ) + + assert total_rows == TOTAL_ROWS, f"Expected {TOTAL_ROWS} rows, got {total_rows}" + + mean_elapsed = statistics.mean(elapsed_times) + stdev_elapsed = statistics.stdev(elapsed_times) if len(elapsed_times) > 1 else 0.0 + mean_throughput = statistics.mean(throughputs) + mean_peak_mem = statistics.mean(peak_memories) + + print( + f" Mean: {mean_elapsed:.2f}s ± {stdev_elapsed:.2f}s, {mean_throughput:,.0f} rows/s, " + f"peak arrow mem: {mean_peak_mem / (1024 * 1024):.1f} MB" + ) diff --git a/tests/io/test_bounded_concurrent_batches.py b/tests/io/test_bounded_concurrent_batches.py new file mode 100644 index 0000000000..96cd5720e2 --- /dev/null +++ b/tests/io/test_bounded_concurrent_batches.py @@ -0,0 +1,255 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Tests for _bounded_concurrent_batches in pyiceberg.io.pyarrow.""" + +import threading +import time +from collections.abc import Iterator +from unittest.mock import MagicMock + +import pyarrow as pa +import pytest + +from pyiceberg.io.pyarrow import _bounded_concurrent_batches +from pyiceberg.table import FileScanTask + + +def _make_task() -> FileScanTask: + """Create a mock FileScanTask.""" + task = MagicMock(spec=FileScanTask) + return task + + +def _make_batches(num_batches: int, rows_per_batch: int = 10, start: int = 0) -> list[pa.RecordBatch]: + """Create a list of simple RecordBatches.""" + return [pa.record_batch({"col": list(range(start + i * rows_per_batch, start + (i + 1) * rows_per_batch))}) for i in range(num_batches)] + + +def test_correctness_single_file() -> None: + """Test that a single file produces correct results.""" + task = _make_task() + expected_batches = _make_batches(3) + + def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]: + yield from expected_batches + + result = list(_bounded_concurrent_batches([task], batch_fn, concurrent_files=1, max_buffered_batches=16)) + + assert len(result) == 3 + total_rows = sum(len(b) for b in result) + assert total_rows == 30 + + +def test_correctness_multiple_files() -> None: + """Test that multiple files produce all expected batches.""" + tasks = [_make_task() for _ in range(4)] + batches_per_file = 3 + + def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]: + idx = tasks.index(t) + yield from _make_batches(batches_per_file, start=idx * 100) + + result = list(_bounded_concurrent_batches(tasks, batch_fn, concurrent_files=2, max_buffered_batches=16)) + + total_rows = sum(len(b) for b in result) + assert total_rows == batches_per_file * len(tasks) * 10 # 3 batches * 4 files * 10 rows + + +def test_streaming_yields_incrementally() -> None: + """Test that batches are yielded incrementally, not all at once.""" + barrier = threading.Event() + tasks = [_make_task(), _make_task()] + + def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]: + yield pa.record_batch({"col": [1, 2, 3]}) + barrier.wait(timeout=5.0) + yield pa.record_batch({"col": [4, 5, 6]}) + + gen = _bounded_concurrent_batches(tasks, batch_fn, concurrent_files=2, max_buffered_batches=16) + + # Should get at least one batch before all are done + first = next(gen) + assert first.num_rows == 3 + + # Unblock remaining batches + barrier.set() + + remaining = list(gen) + total = 1 + len(remaining) + assert total >= 3 # At least 3 more batches (one blocked from each task + the unblocked ones) + + +def test_backpressure() -> None: + """Test that workers block when the queue is full.""" + max_buffered = 2 + tasks = [_make_task()] + produced_count = 0 + produce_lock = threading.Lock() + + def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]: + nonlocal produced_count + for i in range(10): + with produce_lock: + produced_count += 1 + yield pa.record_batch({"col": [i]}) + + gen = _bounded_concurrent_batches(tasks, batch_fn, concurrent_files=1, max_buffered_batches=max_buffered) + + # Consume slowly and check that not all batches are produced immediately + first = next(gen) + assert first is not None + time.sleep(0.3) + + # The producer should be blocked by backpressure at some point + # (not all 10 batches produced instantly) + remaining = list(gen) + assert len(remaining) + 1 == 10 + + +def test_error_propagation() -> None: + """Test that errors from workers are propagated to the consumer.""" + tasks = [_make_task()] + + def batch_fn(t: FileScanTask) -> 
Iterator[pa.RecordBatch]: + yield pa.record_batch({"col": [1]}) + raise ValueError("test error") + + gen = _bounded_concurrent_batches(tasks, batch_fn, concurrent_files=1, max_buffered_batches=16) + + # Should get the first batch + first = next(gen) + assert first.num_rows == 1 + + # Should get the error + with pytest.raises(ValueError, match="test error"): + list(gen) + + +def test_early_termination() -> None: + """Test that stopping consumption cancels workers.""" + tasks = [_make_task() for _ in range(5)] + worker_started = threading.Event() + + def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]: + worker_started.set() + for i in range(100): + yield pa.record_batch({"col": [i]}) + time.sleep(0.01) + + gen = _bounded_concurrent_batches(tasks, batch_fn, concurrent_files=3, max_buffered_batches=4) + + # Consume a few batches then stop + worker_started.wait(timeout=5.0) + batches = [] + for _ in range(5): + batches.append(next(gen)) + + # Close the generator, triggering finally block + gen.close() + + assert len(batches) == 5 + + +def test_concurrency_limit() -> None: + """Test that at most concurrent_files files are read concurrently.""" + concurrent_files = 2 + tasks = [_make_task() for _ in range(6)] + active_count = 0 + max_active = 0 + active_lock = threading.Lock() + + def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]: + nonlocal active_count, max_active + with active_lock: + active_count += 1 + max_active = max(max_active, active_count) + try: + time.sleep(0.05) + yield pa.record_batch({"col": [1]}) + finally: + with active_lock: + active_count -= 1 + + result = list(_bounded_concurrent_batches(tasks, batch_fn, concurrent_files=concurrent_files, max_buffered_batches=16)) + + assert len(result) == 6 + assert max_active <= concurrent_files + + +def test_empty_tasks() -> None: + """Test that no tasks produces no batches.""" + + def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]: + yield from [] + + result = list(_bounded_concurrent_batches([], batch_fn, concurrent_files=2, max_buffered_batches=16)) + assert result == [] + + +def test_concurrent_with_limit_via_arrowscan(tmpdir: str) -> None: + """Test concurrent_files with limit through ArrowScan integration.""" + from pyiceberg.expressions import AlwaysTrue + from pyiceberg.io.pyarrow import ArrowScan, PyArrowFileIO, schema_to_pyarrow, write_file + from pyiceberg.manifest import DataFileContent, FileFormat + from pyiceberg.partitioning import PartitionSpec + from pyiceberg.schema import Schema + from pyiceberg.table.metadata import TableMetadataV2 + from pyiceberg.types import LongType, NestedField + + PYARROW_PARQUET_FIELD_ID_KEY = b"PARQUET:field_id" + + table_schema = Schema(NestedField(1, "col", LongType(), required=True)) + pa_schema = pa.schema([pa.field("col", pa.int64(), nullable=False, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"})]) + + tasks = [] + for i in range(4): + filepath = f"{tmpdir}/file_{i}.parquet" + arrow_table = pa.table({"col": pa.array(range(i * 100, (i + 1) * 100))}, schema=pa_schema) + import pyarrow.parquet as pq + + pq.write_table(arrow_table, filepath) + from pyiceberg.manifest import DataFile + + data_file = DataFile.from_args( + content=DataFileContent.DATA, + file_path=filepath, + file_format=FileFormat.PARQUET, + partition={}, + record_count=100, + file_size_in_bytes=22, + ) + data_file.spec_id = 0 + tasks.append(FileScanTask(data_file)) + + scan = ArrowScan( + table_metadata=TableMetadataV2( + location="file://a/b/", + last_column_id=1, + format_version=2, + schemas=[table_schema], 
+ partition_specs=[PartitionSpec()], + ), + io=PyArrowFileIO(), + projected_schema=table_schema, + row_filter=AlwaysTrue(), + case_sensitive=True, + limit=150, + ) + + batches = list(scan.to_record_batches(tasks, streaming=True, concurrent_files=2)) + total_rows = sum(len(b) for b in batches) + assert total_rows == 150 diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index 04bc3ecfac..847dbadad2 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -3048,6 +3048,144 @@ def _expected_batch(unit: str) -> pa.RecordBatch: assert _expected_batch("ns" if format_version > 2 else "us").equals(actual_result) +def test_task_to_record_batches_with_batch_size(tmpdir: str) -> None: + """Test that batch_size controls the number of rows per batch.""" + num_rows = 1000 + arrow_table = pa.table( + {"col": pa.array(range(num_rows))}, + schema=pa.schema([pa.field("col", pa.int64(), nullable=False, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"})]), + ) + data_file = _write_table_to_data_file(f"{tmpdir}/test_batch_size.parquet", arrow_table.schema, arrow_table) + table_schema = Schema(NestedField(1, "col", LongType(), required=True)) + + batches = list( + _task_to_record_batches( + PyArrowFileIO(), + FileScanTask(data_file), + bound_row_filter=AlwaysTrue(), + projected_schema=table_schema, + table_schema=table_schema, + projected_field_ids={1}, + positional_deletes=None, + case_sensitive=True, + batch_size=100, + ) + ) + + assert len(batches) > 1 + for batch in batches: + assert len(batch) <= 100 + assert sum(len(b) for b in batches) == num_rows + + +def test_task_to_record_batches_default_batch_size(tmpdir: str) -> None: + """Test that batch_size=None uses PyArrow default (single batch for small files).""" + num_rows = 100 + arrow_table = pa.table( + {"col": pa.array(range(num_rows))}, + schema=pa.schema([pa.field("col", pa.int64(), nullable=False, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"})]), + ) + data_file = _write_table_to_data_file(f"{tmpdir}/test_default_batch_size.parquet", arrow_table.schema, arrow_table) + table_schema = Schema(NestedField(1, "col", LongType(), required=True)) + + batches = list( + _task_to_record_batches( + PyArrowFileIO(), + FileScanTask(data_file), + bound_row_filter=AlwaysTrue(), + projected_schema=table_schema, + table_schema=table_schema, + projected_field_ids={1}, + positional_deletes=None, + case_sensitive=True, + ) + ) + + # With default batch_size, a small file should produce a single batch + assert len(batches) == 1 + assert len(batches[0]) == num_rows + + +def _create_scan_and_tasks( + tmpdir: str, num_files: int = 1, rows_per_file: int = 100, limit: int | None = None +) -> tuple[ArrowScan, list[FileScanTask]]: + """Helper to create an ArrowScan and FileScanTasks for testing.""" + table_schema = Schema(NestedField(1, "col", LongType(), required=True)) + pa_schema = pa.schema([pa.field("col", pa.int64(), nullable=False, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"})]) + tasks = [] + for i in range(num_files): + start = i * rows_per_file + arrow_table = pa.table({"col": pa.array(range(start, start + rows_per_file))}, schema=pa_schema) + data_file = _write_table_to_data_file(f"{tmpdir}/file_{i}.parquet", pa_schema, arrow_table) + data_file.spec_id = 0 + tasks.append(FileScanTask(data_file)) + + scan = ArrowScan( + table_metadata=TableMetadataV2( + location="file://a/b/", + last_column_id=1, + format_version=2, + schemas=[table_schema], + partition_specs=[PartitionSpec()], + ), + io=PyArrowFileIO(), + projected_schema=table_schema, + 
row_filter=AlwaysTrue(), + case_sensitive=True, + limit=limit, + ) + return scan, tasks + + +def test_streaming_false_produces_same_results(tmpdir: str) -> None: + """Test that streaming=False produces the same results as the default behavior.""" + scan, tasks = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100) + + batches_default = list(scan.to_record_batches(tasks, streaming=False)) + # Re-create tasks since iterators are consumed + _, tasks2 = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100) + batches_streaming = list(scan.to_record_batches(tasks2, streaming=False)) + + total_default = sum(len(b) for b in batches_default) + total_streaming = sum(len(b) for b in batches_streaming) + assert total_default == 300 + assert total_streaming == 300 + + +def test_streaming_true_yields_all_batches(tmpdir: str) -> None: + """Test that streaming=True yields all batches correctly.""" + scan, tasks = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100) + + batches = list(scan.to_record_batches(tasks, streaming=True)) + + total_rows = sum(len(b) for b in batches) + assert total_rows == 300 + # Verify all values are present + all_values = sorted([v for b in batches for v in b.column("col").to_pylist()]) + assert all_values == list(range(300)) + + +def test_streaming_true_with_limit(tmpdir: str) -> None: + """Test that streaming=True respects the row limit.""" + scan, tasks = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100, limit=150) + + batches = list(scan.to_record_batches(tasks, streaming=True)) + + total_rows = sum(len(b) for b in batches) + assert total_rows == 150 + + +def test_streaming_file_ordering_preserved(tmpdir: str) -> None: + """Test that file ordering is preserved in both streaming modes.""" + scan, tasks = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100) + + batches = list(scan.to_record_batches(tasks, streaming=True)) + all_values = [v for b in batches for v in b.column("col").to_pylist()] + + # Values should be in file order: 0-99 from file 0, 100-199 from file 1, 200-299 from file 2 + assert all_values == list(range(300)) + + def test_parse_location_defaults() -> None: """Test that parse_location uses defaults."""