16 changes: 16 additions & 0 deletions mkdocs/docs/api.md
@@ -355,6 +355,13 @@ for buf in tbl.scan().to_arrow_batch_reader():
print(f"Buffer contains {len(buf)} rows")
```

You can control the number of rows per batch using the `batch_size` parameter:

```python
for buf in tbl.scan().to_arrow_batch_reader(batch_size=1000):
print(f"Buffer contains {len(buf)} rows")
```

To avoid any type inconsistencies during writing, you can convert the Iceberg table schema to Arrow:

```python
@@ -1619,6 +1626,15 @@ table.scan(
).to_arrow_batch_reader()
```

The `batch_size` parameter controls the maximum number of rows per RecordBatch (default is PyArrow's 131,072 rows):

```python
table.scan(
row_filter=GreaterThanOrEqual("trip_distance", 10.0),
selected_fields=("VendorID", "tpep_pickup_datetime", "tpep_dropoff_datetime"),
).to_arrow_batch_reader(batch_size=1000)
```

### Pandas

<!-- prettier-ignore-start -->
24 changes: 15 additions & 9 deletions pyiceberg/io/pyarrow.py
@@ -1581,6 +1581,7 @@ def _task_to_record_batches(
partition_spec: PartitionSpec | None = None,
format_version: TableVersion = TableProperties.DEFAULT_FORMAT_VERSION,
downcast_ns_timestamp_to_us: bool | None = None,
batch_size: int | None = None,
) -> Iterator[pa.RecordBatch]:
arrow_format = _get_file_format(task.file.file_format, pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
with io.new_input(task.file.file_path).open() as fin:
@@ -1612,14 +1613,18 @@

file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)

- fragment_scanner = ds.Scanner.from_fragment(
-     fragment=fragment,
-     schema=physical_schema,
+ scanner_kwargs: dict[str, Any] = {
+     "fragment": fragment,
+     "schema": physical_schema,
      # This will push down the query to Arrow.
      # But in case there are positional deletes, we have to apply them first
-     filter=pyarrow_filter if not positional_deletes else None,
-     columns=[col.name for col in file_project_schema.columns],
- )
+     "filter": pyarrow_filter if not positional_deletes else None,
+     "columns": [col.name for col in file_project_schema.columns],
+ }
+ if batch_size is not None:
+     scanner_kwargs["batch_size"] = batch_size

+ fragment_scanner = ds.Scanner.from_fragment(**scanner_kwargs)

next_index = 0
batches = fragment_scanner.to_batches()
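For context, a minimal standalone sketch (not part of this PR) of the same pattern: only forwarding `batch_size` to `ds.Scanner.from_fragment` when the caller actually set it, so PyArrow's own default applies otherwise. The file path and column name are made up for illustration.

```python
import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq

# Write a small Parquet file so we have a real fragment to scan (illustrative path).
pq.write_table(pa.table({"col": list(range(1_000))}), "/tmp/batch_size_demo.parquet")
dataset = ds.dataset("/tmp/batch_size_demo.parquet", format="parquet")
fragment = next(dataset.get_fragments())

batch_size = 100  # stand-in for the value passed down from the scan; may be None

# Only include batch_size when it was explicitly provided,
# so PyArrow's default applies otherwise.
scanner_kwargs = {"fragment": fragment, "schema": dataset.schema}
if batch_size is not None:
    scanner_kwargs["batch_size"] = batch_size

scanner = ds.Scanner.from_fragment(**scanner_kwargs)
batches = list(scanner.to_batches())
assert all(len(b) <= 100 for b in batches)   # batch_size is an upper bound per batch
assert sum(len(b) for b in batches) == 1_000
```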
@@ -1756,7 +1761,7 @@ def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table:

return result

- def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.RecordBatch]:
+ def to_record_batches(self, tasks: Iterable[FileScanTask], batch_size: int | None = None) -> Iterator[pa.RecordBatch]:
"""Scan the Iceberg table and return an Iterator[pa.RecordBatch].
Returns an Iterator of pa.RecordBatch with data from the Iceberg table
@@ -1783,7 +1788,7 @@ def batches_for_task(task: FileScanTask) -> list[pa.RecordBatch]:
# Materialize the iterator here to ensure execution happens within the executor.
# Otherwise, the iterator would be lazily consumed later (in the main thread),
# defeating the purpose of using executor.map.
- return list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file))
+ return list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file, batch_size))

limit_reached = False
for batches in executor.map(batches_for_task, tasks):
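A small self-contained illustration (separate from the PyIceberg code) of why `batches_for_task` materializes the iterator with `list(...)`: mapping a generator function over a thread pool only creates generator objects in the workers, so the real scan work would otherwise run serially in the consuming thread, defeating `executor.map`.

```python
import time
from concurrent.futures import ThreadPoolExecutor

def slow_batches(task: str):
    # A generator: nothing runs until it is iterated.
    for i in range(3):
        time.sleep(0.1)
        yield f"{task}-{i}"

tasks = ["a", "b", "c"]
with ThreadPoolExecutor(max_workers=3) as pool:
    # Lazy: map() only creates generator objects in the workers;
    # the sleeps happen here, one after another, in the main thread.
    start = time.perf_counter()
    for gen in pool.map(slow_batches, tasks):
        list(gen)
    print(f"lazy:  {time.perf_counter() - start:.2f}s")  # ~0.9s

    # Materialized inside the worker: the sleeps overlap across threads.
    start = time.perf_counter()
    for batches in pool.map(lambda t: list(slow_batches(t)), tasks):
        pass
    print(f"eager: {time.perf_counter() - start:.2f}s")  # ~0.3s
```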
@@ -1803,7 +1808,7 @@ def batches_for_task(task: FileScanTask) -> list[pa.RecordBatch]:
break

def _record_batches_from_scan_tasks_and_deletes(
- self, tasks: Iterable[FileScanTask], deletes_per_file: dict[str, list[ChunkedArray]]
+ self, tasks: Iterable[FileScanTask], deletes_per_file: dict[str, list[ChunkedArray]], batch_size: int | None = None
) -> Iterator[pa.RecordBatch]:
total_row_count = 0
for task in tasks:
@@ -1822,6 +1827,7 @@ def _record_batches_from_scan_tasks_and_deletes(
self._table_metadata.specs().get(task.file.spec_id),
self._table_metadata.format_version,
self._downcast_ns_timestamp_to_us,
batch_size,
)
for batch in batches:
if self._limit is not None:
7 changes: 5 additions & 2 deletions pyiceberg/table/__init__.py
@@ -2157,13 +2157,16 @@ def to_arrow(self) -> pa.Table:
self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit
).to_table(self.plan_files())

- def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
+ def to_arrow_batch_reader(self, batch_size: int | None = None) -> pa.RecordBatchReader:
"""Return an Arrow RecordBatchReader from this DataScan.

For large results, using a RecordBatchReader requires less memory than
loading an Arrow Table for the same DataScan, because a RecordBatch
is read one at a time.

Args:
batch_size: The number of rows per batch. If None, PyArrow's default is used.

Returns:
pa.RecordBatchReader: Arrow RecordBatchReader from the Iceberg table's DataScan
which can be used to read a stream of record batches one by one.
@@ -2175,7 +2178,7 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
target_schema = schema_to_pyarrow(self.projection())
batches = ArrowScan(
self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit
- ).to_record_batches(self.plan_files())
+ ).to_record_batches(self.plan_files(), batch_size=batch_size)

return pa.RecordBatchReader.from_batches(
target_schema,
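For reference, a minimal sketch of what `pa.RecordBatchReader.from_batches` does with the lazy `batches` iterator above: it wraps an iterator of record batches behind a fixed schema so consumers can stream results without materializing a full table. The schema and data here are made up.

```python
import pyarrow as pa

target_schema = pa.schema([("col", pa.int64())])

def batch_iter():
    # Lazily produce 10 batches of 100 rows each.
    for start in range(0, 1_000, 100):
        yield pa.record_batch({"col": list(range(start, start + 100))}, schema=target_schema)

reader = pa.RecordBatchReader.from_batches(target_schema, batch_iter())
total = 0
for batch in reader:  # batches are pulled one at a time
    total += len(batch)
assert total == 1_000
```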
58 changes: 58 additions & 0 deletions tests/io/test_pyarrow.py
@@ -3048,6 +3048,64 @@ def _expected_batch(unit: str) -> pa.RecordBatch:
assert _expected_batch("ns" if format_version > 2 else "us").equals(actual_result)


def test_task_to_record_batches_with_batch_size(tmpdir: str) -> None:
"""Test that batch_size controls the number of rows per batch."""
num_rows = 1000
arrow_table = pa.table(
{"col": pa.array(range(num_rows))},
schema=pa.schema([pa.field("col", pa.int64(), nullable=False, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"})]),
)
data_file = _write_table_to_data_file(f"{tmpdir}/test_batch_size.parquet", arrow_table.schema, arrow_table)
table_schema = Schema(NestedField(1, "col", LongType(), required=True))

batches = list(
_task_to_record_batches(
PyArrowFileIO(),
FileScanTask(data_file),
bound_row_filter=AlwaysTrue(),
projected_schema=table_schema,
table_schema=table_schema,
projected_field_ids={1},
positional_deletes=None,
case_sensitive=True,
batch_size=100,
)
)

assert len(batches) > 1
for batch in batches:
assert len(batch) <= 100
assert sum(len(b) for b in batches) == num_rows


def test_task_to_record_batches_default_batch_size(tmpdir: str) -> None:
"""Test that batch_size=None uses PyArrow default (single batch for small files)."""
num_rows = 100
arrow_table = pa.table(
{"col": pa.array(range(num_rows))},
schema=pa.schema([pa.field("col", pa.int64(), nullable=False, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"})]),
)
data_file = _write_table_to_data_file(f"{tmpdir}/test_default_batch_size.parquet", arrow_table.schema, arrow_table)
table_schema = Schema(NestedField(1, "col", LongType(), required=True))

batches = list(
_task_to_record_batches(
PyArrowFileIO(),
FileScanTask(data_file),
bound_row_filter=AlwaysTrue(),
projected_schema=table_schema,
table_schema=table_schema,
projected_field_ids={1},
positional_deletes=None,
case_sensitive=True,
)
)

# With default batch_size, a small file should produce a single batch
assert len(batches) == 1
assert len(batches[0]) == num_rows


def test_parse_location_defaults() -> None:
"""Test that parse_location uses defaults."""
