diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index a120c3b776..24bac4f58f 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -1581,6 +1581,7 @@ def _task_to_record_batches(
     partition_spec: PartitionSpec | None = None,
     format_version: TableVersion = TableProperties.DEFAULT_FORMAT_VERSION,
     downcast_ns_timestamp_to_us: bool | None = None,
+    batch_size: int | None = None,
 ) -> Iterator[pa.RecordBatch]:
     arrow_format = _get_file_format(task.file.file_format, pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
     with io.new_input(task.file.file_path).open() as fin:
@@ -1612,14 +1613,17 @@
 
         file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)
 
-        fragment_scanner = ds.Scanner.from_fragment(
-            fragment=fragment,
-            schema=physical_schema,
+        scanner_kwargs: dict[str, Any] = {
+            "fragment": fragment,
+            "schema": physical_schema,
             # This will push down the query to Arrow.
             # But in case there are positional deletes, we have to apply them first
-            filter=pyarrow_filter if not positional_deletes else None,
-            columns=[col.name for col in file_project_schema.columns],
-        )
+            "filter": pyarrow_filter if not positional_deletes else None,
+            "columns": [col.name for col in file_project_schema.columns],
+        }
+        if batch_size is not None:
+            scanner_kwargs["batch_size"] = batch_size
+        fragment_scanner = ds.Scanner.from_fragment(**scanner_kwargs)
 
         next_index = 0
         batches = fragment_scanner.to_batches()
@@ -1802,8 +1806,30 @@ def batches_for_task(task: FileScanTask) -> list[pa.RecordBatch]:
                     # This break will also cancel all running tasks in the executor
                     break
 
+    def to_record_batch_stream(self, tasks: Iterable[FileScanTask], batch_size: int | None = None) -> Iterator[pa.RecordBatch]:
+        """Scan the Iceberg table and return an Iterator[pa.RecordBatch] in a streaming fashion.
+
+        Files are read sequentially and batches are yielded one at a time
+        without materializing all batches in memory. Use this when memory
+        efficiency is more important than throughput.
+
+        Args:
+            tasks: FileScanTasks representing the data files and delete files to read from.
+            batch_size: Maximum number of rows per RecordBatch. If None,
+                uses PyArrow's default (131,072 rows).
+
+        Yields:
+            pa.RecordBatch: Record batches from the scan, one at a time.
+        """
+        tasks = list(tasks) if not isinstance(tasks, list) else tasks
+        deletes_per_file = _read_all_delete_files(self._io, tasks)
+        yield from self._record_batches_from_scan_tasks_and_deletes(tasks, deletes_per_file, batch_size)
+
     def _record_batches_from_scan_tasks_and_deletes(
-        self, tasks: Iterable[FileScanTask], deletes_per_file: dict[str, list[ChunkedArray]]
+        self,
+        tasks: Iterable[FileScanTask],
+        deletes_per_file: dict[str, list[ChunkedArray]],
+        batch_size: int | None = None,
     ) -> Iterator[pa.RecordBatch]:
         total_row_count = 0
         for task in tasks:
@@ -1822,6 +1848,7 @@ def _record_batches_from_scan_tasks_and_deletes(
                 self._table_metadata.specs().get(task.file.spec_id),
                 self._table_metadata.format_version,
                 self._downcast_ns_timestamp_to_us,
+                batch_size,
             )
             for batch in batches:
                 if self._limit is not None:
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index cc0d9ff341..14a35745ff 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -2182,6 +2182,26 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
             batches,
         ).cast(target_schema)
 
+    def to_record_batches(self, batch_size: int | None = None) -> Iterator[pa.RecordBatch]:
+        """Read record batches in a streaming fashion from this DataScan.
+
+        Files are read sequentially and batches are yielded one at a time
+        without materializing all batches in memory. Use this when memory
+        efficiency is more important than throughput.
+
+        Args:
+            batch_size: Maximum number of rows per RecordBatch. If None,
+                uses PyArrow's default (131,072 rows).
+
+        Yields:
+            pa.RecordBatch: Record batches from the scan, one at a time.
+        """
+        from pyiceberg.io.pyarrow import ArrowScan
+
+        yield from ArrowScan(
+            self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit
+        ).to_record_batch_stream(self.plan_files(), batch_size)
+
     def to_pandas(self, **kwargs: Any) -> pd.DataFrame:
         """Read a Pandas DataFrame eagerly from this Iceberg table.
diff --git a/tests/integration/test_reads.py b/tests/integration/test_reads.py index 6c8b4a20a7..c78a6cc094 100644 --- a/tests/integration/test_reads.py +++ b/tests/integration/test_reads.py @@ -1272,3 +1272,51 @@ def test_scan_source_field_missing_in_spec(catalog: Catalog, spark: SparkSession table = catalog.load_table(identifier) assert len(list(table.scan().plan_files())) == 3 + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [lf("session_catalog_hive"), lf("session_catalog")]) +def test_datascan_to_record_batches(catalog: Catalog) -> None: + table = create_table(catalog) + + arrow_table = pa.Table.from_pydict( + { + "str": ["a", "b", "c"], + "int": [1, 2, 3], + }, + schema=pa.schema([pa.field("str", pa.large_string()), pa.field("int", pa.int32())]), + ) + table.append(arrow_table) + + scan = table.scan() + streaming_batches = list(scan.to_record_batches()) + streaming_result = pa.concat_tables([pa.Table.from_batches([b]) for b in streaming_batches], promote_options="permissive") + + eager_result = scan.to_arrow() + + assert streaming_result.num_rows == eager_result.num_rows + assert streaming_result.column_names == eager_result.column_names + assert streaming_result.sort_by("int").equals(eager_result.sort_by("int")) + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [lf("session_catalog_hive"), lf("session_catalog")]) +def test_datascan_to_record_batches_with_batch_size(catalog: Catalog) -> None: + table = create_table(catalog) + + arrow_table = pa.Table.from_pydict( + { + "str": [f"val_{i}" for i in range(100)], + "int": list(range(100)), + }, + schema=pa.schema([pa.field("str", pa.large_string()), pa.field("int", pa.int32())]), + ) + table.append(arrow_table) + + scan = table.scan() + batches = list(scan.to_record_batches(batch_size=10)) + + total_rows = sum(len(b) for b in batches) + assert total_rows == 100 + for batch in batches: + assert len(batch) <= 10 diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index 04bc3ecfac..8bba598c55 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -55,6 +55,7 @@ BoundNotStartsWith, BoundReference, BoundStartsWith, + EqualTo, GreaterThan, Not, Or, @@ -4884,3 +4885,862 @@ def test_partition_column_projection_with_schema_evolution(catalog: InMemoryCata result_sorted = result.sort_by("name") assert result_sorted["name"].to_pylist() == ["Alice", "Bob", "Charlie", "David"] assert result_sorted["new_column"].to_pylist() == [None, None, "new1", "new2"] + + +def test_task_to_record_batches_with_batch_size(tmpdir: str) -> None: + schema = Schema(NestedField(1, "id", IntegerType(), required=False)) + pyarrow_schema = schema_to_pyarrow(schema, metadata={ICEBERG_SCHEMA: bytes(schema.model_dump_json(), UTF8)}) + + # Create a parquet file with 1000 rows + table = pa.Table.from_arrays([pa.array(list(range(1000)))], schema=pyarrow_schema) + data_file = _write_table_to_data_file(f"{tmpdir}/batch_size_test.parquet", pyarrow_schema, table) + data_file.spec_id = 0 + + task = FileScanTask(data_file=data_file) + + batches = list( + _task_to_record_batches( + PyArrowFileIO(), + task, + bound_row_filter=AlwaysTrue(), + projected_schema=schema, + table_schema=schema, + projected_field_ids={1}, + positional_deletes=None, + case_sensitive=True, + batch_size=100, + ) + ) + + total_rows = sum(len(b) for b in batches) + assert total_rows == 1000 + for batch in batches: + assert len(batch) <= 100 + + +def test_to_record_batch_stream_basic(tmpdir: str) -> None: + schema = Schema(NestedField(1, "id", 
IntegerType(), required=False)) + pyarrow_schema = schema_to_pyarrow(schema, metadata={ICEBERG_SCHEMA: bytes(schema.model_dump_json(), UTF8)}) + + table = pa.Table.from_arrays([pa.array(list(range(100)))], schema=pyarrow_schema) + data_file = _write_table_to_data_file(f"{tmpdir}/streaming_basic.parquet", pyarrow_schema, table) + data_file.spec_id = 0 + + task = FileScanTask(data_file=data_file) + + scan = ArrowScan( + table_metadata=TableMetadataV2( + location="file://a/b/", + last_column_id=1, + format_version=2, + schemas=[schema], + partition_specs=[PartitionSpec()], + ), + io=PyArrowFileIO(), + projected_schema=schema, + row_filter=AlwaysTrue(), + case_sensitive=True, + ) + + result = scan.to_record_batch_stream([task]) + # Should be a generator/iterator, not a list + import types + + assert isinstance(result, types.GeneratorType) + + batches = list(result) + total_rows = sum(len(b) for b in batches) + assert total_rows == 100 + + +def test_to_record_batch_stream_with_batch_size(tmpdir: str) -> None: + schema = Schema(NestedField(1, "id", IntegerType(), required=False)) + pyarrow_schema = schema_to_pyarrow(schema, metadata={ICEBERG_SCHEMA: bytes(schema.model_dump_json(), UTF8)}) + + table = pa.Table.from_arrays([pa.array(list(range(500)))], schema=pyarrow_schema) + data_file = _write_table_to_data_file(f"{tmpdir}/streaming_batch_size.parquet", pyarrow_schema, table) + data_file.spec_id = 0 + + task = FileScanTask(data_file=data_file) + + scan = ArrowScan( + table_metadata=TableMetadataV2( + location="file://a/b/", + last_column_id=1, + format_version=2, + schemas=[schema], + partition_specs=[PartitionSpec()], + ), + io=PyArrowFileIO(), + projected_schema=schema, + row_filter=AlwaysTrue(), + case_sensitive=True, + ) + + batches = list(scan.to_record_batch_stream([task], batch_size=50)) + + total_rows = sum(len(b) for b in batches) + assert total_rows == 500 + for batch in batches: + assert len(batch) <= 50 + + +def test_to_record_batch_stream_with_limit(tmpdir: str) -> None: + schema = Schema(NestedField(1, "id", IntegerType(), required=False)) + pyarrow_schema = schema_to_pyarrow(schema, metadata={ICEBERG_SCHEMA: bytes(schema.model_dump_json(), UTF8)}) + + table = pa.Table.from_arrays([pa.array(list(range(500)))], schema=pyarrow_schema) + data_file = _write_table_to_data_file(f"{tmpdir}/streaming_limit.parquet", pyarrow_schema, table) + data_file.spec_id = 0 + + task = FileScanTask(data_file=data_file) + + scan = ArrowScan( + table_metadata=TableMetadataV2( + location="file://a/b/", + last_column_id=1, + format_version=2, + schemas=[schema], + partition_specs=[PartitionSpec()], + ), + io=PyArrowFileIO(), + projected_schema=schema, + row_filter=AlwaysTrue(), + case_sensitive=True, + limit=100, + ) + + batches = list(scan.to_record_batch_stream([task])) + + total_rows = sum(len(b) for b in batches) + assert total_rows == 100 + + +def test_to_record_batch_stream_with_deletes( + deletes_file: str, request: pytest.FixtureRequest, table_schema_simple: Schema +) -> None: + file_format = FileFormat.PARQUET if deletes_file.endswith(".parquet") else FileFormat.ORC + + if file_format == FileFormat.PARQUET: + example_task = request.getfixturevalue("example_task") + else: + example_task = request.getfixturevalue("example_task_orc") + + example_task_with_delete = FileScanTask( + data_file=example_task.file, + delete_files={ + DataFile.from_args( + content=DataFileContent.POSITION_DELETES, + file_path=deletes_file, + file_format=file_format, + ) + }, + ) + + metadata_location = "file://a/b/c.json" + 
scan = ArrowScan( + table_metadata=TableMetadataV2( + location=metadata_location, + last_column_id=1, + format_version=2, + current_schema_id=1, + schemas=[table_schema_simple], + partition_specs=[PartitionSpec()], + ), + io=load_file_io(), + projected_schema=table_schema_simple, + row_filter=AlwaysTrue(), + ) + + # Compare streaming path to table path + streaming_batches = list(scan.to_record_batch_stream([example_task_with_delete])) + streaming_table = pa.concat_tables([pa.Table.from_batches([b]) for b in streaming_batches], promote_options="permissive") + eager_table = scan.to_table(tasks=[example_task_with_delete]) + + assert streaming_table.num_rows == eager_table.num_rows + assert streaming_table.column_names == eager_table.column_names + + +def test_to_record_batch_stream_multiple_files(tmpdir: str) -> None: + schema = Schema(NestedField(1, "id", IntegerType(), required=False)) + pyarrow_schema = schema_to_pyarrow(schema, metadata={ICEBERG_SCHEMA: bytes(schema.model_dump_json(), UTF8)}) + + tasks = [] + total_expected = 0 + for i in range(3): + num_rows = (i + 1) * 100 # 100, 200, 300 + total_expected += num_rows + table = pa.Table.from_arrays([pa.array(list(range(num_rows)))], schema=pyarrow_schema) + data_file = _write_table_to_data_file(f"{tmpdir}/multi_{i}.parquet", pyarrow_schema, table) + data_file.spec_id = 0 + tasks.append(FileScanTask(data_file=data_file)) + + scan = ArrowScan( + table_metadata=TableMetadataV2( + location="file://a/b/", + last_column_id=1, + format_version=2, + schemas=[schema], + partition_specs=[PartitionSpec()], + ), + io=PyArrowFileIO(), + projected_schema=schema, + row_filter=AlwaysTrue(), + case_sensitive=True, + ) + + batches = list(scan.to_record_batch_stream(tasks)) + total_rows = sum(len(b) for b in batches) + assert total_rows == total_expected # 600 rows total + + +# ============================================================================ +# Enhanced Streaming Tests - Filters, Projections, Complex Types +# ============================================================================ + + +def test_to_record_batch_stream_with_simple_filter(tmpdir: str) -> None: + """Test row filtering with streaming. + + Verifies that streaming correctly applies simple filters. 
+ Run: pytest tests/io/test_pyarrow.py::test_to_record_batch_stream_with_simple_filter -v + """ + schema = Schema(NestedField(1, "id", IntegerType(), required=False)) + pyarrow_schema = schema_to_pyarrow(schema, metadata={ICEBERG_SCHEMA: bytes(schema.model_dump_json(), UTF8)}) + + # Create 1000 rows with id field (0-999) + table = pa.Table.from_arrays([pa.array(list(range(1000)))], schema=pyarrow_schema) + data_file = _write_table_to_data_file(f"{tmpdir}/filter_simple.parquet", pyarrow_schema, table) + data_file.spec_id = 0 + + task = FileScanTask(data_file=data_file) + + # Apply filter: id > 500 + scan = ArrowScan( + table_metadata=TableMetadataV2( + location="file://a/b/", + last_column_id=1, + format_version=2, + schemas=[schema], + partition_specs=[PartitionSpec()], + ), + io=PyArrowFileIO(), + projected_schema=schema, + row_filter=GreaterThan("id", 500), + case_sensitive=True, + ) + + result = scan.to_record_batch_stream([task]) + + # Verify generator type + import types + + assert isinstance(result, types.GeneratorType) + + # Materialize and verify data + batches = list(result) + total_rows = sum(len(b) for b in batches) + assert total_rows == 499 # ids 501-999 + + # Verify all IDs > 500 + for batch in batches: + ids = batch.column("id").to_pylist() + assert all(id_val > 500 for id_val in ids) + + +def test_to_record_batch_stream_with_complex_filter(tmpdir: str) -> None: + """Test complex AND filter with streaming. + + Verifies that streaming correctly applies complex filters with multiple conditions. + Run: pytest tests/io/test_pyarrow.py::test_to_record_batch_stream_with_complex_filter -v + """ + schema = Schema( + NestedField(1, "id", IntegerType(), required=False), + NestedField(2, "age", IntegerType(), required=False), + NestedField(3, "active", BooleanType(), required=False), + ) + pyarrow_schema = schema_to_pyarrow(schema, metadata={ICEBERG_SCHEMA: bytes(schema.model_dump_json(), UTF8)}) + + # Create test data: 100 rows with varied age and active values + ids = list(range(100)) + ages = [20 + (i % 50) for i in range(100)] # ages 20-69 + actives = [i % 2 == 0 for i in range(100)] # alternating True/False + + table = pa.Table.from_arrays( + [pa.array(ids), pa.array(ages), pa.array(actives)], schema=pyarrow_schema + ) + data_file = _write_table_to_data_file(f"{tmpdir}/filter_complex.parquet", pyarrow_schema, table) + data_file.spec_id = 0 + + task = FileScanTask(data_file=data_file) + + # Apply filter: age > 30 AND active = True + scan = ArrowScan( + table_metadata=TableMetadataV2( + location="file://a/b/", + last_column_id=3, + format_version=2, + schemas=[schema], + partition_specs=[PartitionSpec()], + ), + io=PyArrowFileIO(), + projected_schema=schema, + row_filter=And(GreaterThan("age", 30), EqualTo("active", True)), + case_sensitive=True, + ) + + result = scan.to_record_batch_stream([task]) + + # Verify generator type + import types + + assert isinstance(result, types.GeneratorType) + + # Materialize and verify data + batches = list(result) + total_rows = sum(len(b) for b in batches) + + # Verify all rows match both conditions + for batch in batches: + ages_col = batch.column("age").to_pylist() + actives_col = batch.column("active").to_pylist() + assert all(age > 30 for age in ages_col) + assert all(active is True for active in actives_col) + + # Expected: age > 30 (ages 31-69 = 39 values) AND active = True (every other row starting at 0) + # From 100 rows, ages cycle 20-69, so ages > 30 are ids where (id % 50) > 10 + # Active is True for even ids + # So we need even ids 
where (id % 50) > 10 + expected = len([i for i in range(100) if ages[i] > 30 and actives[i]]) + assert total_rows == expected + + +def test_to_record_batch_stream_empty_result(tmpdir: str) -> None: + """Test filter returning no results. + + Verifies that streaming handles empty results gracefully. + Run: pytest tests/io/test_pyarrow.py::test_to_record_batch_stream_empty_result -v + """ + schema = Schema(NestedField(1, "id", IntegerType(), required=False)) + pyarrow_schema = schema_to_pyarrow(schema, metadata={ICEBERG_SCHEMA: bytes(schema.model_dump_json(), UTF8)}) + + # Create 100 rows with ids 0-99 + table = pa.Table.from_arrays([pa.array(list(range(100)))], schema=pyarrow_schema) + data_file = _write_table_to_data_file(f"{tmpdir}/filter_empty.parquet", pyarrow_schema, table) + data_file.spec_id = 0 + + task = FileScanTask(data_file=data_file) + + # Apply filter that matches nothing: id > 999 + scan = ArrowScan( + table_metadata=TableMetadataV2( + location="file://a/b/", + last_column_id=1, + format_version=2, + schemas=[schema], + partition_specs=[PartitionSpec()], + ), + io=PyArrowFileIO(), + projected_schema=schema, + row_filter=GreaterThan("id", 999), + case_sensitive=True, + ) + + result = scan.to_record_batch_stream([task]) + + # Verify generator type + import types + + assert isinstance(result, types.GeneratorType) + + # Materialize - should be empty + batches = list(result) + assert len(batches) == 0 + + +def test_to_record_batch_stream_with_projection(tmpdir: str) -> None: + """Test column projection with streaming. + + Verifies that streaming correctly projects subset of columns. + Run: pytest tests/io/test_pyarrow.py::test_to_record_batch_stream_with_projection -v + """ + # Full schema: id, name, age, active (4 fields) + full_schema = Schema( + NestedField(1, "id", IntegerType(), required=False), + NestedField(2, "name", StringType(), required=False), + NestedField(3, "age", IntegerType(), required=False), + NestedField(4, "active", BooleanType(), required=False), + ) + pyarrow_schema = schema_to_pyarrow(full_schema, metadata={ICEBERG_SCHEMA: bytes(full_schema.model_dump_json(), UTF8)}) + + # Create test data + table = pa.Table.from_arrays( + [ + pa.array(list(range(100))), + pa.array([f"name_{i}" for i in range(100)]), + pa.array([20 + i for i in range(100)]), + pa.array([i % 2 == 0 for i in range(100)]), + ], + schema=pyarrow_schema, + ) + data_file = _write_table_to_data_file(f"{tmpdir}/projection.parquet", pyarrow_schema, table) + data_file.spec_id = 0 + + task = FileScanTask(data_file=data_file) + + # Projected schema: id, age only (2 fields) + projected_schema = Schema( + NestedField(1, "id", IntegerType(), required=False), + NestedField(3, "age", IntegerType(), required=False), + ) + + scan = ArrowScan( + table_metadata=TableMetadataV2( + location="file://a/b/", + last_column_id=4, + format_version=2, + schemas=[full_schema], + partition_specs=[PartitionSpec()], + ), + io=PyArrowFileIO(), + projected_schema=projected_schema, + row_filter=AlwaysTrue(), + case_sensitive=True, + ) + + result = scan.to_record_batch_stream([task]) + + # Verify generator type + import types + + assert isinstance(result, types.GeneratorType) + + # Materialize and verify data + batches = list(result) + total_rows = sum(len(b) for b in batches) + assert total_rows == 100 + + # Verify only projected columns present + for batch in batches: + assert batch.num_columns == 2 + assert "id" in batch.schema.names + assert "age" in batch.schema.names + assert "name" not in batch.schema.names + assert 
"active" not in batch.schema.names + + +def test_to_record_batch_stream_projection_with_filter(tmpdir: str) -> None: + """Test projection with filter on non-projected field. + + Verifies that filters work correctly even when filtered field is not projected. + Run: pytest tests/io/test_pyarrow.py::test_to_record_batch_stream_projection_with_filter -v + """ + # Full schema: id, name, age, active + full_schema = Schema( + NestedField(1, "id", IntegerType(), required=False), + NestedField(2, "name", StringType(), required=False), + NestedField(3, "age", IntegerType(), required=False), + NestedField(4, "active", BooleanType(), required=False), + ) + pyarrow_schema = schema_to_pyarrow(full_schema, metadata={ICEBERG_SCHEMA: bytes(full_schema.model_dump_json(), UTF8)}) + + # Create test data + table = pa.Table.from_arrays( + [ + pa.array(list(range(100))), + pa.array([f"name_{i}" for i in range(100)]), + pa.array([20 + i for i in range(100)]), + pa.array([i % 2 == 0 for i in range(100)]), + ], + schema=pyarrow_schema, + ) + data_file = _write_table_to_data_file(f"{tmpdir}/projection_filter.parquet", pyarrow_schema, table) + data_file.spec_id = 0 + + task = FileScanTask(data_file=data_file) + + # Projected schema: id, name only + projected_schema = Schema( + NestedField(1, "id", IntegerType(), required=False), + NestedField(2, "name", StringType(), required=False), + ) + + # Filter on non-projected field: age > 30 + scan = ArrowScan( + table_metadata=TableMetadataV2( + location="file://a/b/", + last_column_id=4, + format_version=2, + schemas=[full_schema], + partition_specs=[PartitionSpec()], + ), + io=PyArrowFileIO(), + projected_schema=projected_schema, + row_filter=GreaterThan("age", 30), + case_sensitive=True, + ) + + result = scan.to_record_batch_stream([task]) + + # Verify generator type + import types + + assert isinstance(result, types.GeneratorType) + + # Materialize and verify data + batches = list(result) + total_rows = sum(len(b) for b in batches) + + # Expected: age > 30 means age >= 31, which is 20+i > 30, so i > 10, giving 89 rows (11-99) + assert total_rows == 89 + + # Verify only projected columns present and age is not in output + for batch in batches: + assert batch.num_columns == 2 + assert "id" in batch.schema.names + assert "name" in batch.schema.names + assert "age" not in batch.schema.names + + # Verify IDs are > 10 (since age = 20 + id) + ids = batch.column("id").to_pylist() + assert all(id_val > 10 for id_val in ids) + + +def test_to_record_batch_stream_with_struct_type(tmpdir: str) -> None: + """Test streaming with struct type. + + Verifies that streaming correctly handles nested struct fields. 
+ Run: pytest tests/io/test_pyarrow.py::test_to_record_batch_stream_with_struct_type -v + """ + schema = Schema( + NestedField( + 4, + "location", + StructType( + NestedField(41, "lat", DoubleType()), + NestedField(42, "long", DoubleType()), + ), + ) + ) + pyarrow_schema = schema_to_pyarrow(schema, metadata={ICEBERG_SCHEMA: bytes(schema.model_dump_json(), UTF8)}) + + # Create test data with struct values + locations = pa.array( + [{"lat": 37.7749 + i * 0.1, "long": -122.4194 + i * 0.1} for i in range(50)], + type=pyarrow_schema.field("location").type, + ) + table = pa.Table.from_arrays([locations], schema=pyarrow_schema) + data_file = _write_table_to_data_file(f"{tmpdir}/struct_type.parquet", pyarrow_schema, table) + data_file.spec_id = 0 + + task = FileScanTask(data_file=data_file) + + scan = ArrowScan( + table_metadata=TableMetadataV2( + location="file://a/b/", + last_column_id=42, + format_version=2, + schemas=[schema], + partition_specs=[PartitionSpec()], + ), + io=PyArrowFileIO(), + projected_schema=schema, + row_filter=AlwaysTrue(), + case_sensitive=True, + ) + + result = scan.to_record_batch_stream([task]) + + # Verify generator type + import types + + assert isinstance(result, types.GeneratorType) + + # Materialize and verify data + batches = list(result) + total_rows = sum(len(b) for b in batches) + assert total_rows == 50 + + # Verify struct fields accessible and preserved + for batch in batches: + location_col = batch.column("location") + assert location_col is not None + # Verify struct has both fields + for i in range(len(location_col)): + loc = location_col[i].as_py() + assert "lat" in loc + assert "long" in loc + assert isinstance(loc["lat"], float) + assert isinstance(loc["long"], float) + + +def test_to_record_batch_stream_with_list_type(tmpdir: str) -> None: + """Test streaming with list type. + + Verifies that streaming correctly handles list fields. 
+ Run: pytest tests/io/test_pyarrow.py::test_to_record_batch_stream_with_list_type -v + """ + schema = Schema( + NestedField(5, "ids", ListType(51, IntegerType(), element_required=False), required=False), + ) + pyarrow_schema = schema_to_pyarrow(schema, metadata={ICEBERG_SCHEMA: bytes(schema.model_dump_json(), UTF8)}) + + # Create test data with list values + ids_data = pa.array([[i, i + 1, i + 2] for i in range(0, 100, 2)], type=pyarrow_schema.field("ids").type) + table = pa.Table.from_arrays([ids_data], schema=pyarrow_schema) + data_file = _write_table_to_data_file(f"{tmpdir}/list_type.parquet", pyarrow_schema, table) + data_file.spec_id = 0 + + task = FileScanTask(data_file=data_file) + + scan = ArrowScan( + table_metadata=TableMetadataV2( + location="file://a/b/", + last_column_id=51, + format_version=2, + schemas=[schema], + partition_specs=[PartitionSpec()], + ), + io=PyArrowFileIO(), + projected_schema=schema, + row_filter=AlwaysTrue(), + case_sensitive=True, + ) + + result = scan.to_record_batch_stream([task]) + + # Verify generator type + import types + + assert isinstance(result, types.GeneratorType) + + # Materialize and verify data + batches = list(result) + total_rows = sum(len(b) for b in batches) + assert total_rows == 50 + + # Verify list values preserved and elements accessible + for batch in batches: + ids_col = batch.column("ids") + assert ids_col is not None + for i in range(len(ids_col)): + id_list = ids_col[i].as_py() + assert isinstance(id_list, list) + assert len(id_list) == 3 + assert all(isinstance(x, int) for x in id_list) + + +def test_to_record_batch_stream_with_map_type(tmpdir: str) -> None: + """Test streaming with map type. + + Verifies that streaming correctly handles map fields. + Run: pytest tests/io/test_pyarrow.py::test_to_record_batch_stream_with_map_type -v + """ + schema = Schema( + NestedField( + 5, + "properties", + MapType( + key_id=51, + key_type=StringType(), + value_id=52, + value_type=StringType(), + value_required=True, + ), + required=False, + ), + ) + pyarrow_schema = schema_to_pyarrow(schema, metadata={ICEBERG_SCHEMA: bytes(schema.model_dump_json(), UTF8)}) + + # Create test data with map values + properties_data = pa.array( + [[(f"key{j}", f"value{j}") for j in range(3)] for i in range(30)], + type=pyarrow_schema.field("properties").type, + ) + table = pa.Table.from_arrays([properties_data], schema=pyarrow_schema) + data_file = _write_table_to_data_file(f"{tmpdir}/map_type.parquet", pyarrow_schema, table) + data_file.spec_id = 0 + + task = FileScanTask(data_file=data_file) + + scan = ArrowScan( + table_metadata=TableMetadataV2( + location="file://a/b/", + last_column_id=52, + format_version=2, + schemas=[schema], + partition_specs=[PartitionSpec()], + ), + io=PyArrowFileIO(), + projected_schema=schema, + row_filter=AlwaysTrue(), + case_sensitive=True, + ) + + result = scan.to_record_batch_stream([task]) + + # Verify generator type + import types + + assert isinstance(result, types.GeneratorType) + + # Materialize and verify data + batches = list(result) + total_rows = sum(len(b) for b in batches) + assert total_rows == 30 + + # Verify map structure and key-value pairs preserved + for batch in batches: + props_col = batch.column("properties") + assert props_col is not None + for i in range(len(props_col)): + prop_map = props_col[i].as_py() + assert isinstance(prop_map, list) # PyArrow represents maps as list of (key, value) tuples + assert len(prop_map) == 3 + for key, value in prop_map: + assert isinstance(key, str) + assert 
isinstance(value, str) + assert key.startswith("key") + assert value.startswith("value") + + +def test_to_record_batch_stream_consistency_with_to_table(tmpdir: str) -> None: + """Test that streaming results match to_table() results. + + Verifies that concatenated batches equal to_table() result with same filter/projection. + Run: pytest tests/io/test_pyarrow.py::test_to_record_batch_stream_consistency_with_to_table -v + """ + # Schema with multiple fields + schema = Schema( + NestedField(1, "id", IntegerType(), required=False), + NestedField(2, "name", StringType(), required=False), + NestedField(3, "age", IntegerType(), required=False), + NestedField(4, "active", BooleanType(), required=False), + ) + pyarrow_schema = schema_to_pyarrow(schema, metadata={ICEBERG_SCHEMA: bytes(schema.model_dump_json(), UTF8)}) + + # Create test data + table = pa.Table.from_arrays( + [ + pa.array(list(range(200))), + pa.array([f"name_{i}" for i in range(200)]), + pa.array([20 + (i % 60) for i in range(200)]), + pa.array([i % 3 == 0 for i in range(200)]), + ], + schema=pyarrow_schema, + ) + data_file = _write_table_to_data_file(f"{tmpdir}/consistency.parquet", pyarrow_schema, table) + data_file.spec_id = 0 + + task = FileScanTask(data_file=data_file) + + # Projected schema: subset of columns + projected_schema = Schema( + NestedField(1, "id", IntegerType(), required=False), + NestedField(3, "age", IntegerType(), required=False), + ) + + # Apply filter: age > 25 AND active = True + row_filter = And(GreaterThan("age", 25), EqualTo("active", True)) + + scan = ArrowScan( + table_metadata=TableMetadataV2( + location="file://a/b/", + last_column_id=4, + format_version=2, + schemas=[schema], + partition_specs=[PartitionSpec()], + ), + io=PyArrowFileIO(), + projected_schema=projected_schema, + row_filter=row_filter, + case_sensitive=True, + ) + + # Get streaming result + stream_result = scan.to_record_batch_stream([task]) + + # Verify generator type + import types + + assert isinstance(stream_result, types.GeneratorType) + + # Materialize streaming result + batches = list(stream_result) + stream_table = pa.concat_tables([pa.Table.from_batches([batch]) for batch in batches]) + + # Get to_table result + table_result = scan.to_table([task]) + + # Verify they match + assert stream_table.num_rows == table_result.num_rows + assert stream_table.num_columns == table_result.num_columns + assert stream_table.schema.equals(table_result.schema) + + # Verify same data (column by column) + for col_name in stream_table.schema.names: + stream_col = stream_table.column(col_name).to_pylist() + table_col = table_result.column(col_name).to_pylist() + assert stream_col == table_col + + +def test_to_record_batch_stream_multiple_files_with_filter(tmpdir: str) -> None: + """Test streaming with multiple files and filter. + + Verifies that filters work correctly across multiple files. 
+ Run: pytest tests/io/test_pyarrow.py::test_to_record_batch_stream_multiple_files_with_filter -v + """ + schema = Schema(NestedField(1, "id", IntegerType(), required=False)) + pyarrow_schema = schema_to_pyarrow(schema, metadata={ICEBERG_SCHEMA: bytes(schema.model_dump_json(), UTF8)}) + + # Create 3 files with non-overlapping ID ranges + # File 1: IDs 0-99 + table1 = pa.Table.from_arrays([pa.array(list(range(0, 100)))], schema=pyarrow_schema) + data_file1 = _write_table_to_data_file(f"{tmpdir}/multi_file1.parquet", pyarrow_schema, table1) + data_file1.spec_id = 0 + + # File 2: IDs 100-199 + table2 = pa.Table.from_arrays([pa.array(list(range(100, 200)))], schema=pyarrow_schema) + data_file2 = _write_table_to_data_file(f"{tmpdir}/multi_file2.parquet", pyarrow_schema, table2) + data_file2.spec_id = 0 + + # File 3: IDs 200-299 + table3 = pa.Table.from_arrays([pa.array(list(range(200, 300)))], schema=pyarrow_schema) + data_file3 = _write_table_to_data_file(f"{tmpdir}/multi_file3.parquet", pyarrow_schema, table3) + data_file3.spec_id = 0 + + tasks = [FileScanTask(data_file=df) for df in [data_file1, data_file2, data_file3]] + + # Filter: id > 150 (should only get results from files 2 and 3) + scan = ArrowScan( + table_metadata=TableMetadataV2( + location="file://a/b/", + last_column_id=1, + format_version=2, + schemas=[schema], + partition_specs=[PartitionSpec()], + ), + io=PyArrowFileIO(), + projected_schema=schema, + row_filter=GreaterThan("id", 150), + case_sensitive=True, + ) + + result = scan.to_record_batch_stream(tasks) + + # Verify generator type + import types + + assert isinstance(result, types.GeneratorType) + + # Materialize and verify data + batches = list(result) + total_rows = sum(len(b) for b in batches) + + # Expected: IDs 151-299 = 149 rows + assert total_rows == 149 + + # Verify all IDs > 150 + for batch in batches: + ids = batch.column("id").to_pylist() + assert all(id_val > 150 for id_val in ids)
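A possible follow-up usage pattern, again with hypothetical catalog and table names: the generator returned by to_record_batches can be wrapped in a pa.RecordBatchReader for consumers that expect a reader rather than a plain iterator, while still reading one batch at a time.

import itertools

import pyarrow as pa

from pyiceberg.catalog import load_catalog

# Hypothetical catalog/table names, for illustration only.
table = load_catalog("default").load_table("db.events")
batches = table.scan().to_record_batches(batch_size=1_000)

# Peek at the first batch for its schema, then chain it back in front of the rest
# so the whole scan can be exposed as a streaming pa.RecordBatchReader.
first = next(batches, None)
if first is not None:
    reader = pa.RecordBatchReader.from_batches(first.schema, itertools.chain([first], batches))
    for batch in reader:
        pass  # process one batch at a time; nothing is materialized up front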