diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md index 506547fcd6..6c439f5fed 100644 --- a/mkdocs/docs/api.md +++ b/mkdocs/docs/api.md @@ -355,6 +355,51 @@ for buf in tbl.scan().to_arrow_batch_reader(): print(f"Buffer contains {len(buf)} rows") ``` +You can control the number of rows per batch using the `batch_size` parameter: + +```python +for buf in tbl.scan().to_arrow_batch_reader(batch_size=1000): + print(f"Buffer contains {len(buf)} rows") +``` + +By default, each file's batches are materialized in memory before being yielded (`order=ScanOrder.TASK`). For large files that may exceed available memory, use `order=ScanOrder.ARRIVAL` to yield batches as they are produced without materializing entire files: + +```python +from pyiceberg.table import ScanOrder + +for buf in tbl.scan().to_arrow_batch_reader(order=ScanOrder.ARRIVAL, batch_size=1000): + print(f"Buffer contains {len(buf)} rows") +``` + +For maximum throughput, use `concurrent_files` to read multiple files in parallel with arrival order. Batches are yielded as they arrive from any file — ordering across files is not guaranteed: + +```python +from pyiceberg.table import ScanOrder + +for buf in tbl.scan().to_arrow_batch_reader(order=ScanOrder.ARRIVAL, concurrent_files=4, batch_size=1000): + print(f"Buffer contains {len(buf)} rows") +``` + +**Ordering semantics:** + +| Configuration | File ordering | Within-file ordering | +|---|---|---| +| `ScanOrder.TASK` (default) | Batches grouped by file, in task submission order | Row order | +| `ScanOrder.ARRIVAL` | Interleaved across files (no grouping guarantee) | Row order within each file | + +Within each file, batch ordering always follows row order. The `limit` parameter is enforced correctly regardless of configuration. + +**Which configuration should I use?** + +| Use case | Recommended config | +|---|---| +| Small tables, simple queries | Default — no extra args needed | +| Large tables, memory-constrained | `order=ScanOrder.ARRIVAL` — one file at a time, minimal memory | +| Maximum throughput with bounded memory | `order=ScanOrder.ARRIVAL, concurrent_files=N` — tune N to balance throughput vs memory | +| Fine-grained batch control | Add `batch_size=N` to any of the above | + +**Note:** `ScanOrder.ARRIVAL` yields batches in arrival order (interleaved across files when `concurrent_files > 1`). For deterministic file ordering, use the default `ScanOrder.TASK` mode. `batch_size` is usually an advanced tuning knob — the PyArrow default of 131,072 rows works well for most workloads. + To avoid any type inconsistencies during writing, you can convert the Iceberg table schema to Arrow: ```python @@ -1619,6 +1664,39 @@ table.scan( ).to_arrow_batch_reader() ``` +The `batch_size` parameter controls the maximum number of rows per RecordBatch (default is PyArrow's 131,072 rows): + +```python +table.scan( + row_filter=GreaterThanOrEqual("trip_distance", 10.0), + selected_fields=("VendorID", "tpep_pickup_datetime", "tpep_dropoff_datetime"), +).to_arrow_batch_reader(batch_size=1000) +``` + +Use `order=ScanOrder.ARRIVAL` to avoid materializing entire files in memory. This yields batches as they are produced by PyArrow, one file at a time: + +```python +from pyiceberg.table import ScanOrder + +table.scan( + row_filter=GreaterThanOrEqual("trip_distance", 10.0), + selected_fields=("VendorID", "tpep_pickup_datetime", "tpep_dropoff_datetime"), +).to_arrow_batch_reader(order=ScanOrder.ARRIVAL) +``` + +For concurrent file reads with arrival order, use `concurrent_files`. 
Note that batch ordering across files is not guaranteed: + +```python +from pyiceberg.table import ScanOrder + +table.scan( + row_filter=GreaterThanOrEqual("trip_distance", 10.0), + selected_fields=("VendorID", "tpep_pickup_datetime", "tpep_dropoff_datetime"), +).to_arrow_batch_reader(order=ScanOrder.ARRIVAL, concurrent_files=4) +``` + +When using `concurrent_files > 1`, batches from different files may be interleaved. Within each file, batches are always in row order. See the ordering semantics table in the [Apache Arrow section](#apache-arrow) above for details. + ### Pandas diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index a120c3b776..24af3d0223 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -33,11 +33,14 @@ import logging import operator import os +import queue import re +import threading import uuid import warnings from abc import ABC, abstractmethod -from collections.abc import Callable, Iterable, Iterator +from collections.abc import Callable, Generator, Iterable, Iterator +from concurrent.futures import ThreadPoolExecutor from copy import copy from dataclasses import dataclass from enum import Enum @@ -141,7 +144,7 @@ visit, visit_with_partner, ) -from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE, TableProperties +from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE, ScanOrder, TableProperties from pyiceberg.table.locations import load_location_provider from pyiceberg.table.metadata import TableMetadata from pyiceberg.table.name_mapping import NameMapping, apply_name_mapping @@ -1581,6 +1584,7 @@ def _task_to_record_batches( partition_spec: PartitionSpec | None = None, format_version: TableVersion = TableProperties.DEFAULT_FORMAT_VERSION, downcast_ns_timestamp_to_us: bool | None = None, + batch_size: int | None = None, ) -> Iterator[pa.RecordBatch]: arrow_format = _get_file_format(task.file.file_format, pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8)) with io.new_input(task.file.file_path).open() as fin: @@ -1612,14 +1616,18 @@ def _task_to_record_batches( file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False) - fragment_scanner = ds.Scanner.from_fragment( - fragment=fragment, - schema=physical_schema, + scanner_kwargs: dict[str, Any] = { + "fragment": fragment, + "schema": physical_schema, # This will push down the query to Arrow. # But in case there are positional deletes, we have to apply them first - filter=pyarrow_filter if not positional_deletes else None, - columns=[col.name for col in file_project_schema.columns], - ) + "filter": pyarrow_filter if not positional_deletes else None, + "columns": [col.name for col in file_project_schema.columns], + } + if batch_size is not None: + scanner_kwargs["batch_size"] = batch_size + + fragment_scanner = ds.Scanner.from_fragment(**scanner_kwargs) next_index = 0 batches = fragment_scanner.to_batches() @@ -1677,6 +1685,76 @@ def _read_all_delete_files(io: FileIO, tasks: Iterable[FileScanTask]) -> dict[st return deletes_per_file +_QUEUE_SENTINEL = object() + + +def _bounded_concurrent_batches( + tasks: list[FileScanTask], + batch_fn: Callable[[FileScanTask], Iterator[pa.RecordBatch]], + concurrent_files: int, + max_buffered_batches: int = 16, +) -> Generator[pa.RecordBatch, None, None]: + """Read batches from multiple files concurrently with bounded memory. + + Uses a per-scan ThreadPoolExecutor(max_workers=concurrent_files) to naturally + bound concurrency. 
Workers push batches into a bounded queue which provides + backpressure when the consumer is slower than the producers. + + Args: + tasks: The file scan tasks to process. + batch_fn: A callable that takes a FileScanTask and returns an iterator of RecordBatches. + concurrent_files: Maximum number of files to read concurrently. + max_buffered_batches: Maximum number of batches to buffer in the queue. + """ + if not tasks: + return + + batch_queue: queue.Queue[pa.RecordBatch | BaseException | object] = queue.Queue(maxsize=max_buffered_batches) + cancel = threading.Event() + remaining = len(tasks) + remaining_lock = threading.Lock() + + def worker(task: FileScanTask) -> None: + nonlocal remaining + try: + for batch in batch_fn(task): + if cancel.is_set(): + return + batch_queue.put(batch) + except BaseException as e: + if not cancel.is_set(): + batch_queue.put(e) + finally: + with remaining_lock: + remaining -= 1 + if remaining == 0: + batch_queue.put(_QUEUE_SENTINEL) + + with ThreadPoolExecutor(max_workers=concurrent_files) as executor: + for task in tasks: + executor.submit(worker, task) + + try: + while True: + item = batch_queue.get() + + if item is _QUEUE_SENTINEL: + break + + if isinstance(item, BaseException): + raise item + + yield item + finally: + cancel.set() + # Drain the queue to unblock any workers stuck on put() + while not batch_queue.empty(): + try: + batch_queue.get_nowait() + except queue.Empty: + break + + class ArrowScan: _table_metadata: TableMetadata _io: FileIO @@ -1756,15 +1834,35 @@ def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table: return result - def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.RecordBatch]: + def to_record_batches( + self, + tasks: Iterable[FileScanTask], + batch_size: int | None = None, + order: ScanOrder = ScanOrder.TASK, + concurrent_files: int = 1, + ) -> Iterator[pa.RecordBatch]: """Scan the Iceberg table and return an Iterator[pa.RecordBatch]. Returns an Iterator of pa.RecordBatch with data from the Iceberg table by resolving the right columns that match the current table schema. Only data that matches the provided row_filter expression is returned. + Ordering semantics: + - ScanOrder.TASK (default): Batches are grouped by file in task submission order. + - ScanOrder.ARRIVAL: Batches may be interleaved across files. Within each file, + batch ordering follows row order. + Args: tasks: FileScanTasks representing the data files and delete files to read from. + batch_size: The number of rows per batch. If None, PyArrow's default is used. + order: Controls the order in which record batches are returned. + ScanOrder.TASK (default) returns batches in task order, with each task + fully materialized before proceeding to the next. Allows parallel file + reads via executor. ScanOrder.ARRIVAL yields batches as they are + produced without materializing entire files into memory. + concurrent_files: Number of files to read concurrently when order=ScanOrder.ARRIVAL. + Must be >= 1. When > 1, batches may arrive interleaved across files. + Ignored when order=ScanOrder.TASK. Returns: An Iterator of PyArrow RecordBatches. 
@@ -1772,38 +1870,78 @@ def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.Record Raises: ResolveError: When a required field cannot be found in the file - ValueError: When a field type in the file cannot be projected to the schema type + ValueError: When a field type in the file cannot be projected to the schema type, + or when an invalid order value is provided, or when concurrent_files < 1. """ - deletes_per_file = _read_all_delete_files(self._io, tasks) + if not isinstance(order, ScanOrder): + raise ValueError(f"Invalid order: {order!r}. Must be a ScanOrder enum value (ScanOrder.TASK or ScanOrder.ARRIVAL).") - total_row_count = 0 + if concurrent_files < 1: + raise ValueError(f"concurrent_files must be >= 1, got {concurrent_files}") + + task_list, deletes_per_file = self._prepare_tasks_and_deletes(tasks) + + if order == ScanOrder.ARRIVAL: + return self._apply_limit(self._iter_batches_arrival(task_list, deletes_per_file, batch_size, concurrent_files)) + + return self._apply_limit(self._iter_batches_materialized(task_list, deletes_per_file, batch_size)) + + def _prepare_tasks_and_deletes( + self, tasks: Iterable[FileScanTask] + ) -> tuple[list[FileScanTask], dict[str, list[ChunkedArray]]]: + """Resolve delete files and return tasks as a list.""" + task_list = list(tasks) + deletes_per_file = _read_all_delete_files(self._io, task_list) + return task_list, deletes_per_file + + def _iter_batches_arrival( + self, + task_list: list[FileScanTask], + deletes_per_file: dict[str, list[ChunkedArray]], + batch_size: int | None, + concurrent_files: int, + ) -> Iterator[pa.RecordBatch]: + """Yield batches using bounded concurrent streaming in arrival order.""" + + def batch_fn(task: FileScanTask) -> Iterator[pa.RecordBatch]: + return self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file, batch_size) + + yield from _bounded_concurrent_batches(task_list, batch_fn, concurrent_files) + + def _iter_batches_materialized( + self, + task_list: list[FileScanTask], + deletes_per_file: dict[str, list[ChunkedArray]], + batch_size: int | None, + ) -> Iterator[pa.RecordBatch]: + """Yield batches using executor.map with full file materialization.""" executor = ExecutorFactory.get_or_create() def batches_for_task(task: FileScanTask) -> list[pa.RecordBatch]: - # Materialize the iterator here to ensure execution happens within the executor. - # Otherwise, the iterator would be lazily consumed later (in the main thread), - # defeating the purpose of using executor.map. 
- return list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file)) + return list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file, batch_size)) - limit_reached = False - for batches in executor.map(batches_for_task, tasks): - for batch in batches: - current_batch_size = len(batch) - if self._limit is not None and total_row_count + current_batch_size >= self._limit: - yield batch.slice(0, self._limit - total_row_count) + for batches in executor.map(batches_for_task, task_list): + yield from batches - limit_reached = True - break - else: - yield batch - total_row_count += current_batch_size + def _apply_limit(self, batches: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]: + """Apply row limit across batches.""" + if self._limit is None: + yield from batches + return - if limit_reached: - # This break will also cancel all running tasks in the executor - break + total_row_count = 0 + for batch in batches: + remaining = self._limit - total_row_count + if remaining <= 0: + return + if len(batch) > remaining: + yield batch.slice(0, remaining) + return + yield batch + total_row_count += len(batch) def _record_batches_from_scan_tasks_and_deletes( - self, tasks: Iterable[FileScanTask], deletes_per_file: dict[str, list[ChunkedArray]] + self, tasks: Iterable[FileScanTask], deletes_per_file: dict[str, list[ChunkedArray]], batch_size: int | None = None ) -> Iterator[pa.RecordBatch]: total_row_count = 0 for task in tasks: @@ -1822,6 +1960,7 @@ def _record_batches_from_scan_tasks_and_deletes( self._table_metadata.specs().get(task.file.spec_id), self._table_metadata.format_version, self._downcast_ns_timestamp_to_us, + batch_size, ) for batch in batches: if self._limit is not None: diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index cc0d9ff341..c7f43b48c1 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -23,6 +23,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Iterator from dataclasses import dataclass +from enum import Enum from functools import cached_property from itertools import chain from types import TracebackType @@ -154,6 +155,20 @@ DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE = "downcast-ns-timestamp-to-us-on-write" +class ScanOrder(str, Enum): + """Order in which record batches are returned from a scan. + + Attributes: + TASK: Batches are returned in task order, with each task fully materialized + before proceeding to the next. Allows parallel file reads via executor. + ARRIVAL: Batches are yielded as they are produced, processing tasks + sequentially without materializing entire files into memory. + """ + + TASK = "task" + ARRIVAL = "arrival" + + @dataclass() class UpsertResult: """Summary the upsert operation.""" @@ -2002,13 +2017,11 @@ def _build_residual_evaluator(self, spec_id: int) -> Callable[[DataFile], Residu # The lambda created here is run in multiple threads. # So we avoid creating _EvaluatorExpression methods bound to a single # shared instance across multiple threads. 
- return lambda datafile: ( - residual_evaluator_of( - spec=spec, - expr=self.row_filter, - case_sensitive=self.case_sensitive, - schema=self.table_metadata.schema(), - ) + return lambda datafile: residual_evaluator_of( + spec=spec, + expr=self.row_filter, + case_sensitive=self.case_sensitive, + schema=self.table_metadata.schema(), ) @staticmethod @@ -2157,13 +2170,29 @@ def to_arrow(self) -> pa.Table: self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit ).to_table(self.plan_files()) - def to_arrow_batch_reader(self) -> pa.RecordBatchReader: + def to_arrow_batch_reader( + self, batch_size: int | None = None, order: ScanOrder = ScanOrder.TASK, concurrent_files: int = 1 + ) -> pa.RecordBatchReader: """Return an Arrow RecordBatchReader from this DataScan. For large results, using a RecordBatchReader requires less memory than loading an Arrow Table for the same DataScan, because a RecordBatch is read one at a time. + Ordering semantics: + - ScanOrder.TASK (default): Batches are grouped by file in task submission order. + - ScanOrder.ARRIVAL: Batches may be interleaved across files. Within each file, + batch ordering follows row order. + + Args: + batch_size: The number of rows per batch. If None, PyArrow's default is used. + order: Controls the order in which record batches are returned. + ScanOrder.TASK (default) returns batches in task order with parallel + file reads. ScanOrder.ARRIVAL yields batches as they are produced + without materializing entire files into memory. + concurrent_files: Number of files to read concurrently when order=ScanOrder.ARRIVAL. + When > 1, batches may arrive interleaved across files. + Returns: pa.RecordBatchReader: Arrow RecordBatchReader from the Iceberg table's DataScan which can be used to read a stream of record batches one by one. @@ -2175,7 +2204,7 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader: target_schema = schema_to_pyarrow(self.projection()) batches = ArrowScan( self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit - ).to_record_batches(self.plan_files()) + ).to_record_batches(self.plan_files(), batch_size=batch_size, order=order, concurrent_files=concurrent_files) return pa.RecordBatchReader.from_batches( target_schema, diff --git a/tests/benchmark/test_read_benchmark.py b/tests/benchmark/test_read_benchmark.py new file mode 100644 index 0000000000..809c106dd1 --- /dev/null +++ b/tests/benchmark/test_read_benchmark.py @@ -0,0 +1,167 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Read throughput micro-benchmark for ArrowScan configurations. + +Measures records/sec and peak Arrow memory across ScanOrder, concurrent_files, +and batch_size configurations introduced for issue #3036. 
+ +Memory is measured using pa.total_allocated_bytes() which tracks PyArrow's C++ +memory pool (Arrow buffers, Parquet decompression), not Python heap allocations. + +Run with: uv run pytest tests/benchmark/test_read_benchmark.py -v -s +""" + +import gc +import statistics +import timeit +from datetime import datetime, timezone + +import pyarrow as pa +import pyarrow.parquet as pq +import pytest + +from pyiceberg.catalog.sql import SqlCatalog +from pyiceberg.table import ScanOrder, Table + +NUM_FILES = 32 +ROWS_PER_FILE = 500_000 +TOTAL_ROWS = NUM_FILES * ROWS_PER_FILE +NUM_RUNS = 3 + + +def _generate_parquet_file(path: str, num_rows: int, seed: int) -> pa.Schema: + """Write a synthetic Parquet file and return its schema.""" + table = pa.table( + { + "id": pa.array(range(seed, seed + num_rows), type=pa.int64()), + "value": pa.array([float(i) * 0.1 for i in range(num_rows)], type=pa.float64()), + "label": pa.array([f"row_{i}" for i in range(num_rows)], type=pa.string()), + "flag": pa.array([i % 2 == 0 for i in range(num_rows)], type=pa.bool_()), + "ts": pa.array([datetime.now(timezone.utc)] * num_rows, type=pa.timestamp("us", tz="UTC")), + } + ) + pq.write_table(table, path) + return table.schema + + +@pytest.fixture(scope="session") +def benchmark_table(tmp_path_factory: pytest.TempPathFactory) -> Table: + """Create a catalog and table with synthetic Parquet files for benchmarking.""" + warehouse_path = str(tmp_path_factory.mktemp("benchmark_warehouse")) + catalog = SqlCatalog( + "benchmark", + uri=f"sqlite:///{warehouse_path}/pyiceberg_catalog.db", + warehouse=f"file://{warehouse_path}", + ) + catalog.create_namespace("default") + + # Generate files and append to table + table = None + for i in range(NUM_FILES): + file_path = f"{warehouse_path}/data_{i}.parquet" + _generate_parquet_file(file_path, ROWS_PER_FILE, seed=i * ROWS_PER_FILE) + + file_table = pq.read_table(file_path) + if table is None: + table = catalog.create_table("default.benchmark_read", schema=file_table.schema) + table.append(file_table) + + assert table is not None + return table + + +@pytest.mark.parametrize( + "order,concurrent_files,batch_size", + [ + pytest.param(ScanOrder.TASK, 1, None, id="default"), + pytest.param(ScanOrder.ARRIVAL, 1, None, id="arrival-cf1"), + pytest.param(ScanOrder.ARRIVAL, 2, None, id="arrival-cf2"), + pytest.param(ScanOrder.ARRIVAL, 4, None, id="arrival-cf4"), + pytest.param(ScanOrder.ARRIVAL, 8, None, id="arrival-cf8"), + pytest.param(ScanOrder.ARRIVAL, 16, None, id="arrival-cf16"), + ], +) +def test_read_throughput( + benchmark_table: Table, + order: ScanOrder, + concurrent_files: int, + batch_size: int | None, +) -> None: + """Measure records/sec, time to first record, and peak Arrow memory for a scan configuration.""" + effective_batch_size = batch_size or 131_072 # PyArrow default + if order == ScanOrder.ARRIVAL: + config_str = f"order=ARRIVAL, concurrent_files={concurrent_files}, batch_size={effective_batch_size}" + else: + config_str = f"order=TASK (executor.map, all files parallel), batch_size={effective_batch_size}" + print("\n--- ArrowScan Read Throughput Benchmark ---") + print(f"Config: {config_str}") + print(f" Files: {NUM_FILES}, Rows per file: {ROWS_PER_FILE}, Total rows: {TOTAL_ROWS}") + + elapsed_times: list[float] = [] + throughputs: list[float] = [] + peak_memories: list[int] = [] + ttfr_times: list[float] = [] + + for run in range(NUM_RUNS): + # Measure throughput + gc.collect() + pa.default_memory_pool().release_unused() + baseline_mem = pa.total_allocated_bytes() + 
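+        # pa.total_allocated_bytes() reports bytes currently held by PyArrow's default
+        # C++ memory pool; peak usage is approximated by re-sampling it after each batch.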
peak_mem = baseline_mem + + start = timeit.default_timer() + total_rows = 0 + first_batch_time = None + for batch in benchmark_table.scan().to_arrow_batch_reader( + batch_size=batch_size, + order=order, + concurrent_files=concurrent_files, + ): + if first_batch_time is None: + first_batch_time = timeit.default_timer() - start + total_rows += len(batch) + current_mem = pa.total_allocated_bytes() + if current_mem > peak_mem: + peak_mem = current_mem + elapsed = timeit.default_timer() - start + + peak_above_baseline = peak_mem - baseline_mem + rows_per_sec = total_rows / elapsed if elapsed > 0 else 0 + elapsed_times.append(elapsed) + throughputs.append(rows_per_sec) + peak_memories.append(peak_above_baseline) + ttfr_times.append(first_batch_time or 0.0) + + print( + f" Run {run + 1}: {elapsed:.2f}s, {rows_per_sec:,.0f} rows/s, " + f"TTFR: {(first_batch_time or 0) * 1000:.1f}ms, " + f"peak arrow mem: {peak_above_baseline / (1024 * 1024):.1f} MB" + ) + + assert total_rows == TOTAL_ROWS, f"Expected {TOTAL_ROWS} rows, got {total_rows}" + + mean_elapsed = statistics.mean(elapsed_times) + stdev_elapsed = statistics.stdev(elapsed_times) if len(elapsed_times) > 1 else 0.0 + mean_throughput = statistics.mean(throughputs) + mean_peak_mem = statistics.mean(peak_memories) + mean_ttfr = statistics.mean(ttfr_times) + + print( + f" Mean: {mean_elapsed:.2f}s ± {stdev_elapsed:.2f}s, {mean_throughput:,.0f} rows/s, " + f"TTFR: {mean_ttfr * 1000:.1f}ms, " + f"peak arrow mem: {mean_peak_mem / (1024 * 1024):.1f} MB" + ) diff --git a/tests/catalog/test_hive.py b/tests/catalog/test_hive.py index 88b653e44f..a8c0c943da 100644 --- a/tests/catalog/test_hive.py +++ b/tests/catalog/test_hive.py @@ -1314,8 +1314,8 @@ def test_hive_wait_for_lock() -> None: assert catalog._client.check_lock.call_count == 3 # lock wait should exit with WaitingForLockException finally after enough retries + catalog._client.check_lock.reset_mock() catalog._client.check_lock.side_effect = [waiting for _ in range(10)] - catalog._client.check_lock.call_count = 0 with pytest.raises(WaitingForLockException): catalog._wait_for_lock("db", "tbl", lockid, catalog._client) assert catalog._client.check_lock.call_count == 5 diff --git a/tests/io/test_bounded_concurrent_batches.py b/tests/io/test_bounded_concurrent_batches.py new file mode 100644 index 0000000000..e80e7c1798 --- /dev/null +++ b/tests/io/test_bounded_concurrent_batches.py @@ -0,0 +1,258 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Tests for _bounded_concurrent_batches in pyiceberg.io.pyarrow.""" + +import threading +import time +from collections.abc import Iterator +from unittest.mock import MagicMock + +import pyarrow as pa +import pytest + +from pyiceberg.io.pyarrow import _bounded_concurrent_batches +from pyiceberg.table import FileScanTask, ScanOrder + + +def _make_task() -> FileScanTask: + """Create a mock FileScanTask.""" + task = MagicMock(spec=FileScanTask) + return task + + +def _make_batches(num_batches: int, rows_per_batch: int = 10, start: int = 0) -> list[pa.RecordBatch]: + """Create a list of simple RecordBatches.""" + return [ + pa.record_batch({"col": list(range(start + i * rows_per_batch, start + (i + 1) * rows_per_batch))}) + for i in range(num_batches) + ] + + +def test_correctness_single_file() -> None: + """Test that a single file produces correct results.""" + task = _make_task() + expected_batches = _make_batches(3) + + def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]: + yield from expected_batches + + result = list(_bounded_concurrent_batches([task], batch_fn, concurrent_files=1, max_buffered_batches=16)) + + assert len(result) == 3 + total_rows = sum(len(b) for b in result) + assert total_rows == 30 + + +def test_correctness_multiple_files() -> None: + """Test that multiple files produce all expected batches.""" + tasks = [_make_task() for _ in range(4)] + batches_per_file = 3 + + def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]: + idx = tasks.index(t) + yield from _make_batches(batches_per_file, start=idx * 100) + + result = list(_bounded_concurrent_batches(tasks, batch_fn, concurrent_files=2, max_buffered_batches=16)) + + total_rows = sum(len(b) for b in result) + assert total_rows == batches_per_file * len(tasks) * 10 # 3 batches * 4 files * 10 rows + + +def test_arrival_order_yields_incrementally() -> None: + """Test that batches are yielded incrementally, not all at once.""" + barrier = threading.Event() + tasks = [_make_task(), _make_task()] + + def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]: + yield pa.record_batch({"col": [1, 2, 3]}) + barrier.wait(timeout=5.0) + yield pa.record_batch({"col": [4, 5, 6]}) + + gen = _bounded_concurrent_batches(tasks, batch_fn, concurrent_files=2, max_buffered_batches=16) + + # Should get at least one batch before all are done + first = next(gen) + assert first.num_rows == 3 + + # Unblock remaining batches + barrier.set() + + remaining = list(gen) + total = 1 + len(remaining) + assert total >= 3 # At least 3 more batches (one blocked from each task + the unblocked ones) + + +def test_backpressure() -> None: + """Test that workers block when the queue is full.""" + max_buffered = 2 + tasks = [_make_task()] + produced_count = 0 + produce_lock = threading.Lock() + + def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]: + nonlocal produced_count + for i in range(10): + with produce_lock: + produced_count += 1 + yield pa.record_batch({"col": [i]}) + + gen = _bounded_concurrent_batches(tasks, batch_fn, concurrent_files=1, max_buffered_batches=max_buffered) + + # Consume slowly and check that not all batches are produced immediately + first = next(gen) + assert first is not None + time.sleep(0.3) + + # The producer should be blocked by backpressure at some point + # (not all 10 batches produced instantly) + remaining = list(gen) + assert len(remaining) + 1 == 10 + + +def test_error_propagation() -> None: + """Test that errors from workers are propagated to the consumer.""" + tasks = [_make_task()] + + def batch_fn(t: 
FileScanTask) -> Iterator[pa.RecordBatch]: + yield pa.record_batch({"col": [1]}) + raise ValueError("test error") + + gen = _bounded_concurrent_batches(tasks, batch_fn, concurrent_files=1, max_buffered_batches=16) + + # Should get the first batch + first = next(gen) + assert first.num_rows == 1 + + # Should get the error + with pytest.raises(ValueError, match="test error"): + list(gen) + + +def test_early_termination() -> None: + """Test that stopping consumption cancels workers.""" + tasks = [_make_task() for _ in range(5)] + worker_started = threading.Event() + + def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]: + worker_started.set() + for i in range(100): + yield pa.record_batch({"col": [i]}) + time.sleep(0.01) + + gen = _bounded_concurrent_batches(tasks, batch_fn, concurrent_files=3, max_buffered_batches=4) + + # Consume a few batches then stop + worker_started.wait(timeout=5.0) + batches = [] + for _ in range(5): + batches.append(next(gen)) + + # Close the generator, triggering finally block + gen.close() + + assert len(batches) == 5 + + +def test_concurrency_limit() -> None: + """Test that at most concurrent_files files are read concurrently.""" + concurrent_files = 2 + tasks = [_make_task() for _ in range(6)] + active_count = 0 + max_active = 0 + active_lock = threading.Lock() + + def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]: + nonlocal active_count, max_active + with active_lock: + active_count += 1 + max_active = max(max_active, active_count) + try: + time.sleep(0.05) + yield pa.record_batch({"col": [1]}) + finally: + with active_lock: + active_count -= 1 + + result = list(_bounded_concurrent_batches(tasks, batch_fn, concurrent_files=concurrent_files, max_buffered_batches=16)) + + assert len(result) == 6 + assert max_active <= concurrent_files + + +def test_empty_tasks() -> None: + """Test that no tasks produces no batches.""" + + def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]: + yield from [] + + result = list(_bounded_concurrent_batches([], batch_fn, concurrent_files=2, max_buffered_batches=16)) + assert result == [] + + +def test_concurrent_with_limit_via_arrowscan(tmpdir: str) -> None: + """Test concurrent_files with limit through ArrowScan integration.""" + from pyiceberg.expressions import AlwaysTrue + from pyiceberg.io.pyarrow import ArrowScan, PyArrowFileIO + from pyiceberg.manifest import DataFileContent, FileFormat + from pyiceberg.partitioning import PartitionSpec + from pyiceberg.schema import Schema + from pyiceberg.table.metadata import TableMetadataV2 + from pyiceberg.types import LongType, NestedField + + PYARROW_PARQUET_FIELD_ID_KEY = b"PARQUET:field_id" + + table_schema = Schema(NestedField(1, "col", LongType(), required=True)) + pa_schema = pa.schema([pa.field("col", pa.int64(), nullable=False, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"})]) + + tasks = [] + for i in range(4): + filepath = f"{tmpdir}/file_{i}.parquet" + arrow_table = pa.table({"col": pa.array(range(i * 100, (i + 1) * 100))}, schema=pa_schema) + import pyarrow.parquet as pq + + pq.write_table(arrow_table, filepath) + from pyiceberg.manifest import DataFile + + data_file = DataFile.from_args( + content=DataFileContent.DATA, + file_path=filepath, + file_format=FileFormat.PARQUET, + partition={}, + record_count=100, + file_size_in_bytes=22, + ) + data_file.spec_id = 0 + tasks.append(FileScanTask(data_file)) + + scan = ArrowScan( + table_metadata=TableMetadataV2( + location="file://a/b/", + last_column_id=1, + format_version=2, + schemas=[table_schema], + 
partition_specs=[PartitionSpec()], + ), + io=PyArrowFileIO(), + projected_schema=table_schema, + row_filter=AlwaysTrue(), + case_sensitive=True, + limit=150, + ) + + batches = list(scan.to_record_batches(tasks, order=ScanOrder.ARRIVAL, concurrent_files=2)) + total_rows = sum(len(b) for b in batches) + assert total_rows == 150 diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index 04bc3ecfac..dd450c588e 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -86,7 +86,7 @@ from pyiceberg.manifest import DataFile, DataFileContent, FileFormat from pyiceberg.partitioning import PartitionField, PartitionSpec from pyiceberg.schema import Schema, make_compatible_name, visit -from pyiceberg.table import FileScanTask, TableProperties +from pyiceberg.table import FileScanTask, ScanOrder, TableProperties from pyiceberg.table.metadata import TableMetadataV2 from pyiceberg.table.name_mapping import create_mapping_from_schema from pyiceberg.transforms import HourTransform, IdentityTransform @@ -3048,6 +3048,282 @@ def _expected_batch(unit: str) -> pa.RecordBatch: assert _expected_batch("ns" if format_version > 2 else "us").equals(actual_result) +def test_task_to_record_batches_with_batch_size(tmpdir: str) -> None: + """Test that batch_size controls the number of rows per batch.""" + num_rows = 1000 + arrow_table = pa.table( + {"col": pa.array(range(num_rows))}, + schema=pa.schema([pa.field("col", pa.int64(), nullable=False, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"})]), + ) + data_file = _write_table_to_data_file(f"{tmpdir}/test_batch_size.parquet", arrow_table.schema, arrow_table) + table_schema = Schema(NestedField(1, "col", LongType(), required=True)) + + batches = list( + _task_to_record_batches( + PyArrowFileIO(), + FileScanTask(data_file), + bound_row_filter=AlwaysTrue(), + projected_schema=table_schema, + table_schema=table_schema, + projected_field_ids={1}, + positional_deletes=None, + case_sensitive=True, + batch_size=100, + ) + ) + + assert len(batches) > 1 + for batch in batches: + assert len(batch) <= 100 + assert sum(len(b) for b in batches) == num_rows + + +def test_task_to_record_batches_default_batch_size(tmpdir: str) -> None: + """Test that batch_size=None uses PyArrow default (single batch for small files).""" + num_rows = 100 + arrow_table = pa.table( + {"col": pa.array(range(num_rows))}, + schema=pa.schema([pa.field("col", pa.int64(), nullable=False, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"})]), + ) + data_file = _write_table_to_data_file(f"{tmpdir}/test_default_batch_size.parquet", arrow_table.schema, arrow_table) + table_schema = Schema(NestedField(1, "col", LongType(), required=True)) + + batches = list( + _task_to_record_batches( + PyArrowFileIO(), + FileScanTask(data_file), + bound_row_filter=AlwaysTrue(), + projected_schema=table_schema, + table_schema=table_schema, + projected_field_ids={1}, + positional_deletes=None, + case_sensitive=True, + ) + ) + + # With default batch_size, a small file should produce a single batch + assert len(batches) == 1 + assert len(batches[0]) == num_rows + + +def _create_scan_and_tasks( + tmpdir: str, + num_files: int = 1, + rows_per_file: int = 100, + limit: int | None = None, + delete_rows_per_file: list[list[int]] | None = None, +) -> tuple[ArrowScan, list[FileScanTask]]: + """Helper to create an ArrowScan and FileScanTasks for testing. + + Args: + delete_rows_per_file: If provided, a list of lists of row positions to delete + per file. Length must match num_files. 
Each inner list contains 0-based + row positions within that file to mark as positionally deleted. + """ + table_schema = Schema(NestedField(1, "col", LongType(), required=True)) + pa_schema = pa.schema([pa.field("col", pa.int64(), nullable=False, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"})]) + tasks = [] + for i in range(num_files): + start = i * rows_per_file + arrow_table = pa.table({"col": pa.array(range(start, start + rows_per_file))}, schema=pa_schema) + data_file = _write_table_to_data_file(f"{tmpdir}/file_{i}.parquet", pa_schema, arrow_table) + data_file.spec_id = 0 + + delete_files = set() + if delete_rows_per_file and delete_rows_per_file[i]: + delete_table = pa.table( + { + "file_path": [data_file.file_path] * len(delete_rows_per_file[i]), + "pos": delete_rows_per_file[i], + } + ) + delete_path = f"{tmpdir}/deletes_{i}.parquet" + pq.write_table(delete_table, delete_path) + delete_files.add( + DataFile.from_args( + content=DataFileContent.POSITION_DELETES, + file_path=delete_path, + file_format=FileFormat.PARQUET, + partition={}, + record_count=len(delete_rows_per_file[i]), + file_size_in_bytes=22, + ) + ) + + tasks.append(FileScanTask(data_file=data_file, delete_files=delete_files)) + + scan = ArrowScan( + table_metadata=TableMetadataV2( + location="file://a/b/", + last_column_id=1, + format_version=2, + schemas=[table_schema], + partition_specs=[PartitionSpec()], + ), + io=PyArrowFileIO(), + projected_schema=table_schema, + row_filter=AlwaysTrue(), + case_sensitive=True, + limit=limit, + ) + return scan, tasks + + +def test_task_order_produces_same_results(tmpdir: str) -> None: + """Test that order=ScanOrder.TASK produces the same results as the default behavior.""" + scan, tasks = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100) + + batches_default = list(scan.to_record_batches(tasks, order=ScanOrder.TASK)) + # Re-create tasks since iterators are consumed + _, tasks2 = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100) + batches_task_order = list(scan.to_record_batches(tasks2, order=ScanOrder.TASK)) + + total_default = sum(len(b) for b in batches_default) + total_task_order = sum(len(b) for b in batches_task_order) + assert total_default == 300 + assert total_task_order == 300 + + +def test_arrival_order_yields_all_batches(tmpdir: str) -> None: + """Test that order=ScanOrder.ARRIVAL yields all batches correctly.""" + scan, tasks = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100) + + batches = list(scan.to_record_batches(tasks, order=ScanOrder.ARRIVAL)) + + total_rows = sum(len(b) for b in batches) + assert total_rows == 300 + # Verify all values are present + all_values = sorted([v for b in batches for v in b.column("col").to_pylist()]) + assert all_values == list(range(300)) + + +def test_arrival_order_with_limit(tmpdir: str) -> None: + """Test that order=ScanOrder.ARRIVAL respects the row limit.""" + scan, tasks = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100, limit=150) + + batches = list(scan.to_record_batches(tasks, order=ScanOrder.ARRIVAL)) + + total_rows = sum(len(b) for b in batches) + assert total_rows == 150 + + +def test_arrival_order_within_file_ordering_preserved(tmpdir: str) -> None: + """Test that within-file row ordering is preserved in arrival order mode.""" + scan, tasks = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100) + + batches = list(scan.to_record_batches(tasks, order=ScanOrder.ARRIVAL)) + all_values = sorted([v for b in batches for v in b.column("col").to_pylist()]) + + # All 
values should be present, within-file ordering is preserved + assert all_values == list(range(300)) + + +def test_arrival_order_with_positional_deletes(tmpdir: str) -> None: + """Test that order=ScanOrder.ARRIVAL correctly applies positional deletes.""" + # 3 files, 10 rows each; delete rows 0,5 from file 0, row 3 from file 1, nothing from file 2 + scan, tasks = _create_scan_and_tasks( + tmpdir, + num_files=3, + rows_per_file=10, + delete_rows_per_file=[[0, 5], [3], []], + ) + + batches = list(scan.to_record_batches(tasks, order=ScanOrder.ARRIVAL)) + + total_rows = sum(len(b) for b in batches) + assert total_rows == 27 # 30 - 3 deletes + all_values = sorted([v for b in batches for v in b.column("col").to_pylist()]) + # File 0: 0-9, delete rows 0,5 → values 1,2,3,4,6,7,8,9 + # File 1: 10-19, delete row 3 → values 10,11,12,14,15,16,17,18,19 + # File 2: 20-29, no deletes → values 20-29 + expected = [1, 2, 3, 4, 6, 7, 8, 9] + [10, 11, 12, 14, 15, 16, 17, 18, 19] + list(range(20, 30)) + assert all_values == sorted(expected) + + +def test_arrival_order_with_positional_deletes_and_limit(tmpdir: str) -> None: + """Test that order=ScanOrder.ARRIVAL with positional deletes respects the row limit.""" + # 3 files, 10 rows each; delete row 0 from each file + scan, tasks = _create_scan_and_tasks( + tmpdir, + num_files=3, + rows_per_file=10, + limit=15, + delete_rows_per_file=[[0], [0], [0]], + ) + + batches = list(scan.to_record_batches(tasks, order=ScanOrder.ARRIVAL)) + + total_rows = sum(len(b) for b in batches) + assert total_rows == 15 + + +def test_task_order_with_positional_deletes(tmpdir: str) -> None: + """Test that the default task order mode correctly applies positional deletes.""" + # 3 files, 10 rows each; delete rows from each file + scan, tasks = _create_scan_and_tasks( + tmpdir, + num_files=3, + rows_per_file=10, + delete_rows_per_file=[[0, 5], [3], []], + ) + + batches = list(scan.to_record_batches(tasks, order=ScanOrder.TASK)) + + total_rows = sum(len(b) for b in batches) + assert total_rows == 27 # 30 - 3 deletes + all_values = sorted([v for b in batches for v in b.column("col").to_pylist()]) + expected = [1, 2, 3, 4, 6, 7, 8, 9] + [10, 11, 12, 14, 15, 16, 17, 18, 19] + list(range(20, 30)) + assert all_values == sorted(expected) + + +def test_concurrent_files_with_positional_deletes(tmpdir: str) -> None: + """Test that order=ScanOrder.ARRIVAL with concurrent_files correctly applies positional deletes.""" + # 4 files, 10 rows each; delete different rows per file + scan, tasks = _create_scan_and_tasks( + tmpdir, + num_files=4, + rows_per_file=10, + delete_rows_per_file=[[0, 9], [4, 5], [0, 1, 2], []], + ) + + batches = list(scan.to_record_batches(tasks, order=ScanOrder.ARRIVAL, concurrent_files=2)) + + total_rows = sum(len(b) for b in batches) + assert total_rows == 33 # 40 - 7 deletes + all_values = sorted([v for b in batches for v in b.column("col").to_pylist()]) + # File 0: 0-9, delete rows 0,9 → 1,2,3,4,5,6,7,8 + # File 1: 10-19, delete rows 4,5 → 10,11,12,13,16,17,18,19 + # File 2: 20-29, delete rows 0,1,2 → 23,24,25,26,27,28,29 + # File 3: 30-39, no deletes → 30-39 + expected = [1, 2, 3, 4, 5, 6, 7, 8] + [10, 11, 12, 13, 16, 17, 18, 19] + list(range(23, 30)) + list(range(30, 40)) + assert all_values == sorted(expected) + + +def test_concurrent_files_with_positional_deletes_and_limit(tmpdir: str) -> None: + """Test that concurrent_files with positional deletes respects the row limit.""" + # 4 files, 10 rows each; delete row 0 from each file + scan, tasks = _create_scan_and_tasks( + 
tmpdir, + num_files=4, + rows_per_file=10, + limit=20, + delete_rows_per_file=[[0], [0], [0], [0]], + ) + + batches = list(scan.to_record_batches(tasks, order=ScanOrder.ARRIVAL, concurrent_files=2)) + + total_rows = sum(len(b) for b in batches) + assert total_rows == 20 + + +def test_concurrent_files_invalid_value(tmpdir: str) -> None: + """Test that concurrent_files < 1 raises ValueError.""" + scan, tasks = _create_scan_and_tasks(tmpdir, num_files=1, rows_per_file=10) + + with pytest.raises(ValueError, match="concurrent_files must be >= 1"): + list(scan.to_record_batches(tasks, order=ScanOrder.ARRIVAL, concurrent_files=0)) + + def test_parse_location_defaults() -> None: """Test that parse_location uses defaults."""
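A minimal end-to-end sketch of the new `to_arrow_batch_reader` parameters, not part of the diff itself: it assumes a configured catalog named `default`, a table `default.taxi_dataset`, and an output path, all of which are hypothetical placeholders. It streams a scan to a local Parquet file while keeping only the current batch (plus the bounded queue) in memory:

```python
import pyarrow.parquet as pq

from pyiceberg.catalog import load_catalog
from pyiceberg.table import ScanOrder

catalog = load_catalog("default")                 # hypothetical catalog name
tbl = catalog.load_table("default.taxi_dataset")  # hypothetical table identifier

reader = tbl.scan().to_arrow_batch_reader(
    order=ScanOrder.ARRIVAL,  # yield batches as they are produced
    concurrent_files=4,       # read up to four data files at once
    batch_size=131_072,       # PyArrow's default rows-per-batch, shown explicitly
)

# Write each batch as it arrives instead of materializing the whole table.
with pq.ParquetWriter("/tmp/export.parquet", reader.schema) as writer:
    for batch in reader:
        writer.write_batch(batch)
```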